Skip to content

Commit

Permalink
Applying additional review feedback.
Browse files Browse the repository at this point in the history
  • Loading branch information
PatrykWyzgowski committed Oct 4, 2024
1 parent c95c29d commit c49966a
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,10 @@ async def store(self, entries: List[VectorDBEntry]) -> None:
Args:
entries: The entries to store.
"""
collection = self._get_chroma_collection()

entries_processed = list(map(self._process_db_entry, entries))
ids, embeddings, metadatas = map(list, zip(*entries_processed))

collection.add(ids=ids, embeddings=embeddings, metadatas=metadatas)
self._collection.add(ids=ids, embeddings=embeddings, metadatas=metadatas)

async def retrieve(self, vector: List[float], k: int = 5) -> List[VectorDBEntry]:
"""
Expand All @@ -124,8 +122,7 @@ async def retrieve(self, vector: List[float], k: int = 5) -> List[VectorDBEntry]
Returns:
The retrieved entries.
"""
collection = self._get_chroma_collection()
query_result = collection.query(query_embeddings=[vector], n_results=k)
query_result = self._collection.query(query_embeddings=[vector], n_results=k)

db_entries = []
for meta in query_result.get("metadatas"):
Expand All @@ -139,28 +136,6 @@ async def retrieve(self, vector: List[float], k: int = 5) -> List[VectorDBEntry]

return db_entries

async def find_similar(self, text: str) -> Optional[str]:
"""
Finds the most similar text in the chroma collection or returns None if the most similar text
has distance bigger than `self.max_distance`.
Args:
text: The text to find similar to.
Returns:
The most similar text or None if no similar text is found.
"""

collection = self._get_chroma_collection()

if isinstance(self._embedding_function, Embeddings):
embedding = await self._embedding_function.embed_text([text])
retrieved = collection.query(query_embeddings=embedding, n_results=1)
else:
retrieved = collection.query(query_texts=[text], n_results=1)

return self._return_best_match(retrieved)

def __repr__(self) -> str:
"""
Returns the string representation of the object.
Expand Down
19 changes: 0 additions & 19 deletions packages/ragbits-document-search/tests/unit/test_chromadb_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,25 +132,6 @@ async def test_handles_empty_retrieve(mock_chromadb_store):
assert len(entries) == 0


async def test_find_similar(mock_chromadb_store, mock_embedding_function):
mock_embedding_function.embed_text.return_value = [[0.1, 0.2, 0.3]]
mock_chromadb_store._embedding_function = mock_embedding_function
mock_chromadb_store._chroma_client.get_or_create_collection().query.return_value = {
"documents": [["test content"]],
"distances": [[0.1]],
}


async def test_find_similar_with_custom_embeddings(mock_chromadb_store, custom_embedding_function):
mock_chromadb_store._embedding_function = custom_embedding_function
mock_chromadb_store._chroma_client.get_or_create_collection().query.return_value = {
"documents": [["test content"]],
"distances": [[0.1]],
}
result = await mock_chromadb_store.find_similar("test text")
assert result == "test content"


def test_repr(mock_chromadb_store):
assert repr(mock_chromadb_store) == "ChromaDBStore(index_name=test_index)"

Expand Down

0 comments on commit c49966a

Please sign in to comment.