Skip to content

Commit

Permalink
fix: PgVector - Fallback to default filter policy when deserializin…
Browse files Browse the repository at this point in the history
…g retrievers without the init parameter (#900)

* Add defensive check for filter_policy deserialization

* Update integrations/pgvector/tests/test_retrievers.py

Co-authored-by: David S. Batista <[email protected]>

---------

Co-authored-by: David S. Batista <[email protected]>
  • Loading branch information
vblagoje and davidsbatista authored Jul 17, 2024
1 parent d943f4e commit 8062296
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "PgvectorEmbeddingRetriever":
"""
doc_store_params = data["init_parameters"]["document_store"]
data["init_parameters"]["document_store"] = PgvectorDocumentStore.from_dict(doc_store_params)
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"])
# Pipelines serialized with old versions of the component might not
# have the filter_policy field.
if filter_policy := data["init_parameters"].get("filter_policy"):
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy)
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "PgvectorKeywordRetriever":
"""
doc_store_params = data["init_parameters"]["document_store"]
data["init_parameters"]["document_store"] = PgvectorDocumentStore.from_dict(doc_store_params)
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"])
# Pipelines serialized with old versions of the component might not
# have the filter_policy field.
if filter_policy := data["init_parameters"].get("filter_policy"):
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy)
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
Expand Down
48 changes: 48 additions & 0 deletions integrations/pgvector/tests/test_retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,54 @@ def test_from_dict(self, monkeypatch):
assert retriever.filters == {"field": "value"}
assert retriever.top_k == 5

@pytest.mark.usefixtures("patches_for_unit_tests")
def test_from_dict_without_filter_policy(self, monkeypatch):
monkeypatch.setenv("PG_CONN_STR", "some-connection-string")
t = "haystack_integrations.components.retrievers.pgvector.keyword_retriever.PgvectorKeywordRetriever"
data = {
"type": t,
"init_parameters": {
"document_store": {
"type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore",
"init_parameters": {
"connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"},
"table_name": "haystack_test_to_dict",
"embedding_dimension": 768,
"vector_function": "cosine_similarity",
"recreate_table": True,
"search_strategy": "exact_nearest_neighbor",
"hnsw_recreate_index_if_exists": False,
"hnsw_index_creation_kwargs": {},
"hnsw_index_name": "haystack_hnsw_index",
"hnsw_ef_search": None,
"keyword_index_name": "haystack_keyword_index",
},
},
"filters": {"field": "value"},
"top_k": 5,
},
}

retriever = PgvectorKeywordRetriever.from_dict(data)
document_store = retriever.document_store

assert isinstance(document_store, PgvectorDocumentStore)
assert isinstance(document_store.connection_string, EnvVarSecret)
assert document_store.table_name == "haystack_test_to_dict"
assert document_store.embedding_dimension == 768
assert document_store.vector_function == "cosine_similarity"
assert document_store.recreate_table
assert document_store.search_strategy == "exact_nearest_neighbor"
assert not document_store.hnsw_recreate_index_if_exists
assert document_store.hnsw_index_creation_kwargs == {}
assert document_store.hnsw_index_name == "haystack_hnsw_index"
assert document_store.hnsw_ef_search is None
assert document_store.keyword_index_name == "haystack_keyword_index"

assert retriever.filters == {"field": "value"}
assert retriever.filter_policy == FilterPolicy.REPLACE # defaults to REPLACE
assert retriever.top_k == 5

def test_run(self):
mock_store = Mock(spec=PgvectorDocumentStore)
doc = Document(content="Test doc", embedding=[0.1, 0.2])
Expand Down

0 comments on commit 8062296

Please sign in to comment.