From 60a37ec50d37f75eef482add3520411aa8fdbc4d Mon Sep 17 00:00:00 2001 From: "alicja.kotyla" Date: Tue, 8 Oct 2024 15:20:58 +0200 Subject: [PATCH] Integrate config with providers --- .../src/ragbits/core/vector_store/__init__.py | 3 ++ .../core/vector_store/chromadb_store.py | 18 ++++++++++++ .../examples/from_config_example.py | 10 ++++++- .../src/ragbits/document_search/_main.py | 5 +++- .../ingestion/document_processor.py | 29 +++++++++++++++++++ .../ingestion/providers/__init__.py | 29 +++++++++++++++++++ .../tests/unit/test_document_search.py | 3 +- 7 files changed, 94 insertions(+), 3 deletions(-) diff --git a/packages/ragbits-core/src/ragbits/core/vector_store/__init__.py b/packages/ragbits-core/src/ragbits/core/vector_store/__init__.py index 94a9ccf8..5a5494d5 100644 --- a/packages/ragbits-core/src/ragbits/core/vector_store/__init__.py +++ b/packages/ragbits-core/src/ragbits/core/vector_store/__init__.py @@ -25,4 +25,7 @@ def get_vector_store(vector_store_config: dict) -> VectorStore: vector_store_cls = get_cls_from_config(vector_store_config["type"], module) config = vector_store_config.get("config", {}) + if vector_store_config["type"] == "ChromaDBStore": + return vector_store_cls.from_config(config) + return vector_store_cls(**config) diff --git a/packages/ragbits-core/src/ragbits/core/vector_store/chromadb_store.py b/packages/ragbits-core/src/ragbits/core/vector_store/chromadb_store.py index 259fdc7f..0f3f63a8 100644 --- a/packages/ragbits-core/src/ragbits/core/vector_store/chromadb_store.py +++ b/packages/ragbits-core/src/ragbits/core/vector_store/chromadb_store.py @@ -10,6 +10,7 @@ HAS_CHROMADB = False from ragbits.core.embeddings.base import Embeddings +from ragbits.core.utils import get_cls_from_config from ragbits.core.vector_store.base import VectorStore from ragbits.core.vector_store.in_memory import VectorDBEntry @@ -46,6 +47,23 @@ def __init__( self._metadata = {"hnsw:space": distance_method} self._collection = self._get_chroma_collection() + 
@classmethod + def from_config(cls, config: dict) -> "ChromaDBStore": + chroma_client = get_cls_from_config(config["chroma_client"]["type"], chromadb)( + **config["chroma_client"].get("config", {}) + ) + embedding_function = get_cls_from_config(config["embedding_function"]["type"], chromadb)( + **config["embedding_function"].get("config", {}) + ) + + return cls( + config["index_name"], + chroma_client, + embedding_function, + max_distance=config.get("max_distance"), + distance_method=config.get("distance_method", "l2"), + ) + def _get_chroma_collection(self) -> chromadb.Collection: """ Based on the selected embedding_function, chooses how to retrieve the ChromaDB collection. diff --git a/packages/ragbits-document-search/examples/from_config_example.py b/packages/ragbits-document-search/examples/from_config_example.py index c5b9b898..1599cf84 100644 --- a/packages/ragbits-document-search/examples/from_config_example.py +++ b/packages/ragbits-document-search/examples/from_config_example.py @@ -22,8 +22,16 @@ config = { "embedder": {"type": "LiteLLMEmbeddings"}, - "vector_store": {"type": "InMemoryVectorStore"}, + "vector_store": { + "type": "ChromaDBStore", + "config": { + "chroma_client": {"type": "PersistentClient", "config": {"path": "chroma"}}, + "embedding_function": {"type": "ragbits.core.embeddings.litellm:LiteLLMEmbeddings"}, + "index_name": "jokes", + }, + }, "reranker": {"type": "ragbits.document_search.retrieval.rerankers.noop:NoopReranker"}, + "providers": {"txt": {"type": "DummyProvider"}}, } diff --git a/packages/ragbits-document-search/src/ragbits/document_search/_main.py b/packages/ragbits-document-search/src/ragbits/document_search/_main.py index fa35a42b..69778b5b 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/_main.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/_main.py @@ -79,7 +79,10 @@ def from_config(cls, config: dict) -> "DocumentSearch": reranker = get_reranker(config.get("reranker")) 
vector_store = get_vector_store(config["vector_store"]) - return cls(embedder, vector_store, query_rephraser, reranker) + providers_config = DocumentProcessorRouter.from_dict_to_providers_config(config.get("providers")) + document_processor_router = DocumentProcessorRouter.from_config(providers_config) + + return cls(embedder, vector_store, query_rephraser, reranker, document_processor_router) async def search(self, query: str, search_config: SearchConfig = SearchConfig()) -> list[Element]: """ diff --git a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/document_processor.py b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/document_processor.py index 79da68d7..0ce6b7ce 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/document_processor.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/document_processor.py @@ -2,6 +2,7 @@ from typing import Optional from ragbits.document_search.documents.document import DocumentMeta, DocumentType +from ragbits.document_search.ingestion.providers import get_provider from ragbits.document_search.ingestion.providers.base import BaseProvider from ragbits.document_search.ingestion.providers.unstructured import UnstructuredProvider @@ -39,6 +40,33 @@ class DocumentProcessorRouter: def __init__(self, providers: dict[DocumentType, BaseProvider]): self._providers = providers + @staticmethod + def from_dict_to_providers_config(dict_config: dict) -> ProvidersConfig: + """ + Creates ProvidersConfig from dictionary config. + Example of the dictionary config: + { + "txt": { + { + "type": "UnstructuredProvider" + } + } + } + + Args: + dict_config: The dictionary with configuration. + + Returns: + ProvidersConfig object. 
+ """ + + providers_config = {} + + for document_type, config in dict_config.items(): + providers_config[DocumentType(document_type)] = get_provider(config) + + return providers_config + @classmethod def from_config(cls, providers_config: Optional[ProvidersConfig] = None) -> "DocumentProcessorRouter": """ @@ -58,6 +86,7 @@ def from_config(cls, providers_config: Optional[ProvidersConfig] = None) -> "Doc Returns: The DocumentProcessorRouter. """ + config = copy.deepcopy(DEFAULT_PROVIDERS_CONFIG) config.update(providers_config if providers_config is not None else {}) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/__init__.py b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/__init__.py index e69de29b..0bc05eee 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/__init__.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/__init__.py @@ -0,0 +1,29 @@ +import sys + +from ragbits.core.utils import get_cls_from_config + +from .base import BaseProvider +from .dummy import DummyProvider +from .unstructured import UnstructuredProvider + +__all__ = ["BaseProvider", "DummyProvider", "UnstructuredProvider"] + +module = sys.modules[__name__] + + +def get_provider(provider_config: dict) -> BaseProvider: + """ + Initializes and returns a Provider object based on the provided configuration. + + Args: + provider_config: A dictionary containing configuration details for the provider. + + Returns: + An instance of the specified Provider class, initialized with the provided config + (if any) or default arguments.
+ """ + + provider_cls = get_cls_from_config(provider_config["type"], module) + config = provider_config.get("config", {}) + + return provider_cls(**config) diff --git a/packages/ragbits-document-search/tests/unit/test_document_search.py b/packages/ragbits-document-search/tests/unit/test_document_search.py index 3c192de3..76743c89 100644 --- a/packages/ragbits-document-search/tests/unit/test_document_search.py +++ b/packages/ragbits-document-search/tests/unit/test_document_search.py @@ -15,6 +15,7 @@ "embedder": {"type": "NoopEmbeddings"}, "vector_store": {"type": "ragbits.core.vector_store.in_memory:InMemoryVectorStore"}, "reranker": {"type": "NoopReranker"}, + "providers": {"txt": {"type": "DummyProvider"}}, } @@ -36,7 +37,7 @@ async def test_document_search_from_config(document, expected): document_search = DocumentSearch.from_config(CONFIG) - await document_search.ingest_document(document, document_processor=DummyProvider()) + await document_search.ingest_document(document) results = await document_search.search("Peppa's brother") first_result = results[0]