Skip to content

Commit

Permalink
fix: Chroma - Fallback to default filter policy when deserializing …
Browse files Browse the repository at this point in the history
…retrievers without the init parameter (#897)

* Add defensive check for filter_policy deserialization

* Add unit test

* Fix test
  • Loading branch information
vblagoje authored Jul 15, 2024
1 parent 43ccd3c commit 16b3849
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "ChromaQueryTextRetriever":
"""
document_store = ChromaDocumentStore.from_dict(data["init_parameters"]["document_store"])
data["init_parameters"]["document_store"] = document_store
- data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"])
+ # Pipelines serialized with old versions of the component might not
+ # have the filter_policy field.
+ if filter_policy := data["init_parameters"].get("filter_policy"):
+     data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy)

return default_from_dict(cls, data)

Expand Down
29 changes: 29 additions & 0 deletions integrations/chroma/tests/test_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,32 @@ def test_retriever_from_json(request):
assert retriever.filters == {"bar": "baz"}
assert retriever.top_k == 42
assert retriever.filter_policy == FilterPolicy.REPLACE


@pytest.mark.integration
def test_retriever_from_json_no_filter_policy(request):
    """
    Deserialize a retriever whose init parameters omit ``filter_policy``.

    The serialized dictionary carries filters, top_k and a document store
    definition only; the test expects the resulting retriever to report
    ``FilterPolicy.REPLACE`` alongside the explicitly serialized settings.
    """
    # Serialized form of the document store the retriever wraps.
    store_config = {
        "type": "haystack_integrations.document_stores.chroma.document_store.ChromaDocumentStore",
        "init_parameters": {
            "collection_name": "test_retriever_from_json_no_filter_policy",
            "embedding_function": "HuggingFaceEmbeddingFunction",
            "persist_path": ".",
            "api_key": "1234567890",
            "distance_function": "l2",
        },
    }
    # Deliberately no "filter_policy" entry among the retriever's init parameters.
    serialized = {
        "type": "haystack_integrations.components.retrievers.chroma.retriever.ChromaQueryTextRetriever",
        "init_parameters": {
            "filters": {"bar": "baz"},
            "top_k": 42,
            "document_store": store_config,
        },
    }

    retriever = ChromaQueryTextRetriever.from_dict(serialized)

    # Document store settings are restored from the nested dictionary.
    store = retriever.document_store
    assert store._collection_name == request.node.name
    assert store._embedding_function == "HuggingFaceEmbeddingFunction"
    assert store._embedding_function_params == {"api_key": "1234567890"}
    assert store._persist_path == "."

    # Retriever-level settings.
    assert retriever.filters == {"bar": "baz"}
    assert retriever.top_k == 42
    assert retriever.filter_policy == FilterPolicy.REPLACE  # default even if not specified

0 comments on commit 16b3849

Please sign in to comment.