From d01f28cab0b24673df2a191e41a3975866f9b5ee Mon Sep 17 00:00:00 2001 From: Ludwik Trammer Date: Thu, 5 Dec 2024 12:21:24 +0100 Subject: [PATCH] DocumentSearchConfig.from_config() should take dict --- examples/document-search/from_config.py | 83 +++++++++---------- .../src/ragbits/document_search/_main.py | 18 ++-- .../tests/unit/test_document_search.py | 6 +- 3 files changed, 53 insertions(+), 54 deletions(-) diff --git a/examples/document-search/from_config.py b/examples/document-search/from_config.py index 51b8da749..1b5912f8c 100644 --- a/examples/document-search/from_config.py +++ b/examples/document-search/from_config.py @@ -31,7 +31,6 @@ class to rephrase the query. import asyncio from ragbits.document_search import DocumentSearch -from ragbits.document_search._main import DocumentSearchConfig from ragbits.document_search.documents.document import DocumentMeta documents = [ @@ -57,56 +56,54 @@ class to rephrase the query. ), ] -config = DocumentSearchConfig.model_validate( - { - "embedder": { - "type": "ragbits.core.embeddings.litellm:LiteLLMEmbeddings", - }, - "vector_store": { - "type": "ragbits.core.vector_stores.chroma:ChromaVectorStore", - "config": { - "client": { - "type": "PersistentClient", - "config": { - "path": "chroma", - }, - }, - "index_name": "jokes", - "distance_method": "l2", - "default_options": { - "k": 3, - "max_distance": 1.2, - }, - "metadata_store": { - "type": "InMemoryMetadataStore", +config = { + "embedder": { + "type": "ragbits.core.embeddings.litellm:LiteLLMEmbeddings", + }, + "vector_store": { + "type": "ragbits.core.vector_stores.chroma:ChromaVectorStore", + "config": { + "client": { + "type": "PersistentClient", + "config": { + "path": "chroma", }, }, + "index_name": "jokes", + "distance_method": "l2", + "default_options": { + "k": 3, + "max_distance": 1.2, + }, + "metadata_store": { + "type": "InMemoryMetadataStore", + }, }, - "reranker": { - "type": "ragbits.document_search.retrieval.rerankers.litellm:LiteLLMReranker", - 
"config": { - "model": "cohere/rerank-english-v3.0", - "default_options": { - "top_n": 3, - "max_chunks_per_doc": None, - }, + }, + "reranker": { + "type": "ragbits.document_search.retrieval.rerankers.litellm:LiteLLMReranker", + "config": { + "model": "cohere/rerank-english-v3.0", + "default_options": { + "top_n": 3, + "max_chunks_per_doc": None, }, }, - "providers": {"txt": {"type": "DummyProvider"}}, - "rephraser": { - "type": "LLMQueryRephraser", - "config": { - "llm": { - "type": "ragbits.core.llms.litellm:LiteLLM", - "config": { - "model_name": "gpt-4-turbo", - }, + }, + "providers": {"txt": {"type": "DummyProvider"}}, + "rephraser": { + "type": "LLMQueryRephraser", + "config": { + "llm": { + "type": "ragbits.core.llms.litellm:LiteLLM", + "config": { + "model_name": "gpt-4-turbo", }, - "prompt": "QueryRephraserPrompt", }, + "prompt": "QueryRephraserPrompt", }, - } -) + }, +} async def main() -> None: diff --git a/packages/ragbits-document-search/src/ragbits/document_search/_main.py b/packages/ragbits-document-search/src/ragbits/document_search/_main.py index dfd3cd1b8..976457091 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/_main.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/_main.py @@ -36,7 +36,7 @@ class SearchConfig(BaseModel): class DocumentSearchConfig(BaseModel): """ - Configuration for the DocumentSearch.from_config method. + Schema for the dict taken by DocumentSearch.from_config method. """ embedder: ObjectContructionConfig @@ -84,7 +84,7 @@ def __init__( self.processing_strategy = processing_strategy or SequentialProcessing() @classmethod - def from_config(cls, config: DocumentSearchConfig) -> "DocumentSearch": + def from_config(cls, config: dict) -> "DocumentSearch": """ Creates and returns an instance of the DocumentSearch class from the given configuration. 
@@ -94,13 +94,15 @@ def from_config(cls, config: DocumentSearchConfig) -> "DocumentSearch": Returns: DocumentSearch: An initialized instance of the DocumentSearch class. """ - embedder = Embeddings.subclass_from_config(config.embedder) - query_rephraser = QueryRephraser.subclass_from_config(config.rephraser) - reranker = Reranker.subclass_from_config(config.reranker) - vector_store = VectorStore.subclass_from_config(config.vector_store) - processing_strategy = ProcessingExecutionStrategy.subclass_from_config(config.processing_strategy) + model = DocumentSearchConfig.model_validate(config) - providers_config = DocumentProcessorRouter.from_dict_to_providers_config(config.providers) + embedder = Embeddings.subclass_from_config(model.embedder) + query_rephraser = QueryRephraser.subclass_from_config(model.rephraser) + reranker = Reranker.subclass_from_config(model.reranker) + vector_store = VectorStore.subclass_from_config(model.vector_store) + processing_strategy = ProcessingExecutionStrategy.subclass_from_config(model.processing_strategy) + + providers_config = DocumentProcessorRouter.from_dict_to_providers_config(model.providers) document_processor_router = DocumentProcessorRouter.from_config(providers_config) return cls(embedder, vector_store, query_rephraser, reranker, document_processor_router, processing_strategy) diff --git a/packages/ragbits-document-search/tests/unit/test_document_search.py b/packages/ragbits-document-search/tests/unit/test_document_search.py index 095e42efc..181b420be 100644 --- a/packages/ragbits-document-search/tests/unit/test_document_search.py +++ b/packages/ragbits-document-search/tests/unit/test_document_search.py @@ -7,7 +7,7 @@ from ragbits.core.vector_stores.in_memory import InMemoryVectorStore from ragbits.document_search import DocumentSearch -from ragbits.document_search._main import DocumentSearchConfig, SearchConfig +from ragbits.document_search._main import SearchConfig from ragbits.document_search.documents.document import 
Document, DocumentMeta, DocumentType from ragbits.document_search.documents.element import TextElement from ragbits.document_search.documents.sources import LocalFileSource @@ -41,7 +41,7 @@ ], ) async def test_document_search_from_config(document: DocumentMeta, expected: str): - document_search = DocumentSearch.from_config(DocumentSearchConfig.model_validate(CONFIG)) + document_search = DocumentSearch.from_config(CONFIG) await document_search.ingest([document]) results = await document_search.search("Peppa's brother") @@ -154,7 +154,7 @@ async def test_document_search_with_search_config(): async def test_document_search_ingest_multiple_from_sources(): - document_search = DocumentSearch.from_config(DocumentSearchConfig.model_validate(CONFIG)) + document_search = DocumentSearch.from_config(CONFIG) examples_files = Path(__file__).parent / "example_files" await document_search.ingest(