From 02218f0711771d4154b8e8d97d0abca589fc35f4 Mon Sep 17 00:00:00 2001 From: kdziedzic68 Date: Tue, 3 Dec 2024 12:38:39 +0100 Subject: [PATCH] fix(document-search): do not create all of provider instances when importing document processor (#219) --- .../ingestion/document_processor.py | 59 +++++++++++-------- .../tests/integration/test_unstructured.py | 8 ++- .../tests/unit/test_document_search.py | 3 +- 3 files changed, 43 insertions(+), 27 deletions(-) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/document_processor.py b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/document_processor.py index 64c71150..9a5cf054 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/document_processor.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/document_processor.py @@ -1,4 +1,6 @@ import copy +from collections.abc import Callable +from typing import cast from ragbits.document_search.documents.document import DocumentMeta, DocumentType from ragbits.document_search.ingestion.providers import get_provider @@ -7,30 +9,31 @@ from ragbits.document_search.ingestion.providers.unstructured.images import UnstructuredImageProvider from ragbits.document_search.ingestion.providers.unstructured.pdf import UnstructuredPdfProvider -ProvidersConfig = dict[DocumentType, BaseProvider] +# TODO consider defining with some defined schema +ProvidersConfig = dict[DocumentType, Callable[[], BaseProvider] | BaseProvider] DEFAULT_PROVIDERS_CONFIG: ProvidersConfig = { - DocumentType.TXT: UnstructuredDefaultProvider(), - DocumentType.MD: UnstructuredDefaultProvider(), - DocumentType.PDF: UnstructuredPdfProvider(), - DocumentType.DOCX: UnstructuredDefaultProvider(), - DocumentType.DOC: UnstructuredDefaultProvider(), - DocumentType.PPTX: UnstructuredDefaultProvider(), - DocumentType.PPT: UnstructuredDefaultProvider(), - DocumentType.XLSX: UnstructuredDefaultProvider(), - DocumentType.XLS: UnstructuredDefaultProvider(), - DocumentType.CSV: UnstructuredDefaultProvider(), - DocumentType.HTML: UnstructuredDefaultProvider(), - DocumentType.EPUB: UnstructuredDefaultProvider(), - DocumentType.ORG: UnstructuredDefaultProvider(), - DocumentType.ODT: UnstructuredDefaultProvider(), - DocumentType.RST: UnstructuredDefaultProvider(), - DocumentType.RTF: UnstructuredDefaultProvider(), - DocumentType.TSV: UnstructuredDefaultProvider(), - DocumentType.XML: UnstructuredDefaultProvider(), - DocumentType.JPG: UnstructuredImageProvider(), - DocumentType.PNG: UnstructuredImageProvider(), + DocumentType.TXT: UnstructuredDefaultProvider, + DocumentType.MD: UnstructuredDefaultProvider, + DocumentType.PDF: UnstructuredPdfProvider, + DocumentType.DOCX: UnstructuredDefaultProvider, + DocumentType.DOC: UnstructuredDefaultProvider, + DocumentType.PPTX: UnstructuredDefaultProvider, + DocumentType.PPT: UnstructuredDefaultProvider, + DocumentType.XLSX: UnstructuredDefaultProvider, + DocumentType.XLS: UnstructuredDefaultProvider, + DocumentType.CSV: UnstructuredDefaultProvider, + DocumentType.HTML: UnstructuredDefaultProvider, + DocumentType.EPUB: UnstructuredDefaultProvider, + DocumentType.ORG: UnstructuredDefaultProvider, + DocumentType.ODT: UnstructuredDefaultProvider, + DocumentType.RST: UnstructuredDefaultProvider, + DocumentType.RTF: UnstructuredDefaultProvider, + DocumentType.TSV: UnstructuredDefaultProvider, + DocumentType.XML: UnstructuredDefaultProvider, + DocumentType.JPG: UnstructuredImageProvider, + DocumentType.PNG: UnstructuredImageProvider, } @@ -40,7 +43,7 @@ class DocumentProcessorRouter: metadata such as the document type. """ - def __init__(self, providers: dict[DocumentType, BaseProvider]): + def __init__(self, providers: dict[DocumentType, Callable[[], BaseProvider] | BaseProvider]): self._providers = providers @staticmethod @@ -65,7 +68,9 @@ def from_dict_to_providers_config(dict_config: dict) -> ProvidersConfig: providers_config = {} for document_type, config in dict_config.items(): - providers_config[DocumentType(document_type)] = get_provider(config) + providers_config[DocumentType(document_type)] = cast( + Callable[[], BaseProvider] | BaseProvider, get_provider(config) + ) return providers_config @@ -106,7 +111,11 @@ def get_provider(self, document_meta: DocumentMeta) -> BaseProvider: Raises: ValueError: If no provider is found for the document type. """ - provider = self._providers.get(document_meta.document_type) - if provider is None: + provider_class_or_provider = self._providers.get(document_meta.document_type) + if provider_class_or_provider is None: raise ValueError(f"No provider found for the document type {document_meta.document_type}") + elif isinstance(provider_class_or_provider, BaseProvider): + provider = provider_class_or_provider + else: + provider = provider_class_or_provider() return provider diff --git a/packages/ragbits-document-search/tests/integration/test_unstructured.py b/packages/ragbits-document-search/tests/integration/test_unstructured.py index 05c2e45c..e75de77a 100644 --- a/packages/ragbits-document-search/tests/integration/test_unstructured.py +++ b/packages/ragbits-document-search/tests/integration/test_unstructured.py @@ -18,6 +18,7 @@ "config", [ {}, + pytest.param({DocumentType.TXT: UnstructuredDefaultProvider()}), pytest.param( {DocumentType.TXT: UnstructuredDefaultProvider(use_api=True)}, marks=pytest.mark.skipif( @@ -33,7 +34,12 @@ async def test_document_processor_processes_text_document_with_unstructured_prov elements = await document_processor.get_provider(document_meta).process(document_meta) - assert isinstance(document_processor._providers[DocumentType.TXT], UnstructuredDefaultProvider) + expected_provider_type = ( + UnstructuredDefaultProvider + if isinstance(config.get(DocumentType.TXT), UnstructuredDefaultProvider) + else type(UnstructuredDefaultProvider) + ) + assert isinstance(document_processor._providers[DocumentType.TXT], expected_provider_type) assert len(elements) == 1 assert elements[0].content == "Name of Peppa's brother is George." # type: ignore diff --git a/packages/ragbits-document-search/tests/unit/test_document_search.py b/packages/ragbits-document-search/tests/unit/test_document_search.py index b544ca47..181b420b 100644 --- a/packages/ragbits-document-search/tests/unit/test_document_search.py +++ b/packages/ragbits-document-search/tests/unit/test_document_search.py @@ -1,4 +1,5 @@ import tempfile +from collections.abc import Callable from pathlib import Path from unittest.mock import AsyncMock @@ -55,7 +56,7 @@ async def test_document_search_ingest_from_source(): embeddings_mock = AsyncMock() embeddings_mock.embed_text.return_value = [[0.1, 0.1]] - providers: dict[DocumentType, BaseProvider] = {DocumentType.TXT: DummyProvider()} + providers: dict[DocumentType, Callable[[], BaseProvider] | BaseProvider] = {DocumentType.TXT: DummyProvider()} router = DocumentProcessorRouter.from_config(providers) document_search = DocumentSearch(