Skip to content

Commit

Permalink
make Chroma filter_documents return embeddings (#1361)
Browse files Browse the repository at this point in the history
Co-authored-by: Amna Mubashar <[email protected]>
  • Loading branch information
anakin87 and Amnah199 authored Feb 14, 2025
1 parent 2d77b83 commit c356921
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from haystack import default_from_dict, default_to_dict
from haystack.dataclasses import Document
from haystack.document_stores.types import DuplicatePolicy
from numpy import ndarray

from .filters import _convert_filters
from .utils import get_embedding_function
Expand Down Expand Up @@ -208,18 +209,18 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc
self._ensure_initialized()
assert self._collection is not None

kwargs: Dict[str, Any] = {"include": ["embeddings", "documents", "metadatas"]}

if filters:
chroma_filter = _convert_filters(filters)
kwargs: Dict[str, Any] = {"where": chroma_filter.where}
kwargs["where"] = chroma_filter.where

if chroma_filter.ids:
kwargs["ids"] = chroma_filter.ids
if chroma_filter.where_document:
kwargs["where_document"] = chroma_filter.where_document

result = self._collection.get(**kwargs)
else:
result = self._collection.get()
result = self._collection.get(**kwargs)

return self._get_result_to_documents(result)

Expand Down Expand Up @@ -416,8 +417,11 @@ def _get_result_to_documents(result: GetResult) -> List[Document]:
document_dict["meta"] = result_metadata[i]

result_embeddings = result.get("embeddings")
if result_embeddings:
document_dict["embedding"] = list(result_embeddings[i])
if result_embeddings is not None:
if isinstance(result_embeddings[i], ndarray):
document_dict["embedding"] = result_embeddings[i].tolist()
else:
document_dict["embedding"] = result_embeddings[i]

retval.append(Document.from_dict(document_dict))

Expand Down
8 changes: 7 additions & 1 deletion integrations/chroma/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,13 @@ def test_delete_not_empty_nonexisting(self, document_store: ChromaDocumentStore)
document_store.write_documents([doc])
document_store.delete_documents(["non_existing"])
filters = {"operator": "==", "field": "id", "value": doc.id}
assert document_store.filter_documents(filters=filters) == [doc]

assert document_store.filter_documents(filters=filters)[0].id == doc.id

def test_filter_documents_return_embeddings(self, document_store: ChromaDocumentStore):
document_store.write_documents([Document(content="test doc", embedding=TEST_EMBEDDING_1)])

assert document_store.filter_documents()[0].embedding == pytest.approx(TEST_EMBEDDING_1)

def test_search(self):
document_store = ChromaDocumentStore()
Expand Down

0 comments on commit c356921

Please sign in to comment.