From 80622966951cc2897a689805de203f866c955819 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 17 Jul 2024 09:14:45 +0200 Subject: [PATCH] fix: `PgVector` - Fallback to default filter policy when deserializing retrievers without the init parameter (#900) * Add defensive check for filter_policy deserialization * Update integrations/pgvector/tests/test_retrievers.py Co-authored-by: David S. Batista --------- Co-authored-by: David S. Batista --- .../pgvector/embedding_retriever.py | 5 +- .../retrievers/pgvector/keyword_retriever.py | 5 +- .../pgvector/tests/test_retrievers.py | 48 +++++++++++++++++++ 3 files changed, 56 insertions(+), 2 deletions(-) diff --git a/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py index 1cceffc7d..22aab1a73 100644 --- a/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py +++ b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py @@ -126,7 +126,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "PgvectorEmbeddingRetriever": """ doc_store_params = data["init_parameters"]["document_store"] data["init_parameters"]["document_store"] = PgvectorDocumentStore.from_dict(doc_store_params) - data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"]) + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. + if filter_policy := data["init_parameters"].get("filter_policy"): + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy) return default_from_dict(cls, data) @component.output_types(documents=List[Document]) diff --git a/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/keyword_retriever.py b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/keyword_retriever.py index f185fb8f1..636471c31 100644 --- a/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/keyword_retriever.py +++ b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/keyword_retriever.py @@ -100,7 +100,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "PgvectorKeywordRetriever": """ doc_store_params = data["init_parameters"]["document_store"] data["init_parameters"]["document_store"] = PgvectorDocumentStore.from_dict(doc_store_params) - data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"]) + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. + if filter_policy := data["init_parameters"].get("filter_policy"): + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy) return default_from_dict(cls, data) @component.output_types(documents=List[Document]) diff --git a/integrations/pgvector/tests/test_retrievers.py b/integrations/pgvector/tests/test_retrievers.py index b96bc915c..031c735fd 100644 --- a/integrations/pgvector/tests/test_retrievers.py +++ b/integrations/pgvector/tests/test_retrievers.py @@ -241,6 +241,54 @@ def test_from_dict(self, monkeypatch): assert retriever.filters == {"field": "value"} assert retriever.top_k == 5 + @pytest.mark.usefixtures("patches_for_unit_tests") + def test_from_dict_without_filter_policy(self, monkeypatch): + monkeypatch.setenv("PG_CONN_STR", "some-connection-string") + t = "haystack_integrations.components.retrievers.pgvector.keyword_retriever.PgvectorKeywordRetriever" + data = { + "type": t, + "init_parameters": { + "document_store": { + "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", + "init_parameters": { + "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, + "table_name": "haystack_test_to_dict", + "embedding_dimension": 768, + "vector_function": "cosine_similarity", + "recreate_table": True, + "search_strategy": "exact_nearest_neighbor", + "hnsw_recreate_index_if_exists": False, + "hnsw_index_creation_kwargs": {}, + "hnsw_index_name": "haystack_hnsw_index", + "hnsw_ef_search": None, + "keyword_index_name": "haystack_keyword_index", + }, + }, + "filters": {"field": "value"}, + "top_k": 5, + }, + } + + retriever = PgvectorKeywordRetriever.from_dict(data) + document_store = retriever.document_store + + assert isinstance(document_store, PgvectorDocumentStore) + assert isinstance(document_store.connection_string, EnvVarSecret) + assert document_store.table_name == "haystack_test_to_dict" + assert document_store.embedding_dimension == 768 + assert document_store.vector_function == "cosine_similarity" + assert document_store.recreate_table + assert document_store.search_strategy == "exact_nearest_neighbor" + assert not document_store.hnsw_recreate_index_if_exists + assert document_store.hnsw_index_creation_kwargs == {} + assert document_store.hnsw_index_name == "haystack_hnsw_index" + assert document_store.hnsw_ef_search is None + assert document_store.keyword_index_name == "haystack_keyword_index" + + assert retriever.filters == {"field": "value"} + assert retriever.filter_policy == FilterPolicy.REPLACE # defaults to REPLACE + assert retriever.top_k == 5 + def test_run(self): mock_store = Mock(spec=PgvectorDocumentStore) doc = Document(content="Test doc", embedding=[0.1, 0.2])