Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(document-search): avoid metadata mutation #63

Merged
merged 5 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import json
from copy import deepcopy
from hashlib import sha256
from typing import List, Literal, Optional, Union

Expand Down Expand Up @@ -79,48 +78,16 @@ def _return_best_match(self, retrieved: dict) -> Optional[str]:

return None

def _process_db_entry(self, entry: VectorDBEntry) -> tuple[str, list[float], str, dict]:
def _process_db_entry(self, entry: VectorDBEntry) -> tuple[str, list[float], dict]:
doc_id = sha256(entry.key.encode("utf-8")).hexdigest()
embedding = entry.vector
text = entry.metadata["content"]

metadata = deepcopy(entry.metadata)
metadata["document"]["source"]["path"] = str(metadata["document"]["source"]["path"])
metadata["key"] = entry.key
metadata = {key: json.dumps(val) if isinstance(val, dict) else val for key, val in metadata.items()}
metadata = {
"__key": entry.key,
"__metadata": json.dumps(entry.metadata, default=str),
}

return doc_id, embedding, text, metadata

def _process_metadata(self, metadata: dict) -> dict[str, Union[str, int, float, bool]]:
"""
Processes the metadata dictionary by parsing JSON strings if applicable.

Args:
metadata: A dictionary containing metadata where values may be JSON strings.

Returns:
A dictionary with the same keys as the input, where JSON strings are parsed
into their respective Python data types.
"""
return {key: json.loads(val) if self._is_json(val) else val for key, val in metadata.items()}

def _is_json(self, myjson: str) -> bool:
"""
Check if the provided string is a valid JSON.

Args:
myjson: The string to be checked.

Returns:
True if the string is a valid JSON, False otherwise.
"""
try:
if isinstance(myjson, str):
json.loads(myjson)
return True
return False
except ValueError:
return False
return doc_id, embedding, metadata

@property
def embedding_function(self) -> Union[Embeddings, chromadb.EmbeddingFunction]:
Expand All @@ -139,12 +106,10 @@ async def store(self, entries: List[VectorDBEntry]) -> None:
Args:
entries: The entries to store.
"""
collection = self._get_chroma_collection()

entries_processed = list(map(self._process_db_entry, entries))
ids, embeddings, texts, metadatas = map(list, zip(*entries_processed))
ids, embeddings, metadatas = map(list, zip(*entries_processed))

collection.add(ids=ids, embeddings=embeddings, documents=texts, metadatas=metadatas)
self._collection.add(ids=ids, embeddings=embeddings, metadatas=metadatas)

async def retrieve(self, vector: List[float], k: int = 5) -> List[VectorDBEntry]:
"""
Expand All @@ -157,43 +122,20 @@ async def retrieve(self, vector: List[float], k: int = 5) -> List[VectorDBEntry]
Returns:
The retrieved entries.
"""
collection = self._get_chroma_collection()
query_result = collection.query(query_embeddings=[vector], n_results=k)
query_result = self._collection.query(query_embeddings=[vector], n_results=k)

db_entries = []
for meta in query_result.get("metadatas"):
db_entry = VectorDBEntry(
key=meta[0].get("key"),
key=meta[0]["__key"],
vector=vector,
metadata=self._process_metadata(meta[0]),
metadata=json.loads(meta[0]["__metadata"]),
)

db_entries.append(db_entry)

return db_entries

async def find_similar(self, text: str) -> Optional[str]:
    """
    Looks up the closest stored text to `text` in the chroma collection.

    The best candidate is passed to `self._return_best_match`, which returns None
    when the most similar text has distance bigger than `self.max_distance`.

    Args:
        text: The text to find similar to.

    Returns:
        The most similar text or None if no similar text is found.
    """
    chroma_collection = self._get_chroma_collection()

    if not isinstance(self._embedding_function, Embeddings):
        # No project-level embedder: hand the raw text to the collection query.
        query_kwargs = {"query_texts": [text]}
    else:
        # Embed the text ourselves and query by vector.
        query_kwargs = {"query_embeddings": await self._embedding_function.embed_text([text])}

    best_candidates = chroma_collection.query(n_results=1, **query_kwargs)
    return self._return_best_match(best_candidates)

def __repr__(self) -> str:
"""
Returns the string representation of the object.
Expand Down
56 changes: 6 additions & 50 deletions packages/ragbits-document-search/tests/unit/test_chromadb_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,6 @@ def test_get_chroma_collection(mock_chromadb_store):
assert mock_chromadb_store._chroma_client.get_or_create_collection.called


