Skip to content

Commit

Permalink
DocumentSearchConfig.from_config() should take dict
Browse files Browse the repository at this point in the history
  • Loading branch information
ludwiktrammer committed Dec 5, 2024
1 parent 65d287d commit d01f28c
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 54 deletions.
83 changes: 40 additions & 43 deletions examples/document-search/from_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ class to rephrase the query.
import asyncio

from ragbits.document_search import DocumentSearch
from ragbits.document_search._main import DocumentSearchConfig
from ragbits.document_search.documents.document import DocumentMeta

documents = [
Expand All @@ -57,56 +56,54 @@ class to rephrase the query.
),
]

config = DocumentSearchConfig.model_validate(
{
"embedder": {
"type": "ragbits.core.embeddings.litellm:LiteLLMEmbeddings",
},
"vector_store": {
"type": "ragbits.core.vector_stores.chroma:ChromaVectorStore",
"config": {
"client": {
"type": "PersistentClient",
"config": {
"path": "chroma",
},
},
"index_name": "jokes",
"distance_method": "l2",
"default_options": {
"k": 3,
"max_distance": 1.2,
},
"metadata_store": {
"type": "InMemoryMetadataStore",
config = {
"embedder": {
"type": "ragbits.core.embeddings.litellm:LiteLLMEmbeddings",
},
"vector_store": {
"type": "ragbits.core.vector_stores.chroma:ChromaVectorStore",
"config": {
"client": {
"type": "PersistentClient",
"config": {
"path": "chroma",
},
},
"index_name": "jokes",
"distance_method": "l2",
"default_options": {
"k": 3,
"max_distance": 1.2,
},
"metadata_store": {
"type": "InMemoryMetadataStore",
},
},
"reranker": {
"type": "ragbits.document_search.retrieval.rerankers.litellm:LiteLLMReranker",
"config": {
"model": "cohere/rerank-english-v3.0",
"default_options": {
"top_n": 3,
"max_chunks_per_doc": None,
},
},
"reranker": {
"type": "ragbits.document_search.retrieval.rerankers.litellm:LiteLLMReranker",
"config": {
"model": "cohere/rerank-english-v3.0",
"default_options": {
"top_n": 3,
"max_chunks_per_doc": None,
},
},
"providers": {"txt": {"type": "DummyProvider"}},
"rephraser": {
"type": "LLMQueryRephraser",
"config": {
"llm": {
"type": "ragbits.core.llms.litellm:LiteLLM",
"config": {
"model_name": "gpt-4-turbo",
},
},
"providers": {"txt": {"type": "DummyProvider"}},
"rephraser": {
"type": "LLMQueryRephraser",
"config": {
"llm": {
"type": "ragbits.core.llms.litellm:LiteLLM",
"config": {
"model_name": "gpt-4-turbo",
},
"prompt": "QueryRephraserPrompt",
},
"prompt": "QueryRephraserPrompt",
},
}
)
},
}


async def main() -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class SearchConfig(BaseModel):

class DocumentSearchConfig(BaseModel):
"""
Configuration for the DocumentSearch.from_config method.
Schema for for the dict taken by DocumentSearch.from_config method.
"""

embedder: ObjectContructionConfig
Expand Down Expand Up @@ -84,7 +84,7 @@ def __init__(
self.processing_strategy = processing_strategy or SequentialProcessing()

@classmethod
def from_config(cls, config: DocumentSearchConfig) -> "DocumentSearch":
def from_config(cls, config: dict) -> "DocumentSearch":
"""
Creates and returns an instance of the DocumentSearch class from the given configuration.
Expand All @@ -94,13 +94,15 @@ def from_config(cls, config: DocumentSearchConfig) -> "DocumentSearch":
Returns:
DocumentSearch: An initialized instance of the DocumentSearch class.
"""
embedder = Embeddings.subclass_from_config(config.embedder)
query_rephraser = QueryRephraser.subclass_from_config(config.rephraser)
reranker = Reranker.subclass_from_config(config.reranker)
vector_store = VectorStore.subclass_from_config(config.vector_store)
processing_strategy = ProcessingExecutionStrategy.subclass_from_config(config.processing_strategy)
model = DocumentSearchConfig.model_validate(config)

providers_config = DocumentProcessorRouter.from_dict_to_providers_config(config.providers)
embedder = Embeddings.subclass_from_config(model.embedder)
query_rephraser = QueryRephraser.subclass_from_config(model.rephraser)
reranker = Reranker.subclass_from_config(model.reranker)
vector_store = VectorStore.subclass_from_config(model.vector_store)
processing_strategy = ProcessingExecutionStrategy.subclass_from_config(model.processing_strategy)

providers_config = DocumentProcessorRouter.from_dict_to_providers_config(model.providers)
document_processor_router = DocumentProcessorRouter.from_config(providers_config)

return cls(embedder, vector_store, query_rephraser, reranker, document_processor_router, processing_strategy)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from ragbits.core.vector_stores.in_memory import InMemoryVectorStore
from ragbits.document_search import DocumentSearch
from ragbits.document_search._main import DocumentSearchConfig, SearchConfig
from ragbits.document_search._main import SearchConfig
from ragbits.document_search.documents.document import Document, DocumentMeta, DocumentType
from ragbits.document_search.documents.element import TextElement
from ragbits.document_search.documents.sources import LocalFileSource
Expand Down Expand Up @@ -41,7 +41,7 @@
],
)
async def test_document_search_from_config(document: DocumentMeta, expected: str):
document_search = DocumentSearch.from_config(DocumentSearchConfig.model_validate(CONFIG))
document_search = DocumentSearch.from_config(CONFIG)

await document_search.ingest([document])
results = await document_search.search("Peppa's brother")
Expand Down Expand Up @@ -154,7 +154,7 @@ async def test_document_search_with_search_config():


async def test_document_search_ingest_multiple_from_sources():
document_search = DocumentSearch.from_config(DocumentSearchConfig.model_validate(CONFIG))
document_search = DocumentSearch.from_config(CONFIG)
examples_files = Path(__file__).parent / "example_files"

await document_search.ingest(
Expand Down

0 comments on commit d01f28c

Please sign in to comment.