diff --git a/integrations/astra/src/haystack_integrations/components/retrievers/astra/retriever.py b/integrations/astra/src/haystack_integrations/components/retrievers/astra/retriever.py index 1539489e0..cfa45e81f 100644 --- a/integrations/astra/src/haystack_integrations/components/retrievers/astra/retriever.py +++ b/integrations/astra/src/haystack_integrations/components/retrievers/astra/retriever.py @@ -101,5 +101,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "AstraEmbeddingRetriever": """ document_store = AstraDocumentStore.from_dict(data["init_parameters"]["document_store"]) data["init_parameters"]["document_store"] = document_store - data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"]) + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. + if filter_policy := data["init_parameters"].get("filter_policy"): + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy) return default_from_dict(cls, data) diff --git a/integrations/astra/tests/test_retriever.py b/integrations/astra/tests/test_retriever.py index e702d0788..4ffe30919 100644 --- a/integrations/astra/tests/test_retriever.py +++ b/integrations/astra/tests/test_retriever.py @@ -92,3 +92,33 @@ def test_retriever_from_json(*_): retriever = AstraEmbeddingRetriever.from_dict(data) assert retriever.top_k == 42 assert retriever.filters == {"bar": "baz"} + + +@patch.dict( + "os.environ", + {"ASTRA_DB_APPLICATION_TOKEN": "fake-token", "ASTRA_DB_API_ENDPOINT": "http://fake-url.apps.astra.datastax.com"}, +) +@patch("haystack_integrations.document_stores.astra.document_store.AstraClient") +def test_retriever_from_json_no_filter_policy(*_): + data = { + "type": "haystack_integrations.components.retrievers.astra.retriever.AstraEmbeddingRetriever", + "init_parameters": { + "filters": {"bar": "baz"}, + "top_k": 42, + "document_store": { + "type": "haystack_integrations.document_stores.astra.document_store.AstraDocumentStore", + "init_parameters": { + "api_endpoint": {"type": "env_var", "env_vars": ["ASTRA_DB_API_ENDPOINT"], "strict": True}, + "token": {"type": "env_var", "env_vars": ["ASTRA_DB_APPLICATION_TOKEN"], "strict": True}, + "collection_name": "documents", + "embedding_dimension": 768, + "duplicates_policy": "NONE", + "similarity": "cosine", + }, + }, + }, + } + retriever = AstraEmbeddingRetriever.from_dict(data) + assert retriever.top_k == 42 + assert retriever.filters == {"bar": "baz"} + assert retriever.filter_policy == FilterPolicy.REPLACE # defaults to REPLACE