diff --git a/examples/apps/documents_chat.py b/examples/apps/documents_chat.py index 595577990..6236c30aa 100644 --- a/examples/apps/documents_chat.py +++ b/examples/apps/documents_chat.py @@ -19,7 +19,6 @@ from ragbits.core.vector_store.chromadb_store import ChromaDBStore from ragbits.document_search import DocumentSearch from ragbits.document_search.documents.document import DocumentMeta -from ragbits.document_search.documents.element import TextElement class QueryWithContext(BaseModel): @@ -124,9 +123,7 @@ async def _handle_message( if not self._documents_ingested: yield self.NO_DOCUMENTS_INGESTED_MESSAGE results = await self.document_search.search(message[-1]) - prompt = RAGPrompt( - QueryWithContext(query=message, context=[i.content for i in results if isinstance(i, TextElement)]) - ) + prompt = RAGPrompt(QueryWithContext(query=message, context=[i.get_key() for i in results])) response = await self._llm.generate(prompt) yield response.answer diff --git a/packages/ragbits-core/src/ragbits/core/prompt/base.py b/packages/ragbits-core/src/ragbits/core/prompt/base.py index 47bf427f1..d2faedd73 100644 --- a/packages/ragbits-core/src/ragbits/core/prompt/base.py +++ b/packages/ragbits-core/src/ragbits/core/prompt/base.py @@ -1,10 +1,10 @@ from abc import ABCMeta, abstractmethod -from typing import Dict, Generic, List, Optional, Type +from typing import Dict, Generic, Optional, Type from pydantic import BaseModel from typing_extensions import TypeVar -ChatFormat = List[Dict[str, str]] +ChatFormat = list[dict[str, str]] OutputT = TypeVar("OutputT", default=str) diff --git a/packages/ragbits-document-search/pyproject.toml b/packages/ragbits-document-search/pyproject.toml index ecd0063a6..615ccff13 100644 --- a/packages/ragbits-document-search/pyproject.toml +++ b/packages/ragbits-document-search/pyproject.toml @@ -34,7 +34,8 @@ dependencies = [ "numpy~=1.24.0", "unstructured>=0.15.13", "unstructured-client>=0.26.0", - "ragbits-core==0.1.0", + "pdf2image>=1.17.0", + "ragbits-core==0.1.0" ] [project.optional-dependencies] diff --git a/packages/ragbits-document-search/src/ragbits/document_search/documents/document.py b/packages/ragbits-document-search/src/ragbits/document_search/documents/document.py index 581f33006..a8683b209 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/documents/document.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/documents/document.py @@ -28,6 +28,8 @@ class DocumentType(str, Enum): RTF = "rtf" TSV = "tsv" XML = "xml" + JPG = "jpg" + PNG = "png" UNKNOWN = "unknown" diff --git a/packages/ragbits-document-search/src/ragbits/document_search/documents/element.py b/packages/ragbits-document-search/src/ragbits/document_search/documents/element.py index 744aed729..67a152f62 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/documents/element.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/documents/element.py @@ -85,3 +85,23 @@ def get_key(self) -> str: The key. """ return self.content + + +class ImageElement(Element): + """ + An object representing an image element in a document. + """ + + element_type: str = "image" + description: str + ocr_extracted_text: str + image_bytes: bytes + + def get_key(self) -> str: + """ + Get the key of the element which will be used to generate the vector. + + Returns: + The key. + """ + return self.description + " " + self.ocr_extracted_text diff --git a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/document_processor.py b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/document_processor.py index 0ce6b7cef..f029673c6 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/document_processor.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/document_processor.py @@ -4,30 +4,34 @@ from ragbits.document_search.documents.document import DocumentMeta, DocumentType from ragbits.document_search.ingestion.providers import get_provider from ragbits.document_search.ingestion.providers.base import BaseProvider -from ragbits.document_search.ingestion.providers.unstructured import UnstructuredProvider +from ragbits.document_search.ingestion.providers.unstructured.default import UnstructuredDefaultProvider +from ragbits.document_search.ingestion.providers.unstructured.images import UnstructuredImageProvider +from ragbits.document_search.ingestion.providers.unstructured.pdf import UnstructuredPdfProvider ProvidersConfig = dict[DocumentType, BaseProvider] DEFAULT_PROVIDERS_CONFIG: ProvidersConfig = { - DocumentType.TXT: UnstructuredProvider(), - DocumentType.MD: UnstructuredProvider(), - DocumentType.PDF: UnstructuredProvider(), - DocumentType.DOCX: UnstructuredProvider(), - DocumentType.DOC: UnstructuredProvider(), - DocumentType.PPTX: UnstructuredProvider(), - DocumentType.PPT: UnstructuredProvider(), - DocumentType.XLSX: UnstructuredProvider(), - DocumentType.XLS: UnstructuredProvider(), - DocumentType.CSV: UnstructuredProvider(), - DocumentType.HTML: UnstructuredProvider(), - DocumentType.EPUB: UnstructuredProvider(), - DocumentType.ORG: UnstructuredProvider(), - DocumentType.ODT: UnstructuredProvider(), - DocumentType.RST: UnstructuredProvider(), - DocumentType.RTF: UnstructuredProvider(), - DocumentType.TSV: UnstructuredProvider(), - DocumentType.XML: UnstructuredProvider(), + DocumentType.TXT: UnstructuredDefaultProvider(), + DocumentType.MD: UnstructuredDefaultProvider(), + DocumentType.PDF: UnstructuredPdfProvider(), + DocumentType.DOCX: UnstructuredDefaultProvider(), + DocumentType.DOC: UnstructuredDefaultProvider(), + DocumentType.PPTX: UnstructuredDefaultProvider(), + DocumentType.PPT: UnstructuredDefaultProvider(), + DocumentType.XLSX: UnstructuredDefaultProvider(), + DocumentType.XLS: UnstructuredDefaultProvider(), + DocumentType.CSV: UnstructuredDefaultProvider(), + DocumentType.HTML: UnstructuredDefaultProvider(), + DocumentType.EPUB: UnstructuredDefaultProvider(), + DocumentType.ORG: UnstructuredDefaultProvider(), + DocumentType.ODT: UnstructuredDefaultProvider(), + DocumentType.RST: UnstructuredDefaultProvider(), + DocumentType.RTF: UnstructuredDefaultProvider(), + DocumentType.TSV: UnstructuredDefaultProvider(), + DocumentType.XML: UnstructuredDefaultProvider(), + DocumentType.JPG: UnstructuredImageProvider(), + DocumentType.PNG: UnstructuredImageProvider(), } diff --git a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/__init__.py b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/__init__.py index fc2c5d2f5..dd06046b5 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/__init__.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/__init__.py @@ -4,16 +4,24 @@ from .base import BaseProvider from .dummy import DummyProvider -from .unstructured import UnstructuredProvider - -__all__ = ["BaseProvider", "DummyProvider", "UnstructuredProvider"] +from .unstructured.default import UnstructuredDefaultProvider +from .unstructured.images import UnstructuredImageProvider +from .unstructured.pdf import UnstructuredPdfProvider + +__all__ = [ + "BaseProvider", + "DummyProvider", + "UnstructuredDefaultProvider", + "UnstructuredPdfProvider", + "UnstructuredImageProvider", +] module = sys.modules[__name__] def get_provider(provider_config: dict) -> BaseProvider: """ - Initializes and returns an Provider object based on the provided configuration. + Initializes and returns a Provider object based on the provided configuration. Args: provider_config : A dictionary containing configuration details for the provider. diff --git a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/__init__.py b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured.py b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/default.py similarity index 75% rename from packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured.py rename to packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/default.py index 116d38645..d6c92681d 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/default.py @@ -1,5 +1,5 @@ -import os from io import BytesIO +from pathlib import Path from typing import Optional from unstructured.chunking.basic import chunk_elements @@ -9,8 +9,9 @@ from unstructured_client import UnstructuredClient from ragbits.document_search.documents.document import DocumentMeta, DocumentType -from ragbits.document_search.documents.element import Element, TextElement +from ragbits.document_search.documents.element import Element from ragbits.document_search.ingestion.providers.base import BaseProvider +from ragbits.document_search.ingestion.providers.unstructured.utils import check_required_argument, to_text_element DEFAULT_PARTITION_KWARGS: dict = { "strategy": "hi_res", @@ -26,15 +27,14 @@ UNSTRUCTURED_SERVER_URL_ENV = "UNSTRUCTURED_SERVER_URL" -class UnstructuredProvider(BaseProvider): +class UnstructuredDefaultProvider(BaseProvider): """ - A provider that uses the Unstructured API to process the documents. + A provider that uses the Unstructured API or local SDK to process the documents. """ SUPPORTED_DOCUMENT_TYPES = { DocumentType.TXT, DocumentType.MD, - DocumentType.PDF, DocumentType.DOCX, DocumentType.DOC, DocumentType.PPTX, @@ -59,8 +59,9 @@ def __init__( api_key: Optional[str] = None, api_server: Optional[str] = None, use_api: bool = False, + ignore_images: bool = False, ) -> None: - """Initialize the UnstructuredProvider. + """Initialize the UnstructuredDefaultProvider. Args: partition_kwargs: The additional arguments for the partitioning. Refer to the Unstructured API documentation @@ -70,6 +71,8 @@ def __init__( variable will be used. api_server: The API server URL to use for the Unstructured API. If not specified, the UNSTRUCTURED_SERVER_URL environment variable will be used. + use_api: whether to use Unstructured API, otherwise use local version of Unstructured library + ignore_images: if True images will be skipped """ self.partition_kwargs = partition_kwargs or DEFAULT_PARTITION_KWARGS self.chunking_kwargs = chunking_kwargs or DEFAULT_CHUNKING_KWARGS @@ -77,6 +80,7 @@ def __init__( self.api_server = api_server self.use_api = use_api self._client = None + self.ignore_images = ignore_images @property def client(self) -> UnstructuredClient: @@ -91,8 +95,10 @@ def client(self) -> UnstructuredClient: """ if self._client is not None: return self._client - api_key = _set_or_raise(name="api_key", value=self.api_key, env_var=UNSTRUCTURED_API_KEY_ENV) - api_server = _set_or_raise(name="api_server", value=self.api_server, env_var=UNSTRUCTURED_SERVER_URL_ENV) + api_key = check_required_argument(arg_name="api_key", value=self.api_key, fallback_env=UNSTRUCTURED_API_KEY_ENV) + api_server = check_required_argument( + arg_name="api_server", value=self.api_server, fallback_env=UNSTRUCTURED_SERVER_URL_ENV + ) self._client = UnstructuredClient(api_key_auth=api_key, server_url=api_server) return self._client @@ -120,6 +126,7 @@ async def process(self, document_meta: DocumentMeta) -> list[Element]: "content": document.local_path.read_bytes(), "file_name": document.local_path.name, }, + "coordinates": True, **self.partition_kwargs, } } @@ -132,20 +139,14 @@ async def process(self, document_meta: DocumentMeta) -> list[Element]: **self.partition_kwargs, ) - elements = chunk_elements(elements, **self.chunking_kwargs) - return [_to_text_element(element, document_meta) for element in elements] + return await self._chunk_and_convert(elements, document_meta, document.local_path) - -def _to_text_element(element: UnstructuredElement, document_meta: DocumentMeta) -> TextElement: - return TextElement( - document_meta=document_meta, - content=element.text, - ) - - -def _set_or_raise(name: str, value: Optional[str], env_var: str) -> str: - if value is not None: - return value - if (env_value := os.getenv(env_var)) is None: - raise ValueError(f"Either pass {name} argument or set the {env_var} environment variable") - return env_value + async def _chunk_and_convert( + # pylint: disable=unused-argument + self, + elements: list[UnstructuredElement], + document_meta: DocumentMeta, + document_path: Path, + ) -> list[Element]: + chunked_elements = chunk_elements(elements, **self.chunking_kwargs) + return [to_text_element(element, document_meta) for element in chunked_elements] diff --git a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/images.py b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/images.py new file mode 100644 index 000000000..466ace1dc --- /dev/null +++ b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/images.py @@ -0,0 +1,107 @@ +from pathlib import Path +from typing import Optional + +from PIL import Image +from unstructured.chunking.basic import chunk_elements +from unstructured.documents.elements import Element as UnstructuredElement +from unstructured.documents.elements import ElementType + +from ragbits.core.llms.base import LLM +from ragbits.core.llms.litellm import LiteLLM +from ragbits.document_search.documents.document import DocumentMeta, DocumentType +from ragbits.document_search.documents.element import Element, ImageElement +from ragbits.document_search.ingestion.providers.unstructured.default import UnstructuredDefaultProvider +from ragbits.document_search.ingestion.providers.unstructured.utils import ( + ImageDescriber, + crop_and_convert_to_bytes, + extract_image_coordinates, + to_text_element, +) + +DEFAULT_LLM_IMAGE_SUMMARIZATION_MODEL = "gpt-4o-mini" + + +class UnstructuredImageProvider(UnstructuredDefaultProvider): + """ + A specialized provider that handles pngs and jpgs using the Unstructured + """ + + SUPPORTED_DOCUMENT_TYPES = { + DocumentType.JPG, + DocumentType.PNG, + } + + def __init__( + self, + partition_kwargs: Optional[dict] = None, + chunking_kwargs: Optional[dict] = None, + api_key: Optional[str] = None, + api_server: Optional[str] = None, + use_api: bool = False, + llm: Optional[LLM] = None, + ) -> None: + """Initialize the UnstructuredPdfProvider. + + Args: + partition_kwargs: The additional arguments for the partitioning. Refer to the Unstructured API documentation + for the available options: https://docs.unstructured.io/api-reference/api-services/api-parameters + chunking_kwargs: The additional arguments for the chunking. + api_key: The API key to use for the Unstructured API. If not specified, the UNSTRUCTURED_API_KEY environment + variable will be used. + api_server: The API server URL to use for the Unstructured API. If not specified, the + UNSTRUCTURED_SERVER_URL environment variable will be used. + llm: llm to use + """ + super().__init__(partition_kwargs, chunking_kwargs, api_key, api_server, use_api) + self.image_summarizer = ImageDescriber(llm or LiteLLM(DEFAULT_LLM_IMAGE_SUMMARIZATION_MODEL)) + + async def _chunk_and_convert( + self, elements: list[UnstructuredElement], document_meta: DocumentMeta, document_path: Path + ) -> list[Element]: + image_elements = [e for e in elements if e.category == ElementType.IMAGE] + other_elements = [e for e in elements if e.category != ElementType.IMAGE] + chunked_other_elements = chunk_elements(other_elements, **self.chunking_kwargs) + + text_elements: list[Element] = [to_text_element(element, document_meta) for element in chunked_other_elements] + if self.ignore_images: + return text_elements + return text_elements + [ + await self._to_image_element(element, document_meta, document_path) for element in image_elements + ] + + async def _to_image_element( + self, element: UnstructuredElement, document_meta: DocumentMeta, document_path: Path + ) -> ImageElement: + top_x, top_y, bottom_x, bottom_y = extract_image_coordinates(element) + image = self._load_document_as_image(document_path) + top_x, top_y, bottom_x, bottom_y = self._convert_coordinates( + top_x, top_y, bottom_x, bottom_y, image.width, image.height, element + ) + + img_bytes = crop_and_convert_to_bytes(image, top_x, top_y, bottom_x, bottom_y) + image_description = await self.image_summarizer.get_image_description(img_bytes) + return ImageElement( + description=image_description, + ocr_extracted_text=element.text, + image_bytes=img_bytes, + document_meta=document_meta, + ) + + @staticmethod + def _load_document_as_image( + document_path: Path, page: Optional[int] = None # pylint: disable=unused-argument + ) -> Image: + return Image.open(document_path).convert("RGB") + + @staticmethod + def _convert_coordinates( + # pylint: disable=unused-argument + top_x: float, + top_y: float, + bottom_x: float, + bottom_y: float, + image_width: int, + image_height: int, + element: UnstructuredElement, + ) -> tuple[float, float, float, float]: + return top_x, top_y, bottom_x, bottom_y diff --git a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/pdf.py b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/pdf.py new file mode 100644 index 000000000..585282196 --- /dev/null +++ b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/pdf.py @@ -0,0 +1,44 @@ +from pathlib import Path +from typing import Optional + +from pdf2image import convert_from_path +from PIL import Image +from unstructured.documents.coordinates import CoordinateSystem, Orientation +from unstructured.documents.elements import Element as UnstructuredElement + +from ragbits.document_search.documents.document import DocumentType +from ragbits.document_search.ingestion.providers.unstructured.images import UnstructuredImageProvider + + +class UnstructuredPdfProvider(UnstructuredImageProvider): + """ + A specialized provider that handles pdfs using the Unstructured + """ + + SUPPORTED_DOCUMENT_TYPES = { + DocumentType.PDF, + } + + @staticmethod + def _load_document_as_image(document_path: Path, page: Optional[int] = None) -> Image: + return convert_from_path(document_path, first_page=page, last_page=page)[0] + + @staticmethod + def _convert_coordinates( + top_x: float, + top_y: float, + bottom_x: float, + bottom_y: float, + image_width: int, + image_height: int, + element: UnstructuredElement, + ) -> tuple[float, float, float, float]: + new_system = CoordinateSystem(image_width, image_height) + new_system.orientation = Orientation.SCREEN + new_top_x, new_top_y = element.metadata.coordinates.system.convert_coordinates_to_new_system( + new_system, top_x, top_y + ) + new_bottom_x, new_bottom_y = element.metadata.coordinates.system.convert_coordinates_to_new_system( + new_system, bottom_x, bottom_y + ) + return new_top_x, new_top_y, new_bottom_x, new_bottom_y diff --git a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/utils.py b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/utils.py new file mode 100644 index 000000000..5347168d4 --- /dev/null +++ b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/utils.py @@ -0,0 +1,119 @@ +import base64 +import io +import os +from typing import Optional + +from PIL import Image +from unstructured.documents.elements import Element as UnstructuredElement + +from ragbits.core.llms.base import LLM +from ragbits.document_search.documents.document import DocumentMeta +from ragbits.document_search.documents.element import TextElement + + +def to_text_element(element: UnstructuredElement, document_meta: DocumentMeta) -> TextElement: + """ + Converts unstructured element to ragbits text element + + Args: + element: element from unstructured + document_meta: metadata of the document + + Returns: + text element + """ + return TextElement( + document_meta=document_meta, + content=element.text, + ) + + +def check_required_argument(value: Optional[str], arg_name: str, fallback_env: str) -> str: + """ + Checks if given environment variable is set and returns it or raises an error + + Args: + arg_name: name of the variable + value: optional default value + fallback_env: name of the environment variable to get + + Raises: + ValueError: if environment variable is not set + + Returns: + environment variable value + """ + if value is not None: + return value + if (env_value := os.getenv(fallback_env)) is None: + raise ValueError(f"Either pass {arg_name} argument or set the {fallback_env} environment variable") + return env_value + + +def extract_image_coordinates(element: UnstructuredElement) -> tuple[float, float, float, float]: + """ + Extracts image coordinates from unstructured element + Args: + element: element from unstructured + Returns: + x of top left corner, y of top left corner, x of bottom right corner, y of bottom right corner + """ + p1, p2, p3, p4 = element.metadata.coordinates.points + return min(p1[0], p2[0]), min(p1[1], p4[1]), max(p3[0], p4[0]), max(p2[1], p3[1]) + + +def crop_and_convert_to_bytes(image: Image, x0: float, y0: float, x1: float, y1: float) -> bytes: + """ + Crops the image and converts to bytes + Args: + image: PIL image + x0: x of top left corner + y0: y of top left corner + x1: x of bottom right corner + y1: y of bottom right corner + Returns: + bytes of the cropped image + """ + image = image.crop((x0, y0, x1, y1)) + buffered = io.BytesIO() + image.save(buffered, format="JPEG") + return buffered.getvalue() + + +class ImageDescriber: + """ + Describes images content using an LLM + """ + + DEFAULT_PROMPT = "Describe the content of the image." + + def __init__(self, llm: LLM): + self.llm = llm + + async def get_image_description(self, image_bytes: bytes, prompt: Optional[str] = DEFAULT_PROMPT) -> str: + """ + Provides summary of the image (passed as bytes) + + Args: + image_bytes: bytes of the image + prompt: prompt to be used + + Returns: + summary of the image + """ + img_base64 = base64.b64encode(image_bytes).decode("utf-8") + + # TODO make this use prompt structure from ragbits core once there is a support for images + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": f"{prompt}"}, + { + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"}, + }, + ], + } + ] + return await self.llm.client.call(messages, self.llm.default_options) # type: ignore diff --git a/packages/ragbits-document-search/tests/integration/test_unstructured.py b/packages/ragbits-document-search/tests/integration/test_unstructured.py index 0ed20475f..e37cc1209 100644 --- a/packages/ragbits-document-search/tests/integration/test_unstructured.py +++ b/packages/ragbits-document-search/tests/integration/test_unstructured.py @@ -4,11 +4,11 @@ from ragbits.document_search.documents.document import DocumentMeta, DocumentType from ragbits.document_search.ingestion.document_processor import DocumentProcessorRouter -from ragbits.document_search.ingestion.providers.unstructured import ( +from ragbits.document_search.ingestion.providers.unstructured.default import ( DEFAULT_PARTITION_KWARGS, UNSTRUCTURED_API_KEY_ENV, UNSTRUCTURED_SERVER_URL_ENV, - UnstructuredProvider, + UnstructuredDefaultProvider, ) from ..helpers import env_vars_not_set @@ -19,7 +19,7 @@ [ {}, pytest.param( - {DocumentType.TXT: UnstructuredProvider(use_api=True)}, + {DocumentType.TXT: UnstructuredDefaultProvider(use_api=True)}, marks=pytest.mark.skipif( env_vars_not_set([UNSTRUCTURED_SERVER_URL_ENV, UNSTRUCTURED_API_KEY_ENV]), reason="Unstructured API environment variables not set", @@ -33,7 +33,7 @@ async def test_document_processor_processes_text_document_with_unstructured_prov elements = await document_processor.get_provider(document_meta).process(document_meta) - assert isinstance(document_processor._providers[DocumentType.TXT], UnstructuredProvider) + assert isinstance(document_processor._providers[DocumentType.TXT], UnstructuredDefaultProvider) assert len(elements) == 1 assert elements[0].content == "Name of Peppa's brother is George." @@ -52,6 +52,25 @@ async def test_document_processor_processes_md_document_with_unstructured_provid assert elements[0].content == "Ragbits\n\nRepository for internal experiment with our upcoming LLM framework." +@pytest.mark.skipif( + env_vars_not_set([UNSTRUCTURED_SERVER_URL_ENV, UNSTRUCTURED_API_KEY_ENV]), + reason="Unstructured API environment variables not set", +) +@pytest.mark.skipif( + env_vars_not_set(["OPENAI_API_KEY"]), + reason="OpenAI API environment variables not set", +) +@pytest.mark.parametrize("file_name", ["transformers_paper_page.pdf", "transformers_paper_page.png"]) +async def test_document_processor_processes_image_document_with_unstructured_provider(file_name): + document_processor = DocumentProcessorRouter.from_config() + document_meta = DocumentMeta.from_local_path(Path(__file__).parent / file_name) + + elements = await document_processor.get_provider(document_meta).process(document_meta) + + assert len(elements) == 7 + assert elements[-1].description != "" + + @pytest.mark.parametrize( "use_api", [ @@ -67,7 +86,7 @@ async def test_document_processor_processes_md_document_with_unstructured_provid ) async def test_unstructured_provider_document_with_default_partition_kwargs(use_api): document_meta = DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George.") - unstructured_provider = UnstructuredProvider(use_api=use_api) + unstructured_provider = UnstructuredDefaultProvider(use_api=use_api) elements = await unstructured_provider.process(document_meta) assert unstructured_provider.partition_kwargs == DEFAULT_PARTITION_KWARGS @@ -91,7 +110,7 @@ async def test_unstructured_provider_document_with_default_partition_kwargs(use_ async def test_unstructured_provider_document_with_custom_partition_kwargs(use_api): document_meta = DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George.") partition_kwargs = {"languages": ["pl"], "strategy": "fast"} - unstructured_provider = UnstructuredProvider(use_api=use_api, partition_kwargs=partition_kwargs) + unstructured_provider = UnstructuredDefaultProvider(use_api=use_api, partition_kwargs=partition_kwargs) elements = await unstructured_provider.process(document_meta) assert unstructured_provider.partition_kwargs == partition_kwargs diff --git a/packages/ragbits-document-search/tests/integration/transformers_paper_page.pdf b/packages/ragbits-document-search/tests/integration/transformers_paper_page.pdf new file mode 100644 index 000000000..6d19028d3 Binary files /dev/null and b/packages/ragbits-document-search/tests/integration/transformers_paper_page.pdf differ diff --git a/packages/ragbits-document-search/tests/integration/transformers_paper_page.png b/packages/ragbits-document-search/tests/integration/transformers_paper_page.png new file mode 100644 index 000000000..35be618c0 Binary files /dev/null and b/packages/ragbits-document-search/tests/integration/transformers_paper_page.png differ diff --git a/packages/ragbits-document-search/tests/unit/test_providers.py b/packages/ragbits-document-search/tests/unit/test_providers.py index dc936c9f0..8444dbf10 100644 --- a/packages/ragbits-document-search/tests/unit/test_providers.py +++ b/packages/ragbits-document-search/tests/unit/test_providers.py @@ -5,25 +5,37 @@ from ragbits.document_search.documents.document import DocumentMeta, DocumentType from ragbits.document_search.ingestion.providers.base import DocumentTypeNotSupportedError -from ragbits.document_search.ingestion.providers.unstructured import UnstructuredProvider +from ragbits.document_search.ingestion.providers.unstructured.default import UnstructuredDefaultProvider +from ragbits.document_search.ingestion.providers.unstructured.images import UnstructuredImageProvider +from ragbits.document_search.ingestion.providers.unstructured.pdf import UnstructuredPdfProvider -@pytest.mark.parametrize("document_type", UnstructuredProvider.SUPPORTED_DOCUMENT_TYPES) +@pytest.mark.parametrize("document_type", UnstructuredDefaultProvider.SUPPORTED_DOCUMENT_TYPES) def test_unsupported_provider_validates_supported_document_types_passes(document_type: DocumentType): - UnstructuredProvider().validate_document_type(document_type) + UnstructuredDefaultProvider().validate_document_type(document_type) + + +@pytest.mark.parametrize("document_type", UnstructuredPdfProvider.SUPPORTED_DOCUMENT_TYPES) +def test_unsupported_pdf_provider_validates_supported_document_types_passes(document_type: DocumentType): + UnstructuredPdfProvider().validate_document_type(document_type) + + +@pytest.mark.parametrize("document_type", UnstructuredImageProvider.SUPPORTED_DOCUMENT_TYPES) +def test_unsupported_images_provider_validates_supported_document_types_passes(document_type: DocumentType): + UnstructuredImageProvider().validate_document_type(document_type) def test_unsupported_provider_validates_supported_document_types_fails(): with pytest.raises(DocumentTypeNotSupportedError) as err: - UnstructuredProvider().validate_document_type(DocumentType.UNKNOWN) + UnstructuredDefaultProvider().validate_document_type(DocumentType.UNKNOWN) - assert "Document type unknown is not supported by the UnstructuredProvider" in str(err.value) + assert "Document type unknown is not supported by the UnstructuredDefaultProvider" in str(err.value) @patch.dict(os.environ, {}, clear=True) async def test_unstructured_provider_raises_value_error_when_api_key_not_set(): with pytest.raises(ValueError) as err: - await UnstructuredProvider(use_api=True).process( + await UnstructuredDefaultProvider(use_api=True).process( DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George.") ) @@ -33,7 +45,7 @@ async def test_unstructured_provider_raises_value_error_when_api_key_not_set(): @patch.dict(os.environ, {}, clear=True) async def test_unstructured_provider_raises_value_error_when_server_url_not_set(): with pytest.raises(ValueError) as err: - await UnstructuredProvider(api_key="api_key", use_api=True).process( + await UnstructuredDefaultProvider(api_key="api_key", use_api=True).process( DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George.") )