Skip to content

Commit

Permalink
fix: Astra - Fallback to default filter policy when deserializing r…
Browse files Browse the repository at this point in the history
…etrievers without the init parameter (#896)

* Add defensive check for filter_policy deserialization

* Add unit test

* Add comment
  • Loading branch information
vblagoje authored Jul 15, 2024
1 parent a9da4ed commit bcdf33d
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -101,5 +101,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "AstraEmbeddingRetriever":
"""
document_store = AstraDocumentStore.from_dict(data["init_parameters"]["document_store"])
data["init_parameters"]["document_store"] = document_store
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"])
# Pipelines serialized with old versions of the component might not
# have the filter_policy field.
if filter_policy := data["init_parameters"].get("filter_policy"):
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy)
return default_from_dict(cls, data)
30 changes: 30 additions & 0 deletions integrations/astra/tests/test_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,33 @@ def test_retriever_from_json(*_):
retriever = AstraEmbeddingRetriever.from_dict(data)
assert retriever.top_k == 42
assert retriever.filters == {"bar": "baz"}


@patch.dict(
"os.environ",
{"ASTRA_DB_APPLICATION_TOKEN": "fake-token", "ASTRA_DB_API_ENDPOINT": "http://fake-url.apps.astra.datastax.com"},
)
@patch("haystack_integrations.document_stores.astra.document_store.AstraClient")
def test_retriever_from_json_no_filter_policy(*_):
data = {
"type": "haystack_integrations.components.retrievers.astra.retriever.AstraEmbeddingRetriever",
"init_parameters": {
"filters": {"bar": "baz"},
"top_k": 42,
"document_store": {
"type": "haystack_integrations.document_stores.astra.document_store.AstraDocumentStore",
"init_parameters": {
"api_endpoint": {"type": "env_var", "env_vars": ["ASTRA_DB_API_ENDPOINT"], "strict": True},
"token": {"type": "env_var", "env_vars": ["ASTRA_DB_APPLICATION_TOKEN"], "strict": True},
"collection_name": "documents",
"embedding_dimension": 768,
"duplicates_policy": "NONE",
"similarity": "cosine",
},
},
},
}
retriever = AstraEmbeddingRetriever.from_dict(data)
assert retriever.top_k == 42
assert retriever.filters == {"bar": "baz"}
assert retriever.filter_policy == FilterPolicy.REPLACE # defaults to REPLACE

0 comments on commit bcdf33d

Please sign in to comment.