From 704eef2445895092f3f1ef05a0b00c24c0325780 Mon Sep 17 00:00:00 2001 From: PatrykWyzgowski Date: Mon, 7 Oct 2024 09:04:27 +0200 Subject: [PATCH] fix(document-search): avoid metadata mutation (#63) --- .../vector_store/chromadb_store.py | 80 +++---------------- .../tests/unit/test_chromadb_store.py | 56 ++----------- 2 files changed, 17 insertions(+), 119 deletions(-) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/vector_store/chromadb_store.py b/packages/ragbits-document-search/src/ragbits/document_search/vector_store/chromadb_store.py index 685950be3..6d4d4bc4b 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/vector_store/chromadb_store.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/vector_store/chromadb_store.py @@ -1,5 +1,4 @@ import json -from copy import deepcopy from hashlib import sha256 from typing import List, Literal, Optional, Union @@ -79,48 +78,16 @@ def _return_best_match(self, retrieved: dict) -> Optional[str]: return None - def _process_db_entry(self, entry: VectorDBEntry) -> tuple[str, list[float], str, dict]: + def _process_db_entry(self, entry: VectorDBEntry) -> tuple[str, list[float], dict]: doc_id = sha256(entry.key.encode("utf-8")).hexdigest() embedding = entry.vector - text = entry.metadata["content"] - metadata = deepcopy(entry.metadata) - metadata["document"]["source"]["path"] = str(metadata["document"]["source"]["path"]) - metadata["key"] = entry.key - metadata = {key: json.dumps(val) if isinstance(val, dict) else val for key, val in metadata.items()} + metadata = { + "__key": entry.key, + "__metadata": json.dumps(entry.metadata, default=str), + } - return doc_id, embedding, text, metadata - - def _process_metadata(self, metadata: dict) -> dict[str, Union[str, int, float, bool]]: - """ - Processes the metadata dictionary by parsing JSON strings if applicable. - - Args: - metadata: A dictionary containing metadata where values may be JSON strings. - - Returns: - A dictionary with the same keys as the input, where JSON strings are parsed - into their respective Python data types. - """ - return {key: json.loads(val) if self._is_json(val) else val for key, val in metadata.items()} - - def _is_json(self, myjson: str) -> bool: - """ - Check if the provided string is a valid JSON. - - Args: - myjson: The string to be checked. - - Returns: - True if the string is a valid JSON, False otherwise. - """ - try: - if isinstance(myjson, str): - json.loads(myjson) - return True - return False - except ValueError: - return False + return doc_id, embedding, metadata @property def embedding_function(self) -> Union[Embeddings, chromadb.EmbeddingFunction]: @@ -139,12 +106,10 @@ async def store(self, entries: List[VectorDBEntry]) -> None: Args: entries: The entries to store. """ - collection = self._get_chroma_collection() - entries_processed = list(map(self._process_db_entry, entries)) - ids, embeddings, texts, metadatas = map(list, zip(*entries_processed)) + ids, embeddings, metadatas = map(list, zip(*entries_processed)) - collection.add(ids=ids, embeddings=embeddings, documents=texts, metadatas=metadatas) + self._collection.add(ids=ids, embeddings=embeddings, metadatas=metadatas) async def retrieve(self, vector: List[float], k: int = 5) -> List[VectorDBEntry]: """ @@ -157,43 +122,20 @@ async def retrieve(self, vector: List[float], k: int = 5) -> List[VectorDBEntry] Returns: The retrieved entries. """ - collection = self._get_chroma_collection() - query_result = collection.query(query_embeddings=[vector], n_results=k) + query_result = self._collection.query(query_embeddings=[vector], n_results=k) db_entries = [] for meta in query_result.get("metadatas"): db_entry = VectorDBEntry( - key=meta[0].get("key"), + key=meta[0]["__key"], vector=vector, - metadata=self._process_metadata(meta[0]), + metadata=json.loads(meta[0]["__metadata"]), ) db_entries.append(db_entry) return db_entries - async def find_similar(self, text: str) -> Optional[str]: - """ - Finds the most similar text in the chroma collection or returns None if the most similar text - has distance bigger than `self.max_distance`. - - Args: - text: The text to find similar to. - - Returns: - The most similar text or None if no similar text is found. - """ - - collection = self._get_chroma_collection() - - if isinstance(self._embedding_function, Embeddings): - embedding = await self._embedding_function.embed_text([text]) - retrieved = collection.query(query_embeddings=embedding, n_results=1) - else: - retrieved = collection.query(query_texts=[text], n_results=1) - - return self._return_best_match(retrieved) - def __repr__(self) -> str: """ Returns the string representation of the object. diff --git a/packages/ragbits-document-search/tests/unit/test_chromadb_store.py b/packages/ragbits-document-search/tests/unit/test_chromadb_store.py index fd86dde0a..9d45bdc10 100644 --- a/packages/ragbits-document-search/tests/unit/test_chromadb_store.py +++ b/packages/ragbits-document-search/tests/unit/test_chromadb_store.py @@ -71,15 +71,6 @@ def test_get_chroma_collection(mock_chromadb_store): assert mock_chromadb_store._chroma_client.get_or_create_collection.called -def test_get_chroma_collection_with_custom_embedding_function( - custom_embedding_function, mock_chromadb_store_with_custom_embedding_function, mock_chroma_client -): - mock_chroma_client.get_or_create_collection.assert_called_once_with( - name="test_index", - metadata={"hnsw:space": "l2"}, - ) - - async def test_stores_entries_correctly(mock_chromadb_store): data = [ VectorDBEntry( @@ -96,17 +87,15 @@ async def test_stores_entries_correctly(mock_chromadb_store): def test_process_db_entry(mock_chromadb_store, mock_vector_db_entry): - id, embedding, text, metadata = mock_chromadb_store._process_db_entry(mock_vector_db_entry) - print(f"metadata: {metadata}, type: {type(metadata)}") + id, embedding, metadata = mock_chromadb_store._process_db_entry(mock_vector_db_entry) assert id == sha256(b"test_key").hexdigest() assert embedding == [0.1, 0.2, 0.3] - assert text == "test content" assert ( - metadata["document"] - == '{"title": "test title", "source": {"path": "/test/path"}, "document_type": "test_type"}' + metadata["__metadata"] + == '{"content": "test content", "document": {"title": "test title", "source": {"path": "/test/path"}, "document_type": "test_type"}}' ) - assert metadata["key"] == "test_key" + assert metadata["__key"] == "test_key" async def test_store(mock_chromadb_store, mock_vector_db_entry): @@ -122,9 +111,8 @@ async def test_retrieves_entries_correctly(mock_chromadb_store): "metadatas": [ [ { - "key": "test_key", - "content": "test content", - "document": {"title": "test title", "source": {"path": "/test/path"}, "document_type": "test_type"}, + "__key": "test_key", + "__metadata": '{"content": "test content", "document": {"title": "test title", "source": {"path": "/test/path"}, "document_type": "test_type"}}', } ] ], @@ -143,27 +131,6 @@ async def test_handles_empty_retrieve(mock_chromadb_store): assert len(entries) == 0 -async def test_find_similar(mock_chromadb_store, mock_embedding_function): - mock_embedding_function.embed_text.return_value = [[0.1, 0.2, 0.3]] - mock_chromadb_store._embedding_function = mock_embedding_function - mock_chromadb_store._chroma_client.get_or_create_collection().query.return_value = { - "documents": [["test content"]], - "distances": [[0.1]], - } - result = await mock_chromadb_store.find_similar("test text") - assert result == "test content" - - -async def test_find_similar_with_custom_embeddings(mock_chromadb_store, custom_embedding_function): - mock_chromadb_store._embedding_function = custom_embedding_function - mock_chromadb_store._chroma_client.get_or_create_collection().query.return_value = { - "documents": [["test content"]], - "distances": [[0.1]], - } - result = await mock_chromadb_store.find_similar("test text") - assert result == "test content" - - def test_repr(mock_chromadb_store): assert repr(mock_chromadb_store) == "ChromaDBStore(index_name=test_index)" @@ -180,14 +147,3 @@ def test_return_best_match(mock_chromadb_store, retrieved, max_distance, expecte mock_chromadb_store._max_distance = max_distance result = mock_chromadb_store._return_best_match(retrieved) assert result == expected - - -def test_is_json_valid_string(mock_chromadb_store): - # Arrange - valid_json_string = '{"key": "value"}' - - # Act - result = mock_chromadb_store._is_json(valid_json_string) - - # Assert - assert result is True