Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(document-search): do not create all of provider instances when importing document processor #219

Merged
merged 4 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import copy
from collections.abc import Callable
from typing import cast

from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.ingestion.providers import get_provider
Expand All @@ -7,30 +9,31 @@
from ragbits.document_search.ingestion.providers.unstructured.images import UnstructuredImageProvider
from ragbits.document_search.ingestion.providers.unstructured.pdf import UnstructuredPdfProvider

ProvidersConfig = dict[DocumentType, BaseProvider]
# TODO consider defining with some defined schema
ProvidersConfig = dict[DocumentType, Callable[[], BaseProvider] | BaseProvider]


# Default mapping from document type to the provider used to process it.
#
# NOTE(review): this span appears to hold both sides of a diff hunk: the first
# group of entries maps each DocumentType to a provider *instance*
# (`Provider()`), and the second group repeats the same keys with the provider
# *class* (not instantiated). Read literally as one dict display, the later
# duplicate keys win, but the constructor calls in the first group are still
# evaluated when the module is imported. Confirm against the merged file,
# which presumably keeps only the second group.
DEFAULT_PROVIDERS_CONFIG: ProvidersConfig = {
# Entries with eagerly constructed provider instances.
DocumentType.TXT: UnstructuredDefaultProvider(),
DocumentType.MD: UnstructuredDefaultProvider(),
DocumentType.PDF: UnstructuredPdfProvider(),
DocumentType.DOCX: UnstructuredDefaultProvider(),
DocumentType.DOC: UnstructuredDefaultProvider(),
DocumentType.PPTX: UnstructuredDefaultProvider(),
DocumentType.PPT: UnstructuredDefaultProvider(),
DocumentType.XLSX: UnstructuredDefaultProvider(),
DocumentType.XLS: UnstructuredDefaultProvider(),
DocumentType.CSV: UnstructuredDefaultProvider(),
DocumentType.HTML: UnstructuredDefaultProvider(),
DocumentType.EPUB: UnstructuredDefaultProvider(),
DocumentType.ORG: UnstructuredDefaultProvider(),
DocumentType.ODT: UnstructuredDefaultProvider(),
DocumentType.RST: UnstructuredDefaultProvider(),
DocumentType.RTF: UnstructuredDefaultProvider(),
DocumentType.TSV: UnstructuredDefaultProvider(),
DocumentType.XML: UnstructuredDefaultProvider(),
DocumentType.JPG: UnstructuredImageProvider(),
DocumentType.PNG: UnstructuredImageProvider(),
# Entries with provider classes; a value that is not a BaseProvider instance
# is called with no arguments when a provider is looked up for a document.
DocumentType.TXT: UnstructuredDefaultProvider,
DocumentType.MD: UnstructuredDefaultProvider,
DocumentType.PDF: UnstructuredPdfProvider,
DocumentType.DOCX: UnstructuredDefaultProvider,
DocumentType.DOC: UnstructuredDefaultProvider,
DocumentType.PPTX: UnstructuredDefaultProvider,
DocumentType.PPT: UnstructuredDefaultProvider,
DocumentType.XLSX: UnstructuredDefaultProvider,
DocumentType.XLS: UnstructuredDefaultProvider,
DocumentType.CSV: UnstructuredDefaultProvider,
DocumentType.HTML: UnstructuredDefaultProvider,
DocumentType.EPUB: UnstructuredDefaultProvider,
DocumentType.ORG: UnstructuredDefaultProvider,
DocumentType.ODT: UnstructuredDefaultProvider,
DocumentType.RST: UnstructuredDefaultProvider,
DocumentType.RTF: UnstructuredDefaultProvider,
DocumentType.TSV: UnstructuredDefaultProvider,
DocumentType.XML: UnstructuredDefaultProvider,
DocumentType.JPG: UnstructuredImageProvider,
DocumentType.PNG: UnstructuredImageProvider,
}


Expand All @@ -40,7 +43,7 @@ class DocumentProcessorRouter:
metadata such as the document type.
"""

def __init__(self, providers: dict[DocumentType, BaseProvider]):
def __init__(self, providers: dict[DocumentType, Callable[[], BaseProvider] | BaseProvider]):
self._providers = providers

@staticmethod
Expand All @@ -65,7 +68,9 @@ def from_dict_to_providers_config(dict_config: dict) -> ProvidersConfig:
providers_config = {}

for document_type, config in dict_config.items():
providers_config[DocumentType(document_type)] = get_provider(config)
providers_config[DocumentType(document_type)] = cast(
Callable[[], BaseProvider] | BaseProvider, get_provider(config)
)

return providers_config

Expand Down Expand Up @@ -106,7 +111,11 @@ def get_provider(self, document_meta: DocumentMeta) -> BaseProvider:
Raises:
ValueError: If no provider is found for the document type.
"""
provider = self._providers.get(document_meta.document_type)
if provider is None:
provider_class_or_provider = self._providers.get(document_meta.document_type)
if provider_class_or_provider is None:
raise ValueError(f"No provider found for the document type {document_meta.document_type}")
elif isinstance(provider_class_or_provider, BaseProvider):
provider = provider_class_or_provider
else:
provider = provider_class_or_provider()
kdziedzic68 marked this conversation as resolved.
Show resolved Hide resolved
return provider
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"config",
[
{},
pytest.param({DocumentType.TXT: UnstructuredDefaultProvider()}),
pytest.param(
{DocumentType.TXT: UnstructuredDefaultProvider(use_api=True)},
marks=pytest.mark.skipif(
Expand All @@ -33,7 +34,12 @@ async def test_document_processor_processes_text_document_with_unstructured_prov

elements = await document_processor.get_provider(document_meta).process(document_meta)

assert isinstance(document_processor._providers[DocumentType.TXT], UnstructuredDefaultProvider)
expected_provider_type = (
UnstructuredDefaultProvider
if isinstance(config.get(DocumentType.TXT), UnstructuredDefaultProvider)
else type(UnstructuredDefaultProvider)
)
assert isinstance(document_processor._providers[DocumentType.TXT], expected_provider_type)
assert len(elements) == 1
assert elements[0].content == "Name of Peppa's brother is George." # type: ignore

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import tempfile
from collections.abc import Callable
from pathlib import Path
from unittest.mock import AsyncMock

Expand Down Expand Up @@ -55,7 +56,7 @@ async def test_document_search_ingest_from_source():
embeddings_mock = AsyncMock()
embeddings_mock.embed_text.return_value = [[0.1, 0.1]]

providers: dict[DocumentType, BaseProvider] = {DocumentType.TXT: DummyProvider()}
providers: dict[DocumentType, Callable[[], BaseProvider] | BaseProvider] = {DocumentType.TXT: DummyProvider()}
router = DocumentProcessorRouter.from_config(providers)

document_search = DocumentSearch(
Expand Down
Loading