Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(document-search): Allow to create DocumentSearch instances from config #62

24 changes: 23 additions & 1 deletion packages/ragbits-core/src/ragbits/core/embeddings/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,27 @@
import sys

from .base import Embeddings
from .litellm import LiteLLMEmbeddings
from .local import LocalEmbeddings
from .noop import NoopEmbeddings

__all__ = ["Embeddings", "LiteLLMEmbeddings", "LocalEmbeddings", "NoopEmbeddings"]

module = sys.modules[__name__]


def get_embeddings(embedder_config: dict) -> Embeddings:
"""
Initializes and returns an Embeddings object based on the provided embedder configuration.

Args:
embedder_config : A dictionary containing configuration details for the embedder.

Returns:
An instance of the specified Embeddings class, initialized with the provided config
(if any) or default arguments.
"""
embeddings_type = embedder_config["type"]
config = embedder_config.get("config", {})

__all__ = ["Embeddings", "LiteLLMEmbeddings", "LocalEmbeddings"]
return getattr(module, embeddings_type)(**config)
25 changes: 25 additions & 0 deletions packages/ragbits-core/src/ragbits/core/embeddings/noop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from ragbits.core.embeddings.base import Embeddings


class NoopEmbeddings(Embeddings):
"""
A no-op implementation of the Embeddings class.

This class provides a simple embedding method that returns a fixed
embedding vector for each input text. It's mainly useful for testing
or as a placeholder when an actual embedding model is not required.
"""

async def embed_text(self, data: list[str]) -> list[list[float]]:
"""
Embeds a list of strings into a list of vectors.

Args:
data: A list of input text strings to embed.

Returns:
A list of embedding vectors, where each vector
is a fixed value of [0.1, 0.1] for each input string.
"""

return [[0.1, 0.1]] * len(data)
26 changes: 26 additions & 0 deletions packages/ragbits-core/src/ragbits/core/utils/config_handling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from importlib import import_module
from types import ModuleType
from typing import Any


def get_cls_from_config(cls_path: str, default_module: ModuleType) -> Any:
"""
Retrieves and returns a class based on the given type string. The class can be either in the
default module or a specified module if provided in the type string.

Args:
cls_path: A string representing the path to the class or object. This can either be a
path implicitly referencing the default module or a full path (module.submodule:ClassName)
if the class is located in a different module.
default_module: The default module to search for the class if no specific module
is provided in the type string.

Returns:
Any: The object retrieved from the specified or default module.
"""
if ":" in cls_path:
module_stringified, object_stringified = cls_path.split(":")
module = import_module(module_stringified)
return getattr(module, object_stringified)

return getattr(default_module, cls_path)
28 changes: 27 additions & 1 deletion packages/ragbits-core/src/ragbits/core/vector_store/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,31 @@
import sys

from ..utils.config_handling import get_cls_from_config
from .base import VectorDBEntry, VectorStore
from .chromadb_store import ChromaDBStore
from .in_memory import InMemoryVectorStore

__all__ = ["VectorStore", "VectorDBEntry", "InMemoryVectorStore", "ChromaDBStore"]
__all__ = ["InMemoryVectorStore", "VectorDBEntry", "VectorStore", "ChromaDBStore"]

module = sys.modules[__name__]


def get_vector_store(vector_store_config: dict) -> VectorStore:
"""
Initializes and returns a VectorStore object based on the provided configuration.

Args:
vector_store_config: A dictionary containing configuration details for the VectorStore.

Returns:
An instance of the specified VectorStore class, initialized with the provided config
(if any) or default arguments.
"""

vector_store_cls = get_cls_from_config(vector_store_config["type"], module)
config = vector_store_config.get("config", {})

if vector_store_config["type"] == "ChromaDBStore":
return vector_store_cls.from_config(config)

return vector_store_cls(**config)
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
HAS_CHROMADB = False

from ragbits.core.embeddings import Embeddings
from ragbits.core.utils.config_handling import get_cls_from_config
from ragbits.core.vector_store import VectorDBEntry, VectorStore


