From 52345424db8232e353516577a96b5323b276d6c2 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Mon, 15 Jul 2024 18:29:35 +0200 Subject: [PATCH] fix: `qdrant` - Fallback to default filter policy when deserializing retrievers without the init parameter (#902) * Add defensive check for filter_policy deserialization * Add unit tests --- .../components/retrievers/qdrant/retriever.py | 15 ++++-- integrations/qdrant/tests/test_retriever.py | 48 +++++++++++++++++++ 2 files changed, 60 insertions(+), 3 deletions(-) diff --git a/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/retriever.py b/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/retriever.py index 7dd22aab5..275a46f95 100644 --- a/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/retriever.py +++ b/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/retriever.py @@ -108,7 +108,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "QdrantEmbeddingRetriever": """ document_store = QdrantDocumentStore.from_dict(data["init_parameters"]["document_store"]) data["init_parameters"]["document_store"] = document_store - data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"]) + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. + if filter_policy := data["init_parameters"].get("filter_policy"): + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy) return default_from_dict(cls, data) @component.output_types(documents=List[Document]) @@ -249,7 +252,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "QdrantSparseEmbeddingRetriever": """ document_store = QdrantDocumentStore.from_dict(data["init_parameters"]["document_store"]) data["init_parameters"]["document_store"] = document_store - data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"]) + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. + if filter_policy := data["init_parameters"].get("filter_policy"): + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy) return default_from_dict(cls, data) @component.output_types(documents=List[Document]) @@ -394,7 +400,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "QdrantHybridRetriever": """ document_store = QdrantDocumentStore.from_dict(data["init_parameters"]["document_store"]) data["init_parameters"]["document_store"] = document_store - data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"]) + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. + if filter_policy := data["init_parameters"].get("filter_policy"): + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy) return default_from_dict(cls, data) @component.output_types(documents=List[Document]) diff --git a/integrations/qdrant/tests/test_retriever.py b/integrations/qdrant/tests/test_retriever.py index 396d4b519..a92f6917f 100644 --- a/integrations/qdrant/tests/test_retriever.py +++ b/integrations/qdrant/tests/test_retriever.py @@ -296,6 +296,31 @@ def test_from_dict(self): assert retriever._return_embedding is True assert retriever._score_threshold is None + def test_from_dict_no_filter_policy(self): + data = { + "type": "haystack_integrations.components.retrievers.qdrant.retriever.QdrantSparseEmbeddingRetriever", + "init_parameters": { + "document_store": { + "init_parameters": {"location": ":memory:", "index": "test"}, + "type": "haystack_integrations.document_stores.qdrant.document_store.QdrantDocumentStore", + }, + "filters": None, + "top_k": 5, + "scale_score": False, + "return_embedding": True, + "score_threshold": None, + }, + } + retriever = QdrantSparseEmbeddingRetriever.from_dict(data) + assert isinstance(retriever._document_store, QdrantDocumentStore) + assert retriever._document_store.index == "test" + assert retriever._filters is None + assert retriever._top_k == 5 + assert retriever._filter_policy == FilterPolicy.REPLACE # defaults to REPLACE + assert retriever._scale_score is False + assert retriever._return_embedding is True + assert retriever._score_threshold is None + def test_run(self, filterable_docs: List[Document], generate_sparse_embedding): document_store = QdrantDocumentStore(location=":memory:", index="Boi", use_sparse_embeddings=True) @@ -414,6 +439,29 @@ def test_from_dict(self): assert retriever._return_embedding assert retriever._score_threshold is None + def test_from_dict_no_filter_policy(self): + data = { + "type": "haystack_integrations.components.retrievers.qdrant.retriever.QdrantHybridRetriever", + "init_parameters": { + "document_store": { + "init_parameters": {"location": ":memory:", "index": "test"}, + "type": "haystack_integrations.document_stores.qdrant.document_store.QdrantDocumentStore", + }, + "filters": None, + "top_k": 5, + "return_embedding": True, + "score_threshold": None, + }, + } + retriever = QdrantHybridRetriever.from_dict(data) + assert isinstance(retriever._document_store, QdrantDocumentStore) + assert retriever._document_store.index == "test" + assert retriever._filters is None + assert retriever._top_k == 5 + assert retriever._filter_policy == FilterPolicy.REPLACE # defaults to REPLACE + assert retriever._return_embedding + assert retriever._score_threshold is None + def test_run(self): mock_store = Mock(spec=QdrantDocumentStore) sparse_embedding = SparseEmbedding(indices=[0, 1, 2, 3], values=[0.1, 0.8, 0.05, 0.33])