Skip to content

Commit

Permalink
Integrate config with providers
Browse files Browse the repository at this point in the history
  • Loading branch information
akotyla committed Oct 8, 2024
1 parent db87234 commit 60a37ec
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,7 @@ def get_vector_store(vector_store_config: dict) -> VectorStore:
vector_store_cls = get_cls_from_config(vector_store_config["type"], module)
config = vector_store_config.get("config", {})

if vector_store_config["type"] == "ChromaDBStore":
return vector_store_cls.from_config(config)

return vector_store_cls(**config)
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
HAS_CHROMADB = False

from ragbits.core.embeddings.base import Embeddings
from ragbits.core.utils import get_cls_from_config
from ragbits.core.vector_store.base import VectorStore
from ragbits.core.vector_store.in_memory import VectorDBEntry

Expand Down Expand Up @@ -46,6 +47,23 @@ def __init__(
self._metadata = {"hnsw:space": distance_method}
self._collection = self._get_chroma_collection()

@classmethod
def from_config(cls, config: dict) -> "ChromaDBStore":
    """
    Creates a ChromaDBStore from a dictionary configuration.

    The "chroma_client" and "embedding_function" entries are each a dictionary
    with a "type" (resolved through get_cls_from_config, with the chromadb module
    passed as the lookup module) and an optional "config" holding the keyword
    arguments for that type.

    Args:
        config: The store configuration. Requires "index_name", "chroma_client"
            and "embedding_function"; "max_distance" and "distance_method"
            are optional.

    Returns:
        The configured ChromaDBStore.
    """
    # Build the client completely before touching the embedding function so that
    # any construction side effects happen in the same order as lookups fail.
    client_cls = get_cls_from_config(config["chroma_client"]["type"], chromadb)
    client_kwargs = config["chroma_client"].get("config", {})
    client = client_cls(**client_kwargs)

    embedding_cls = get_cls_from_config(config["embedding_function"]["type"], chromadb)
    embedding_kwargs = config["embedding_function"].get("config", {})
    embedder = embedding_cls(**embedding_kwargs)

    return cls(
        config["index_name"],
        client,
        embedder,
        max_distance=config.get("max_distance"),
        # "l2" is the fallback distance method when none is configured.
        distance_method=config.get("distance_method", "l2"),
    )

def _get_chroma_collection(self) -> chromadb.Collection:
"""
Based on the selected embedding_function, chooses how to retrieve the ChromaDB collection.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,16 @@

# Component configuration dictionary: each entry names a component "type" and,
# optionally, a nested "config" with that component's settings. Types are written
# either as a bare class name or as "module.path:ClassName".
config = {
"embedder": {"type": "LiteLLMEmbeddings"},
# NOTE(review): "vector_store" is listed twice in this literal; in a Python dict
# literal the later entry wins, so the ChromaDBStore entry below is the effective
# one. This first line looks like the removed side of a diff - confirm and drop it.
"vector_store": {"type": "InMemoryVectorStore"},
"vector_store": {
"type": "ChromaDBStore",
"config": {
# Client and embedding function are themselves given as type/config pairs.
"chroma_client": {"type": "PersistentClient", "config": {"path": "chroma"}},
"embedding_function": {"type": "ragbits.core.embeddings.litellm:LiteLLMEmbeddings"},
"index_name": "jokes",
},
},
"reranker": {"type": "ragbits.document_search.retrieval.rerankers.noop:NoopReranker"},
# Maps a document type name ("txt") to the provider used for it.
"providers": {"txt": {"type": "DummyProvider"}},
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,10 @@ def from_config(cls, config: dict) -> "DocumentSearch":
reranker = get_reranker(config.get("reranker"))
vector_store = get_vector_store(config["vector_store"])

return cls(embedder, vector_store, query_rephraser, reranker)
providers_config = DocumentProcessorRouter.from_dict_to_providers_config(config.get("providers"))
document_processor_router = DocumentProcessorRouter.from_config(providers_config)

return cls(embedder, vector_store, query_rephraser, reranker, document_processor_router)

async def search(self, query: str, search_config: SearchConfig = SearchConfig()) -> list[Element]:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Optional

from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.ingestion.providers import get_provider
from ragbits.document_search.ingestion.providers.base import BaseProvider
from ragbits.document_search.ingestion.providers.unstructured import UnstructuredProvider

Expand Down Expand Up @@ -39,6 +40,33 @@ class DocumentProcessorRouter:
def __init__(self, providers: dict[DocumentType, BaseProvider]):
self._providers = providers

@staticmethod
def from_dict_to_providers_config(dict_config: Optional[dict]) -> ProvidersConfig:
    """
    Creates ProvidersConfig from dictionary config.

    Example of the dictionary config:
        {
            "txt": {
                "type": "UnstructuredProvider"
            }
        }

    Args:
        dict_config: The dictionary with configuration, keyed by document type name.
            May be None or empty (e.g. when the "providers" section is absent from
            the parent config), in which case an empty ProvidersConfig is returned.

    Returns:
        ProvidersConfig object mapping each DocumentType to the provider built
        by get_provider from that entry.
    """
    # Callers pass `config.get("providers")`, which is None when the section is
    # missing; treat that the same as "no overrides" instead of failing on .items().
    if not dict_config:
        return {}

    return {
        DocumentType(document_type): get_provider(provider_config)
        for document_type, provider_config in dict_config.items()
    }

@classmethod
def from_config(cls, providers_config: Optional[ProvidersConfig] = None) -> "DocumentProcessorRouter":
"""
Expand All @@ -58,6 +86,7 @@ def from_config(cls, providers_config: Optional[ProvidersConfig] = None) -> "Doc
Returns:
The DocumentProcessorRouter.
"""

config = copy.deepcopy(DEFAULT_PROVIDERS_CONFIG)
config.update(providers_config if providers_config is not None else {})

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import sys

from ragbits.core.utils import get_cls_from_config

from .base import BaseProvider
from .dummy import DummyProvider
from .unstructured import UnstructuredProvider

__all__ = ["BaseProvider", "DummyProvider", "UnstructuredProvider"]

module = sys.modules[__name__]


def get_provider(provider_config: dict) -> BaseProvider:
    """
    Initializes and returns a Provider object based on the provided configuration.

    Args:
        provider_config: A dictionary containing configuration details for the provider.
            Its "type" entry selects the provider class; the optional "config" entry
            holds the keyword arguments passed to that class.

    Returns:
        An instance of the specified Provider class, initialized with the provided config
        (if any) or default arguments.
    """
    # Resolve the class first (looked up via this package's module object), then
    # instantiate it with whatever keyword arguments were configured.
    resolved_cls = get_cls_from_config(provider_config["type"], module)
    return resolved_cls(**provider_config.get("config", {}))
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"embedder": {"type": "NoopEmbeddings"},
"vector_store": {"type": "ragbits.core.vector_store.in_memory:InMemoryVectorStore"},
"reranker": {"type": "NoopReranker"},
"providers": {"txt": {"type": "DummyProvider"}},
}


Expand All @@ -36,7 +37,7 @@
async def test_document_search_from_config(document, expected):
document_search = DocumentSearch.from_config(CONFIG)

await document_search.ingest_document(document, document_processor=DummyProvider())
await document_search.ingest_document(document)
results = await document_search.search("Peppa's brother")

first_result = results[0]
Expand Down

0 comments on commit 60a37ec

Please sign in to comment.