diff --git a/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/retriever.py b/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/retriever.py index 67a771742..71ac3457e 100644 --- a/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/retriever.py +++ b/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/retriever.py @@ -100,7 +100,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "ChromaQueryTextRetriever": """ document_store = ChromaDocumentStore.from_dict(data["init_parameters"]["document_store"]) data["init_parameters"]["document_store"] = document_store - data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"]) + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. + if filter_policy := data["init_parameters"].get("filter_policy"): + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy) return default_from_dict(cls, data) diff --git a/integrations/chroma/tests/test_retriever.py b/integrations/chroma/tests/test_retriever.py index bea40cb5f..f0e71828d 100644 --- a/integrations/chroma/tests/test_retriever.py +++ b/integrations/chroma/tests/test_retriever.py @@ -73,3 +73,32 @@ def test_retriever_from_json(request): assert retriever.filters == {"bar": "baz"} assert retriever.top_k == 42 assert retriever.filter_policy == FilterPolicy.REPLACE + + +@pytest.mark.integration +def test_retriever_from_json_no_filter_policy(request): + data = { + "type": "haystack_integrations.components.retrievers.chroma.retriever.ChromaQueryTextRetriever", + "init_parameters": { + "filters": {"bar": "baz"}, + "top_k": 42, + "document_store": { + "type": "haystack_integrations.document_stores.chroma.document_store.ChromaDocumentStore", + "init_parameters": { + "collection_name": "test_retriever_from_json_no_filter_policy", + "embedding_function": "HuggingFaceEmbeddingFunction", + "persist_path": ".", + "api_key": "1234567890", + "distance_function": "l2", + }, + }, + }, + } + retriever = ChromaQueryTextRetriever.from_dict(data) + assert retriever.document_store._collection_name == request.node.name + assert retriever.document_store._embedding_function == "HuggingFaceEmbeddingFunction" + assert retriever.document_store._embedding_function_params == {"api_key": "1234567890"} + assert retriever.document_store._persist_path == "." + assert retriever.filters == {"bar": "baz"} + assert retriever.top_k == 42 + assert retriever.filter_policy == FilterPolicy.REPLACE # default even if not specified