Skip to content

Commit

Permalink
fix(document-search): avoid metadata mutation (#63)
Browse files Browse the repository at this point in the history
  • Loading branch information
PatrykWyzgowski authored Oct 7, 2024
1 parent 08acb63 commit 704eef2
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 119 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import json
from copy import deepcopy
from hashlib import sha256
from typing import List, Literal, Optional, Union

Expand Down Expand Up @@ -79,48 +78,16 @@ def _return_best_match(self, retrieved: dict) -> Optional[str]:

return None

def _process_db_entry(self, entry: VectorDBEntry) -> tuple[str, list[float], str, dict]:
def _process_db_entry(self, entry: VectorDBEntry) -> tuple[str, list[float], dict]:
doc_id = sha256(entry.key.encode("utf-8")).hexdigest()
embedding = entry.vector
text = entry.metadata["content"]

metadata = deepcopy(entry.metadata)
metadata["document"]["source"]["path"] = str(metadata["document"]["source"]["path"])
metadata["key"] = entry.key
metadata = {key: json.dumps(val) if isinstance(val, dict) else val for key, val in metadata.items()}
metadata = {
"__key": entry.key,
"__metadata": json.dumps(entry.metadata, default=str),
}

return doc_id, embedding, text, metadata

def _process_metadata(self, metadata: dict) -> dict[str, Union[str, int, float, bool]]:
    """
    Processes the metadata dictionary by parsing JSON strings if applicable.

    Each string value is decoded at most once: the decoded object is kept when
    parsing succeeds and the original string is kept when it does not.
    Non-string values are passed through untouched.

    Args:
        metadata: A dictionary containing metadata where values may be JSON strings.

    Returns:
        A new dictionary with the same keys as the input, where JSON strings are parsed
        into their respective Python data types. The input dictionary is not modified.
    """
    processed = {}
    for key, val in metadata.items():
        if isinstance(val, str):
            try:
                val = json.loads(val)
            except ValueError:
                # Not JSON (json.JSONDecodeError subclasses ValueError):
                # deliberately keep the raw string as the value.
                pass
        processed[key] = val
    return processed

def _is_json(self, myjson: str) -> bool:
    """
    Tell whether the given value is a string that parses as JSON.

    Args:
        myjson: The candidate value. Anything that is not a `str` is rejected
            without attempting to parse it.

    Returns:
        True if `myjson` is a string accepted by `json.loads`, False otherwise.
    """
    if not isinstance(myjson, str):
        return False

    try:
        json.loads(myjson)
    except ValueError:
        return False
    return True
return doc_id, embedding, metadata

@property
def embedding_function(self) -> Union[Embeddings, chromadb.EmbeddingFunction]:
Expand All @@ -139,12 +106,10 @@ async def store(self, entries: List[VectorDBEntry]) -> None:
Args:
entries: The entries to store.
"""
collection = self._get_chroma_collection()

entries_processed = list(map(self._process_db_entry, entries))
ids, embeddings, texts, metadatas = map(list, zip(*entries_processed))
ids, embeddings, metadatas = map(list, zip(*entries_processed))

collection.add(ids=ids, embeddings=embeddings, documents=texts, metadatas=metadatas)
self._collection.add(ids=ids, embeddings=embeddings, metadatas=metadatas)

async def retrieve(self, vector: List[float], k: int = 5) -> List[VectorDBEntry]:
"""
Expand All @@ -157,43 +122,20 @@ async def retrieve(self, vector: List[float], k: int = 5) -> List[VectorDBEntry]
Returns:
The retrieved entries.
"""
collection = self._get_chroma_collection()
query_result = collection.query(query_embeddings=[vector], n_results=k)
query_result = self._collection.query(query_embeddings=[vector], n_results=k)

db_entries = []
for meta in query_result.get("metadatas"):
db_entry = VectorDBEntry(
key=meta[0].get("key"),
key=meta[0]["__key"],
vector=vector,
metadata=self._process_metadata(meta[0]),
metadata=json.loads(meta[0]["__metadata"]),
)

db_entries.append(db_entry)

return db_entries

async def find_similar(self, text: str) -> Optional[str]:
    """
    Look up the single closest entry to `text` in the chroma collection.

    When the configured embedding function is an `Embeddings` instance the text is
    embedded here and the collection is queried by embedding; otherwise the raw text
    is handed to the collection's own query.

    Args:
        text: The text to find similar to.

    Returns:
        The value produced by `self._return_best_match` for the query result: the most
        similar text, or None if no similar text is found (per the original contract,
        when its distance is bigger than `self.max_distance`).
    """
    chroma_collection = self._get_chroma_collection()
    embedder = self._embedding_function

    if isinstance(embedder, Embeddings):
        query_kwargs = {"query_embeddings": await embedder.embed_text([text])}
    else:
        query_kwargs = {"query_texts": [text]}

    nearest = chroma_collection.query(n_results=1, **query_kwargs)
    return self._return_best_match(nearest)

def __repr__(self) -> str:
"""
Returns the string representation of the object.
Expand Down
56 changes: 6 additions & 50 deletions packages/ragbits-document-search/tests/unit/test_chromadb_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,6 @@ def test_get_chroma_collection(mock_chromadb_store):
assert mock_chromadb_store._chroma_client.get_or_create_collection.called


def test_get_chroma_collection_with_custom_embedding_function(
    custom_embedding_function, mock_chromadb_store_with_custom_embedding_function, mock_chroma_client
):
    """
    Check that the chroma client was asked exactly once for the "test_index" collection.

    The first two fixtures are not used in the body; presumably requesting them builds
    the store and triggers the call being asserted - confirm against the fixtures.
    """
    expected_call = {"name": "test_index", "metadata": {"hnsw:space": "l2"}}
    mock_chroma_client.get_or_create_collection.assert_called_once_with(**expected_call)


async def test_stores_entries_correctly(mock_chromadb_store):
data = [
VectorDBEntry(
Expand All @@ -96,17 +87,15 @@ async def test_stores_entries_correctly(mock_chromadb_store):


def test_process_db_entry(mock_chromadb_store, mock_vector_db_entry):
id, embedding, text, metadata = mock_chromadb_store._process_db_entry(mock_vector_db_entry)
print(f"metadata: {metadata}, type: {type(metadata)}")
id, embedding, metadata = mock_chromadb_store._process_db_entry(mock_vector_db_entry)

assert id == sha256(b"test_key").hexdigest()
assert embedding == [0.1, 0.2, 0.3]
assert text == "test content"
assert (
metadata["document"]
== '{"title": "test title", "source": {"path": "/test/path"}, "document_type": "test_type"}'
metadata["__metadata"]
== '{"content": "test content", "document": {"title": "test title", "source": {"path": "/test/path"}, "document_type": "test_type"}}'
)
assert metadata["key"] == "test_key"
assert metadata["__key"] == "test_key"


async def test_store(mock_chromadb_store, mock_vector_db_entry):
Expand All @@ -122,9 +111,8 @@ async def test_retrieves_entries_correctly(mock_chromadb_store):
"metadatas": [
[
{
"key": "test_key",
"content": "test content",
"document": {"title": "test title", "source": {"path": "/test/path"}, "document_type": "test_type"},
"__key": "test_key",
"__metadata": '{"content": "test content", "document": {"title": "test title", "source": {"path": "/test/path"}, "document_type": "test_type"}}',
}
]
],
Expand All @@ -143,27 +131,6 @@ async def test_handles_empty_retrieve(mock_chromadb_store):
assert len(entries) == 0


async def test_find_similar(mock_chromadb_store, mock_embedding_function):
    """Check that `find_similar` returns the document from the mocked query result."""
    # Arrange: install the mocked embedding function and a canned query response.
    query_response = {
        "documents": [["test content"]],
        "distances": [[0.1]],
    }
    mock_embedding_function.embed_text.return_value = [[0.1, 0.2, 0.3]]
    mock_chromadb_store._embedding_function = mock_embedding_function
    collection = mock_chromadb_store._chroma_client.get_or_create_collection()
    collection.query.return_value = query_response

    # Act
    similar = await mock_chromadb_store.find_similar("test text")

    # Assert
    assert similar == "test content"


async def test_find_similar_with_custom_embeddings(mock_chromadb_store, custom_embedding_function):
    """Check that `find_similar` returns the mocked document with the custom embedding function installed."""
    # Arrange: swap in the custom embedding function and a canned query response.
    query_response = {
        "documents": [["test content"]],
        "distances": [[0.1]],
    }
    mock_chromadb_store._embedding_function = custom_embedding_function
    collection = mock_chromadb_store._chroma_client.get_or_create_collection()
    collection.query.return_value = query_response

    # Act
    similar = await mock_chromadb_store.find_similar("test text")

    # Assert
    assert similar == "test content"


def test_repr(mock_chromadb_store):
    """Check the exact `repr` string of the store fixture."""
    expected = "ChromaDBStore(index_name=test_index)"
    actual = repr(mock_chromadb_store)
    assert actual == expected

Expand All @@ -180,14 +147,3 @@ def test_return_best_match(mock_chromadb_store, retrieved, max_distance, expecte
mock_chromadb_store._max_distance = max_distance
result = mock_chromadb_store._return_best_match(retrieved)
assert result == expected


def test_is_json_valid_string(mock_chromadb_store):
    """Check that `_is_json` returns True for a string holding a JSON object."""
    payload = '{"key": "value"}'

    outcome = mock_chromadb_store._is_json(payload)

    assert outcome is True

0 comments on commit 704eef2

Please sign in to comment.