diff --git a/packages/ragbits-document-search/examples/chromadb_example.py b/packages/ragbits-document-search/examples/chromadb_example.py index 9bd5ab1d8..03f26bd25 100644 --- a/packages/ragbits-document-search/examples/chromadb_example.py +++ b/packages/ragbits-document-search/examples/chromadb_example.py @@ -9,11 +9,11 @@ import os import chromadb + from ragbits.core.embeddings.litellm import LiteLLMEmbeddings from ragbits.document_search import DocumentSearch from ragbits.document_search.documents.document import DocumentMeta from ragbits.document_search.vector_store.chromadb_store import ChromaDBStore -from ragbits.document_search.vector_store.in_memory import InMemoryVectorStore documents = [ DocumentMeta.create_text_document_from_literal("RIP boiled water. You will be mist."), diff --git a/packages/ragbits-document-search/src/ragbits/document_search/vector_store/chromadb_store.py b/packages/ragbits-document-search/src/ragbits/document_search/vector_store/chromadb_store.py index cb91064a0..24cd3b90d 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/vector_store/chromadb_store.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/vector_store/chromadb_store.py @@ -1,12 +1,11 @@ +import json from copy import deepcopy from hashlib import sha256 -import json -from typing import Literal, Optional, Union, List - -from ragbits.document_search.documents.element import TextElement +from typing import List, Literal, Optional, Union try: import chromadb + HAS_CHROMADB = True except ImportError: HAS_CHROMADB = False @@ -37,7 +36,7 @@ def __init__( distance_method (Literal["l2", "ip", "cosine"], default="l2"): The distance method to use. """ if not HAS_CHROMADB: - raise ImportError("You need to install the 'ragbits-document-search[chromadb]' extra requirement of to use LiteLLM embeddings models") + raise ImportError("Install the 'ragbits-document-search[chromadb]' extra to use LiteLLM embeddings models") super().__init__() self.index_name = index_name @@ -78,9 +77,9 @@ def _return_best_match(self, retrieved: dict) -> Optional[str]: return retrieved["documents"][0][0] return None - - def _process_db_entry(self, entry: VectorDBEntry): - id = sha256(entry.key.encode("utf-8")).hexdigest() + + def _process_db_entry(self, entry: VectorDBEntry) -> tuple[str, list[float], str, dict]: + doc_id = sha256(entry.key.encode("utf-8")).hexdigest() embedding = entry.vector text = entry.metadata["content"] @@ -89,14 +88,31 @@ def _process_db_entry(self, entry: VectorDBEntry): metadata["key"] = entry.key metadata = {key: json.dumps(val) if isinstance(val, dict) else val for key, val in metadata.items()} + return doc_id, embedding, text, metadata - return id, embedding, text, metadata + def _process_metadata(self, metadata: dict) -> dict[str, Union[str, int, float, bool]]: + """ + Processes the metadata dictionary by parsing JSON strings if applicable. - def _process_metadata(self, metadata): - return {key: json.loads(val) if self.is_json(val) else val - for key, val in metadata.items()} + Args: + metadata (dict): A dictionary containing metadata where values may be JSON strings. - def is_json(self, myjson) -> bool: + Returns: + dict[str, Union[str, int, float, bool]]: A dictionary with the same keys as the input, + where JSON strings are parsed into their respective Python data types. + """ + return {key: json.loads(val) if self.is_json(val) else val for key, val in metadata.items()} + + def is_json(self, myjson: str) -> bool: + """ + Check if the provided string is a valid JSON. + + Args: + myjson (str): The string to be checked. + + Returns: + bool: True if the string is a valid JSON, False otherwise. + """ try: if isinstance(myjson, str): json.loads(myjson) @@ -104,8 +120,7 @@ def is_json(self, myjson) -> bool: return False except ValueError: return False - - + async def store(self, entries: List[VectorDBEntry]) -> None: """ Stores entries in the ChromaDB collection. @@ -114,9 +129,7 @@ async def store(self, entries: List[VectorDBEntry]) -> None: entries (List[VectorDBEntry]): The entries to store. """ collection = self._get_chroma_collection() - - entries_processed = list(map(self._process_db_entry, entries)) ids, embeddings, texts, metadatas = map(list, zip(*entries_processed)) @@ -137,12 +150,12 @@ async def retrieve(self, vector: List[float], k: int = 5) -> List[VectorDBEntry] query_result = collection.query(query_embeddings=[vector], n_results=k) db_entries = [] - for doc, meta in zip(query_result.get("documents"), query_result.get("metadatas")): + for meta in query_result.get("metadatas"): db_entry = VectorDBEntry( key=meta[0].get("key"), vector=vector, metadata=self._process_metadata(meta[0]), - ) + ) db_entries.append(db_entry) @@ -177,4 +190,4 @@ def __repr__(self) -> str: Returns: str: The string representation of the object. """ - return f"{self.__class__.__name__}(index_name={self.index_name})" \ No newline at end of file + return f"{self.__class__.__name__}(index_name={self.index_name})" diff --git a/packages/ragbits-document-search/tests/unit/test_chromadb_store.py b/packages/ragbits-document-search/tests/unit/test_chromadb_store.py index da4dea3ed..8ef0b95de 100644 --- a/packages/ragbits-document-search/tests/unit/test_chromadb_store.py +++ b/packages/ragbits-document-search/tests/unit/test_chromadb_store.py @@ -1,13 +1,9 @@ from hashlib import sha256 -import json from unittest.mock import AsyncMock, MagicMock, patch -import uuid -import chromadb import pytest from ragbits.core.embeddings.base import Embeddings -from ragbits.document_search.vector_store import chromadb_store from ragbits.document_search.vector_store.chromadb_store import ChromaDBStore, VectorDBEntry @@ -33,6 +29,7 @@ def mock_chromadb_store(mock_chroma_client, mock_embedding_function): class MockEmbeddings(Embeddings): async def embed_text(self, text): return [[0.4, 0.5, 0.6]] + def __call__(self, input): return self.embed_text(input) @@ -56,18 +53,17 @@ def mock_vector_db_entry(): return VectorDBEntry( key="test_key", vector=[0.1, 0.2, 0.3], - metadata={"content": "test content", "document": {"title": "test title", "source": {"path": "/test/path"}, "document_type": "test_type"}} + metadata={ + "content": "test content", + "document": {"title": "test title", "source": {"path": "/test/path"}, "document_type": "test_type"}, + }, ) - + def test_chromadbstore_init_import_error(): - with patch('ragbits.document_search.vector_store.chromadb_store.HAS_CHROMADB', False): + with patch("ragbits.document_search.vector_store.chromadb_store.HAS_CHROMADB", False): with pytest.raises(ImportError): - ChromaDBStore( - index_name="test_index", - chroma_client=MagicMock(), - embedding_function=MagicMock() - ) + ChromaDBStore(index_name="test_index", chroma_client=MagicMock(), embedding_function=MagicMock()) def test_get_chroma_collection(mock_chromadb_store): @@ -75,7 +71,9 @@ def test_get_chroma_collection(mock_chromadb_store): assert mock_chromadb_store.chroma_client.get_or_create_collection.called -def test_get_chroma_collection_with_custom_embedding_function(custom_embedding_function, mock_chromadb_store_with_custom_embedding_function, mock_chroma_client): +def test_get_chroma_collection_with_custom_embedding_function( + custom_embedding_function, mock_chromadb_store_with_custom_embedding_function, mock_chroma_client +): _ = mock_chromadb_store_with_custom_embedding_function._get_chroma_collection() mock_chroma_client.get_or_create_collection.assert_called_once_with( @@ -90,7 +88,10 @@ async def test_stores_entries_correctly(mock_chromadb_store): VectorDBEntry( key="test_key", vector=[0.1, 0.2, 0.3], - metadata={"content": "test content", "document": {"title": "test title", "source": {"path": "/test/path"}, "document_type": "test_type"}}, + metadata={ + "content": "test content", + "document": {"title": "test title", "source": {"path": "/test/path"}, "document_type": "test_type"}, + }, ) ] await mock_chromadb_store.store(data) @@ -101,10 +102,13 @@ def test_process_db_entry(mock_chromadb_store, mock_vector_db_entry): id, embedding, text, metadata = mock_chromadb_store._process_db_entry(mock_vector_db_entry) print(f"metadata: {metadata}, type: {type(metadata)}") - assert id == sha256("test_key".encode("utf-8")).hexdigest() + assert id == sha256(b"test_key").hexdigest() assert embedding == [0.1, 0.2, 0.3] assert text == "test content" - assert metadata["document"] == '{"title": "test title", "source": {"path": "/test/path"}, "document_type": "test_type"}' + assert ( + metadata["document"] + == '{"title": "test title", "source": {"path": "/test/path"}, "document_type": "test_type"}' + ) assert metadata["key"] == "test_key" @@ -120,7 +124,15 @@ async def test_retrieves_entries_correctly(mock_chromadb_store): mock_collection = mock_chromadb_store._get_chroma_collection() mock_collection.query.return_value = { "documents": [["test content"]], - "metadatas": [[{"key": "test_key", "content": "test content", "document": {"title": "test title", "source": {"path": "/test/path"}, "document_type": "test_type"}}]], + "metadatas": [ + [ + { + "key": "test_key", + "content": "test content", + "document": {"title": "test title", "source": {"path": "/test/path"}, "document_type": "test_type"}, + } + ] + ], } entries = await mock_chromadb_store.retrieve(vector) assert len(entries) == 1 @@ -143,7 +155,7 @@ async def test_find_similar(mock_chromadb_store, mock_embedding_function): mock_chromadb_store.embedding_function = mock_embedding_function mock_chromadb_store.chroma_client.get_or_create_collection().query.return_value = { "documents": [["test content"]], - "distances": [[0.1]] + "distances": [[0.1]], } result = await mock_chromadb_store.find_similar("test text") assert result == "test content" @@ -154,7 +166,7 @@ async def test_find_similar_with_custom_embeddings(mock_chromadb_store, custom_e mock_chromadb_store.embedding_function = custom_embedding_function mock_chromadb_store.chroma_client.get_or_create_collection().query.return_value = { "documents": [["test content"]], - "distances": [[0.1]] + "distances": [[0.1]], } result = await mock_chromadb_store.find_similar("test text") assert result == "test content" @@ -163,6 +175,7 @@ async def test_find_similar_with_custom_embeddings(mock_chromadb_store, custom_e def test_repr(mock_chromadb_store): assert repr(mock_chromadb_store) == "ChromaDBStore(index_name=test_index)" + @pytest.mark.parametrize( "retrieved, max_distance, expected", [ @@ -180,10 +193,9 @@ def test_return_best_match(mock_chromadb_store, retrieved, max_distance, expecte def test_is_json_valid_string(mock_chromadb_store): # Arrange valid_json_string = '{"key": "value"}' - + # Act result = mock_chromadb_store.is_json(valid_json_string) - + # Assert assert result is True -