diff --git a/examples/document-search/chroma.py b/examples/document-search/chroma.py index 0c5421a2..5f4ec70b 100644 --- a/examples/document-search/chroma.py +++ b/examples/document-search/chroma.py @@ -7,7 +7,7 @@ # /// import asyncio -import chromadb +from chromadb import PersistentClient from ragbits.core.embeddings import LiteLLMEmbeddings from ragbits.core.vector_store.chromadb_store import ChromaDBStore @@ -27,27 +27,36 @@ async def main() -> None: """ Run the example. """ - chroma_client = chromadb.PersistentClient(path="chroma") - embedding_client = LiteLLMEmbeddings() - vector_store = ChromaDBStore( + client=PersistentClient("./chroma"), index_name="jokes", - chroma_client=chroma_client, - embedding_function=embedding_client, ) - document_search = DocumentSearch(embedder=embedding_client, vector_store=vector_store) + embedder = LiteLLMEmbeddings("text-embedding-3-small") + document_search = DocumentSearch( + embedder=embedder, + vector_store=vector_store, + ) await document_search.ingest(documents) + all_documents = await vector_store.list() + print() print("All documents:") - all_documents = await vector_store.list() print([doc.metadata["content"] for doc in all_documents]) query = "I'm boiling my water and I need a joke" + vector_store_kwargs = { + "k": 2, + "max_distance": None, + } + results = await document_search.search( + query, + config=SearchConfig(vector_store_kwargs=vector_store_kwargs), + ) + print() print(f"Documents similar to: {query}") - results = await document_search.search(query, search_config=SearchConfig(vector_store_kwargs={"k": 2})) print([element.get_key() for element in results]) diff --git a/examples/document-search/from_config.py b/examples/document-search/from_config.py index 3e3aa806..507af8e0 100644 --- a/examples/document-search/from_config.py +++ b/examples/document-search/from_config.py @@ -25,9 +25,18 @@ "vector_store": { "type": "ChromaDBStore", "config": { - "chroma_client": {"type": "PersistentClient", "config": {"path": "chroma"}}, - "embedding_function": {"type": "ragbits.core.embeddings.litellm:LiteLLMEmbeddings"}, + "client": { + "type": "PersistentClient", + "config": { + "path": "chroma", + }, + }, "index_name": "jokes", + "distance_method": "l2", + "default_options": { + "k": 3, + "max_distance": 1.15, + }, }, }, "reranker": {"type": "ragbits.document_search.retrieval.rerankers.noop:NoopReranker"}, diff --git a/packages/ragbits-core/src/ragbits/core/cli.py b/packages/ragbits-core/src/ragbits/core/cli.py index f1ac8dee..c359312e 100644 --- a/packages/ragbits-core/src/ragbits/core/cli.py +++ b/packages/ragbits-core/src/ragbits/core/cli.py @@ -81,7 +81,6 @@ def execute( Raises: ValueError: If `llm_factory` is not provided. """ - from ragbits.core.llms.factory import get_llm_from_factory prompt = _render(prompt_path=prompt_path, payload=payload) diff --git a/packages/ragbits-core/src/ragbits/core/vector_store/base.py b/packages/ragbits-core/src/ragbits/core/vector_store/base.py index 29e76e7e..a04cebd9 100644 --- a/packages/ragbits-core/src/ragbits/core/vector_store/base.py +++ b/packages/ragbits-core/src/ragbits/core/vector_store/base.py @@ -1,7 +1,9 @@ -import abc +from abc import ABC, abstractmethod from pydantic import BaseModel +WhereQuery = dict[str, str | int | float | bool] + class VectorDBEntry(BaseModel): """ @@ -13,15 +15,25 @@ class VectorDBEntry(BaseModel): metadata: dict -WhereQuery = dict[str, str | int | float | bool] +class VectorStoreOptions(BaseModel, ABC): + """ + An object representing the options for the vector store. + """ + + k: int = 5 + max_distance: float | None = None -class VectorStore(abc.ABC): +class VectorStore(ABC): """ A class with an implementation of Vector Store, allowing to store and retrieve vectors by similarity function. """ - @abc.abstractmethod + def __init__(self, default_options: VectorStoreOptions | None = None) -> None: + super().__init__() + self._default_options = default_options or VectorStoreOptions() + + @abstractmethod async def store(self, entries: list[VectorDBEntry]) -> None: """ Store entries in the vector store. @@ -30,20 +42,20 @@ async def store(self, entries: list[VectorDBEntry]) -> None: entries: The entries to store. """ - @abc.abstractmethod - async def retrieve(self, vector: list[float], k: int = 5) -> list[VectorDBEntry]: + @abstractmethod + async def retrieve(self, vector: list[float], options: VectorStoreOptions | None = None) -> list[VectorDBEntry]: """ Retrieve entries from the vector store. Args: vector: The vector to search for. - k: The number of entries to retrieve. + options: The options for querying the vector store. Returns: The entries. """ - @abc.abstractmethod + @abstractmethod async def list( self, where: WhereQuery | None = None, limit: int | None = None, offset: int = 0 ) -> list[VectorDBEntry]: diff --git a/packages/ragbits-core/src/ragbits/core/vector_store/chromadb_store.py b/packages/ragbits-core/src/ragbits/core/vector_store/chromadb_store.py index 0296a9c8..1918d7e1 100644 --- a/packages/ragbits-core/src/ragbits/core/vector_store/chromadb_store.py +++ b/packages/ragbits-core/src/ragbits/core/vector_store/chromadb_store.py @@ -6,14 +6,15 @@ try: import chromadb - - HAS_CHROMADB = True + from chromadb import Collection + from chromadb.api import ClientAPI except ImportError: HAS_CHROMADB = False +else: + HAS_CHROMADB = True -from ragbits.core.embeddings import Embeddings from ragbits.core.utils.config_handling import get_cls_from_config -from ragbits.core.vector_store import VectorDBEntry, VectorStore, WhereQuery +from ragbits.core.vector_store.base import VectorDBEntry, VectorStore, VectorStoreOptions, WhereQuery class ChromaDBStore(VectorStore): @@ -23,33 +24,41 @@ class ChromaDBStore(VectorStore): def __init__( self, + client: ClientAPI, index_name: str, - chroma_client: chromadb.ClientAPI, - embedding_function: Embeddings | chromadb.EmbeddingFunction, - max_distance: float | None = None, distance_method: Literal["l2", "ip", "cosine"] = "l2", + default_options: VectorStoreOptions | None = None, ): """ Initializes the ChromaDBStore with the given parameters. Args: + client: The ChromaDB client. index_name: The name of the index. - chroma_client: The ChromaDB client. - embedding_function: The embedding function. - max_distance: The maximum distance for similarity. distance_method: The distance method to use. + default_options: The default options for querying the vector store. """ if not HAS_CHROMADB: raise ImportError("Install the 'ragbits-document-search[chromadb]' extra to use LiteLLM embeddings models") - super().__init__() + super().__init__(default_options) + self._client = client self._index_name = index_name - self._chroma_client = chroma_client - self._embedding_function = embedding_function - self._max_distance = max_distance - self._metadata = {"hnsw:space": distance_method} + self._distance_method = distance_method self._collection = self._get_chroma_collection() + def _get_chroma_collection(self) -> Collection: + """ + Gets or creates a collection with the given name and metadata. + + Returns: + The collection. + """ + return self._client.get_or_create_collection( + name=self._index_name, + metadata={"hnsw:space": self._distance_method}, + ) + @classmethod def from_config(cls, config: dict) -> ChromaDBStore: """ @@ -61,75 +70,14 @@ def from_config(cls, config: dict) -> ChromaDBStore: Returns: An initialized instance of the ChromaDBStore class. """ - chroma_client = get_cls_from_config(config["chroma_client"]["type"], chromadb)( - **config["chroma_client"].get("config", {}) - ) - embedding_function = get_cls_from_config(config["embedding_function"]["type"], chromadb)( - **config["embedding_function"].get("config", {}) - ) - + client = get_cls_from_config(config["client"]["type"], chromadb) # type: ignore return cls( - config["index_name"], - chroma_client, - embedding_function, - max_distance=config.get("max_distance"), + client=client(**config["client"].get("config", {})), + index_name=config["index_name"], distance_method=config.get("distance_method", "l2"), + default_options=VectorStoreOptions(**config.get("options", {})), ) - def _get_chroma_collection(self) -> chromadb.Collection: - """ - Based on the selected embedding_function, chooses how to retrieve the ChromaDB collection. - If the collection doesn't exist, it creates one. - - Returns: - Retrieved collection - """ - if isinstance(self._embedding_function, Embeddings): - return self._chroma_client.get_or_create_collection(name=self._index_name, metadata=self._metadata) - - return self._chroma_client.get_or_create_collection( - name=self._index_name, - metadata=self._metadata, - embedding_function=self._embedding_function, - ) - - def _return_best_match(self, retrieved: dict) -> str | None: - """ - Based on the retrieved data, returns the best match or None if no match is found. - - Args: - retrieved: Retrieved data, with a column-first format. - - Returns: - The best match or None if no match is found. - """ - if self._max_distance is None or retrieved["distances"][0][0] <= self._max_distance: - return retrieved["documents"][0][0] - - return None - - @staticmethod - def _process_db_entry(entry: VectorDBEntry) -> tuple[str, list[float], dict]: - doc_id = sha256(entry.key.encode("utf-8")).hexdigest() - embedding = entry.vector - - metadata = { - "__key": entry.key, - "__metadata": json.dumps(entry.metadata, default=str), - } - - return doc_id, embedding, metadata - - @property - def embedding_function(self) -> Embeddings | chromadb.EmbeddingFunction: - """ - Returns the embedding function. - - Returns: - The embedding function. - """ - return self._embedding_function - async def store(self, entries: list[VectorDBEntry]) -> None: """ Stores entries in the ChromaDB collection. @@ -137,38 +85,47 @@ async def store(self, entries: list[VectorDBEntry]) -> None: Args: entries: The entries to store. """ - entries_processed = list(map(self._process_db_entry, entries)) - ids, embeddings, metadatas = map(list, zip(*entries_processed, strict=False)) + # TODO: Think about better id components for hashing + ids = [sha256(entry.key.encode("utf-8")).hexdigest() for entry in entries] + embeddings = [entry.vector for entry in entries] + metadatas = [ + { + "__key": entry.key, + "__metadata": json.dumps(entry.metadata, default=str), + } + for entry in entries + ] + self._collection.add(ids=ids, embeddings=embeddings, metadatas=metadatas) # type: ignore - self._collection.add(ids=ids, embeddings=embeddings, metadatas=metadatas) - - async def retrieve(self, vector: list[float], k: int = 5) -> list[VectorDBEntry]: + async def retrieve(self, vector: list[float], options: VectorStoreOptions | None = None) -> list[VectorDBEntry]: """ Retrieves entries from the ChromaDB collection. Args: vector: The vector to query. - k: The number of entries to retrieve. + options: The options for querying the vector store. Returns: The retrieved entries. """ - query_result = self._collection.query(query_embeddings=vector, n_results=k, include=["metadatas", "embeddings"]) - metadatas = query_result.get("metadatas") or [] - embeddings = query_result.get("embeddings") or [] - - db_entries = [] - for meta_list, embeddings_list in zip(metadatas, embeddings, strict=False): - for meta, embedding in zip(meta_list, embeddings_list, strict=False): - db_entry = VectorDBEntry( - key=str(meta["__key"]), - vector=list(embedding), - metadata=json.loads(str(meta["__metadata"])), - ) - - db_entries.append(db_entry) + options = self._default_options if options is None else options + results = self._collection.query( + query_embeddings=vector, + n_results=options.k, + include=["metadatas", "embeddings", "distances"], + ) + metadatas = results.get("metadatas") or [] + embeddings = results.get("embeddings") or [] + distances = results.get("distances") or [] - return db_entries + return [ + VectorDBEntry( + key=str(metadata["__key"]), vector=list(embeddings), metadata=json.loads(str(metadata["__metadata"])) + ) + for batch in zip(metadatas, embeddings, distances, strict=False) + for metadata, embeddings, distance in zip(*batch, strict=False) + if options.max_distance is None or distance <= options.max_distance + ] async def list( self, where: WhereQuery | None = None, limit: int | None = None, offset: int = 0 @@ -189,28 +146,19 @@ async def list( where_chroma: chromadb.Where | None = dict(where) if where else None get_results = self._collection.get( - where=where_chroma, limit=limit, offset=offset, include=["metadatas", "embeddings"] + where=where_chroma, + limit=limit, + offset=offset, + include=["metadatas", "embeddings"], ) metadatas = get_results.get("metadatas") or [] embeddings = get_results.get("embeddings") or [] - db_entries = [] - for meta, embedding in zip(metadatas, embeddings, strict=False): - db_entry = VectorDBEntry( - key=str(meta["__key"]), + return [ + VectorDBEntry( + key=str(metadata["__key"]), vector=list(embedding), - metadata=json.loads(str(meta["__metadata"])), + metadata=json.loads(str(metadata["__metadata"])), ) - - db_entries.append(db_entry) - - return db_entries - - def __repr__(self) -> str: - """ - Returns the string representation of the object. - - Returns: - The string representation of the object. - """ - return f"{self.__class__.__name__}(index_name={self._index_name})" + for metadata, embedding in zip(metadatas, embeddings, strict=False) + ] diff --git a/packages/ragbits-core/tests/unit/vector_stores/test_chromadb_store.py b/packages/ragbits-core/tests/unit/vector_stores/test_chromadb_store.py index 5a44f877..9994cffe 100644 --- a/packages/ragbits-core/tests/unit/vector_stores/test_chromadb_store.py +++ b/packages/ragbits-core/tests/unit/vector_stores/test_chromadb_store.py @@ -1,85 +1,30 @@ -from hashlib import sha256 -from typing import Any -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import MagicMock, patch import pytest -from ragbits.core.embeddings import Embeddings -from ragbits.core.vector_store import ChromaDBStore, VectorDBEntry +from ragbits.core.vector_store.base import VectorDBEntry +from ragbits.core.vector_store.chromadb_store import ChromaDBStore @pytest.fixture -def mock_chroma_client(): - return MagicMock() - - -@pytest.fixture -def mock_embedding_function(): - return AsyncMock() - - -@pytest.fixture -def mock_chromadb_store(mock_chroma_client: MagicMock, mock_embedding_function: MagicMock): +def mock_chromadb_store(): return ChromaDBStore( + client=MagicMock(), index_name="test_index", - chroma_client=mock_chroma_client, - embedding_function=mock_embedding_function, - ) - - -class MockEmbeddings(Embeddings): - async def embed_text(self, text: list[str]): # noqa: PLR6301 - return [[0.4, 0.5, 0.6]] - - def __call__(self, input: list[str]): - return self.embed_text(input) - - -@pytest.fixture -def custom_embedding_function(): - return MockEmbeddings() - - -@pytest.fixture -def mock_chromadb_store_with_custom_embedding_function( - mock_chroma_client: MagicMock, custom_embedding_function: MagicMock -): - return ChromaDBStore( - index_name="test_index", - chroma_client=mock_chroma_client, - embedding_function=custom_embedding_function, - ) - - -@pytest.fixture -def mock_vector_db_entry(): - return VectorDBEntry( - key="test_key", - vector=[0.1, 0.2, 0.3], - metadata={ - "content": "test content", - "document": { - "title": "test title", - "source": {"path": "/test/path"}, - "document_type": "test_type", - }, - }, ) def test_chromadbstore_init_import_error(): with patch("ragbits.core.vector_store.chromadb_store.HAS_CHROMADB", False), pytest.raises(ImportError): ChromaDBStore( + client=MagicMock(), index_name="test_index", - chroma_client=MagicMock(), - embedding_function=MagicMock(), ) def test_get_chroma_collection(mock_chromadb_store: ChromaDBStore): _ = mock_chromadb_store._get_chroma_collection() - - assert mock_chromadb_store._chroma_client.get_or_create_collection.called # type: ignore + assert mock_chromadb_store._client.get_or_create_collection.call_count == 2 # type: ignore async def test_stores_entries_correctly(mock_chromadb_store: ChromaDBStore): @@ -100,42 +45,35 @@ async def test_stores_entries_correctly(mock_chromadb_store: ChromaDBStore): await mock_chromadb_store.store(data) - mock_chromadb_store._chroma_client.get_or_create_collection().add.assert_called_once() # type: ignore - - -def test_process_db_entry(mock_chromadb_store: ChromaDBStore, mock_vector_db_entry: VectorDBEntry): - id, embedding, metadata = mock_chromadb_store._process_db_entry(mock_vector_db_entry) - - assert id == sha256(b"test_key").hexdigest() - assert embedding == [0.1, 0.2, 0.3] - assert ( - metadata["__metadata"] == '{"content": "test content", "document": {"title": "test title", "source":' - ' {"path": "/test/path"}, "document_type": "test_type"}}' + mock_chromadb_store._client.get_or_create_collection().add.assert_called_once() # type: ignore + mock_chromadb_store._client.get_or_create_collection().add.assert_called_with( # type: ignore + ids=["92488e1e3eeecdf99f3ed2ce59233efb4b4fb612d5655c0ce9ea52b5a502e655"], + embeddings=[[0.1, 0.2, 0.3]], + metadatas=[ + { + "__key": "test_key", + "__metadata": '{"content": "test content", "document": {"title": "test title", "source":' + ' {"path": "/test/path"}, "document_type": "test_type"}}', + } + ], ) - assert metadata["__key"] == "test_key" - - -async def test_store(mock_chromadb_store: ChromaDBStore, mock_vector_db_entry: VectorDBEntry): - await mock_chromadb_store.store([mock_vector_db_entry]) - - assert mock_chromadb_store._chroma_client.get_or_create_collection().add.called # type: ignore async def test_retrieves_entries_correctly(mock_chromadb_store: ChromaDBStore): vector = [0.1, 0.2, 0.3] mock_collection = mock_chromadb_store._get_chroma_collection() mock_collection.query.return_value = { # type: ignore - "documents": [["test content"]], "metadatas": [ [ { "__key": "test_key", "__metadata": '{"content": "test content", "document": {"title": "test title", "source":' - ' {"path": "/test/path"}, "document_type": "test_type"}}', - } + ' {"path": "/test/path-1"}, "document_type": "txt"}}', + }, ] ], "embeddings": [[[0.12, 0.25, 0.29]]], + "distances": [[0.1]], } entries = await mock_chromadb_store.retrieve(vector) @@ -149,7 +87,6 @@ async def test_retrieves_entries_correctly(mock_chromadb_store: ChromaDBStore): async def test_lists_entries_correctly(mock_chromadb_store: ChromaDBStore): mock_collection = mock_chromadb_store._get_chroma_collection() mock_collection.get.return_value = { # type: ignore - "documents": ["test content", "test content 2"], "metadatas": [ { "__key": "test_key", @@ -174,35 +111,3 @@ async def test_lists_entries_correctly(mock_chromadb_store: ChromaDBStore): assert entries[1].metadata["content"] == "test content 2" assert entries[1].metadata["document"]["title"] == "test title 2" assert entries[1].vector == [0.13, 0.26, 0.30] - - -async def test_handles_empty_retrieve(mock_chromadb_store: ChromaDBStore): - vector = [0.1, 0.2, 0.3] - mock_collection = mock_chromadb_store._get_chroma_collection() - mock_collection.query.return_value = {"documents": [], "metadatas": []} # type: ignore - - entries = await mock_chromadb_store.retrieve(vector) - - assert len(entries) == 0 - - -def test_repr(mock_chromadb_store: ChromaDBStore): - assert repr(mock_chromadb_store) == "ChromaDBStore(index_name=test_index)" - - -@pytest.mark.parametrize( - ("retrieved", "max_distance", "expected"), - [ - ({"distances": [[0.1]], "documents": [["test content"]]}, None, "test content"), - ({"distances": [[0.1]], "documents": [["test content"]]}, 0.2, "test content"), - ({"distances": [[0.3]], "documents": [["test content"]]}, 0.2, None), - ], -) -def test_return_best_match( - mock_chromadb_store: ChromaDBStore, retrieved: dict[str, Any], max_distance: float | None, expected: str | None -): - mock_chromadb_store._max_distance = max_distance - - result = mock_chromadb_store._return_best_match(retrieved) - - assert result == expected diff --git a/packages/ragbits-document-search/src/ragbits/document_search/_main.py b/packages/ragbits-document-search/src/ragbits/document_search/_main.py index a537ae99..ff37e9ea 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/_main.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/_main.py @@ -42,9 +42,7 @@ class DocumentSearch: """ embedder: Embeddings - vector_store: VectorStore - query_rephraser: QueryRephraser reranker: Reranker @@ -84,23 +82,23 @@ def from_config(cls, config: dict) -> "DocumentSearch": return cls(embedder, vector_store, query_rephraser, reranker, document_processor_router) - async def search(self, query: str, search_config: SearchConfig | None = None) -> list[Element]: + async def search(self, query: str, config: SearchConfig | None = None) -> list[Element]: """ Search for the most relevant chunks for a query. Args: query: The query to search for. - search_config: The search configuration. + config: The search configuration. Returns: A list of chunks. """ - search_config = search_config or SearchConfig() + config = config or SearchConfig() queries = await self.query_rephraser.rephrase(query) elements = [] for rephrased_query in queries: search_vector = await self.embedder.embed_text([rephrased_query]) - entries = await self.vector_store.retrieve(search_vector[0], **search_config.vector_store_kwargs) + entries = await self.vector_store.retrieve(search_vector[0]) elements.extend([Element.from_vector_db_entry(entry) for entry in entries]) return self.reranker.rerank(elements) diff --git a/packages/ragbits-document-search/tests/unit/test_document_search.py b/packages/ragbits-document-search/tests/unit/test_document_search.py index 8edb635e..fdbcda8f 100644 --- a/packages/ragbits-document-search/tests/unit/test_document_search.py +++ b/packages/ragbits-document-search/tests/unit/test_document_search.py @@ -144,7 +144,7 @@ async def test_document_search_with_search_config(): document_processor=DummyProvider(), ) - results = await document_search.search("Peppa's brother", search_config=SearchConfig(vector_store_kwargs={"k": 1})) + results = await document_search.search("Peppa's brother", config=SearchConfig(vector_store_kwargs={"k": 1})) assert len(results) == 1 assert results[0].content == "Name of Peppa's brother is George" # type: ignore