Skip to content

Commit

Permalink
fix(document-search): do not create all of provider instances when importing document processor (#219)
Browse files Browse the repository at this point in the history
  • Loading branch information
kdziedzic68 authored Dec 3, 2024
1 parent 9c9a7c9 commit 02218f0
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 27 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import copy
from collections.abc import Callable
from typing import cast

from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.ingestion.providers import get_provider
Expand All @@ -7,30 +9,31 @@
from ragbits.document_search.ingestion.providers.unstructured.images import UnstructuredImageProvider
from ragbits.document_search.ingestion.providers.unstructured.pdf import UnstructuredPdfProvider

ProvidersConfig = dict[DocumentType, BaseProvider]
# TODO: consider defining this with an explicit schema
ProvidersConfig = dict[DocumentType, Callable[[], BaseProvider] | BaseProvider]


# Default mapping from each DocumentType to the provider that handles it.
#
# Every key below appears twice: first mapped to a provider *instance*
# (constructor called), then mapped to the provider *class* itself (not called).
# In a Python dict literal the later duplicate key wins, so as written the
# effective values are the classes; the earlier instance expressions are still
# evaluated when this statement runs.
#
# NOTE(review): this text looks like a rendered diff with its +/- markers
# stripped — presumably the first group is the removed side and the second
# group is the added side. Confirm against the actual file before relying on it.
DEFAULT_PROVIDERS_CONFIG: ProvidersConfig = {
# Group 1: values are provider instances, created when this dict is built.
DocumentType.TXT: UnstructuredDefaultProvider(),
DocumentType.MD: UnstructuredDefaultProvider(),
DocumentType.PDF: UnstructuredPdfProvider(),
DocumentType.DOCX: UnstructuredDefaultProvider(),
DocumentType.DOC: UnstructuredDefaultProvider(),
DocumentType.PPTX: UnstructuredDefaultProvider(),
DocumentType.PPT: UnstructuredDefaultProvider(),
DocumentType.XLSX: UnstructuredDefaultProvider(),
DocumentType.XLS: UnstructuredDefaultProvider(),
DocumentType.CSV: UnstructuredDefaultProvider(),
DocumentType.HTML: UnstructuredDefaultProvider(),
DocumentType.EPUB: UnstructuredDefaultProvider(),
DocumentType.ORG: UnstructuredDefaultProvider(),
DocumentType.ODT: UnstructuredDefaultProvider(),
DocumentType.RST: UnstructuredDefaultProvider(),
DocumentType.RTF: UnstructuredDefaultProvider(),
DocumentType.TSV: UnstructuredDefaultProvider(),
DocumentType.XML: UnstructuredDefaultProvider(),
DocumentType.JPG: UnstructuredImageProvider(),
DocumentType.PNG: UnstructuredImageProvider(),
# Group 2: values are provider classes (zero-argument callables), so no
# provider object is constructed by these entries.
DocumentType.TXT: UnstructuredDefaultProvider,
DocumentType.MD: UnstructuredDefaultProvider,
DocumentType.PDF: UnstructuredPdfProvider,
DocumentType.DOCX: UnstructuredDefaultProvider,
DocumentType.DOC: UnstructuredDefaultProvider,
DocumentType.PPTX: UnstructuredDefaultProvider,
DocumentType.PPT: UnstructuredDefaultProvider,
DocumentType.XLSX: UnstructuredDefaultProvider,
DocumentType.XLS: UnstructuredDefaultProvider,
DocumentType.CSV: UnstructuredDefaultProvider,
DocumentType.HTML: UnstructuredDefaultProvider,
DocumentType.EPUB: UnstructuredDefaultProvider,
DocumentType.ORG: UnstructuredDefaultProvider,
DocumentType.ODT: UnstructuredDefaultProvider,
DocumentType.RST: UnstructuredDefaultProvider,
DocumentType.RTF: UnstructuredDefaultProvider,
DocumentType.TSV: UnstructuredDefaultProvider,
DocumentType.XML: UnstructuredDefaultProvider,
DocumentType.JPG: UnstructuredImageProvider,
DocumentType.PNG: UnstructuredImageProvider,
}


Expand All @@ -40,7 +43,7 @@ class DocumentProcessorRouter:
metadata such as the document type.
"""

def __init__(self, providers: dict[DocumentType, BaseProvider]):
def __init__(self, providers: dict[DocumentType, Callable[[], BaseProvider] | BaseProvider]):
self._providers = providers

@staticmethod
Expand All @@ -65,7 +68,9 @@ def from_dict_to_providers_config(dict_config: dict) -> ProvidersConfig:
providers_config = {}

for document_type, config in dict_config.items():
providers_config[DocumentType(document_type)] = get_provider(config)
providers_config[DocumentType(document_type)] = cast(
Callable[[], BaseProvider] | BaseProvider, get_provider(config)
)

return providers_config

Expand Down Expand Up @@ -106,7 +111,11 @@ def get_provider(self, document_meta: DocumentMeta) -> BaseProvider:
Raises:
ValueError: If no provider is found for the document type.
"""
provider = self._providers.get(document_meta.document_type)
if provider is None:
provider_class_or_provider = self._providers.get(document_meta.document_type)
if provider_class_or_provider is None:
raise ValueError(f"No provider found for the document type {document_meta.document_type}")
elif isinstance(provider_class_or_provider, BaseProvider):
provider = provider_class_or_provider
else:
provider = provider_class_or_provider()
return provider
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"config",
[
{},
pytest.param({DocumentType.TXT: UnstructuredDefaultProvider()}),
pytest.param(
{DocumentType.TXT: UnstructuredDefaultProvider(use_api=True)},
marks=pytest.mark.skipif(
Expand All @@ -33,7 +34,12 @@ async def test_document_processor_processes_text_document_with_unstructured_prov

elements = await document_processor.get_provider(document_meta).process(document_meta)

assert isinstance(document_processor._providers[DocumentType.TXT], UnstructuredDefaultProvider)
expected_provider_type = (
UnstructuredDefaultProvider
if isinstance(config.get(DocumentType.TXT), UnstructuredDefaultProvider)
else type(UnstructuredDefaultProvider)
)
assert isinstance(document_processor._providers[DocumentType.TXT], expected_provider_type)
assert len(elements) == 1
assert elements[0].content == "Name of Peppa's brother is George." # type: ignore

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import tempfile
from collections.abc import Callable
from pathlib import Path
from unittest.mock import AsyncMock

Expand Down Expand Up @@ -55,7 +56,7 @@ async def test_document_search_ingest_from_source():
embeddings_mock = AsyncMock()
embeddings_mock.embed_text.return_value = [[0.1, 0.1]]

providers: dict[DocumentType, BaseProvider] = {DocumentType.TXT: DummyProvider()}
providers: dict[DocumentType, Callable[[], BaseProvider] | BaseProvider] = {DocumentType.TXT: DummyProvider()}
router = DocumentProcessorRouter.from_config(providers)

document_search = DocumentSearch(
Expand Down

0 comments on commit 02218f0

Please sign in to comment.