def test_get_chroma_collection_with_custom_embedding_function(
    custom_embedding_function, mock_chromadb_store_with_custom_embedding_function, mock_chroma_client
):
    """The chroma client is asked exactly once for the collection, with the expected name and metadata."""
    # NOTE(review): the store fixture is requested but not used directly;
    # presumably building it is what triggers the client call - confirm.
    expected_call = {
        "name": "test_index",
        "metadata": {"hnsw:space": "l2"},
    }
    mock_chroma_client.get_or_create_collection.assert_called_once_with(**expected_call)


async def test_stores_entries_correctly(mock_chromadb_store):
data = [
VectorDBEntry(
Expand All @@ -96,17 +87,15 @@ async def test_stores_entries_correctly(mock_chromadb_store):


def test_process_db_entry(mock_chromadb_store, mock_vector_db_entry):
id, embedding, text, metadata = mock_chromadb_store._process_db_entry(mock_vector_db_entry)
print(f"metadata: {metadata}, type: {type(metadata)}")
id, embedding, metadata = mock_chromadb_store._process_db_entry(mock_vector_db_entry)

assert id == sha256(b"test_key").hexdigest()
assert embedding == [0.1, 0.2, 0.3]
assert text == "test content"
assert (
metadata["document"]
== '{"title": "test title", "source": {"path": "/test/path"}, "document_type": "test_type"}'
metadata["__metadata"]
== '{"content": "test content", "document": {"title": "test title", "source": {"path": "/test/path"}, "document_type": "test_type"}}'
)
assert metadata["key"] == "test_key"
assert metadata["__key"] == "test_key"


async def test_store(mock_chromadb_store, mock_vector_db_entry):
Expand All @@ -122,9 +111,8 @@ async def test_retrieves_entries_correctly(mock_chromadb_store):
"metadatas": [
[
{
"key": "test_key",
"content": "test content",
"document": {"title": "test title", "source": {"path": "/test/path"}, "document_type": "test_type"},
"__key": "test_key",
"__metadata": '{"content": "test content", "document": {"title": "test title", "source": {"path": "/test/path"}, "document_type": "test_type"}}',
}
]
],
Expand All @@ -143,27 +131,6 @@ async def test_handles_empty_retrieve(mock_chromadb_store):
assert len(entries) == 0


async def test_find_similar(mock_chromadb_store, mock_embedding_function):
    """find_similar returns the stored document when the store uses the mocked embedding function."""
    mock_embedding_function.embed_text.return_value = [[0.1, 0.2, 0.3]]
    mock_chromadb_store._embedding_function = mock_embedding_function

    # Stub the collection's query to return a single close match.
    collection_mock = mock_chromadb_store._chroma_client.get_or_create_collection()
    collection_mock.query.return_value = {
        "documents": [["test content"]],
        "distances": [[0.1]],
    }

    best_match = await mock_chromadb_store.find_similar("test text")

    assert best_match == "test content"


async def test_find_similar_with_custom_embeddings(mock_chromadb_store, custom_embedding_function):
    """find_similar returns the stored document when the store uses the custom embedding function fixture."""
    mock_chromadb_store._embedding_function = custom_embedding_function

    # Stub the collection's query to return a single close match.
    collection_mock = mock_chromadb_store._chroma_client.get_or_create_collection()
    collection_mock.query.return_value = {
        "documents": [["test content"]],
        "distances": [[0.1]],
    }

    best_match = await mock_chromadb_store.find_similar("test text")

    assert best_match == "test content"


def test_repr(mock_chromadb_store):
    """repr() of the store shows the class name and the index name."""
    expected_repr = "ChromaDBStore(index_name=test_index)"
    assert repr(mock_chromadb_store) == expected_repr

Expand All @@ -180,14 +147,3 @@ def test_return_best_match(mock_chromadb_store, retrieved, max_distance, expecte
mock_chromadb_store._max_distance = max_distance
result = mock_chromadb_store._return_best_match(retrieved)
assert result == expected


def test_is_json_valid_string(mock_chromadb_store):
    """A well-formed JSON object string is reported as JSON."""
    assert mock_chromadb_store._is_json('{"key": "value"}') is True
Loading