diff --git a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/default.py b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/default.py index 7feb5af8..3151a1e8 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/default.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/default.py @@ -4,6 +4,7 @@ from unstructured.chunking.basic import chunk_elements from unstructured.documents.elements import Element as UnstructuredElement +from unstructured.documents.elements import ElementType from unstructured.partition.auto import partition from unstructured.staging.base import elements_from_dicts from unstructured_client import UnstructuredClient @@ -11,7 +12,7 @@ from ragbits.document_search.documents.document import DocumentMeta, DocumentType from ragbits.document_search.documents.element import Element from ragbits.document_search.ingestion.providers.base import BaseProvider -from ragbits.document_search.ingestion.providers.unstructured.utils import set_or_raise, to_text_element +from ragbits.document_search.ingestion.providers.unstructured.utils import check_required_argument, to_text_element DEFAULT_PARTITION_KWARGS: dict = { "strategy": "hi_res", @@ -71,7 +72,7 @@ def __init__( variable will be used. api_server: The API server URL to use for the Unstructured API. If not specified, the UNSTRUCTURED_SERVER_URL environment variable will be used. - use_api: whether to use unstructured API + use_api: whether to use Unstructured API, otherwise use local version of Unstructured library ignore_images: if True images will be skipped """ self.partition_kwargs = partition_kwargs or DEFAULT_PARTITION_KWARGS @@ -95,8 +96,10 @@ def client(self) -> UnstructuredClient: """ if self._client is not None: return self._client - api_key = set_or_raise(name="api_key", value=self.api_key, env_var=UNSTRUCTURED_API_KEY_ENV) - api_server = set_or_raise(name="api_server", value=self.api_server, env_var=UNSTRUCTURED_SERVER_URL_ENV) + api_key = check_required_argument(arg_name="api_key", value=self.api_key, fallback_env=UNSTRUCTURED_API_KEY_ENV) + api_server = check_required_argument( + arg_name="api_server", value=self.api_server, fallback_env=UNSTRUCTURED_SERVER_URL_ENV + ) self._client = UnstructuredClient(api_key_auth=api_key, server_url=api_server) return self._client @@ -145,8 +148,8 @@ async def _chunk_and_convert( if self.__class__ == UnstructuredDefaultProvider: chunked_elements = chunk_elements(elements, **self.chunking_kwargs) return [to_text_element(element, document_meta) for element in chunked_elements] - image_elements = [e for e in elements if e.category == "Image"] - other_elements = [e for e in elements if e.category != "Image"] + image_elements = [e for e in elements if e.category == ElementType.IMAGE] + other_elements = [e for e in elements if e.category != ElementType.IMAGE] chunked_other_elements = chunk_elements(other_elements, **self.chunking_kwargs) text_elements: list[Element] = [to_text_element(element, document_meta) for element in chunked_other_elements] diff --git a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/images.py b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/images.py index ae7b084b..eaff2b81 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/images.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/images.py @@ -59,9 +59,9 @@ def __init__( async def _to_image_element( self, element: UnstructuredElement, document_meta: DocumentMeta, document_path: Path ) -> ImageElement: - image_coordinates = extract_image_coordinates(element) + top_x, top_y, bottom_x, bottom_y = extract_image_coordinates(element) image = Image.open(document_path).convert("RGB") - img_bytes = crop_and_convert_to_bytes(image, *image_coordinates) + img_bytes = crop_and_convert_to_bytes(image, top_x, top_y, bottom_x, bottom_y) image_description = await self.image_summarizer.get_image_description(img_bytes) return ImageElement( description=image_description, diff --git a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/pdf.py b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/pdf.py index fc1b3c6a..bf2a9018 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/pdf.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/pdf.py @@ -59,18 +59,14 @@ def __init__( async def _to_image_element( self, element: UnstructuredElement, document_meta: DocumentMeta, document_path: Path ) -> ImageElement: - image_coordinates = extract_image_coordinates(element) + top_x, top_y, bottom_x, bottom_y = extract_image_coordinates(element) page_number = element.metadata.page_number image = convert_from_path(document_path, first_page=page_number, last_page=page_number)[0] new_system = CoordinateSystem(image.width, image.height) new_system.orientation = Orientation.SCREEN - x0, y0 = element.metadata.coordinates.system.convert_coordinates_to_new_system( - new_system, image_coordinates[0], image_coordinates[1] - ) - x1, y1 = element.metadata.coordinates.system.convert_coordinates_to_new_system( - new_system, image_coordinates[2], image_coordinates[3] - ) + x0, y0 = element.metadata.coordinates.system.convert_coordinates_to_new_system(new_system, top_x, top_y) + x1, y1 = element.metadata.coordinates.system.convert_coordinates_to_new_system(new_system, bottom_x, bottom_y) img_bytes = crop_and_convert_to_bytes(image, x0, y0, x1, y1) image_description = await self.image_summarizer.get_image_description(img_bytes) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/utils.py b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/utils.py index feeeaf31..105c2706 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/utils.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/utils.py @@ -28,14 +28,14 @@ def to_text_element(element: UnstructuredElement, document_meta: DocumentMeta) - ) -def set_or_raise(name: str, value: Optional[str], env_var: str) -> str: +def check_required_argument(value: Optional[str], arg_name: str, fallback_env: str) -> str: """ Checks if given environment variable is set and returns it or raises an error Args: - name: name of the variable + arg_name: name of the variable value: optional default value - env_var: name of the environment variable to get + fallback_env: name of the environment variable to get Raises: ValueError: if environment variable is not set @@ -45,8 +45,8 @@ def set_or_raise(name: str, value: Optional[str], env_var: str) -> str: """ if value is not None: return value - if (env_value := os.getenv(env_var)) is None: - raise ValueError(f"Either pass {name} argument or set the {env_var} environment variable") + if (env_value := os.getenv(fallback_env)) is None: + raise ValueError(f"Either pass {arg_name} argument or set the {fallback_env} environment variable") return env_value @@ -87,10 +87,10 @@ class ImageDescriber: DEFAULT_PROMPT = "Describe the content of the image." - def __init__(self, model_name: str, model_options: LiteLLMOptions): - self.model_name = model_name - self.model: Optional[LiteLLM] = None - self.model_options = model_options + def __init__(self, llm_name: str, llm_options: LiteLLMOptions): + self.llm_name = llm_name + self.llm: Optional[LiteLLM] = None + self.llm_options = llm_options async def get_image_description(self, image_bytes: bytes, prompt: Optional[str] = DEFAULT_PROMPT) -> str: """ @@ -103,8 +103,8 @@ async def get_image_description(self, image_bytes: bytes, prompt: Optional[str] Returns: summary of the image """ - if not self.model: - self.model = LiteLLM(self.model_name) + if not self.llm: + self.llm = LiteLLM(self.llm_name) img_base64 = base64.b64encode(image_bytes).decode("utf-8") @@ -121,4 +121,4 @@ async def get_image_description(self, image_bytes: bytes, prompt: Optional[str] ], } ] - return await self.model.client.call(messages, self.model_options) + return await self.llm.client.call(messages, self.llm_options)