From 8df50f8143a88de4e927f48f918629a7fdbc0856 Mon Sep 17 00:00:00 2001 From: kdziedzic68 Date: Mon, 2 Dec 2024 11:13:09 +0100 Subject: [PATCH 1/4] remove providers from document processors import --- .../ingestion/document_processor.py | 57 +++++++++++-------- .../tests/integration/test_unstructured.py | 2 +- .../tests/unit/test_document_search.py | 2 +- 3 files changed, 34 insertions(+), 27 deletions(-) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/document_processor.py b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/document_processor.py index f6160937..b5ceff8f 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/document_processor.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/document_processor.py @@ -1,4 +1,5 @@ import copy +from typing import cast from ragbits.document_search.documents.document import DocumentMeta, DocumentType from ragbits.document_search.ingestion.providers import get_provider @@ -7,30 +8,30 @@ from ragbits.document_search.ingestion.providers.unstructured.images import UnstructuredImageProvider from ragbits.document_search.ingestion.providers.unstructured.pdf import UnstructuredPdfProvider -ProvidersConfig = dict[DocumentType, BaseProvider] +ProvidersConfig = dict[DocumentType, type[BaseProvider] | BaseProvider] DEFAULT_PROVIDERS_CONFIG: ProvidersConfig = { - DocumentType.TXT: UnstructuredDefaultProvider(), - DocumentType.MD: UnstructuredDefaultProvider(), - DocumentType.PDF: UnstructuredPdfProvider(), - DocumentType.DOCX: UnstructuredDefaultProvider(), - DocumentType.DOC: UnstructuredDefaultProvider(), - DocumentType.PPTX: UnstructuredDefaultProvider(), - DocumentType.PPT: UnstructuredDefaultProvider(), - DocumentType.XLSX: UnstructuredDefaultProvider(), - DocumentType.XLS: UnstructuredDefaultProvider(), - DocumentType.CSV: UnstructuredDefaultProvider(), - DocumentType.HTML: UnstructuredDefaultProvider(), - DocumentType.EPUB: UnstructuredDefaultProvider(), - DocumentType.ORG: UnstructuredDefaultProvider(), - DocumentType.ODT: UnstructuredDefaultProvider(), - DocumentType.RST: UnstructuredDefaultProvider(), - DocumentType.RTF: UnstructuredDefaultProvider(), - DocumentType.TSV: UnstructuredDefaultProvider(), - DocumentType.XML: UnstructuredDefaultProvider(), - DocumentType.JPG: UnstructuredImageProvider(), - DocumentType.PNG: UnstructuredImageProvider(), + DocumentType.TXT: UnstructuredDefaultProvider, + DocumentType.MD: UnstructuredDefaultProvider, + DocumentType.PDF: UnstructuredPdfProvider, + DocumentType.DOCX: UnstructuredDefaultProvider, + DocumentType.DOC: UnstructuredDefaultProvider, + DocumentType.PPTX: UnstructuredDefaultProvider, + DocumentType.PPT: UnstructuredDefaultProvider, + DocumentType.XLSX: UnstructuredDefaultProvider, + DocumentType.XLS: UnstructuredDefaultProvider, + DocumentType.CSV: UnstructuredDefaultProvider, + DocumentType.HTML: UnstructuredDefaultProvider, + DocumentType.EPUB: UnstructuredDefaultProvider, + DocumentType.ORG: UnstructuredDefaultProvider, + DocumentType.ODT: UnstructuredDefaultProvider, + DocumentType.RST: UnstructuredDefaultProvider, + DocumentType.RTF: UnstructuredDefaultProvider, + DocumentType.TSV: UnstructuredDefaultProvider, + DocumentType.XML: UnstructuredDefaultProvider, + DocumentType.JPG: UnstructuredImageProvider, + DocumentType.PNG: UnstructuredImageProvider, } @@ -40,7 +41,7 @@ class DocumentProcessorRouter: metadata such as the document type. """ - def __init__(self, providers: dict[DocumentType, BaseProvider]): + def __init__(self, providers: dict[DocumentType, type[BaseProvider] | BaseProvider]): self._providers = providers @staticmethod @@ -65,7 +66,9 @@ def from_dict_to_providers_config(dict_config: dict) -> ProvidersConfig: providers_config = {} for document_type, config in dict_config.items(): - providers_config[DocumentType(document_type)] = get_provider(config) + providers_config[DocumentType(document_type)] = cast( + type[BaseProvider] | BaseProvider, get_provider(config) + ) return providers_config @@ -106,7 +109,11 @@ def get_provider(self, document_meta: DocumentMeta) -> BaseProvider: Raises: ValueError: If no provider is found for the document type. """ - provider = self._providers.get(document_meta.document_type) - if provider is None: + provider_class_or_provider = self._providers.get(document_meta.document_type) + if provider_class_or_provider is None: raise ValueError(f"No provider found for the document type {document_meta.document_type}") + elif isinstance(provider_class_or_provider, BaseProvider): + provider = provider_class_or_provider + else: + provider = provider_class_or_provider() return provider diff --git a/packages/ragbits-document-search/tests/integration/test_unstructured.py b/packages/ragbits-document-search/tests/integration/test_unstructured.py index 05c2e45c..b3516ae0 100644 --- a/packages/ragbits-document-search/tests/integration/test_unstructured.py +++ b/packages/ragbits-document-search/tests/integration/test_unstructured.py @@ -33,7 +33,7 @@ async def test_document_processor_processes_text_document_with_unstructured_prov elements = await document_processor.get_provider(document_meta).process(document_meta) - assert isinstance(document_processor._providers[DocumentType.TXT], UnstructuredDefaultProvider) + assert isinstance(document_processor._providers[DocumentType.TXT], type(UnstructuredDefaultProvider)) assert len(elements) == 1 assert elements[0].content == "Name of Peppa's brother is George." # type: ignore diff --git a/packages/ragbits-document-search/tests/unit/test_document_search.py b/packages/ragbits-document-search/tests/unit/test_document_search.py index b544ca47..c6f89686 100644 --- a/packages/ragbits-document-search/tests/unit/test_document_search.py +++ b/packages/ragbits-document-search/tests/unit/test_document_search.py @@ -55,7 +55,7 @@ async def test_document_search_ingest_from_source(): embeddings_mock = AsyncMock() embeddings_mock.embed_text.return_value = [[0.1, 0.1]] - providers: dict[DocumentType, BaseProvider] = {DocumentType.TXT: DummyProvider()} + providers: dict[DocumentType, BaseProvider | type[BaseProvider]] = {DocumentType.TXT: DummyProvider()} router = DocumentProcessorRouter.from_config(providers) document_search = DocumentSearch( From 50b19a564ee93a649002bf6682e5cd6780a19654 Mon Sep 17 00:00:00 2001 From: kdziedzic68 Date: Tue, 3 Dec 2024 10:40:52 +0100 Subject: [PATCH 2/4] test for providers --- .../tests/integration/test_unstructured.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/packages/ragbits-document-search/tests/integration/test_unstructured.py b/packages/ragbits-document-search/tests/integration/test_unstructured.py index b3516ae0..e75de77a 100644 --- a/packages/ragbits-document-search/tests/integration/test_unstructured.py +++ b/packages/ragbits-document-search/tests/integration/test_unstructured.py @@ -18,6 +18,7 @@ "config", [ {}, + pytest.param({DocumentType.TXT: UnstructuredDefaultProvider()}), pytest.param( {DocumentType.TXT: UnstructuredDefaultProvider(use_api=True)}, marks=pytest.mark.skipif( @@ -33,7 +34,12 @@ async def test_document_processor_processes_text_document_with_unstructured_prov elements = await document_processor.get_provider(document_meta).process(document_meta) - assert isinstance(document_processor._providers[DocumentType.TXT], type(UnstructuredDefaultProvider)) + expected_provider_type = ( + UnstructuredDefaultProvider + if isinstance(config.get(DocumentType.TXT), UnstructuredDefaultProvider) + else type(UnstructuredDefaultProvider) + ) + assert isinstance(document_processor._providers[DocumentType.TXT], expected_provider_type) assert len(elements) == 1 assert elements[0].content == "Name of Peppa's brother is George." # type: ignore From 61e54820b3a9c2a3650560830a7a6f71ddf8449d Mon Sep 17 00:00:00 2001 From: kdziedzic68 Date: Tue, 3 Dec 2024 12:21:36 +0100 Subject: [PATCH 3/4] change typing --- .../document_search/ingestion/document_processor.py | 8 +++++--- .../tests/unit/test_document_search.py | 3 ++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/document_processor.py b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/document_processor.py index b5ceff8f..3577f999 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/document_processor.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/document_processor.py @@ -1,4 +1,5 @@ import copy +from collections.abc import Callable from typing import cast from ragbits.document_search.documents.document import DocumentMeta, DocumentType @@ -8,7 +9,8 @@ from ragbits.document_search.ingestion.providers.unstructured.images import UnstructuredImageProvider from ragbits.document_search.ingestion.providers.unstructured.pdf import UnstructuredPdfProvider -ProvidersConfig = dict[DocumentType, type[BaseProvider] | BaseProvider] +# TODO consider defining with some defined schema +ProvidersConfig = dict[DocumentType, Callable[[], BaseProvider] | BaseProvider] DEFAULT_PROVIDERS_CONFIG: ProvidersConfig = { @@ -41,7 +43,7 @@ class DocumentProcessorRouter: metadata such as the document type. """ - def __init__(self, providers: dict[DocumentType, type[BaseProvider] | BaseProvider]): + def __init__(self, providers: dict[DocumentType, Callable[[], BaseProvider] | BaseProvider]): self._providers = providers @staticmethod @@ -67,7 +69,7 @@ def from_dict_to_providers_config(dict_config: dict) -> ProvidersConfig: for document_type, config in dict_config.items(): providers_config[DocumentType(document_type)] = cast( - type[BaseProvider] | BaseProvider, get_provider(config) + Callable[[], BaseProvider] | BaseProvider, get_provider(config) ) return providers_config diff --git a/packages/ragbits-document-search/tests/unit/test_document_search.py b/packages/ragbits-document-search/tests/unit/test_document_search.py index c6f89686..d35149ca 100644 --- a/packages/ragbits-document-search/tests/unit/test_document_search.py +++ b/packages/ragbits-document-search/tests/unit/test_document_search.py @@ -1,6 +1,7 @@ import tempfile from pathlib import Path from unittest.mock import AsyncMock +from typing import Callable import pytest @@ -55,7 +56,7 @@ async def test_document_search_ingest_from_source(): embeddings_mock = AsyncMock() embeddings_mock.embed_text.return_value = [[0.1, 0.1]] - providers: dict[DocumentType, BaseProvider | type[BaseProvider]] = {DocumentType.TXT: DummyProvider()} + providers: dict[DocumentType, Callable[[], BaseProvider] | BaseProvider] = {DocumentType.TXT: DummyProvider()} router = DocumentProcessorRouter.from_config(providers) document_search = DocumentSearch( From a0201ec57abcc2baf7012fd3e90c7f03490c7aa4 Mon Sep 17 00:00:00 2001 From: kdziedzic68 Date: Tue, 3 Dec 2024 12:23:18 +0100 Subject: [PATCH 4/4] fix ruff --- .../ragbits-document-search/tests/unit/test_document_search.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/ragbits-document-search/tests/unit/test_document_search.py b/packages/ragbits-document-search/tests/unit/test_document_search.py index d35149ca..181b420b 100644 --- a/packages/ragbits-document-search/tests/unit/test_document_search.py +++ b/packages/ragbits-document-search/tests/unit/test_document_search.py @@ -1,7 +1,7 @@ import tempfile +from collections.abc import Callable from pathlib import Path from unittest.mock import AsyncMock -from typing import Callable import pytest