Expand Down Expand Up @@ -45,6 +46,32 @@ def __init__(
self._metadata = {"hnsw:space": distance_method}
self._collection = self._get_chroma_collection()

@classmethod
def from_config(cls, config: dict) -> "ChromaDBStore":
"""
Creates and returns an instance of the ChromaDBStore class from the given configuration.

Args:
config: A dictionary containing the configuration for initializing the ChromaDBStore instance.

Returns:
An initialized instance of the ChromaDBStore class.
"""
chroma_client = get_cls_from_config(config["chroma_client"]["type"], chromadb)(
**config["chroma_client"].get("config", {})
)
embedding_function = get_cls_from_config(config["embedding_function"]["type"], chromadb)(
**config["embedding_function"].get("config", {})
)

return cls(
config["index_name"],
chroma_client,
embedding_function,
max_distance=config.get("max_distance"),
distance_method=config.get("distance_method", "l2"),
)

def _get_chroma_collection(self) -> chromadb.Collection:
"""
Based on the selected embedding_function, chooses how to retrieve the ChromaDB collection.
Expand Down
51 changes: 51 additions & 0 deletions packages/ragbits-document-search/examples/from_config_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "ragbits-document-search",
# "ragbits[litellm]",
# ]
# ///
import asyncio

from ragbits.document_search import DocumentSearch
from ragbits.document_search.documents.document import DocumentMeta

documents = [
DocumentMeta.create_text_document_from_literal("RIP boiled water. You will be mist."),
DocumentMeta.create_text_document_from_literal(
"Why doesn't James Bond fart in bed? Because it would blow his cover."
),
DocumentMeta.create_text_document_from_literal(
"Why programmers don't like to swim? Because they're scared of the floating points."
),
]

config = {
"embedder": {"type": "LiteLLMEmbeddings"},
"vector_store": {
"type": "ChromaDBStore",
"config": {
"chroma_client": {"type": "PersistentClient", "config": {"path": "chroma"}},
"embedding_function": {"type": "ragbits.core.embeddings.litellm:LiteLLMEmbeddings"},
"index_name": "jokes",
},
},
"reranker": {"type": "ragbits.document_search.retrieval.rerankers.noop:NoopReranker"},
"providers": {"txt": {"type": "DummyProvider"}},
}


async def main():
"""Run the example."""

document_search = DocumentSearch.from_config(config)

for document in documents:
await document_search.ingest_document(document)

results = await document_search.search("I'm boiling my water and I need a joke")
print(results)


if __name__ == "__main__":
asyncio.run(main())
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@

from pydantic import BaseModel, Field

from ragbits.core.embeddings import Embeddings
from ragbits.core.vector_store import VectorStore
from ragbits.core.embeddings import Embeddings, get_embeddings
from ragbits.core.vector_store import VectorStore, get_vector_store
from ragbits.document_search.documents.document import Document, DocumentMeta
from ragbits.document_search.documents.element import Element
from ragbits.document_search.ingestion.document_processor import DocumentProcessorRouter
from ragbits.document_search.ingestion.providers.base import BaseProvider
from ragbits.document_search.retrieval.rephrasers import get_rephraser
from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser
from ragbits.document_search.retrieval.rephrasers.noop import NoopQueryRephraser
from ragbits.document_search.retrieval.rerankers import get_reranker
from ragbits.document_search.retrieval.rerankers.base import Reranker
from ragbits.document_search.retrieval.rerankers.noop import NoopReranker

Expand Down Expand Up @@ -58,6 +60,29 @@ def __init__(
self.reranker = reranker or NoopReranker()
self.document_processor_router = document_processor_router or DocumentProcessorRouter.from_config()

@classmethod
def from_config(cls, config: dict) -> "DocumentSearch":
"""
Creates and returns an instance of the DocumentSearch class from the given configuration.

Args:
config: A dictionary containing the configuration for initializing the DocumentSearch instance.

Returns:
DocumentSearch: An initialized instance of the DocumentSearch class.
"""

embedder = get_embeddings(config["embedder"])
query_rephraser = get_rephraser(config.get("rephraser"))
reranker = get_reranker(config.get("reranker"))
vector_store = get_vector_store(config["vector_store"])

providers_config_dict: dict = config.get("providers", {})
providers_config = DocumentProcessorRouter.from_dict_to_providers_config(providers_config_dict)
document_processor_router = DocumentProcessorRouter.from_config(providers_config)

return cls(embedder, vector_store, query_rephraser, reranker, document_processor_router)

async def search(self, query: str, search_config: SearchConfig = SearchConfig()) -> list[Element]:
"""
Search for the most relevant chunks for a query.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Optional

from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.ingestion.providers import get_provider
from ragbits.document_search.ingestion.providers.base import BaseProvider
from ragbits.document_search.ingestion.providers.unstructured import UnstructuredProvider

Expand Down Expand Up @@ -39,6 +40,33 @@ class DocumentProcessorRouter:
def __init__(self, providers: dict[DocumentType, BaseProvider]):
self._providers = providers

@staticmethod
def from_dict_to_providers_config(dict_config: dict) -> ProvidersConfig:
"""
Creates ProvidersConfig from dictionary config.
Example of the dictionary config:
{
"txt": {
{
"type": "UnstructuredProvider"
}
}
}

Args:
dict_config: The dictionary with configuration.

Returns:
ProvidersConfig object.
"""

providers_config = {}

for document_type, config in dict_config.items():
providers_config[DocumentType(document_type)] = get_provider(config)

return providers_config

@classmethod
def from_config(cls, providers_config: Optional[ProvidersConfig] = None) -> "DocumentProcessorRouter":
"""
Expand All @@ -58,6 +86,7 @@ def from_config(cls, providers_config: Optional[ProvidersConfig] = None) -> "Doc
Returns:
The DocumentProcessorRouter.
"""

config = copy.deepcopy(DEFAULT_PROVIDERS_CONFIG)
config.update(providers_config if providers_config is not None else {})

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import sys

from ragbits.core.utils.config_handling import get_cls_from_config

from .base import BaseProvider
from .dummy import DummyProvider
from .unstructured import UnstructuredProvider

__all__ = ["BaseProvider", "DummyProvider", "UnstructuredProvider"]

module = sys.modules[__name__]


def get_provider(provider_config: dict) -> BaseProvider:
"""
Initializes and returns an Provider object based on the provided configuration.

Args:
provider_config : A dictionary containing configuration details for the provider.

Returns:
An instance of the specified Provider class, initialized with the provided config
(if any) or default arguments.
"""

provider_cls = get_cls_from_config(provider_config["type"], module)
config = provider_config.get("config", {})

return provider_cls(**config)
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import sys
from typing import Optional

from ragbits.core.utils.config_handling import get_cls_from_config

from .base import QueryRephraser
from .noop import NoopQueryRephraser

__all__ = ["NoopQueryRephraser", "QueryRephraser"]

module = sys.modules[__name__]


def get_rephraser(rephraser_config: Optional[dict]) -> QueryRephraser:
"""
Initializes and returns a QueryRephraser object based on the provided configuration.

Args:
rephraser_config: A dictionary containing configuration details for the QueryRephraser.

Returns:
An instance of the specified QueryRephraser class, initialized with the provided config
(if any) or default arguments.
"""

if rephraser_config is None:
return NoopQueryRephraser()

rephraser_cls = get_cls_from_config(rephraser_config["type"], module)
config = rephraser_config.get("config", {})

return rephraser_cls(**config)
Loading
Loading