Skip to content

Commit

Permalink
Add defensive check for filter_policy deserialization (#903)
Browse files Browse the repository at this point in the history
  • Loading branch information
vblagoje authored Jul 17, 2024
1 parent be04358 commit db2b5f7
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,11 @@ def from_dict(cls, data: Dict[str, Any]) -> "WeaviateBM25Retriever":
data["init_parameters"]["document_store"] = WeaviateDocumentStore.from_dict(
data["init_parameters"]["document_store"]
)
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"])
# Pipelines serialized with old versions of the component might not
# have the filter_policy field.
if filter_policy := data["init_parameters"].get("filter_policy"):
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy)

return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,11 @@ def from_dict(cls, data: Dict[str, Any]) -> "WeaviateEmbeddingRetriever":
data["init_parameters"]["document_store"]
)

data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"])
# Pipelines serialized with old versions of the component might not
# have the filter_policy field.
if filter_policy := data["init_parameters"].get("filter_policy"):
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy)

return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
Expand Down
39 changes: 39 additions & 0 deletions integrations/weaviate/tests/test_bm25_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,45 @@ def test_from_dict(_mock_weaviate):
assert retriever._top_k == 10


@patch("haystack_integrations.document_stores.weaviate.document_store.weaviate")
def test_from_dict_no_filter_policy(_mock_weaviate):
    """Deserialize a BM25 retriever from a payload that has no ``filter_policy`` key.

    The serialized ``init_parameters`` deliberately omit ``filter_policy``;
    the test checks that ``from_dict`` still builds the retriever and that
    its filter policy ends up as ``FilterPolicy.REPLACE``.
    """
    # Serialized document store, nested inside the retriever payload below.
    serialized_store = {
        "type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore",
        "init_parameters": {
            "url": None,
            "collection_settings": {
                "class": "Default",
                "invertedIndexConfig": {"indexNullState": True},
                "properties": [
                    {"name": "_original_id", "dataType": ["text"]},
                    {"name": "content", "dataType": ["text"]},
                    {"name": "dataframe", "dataType": ["text"]},
                    {"name": "blob_data", "dataType": ["blob"]},
                    {"name": "blob_mime_type", "dataType": ["text"]},
                    {"name": "score", "dataType": ["number"]},
                ],
            },
            "auth_client_secret": None,
            "additional_headers": None,
            "embedded_options": None,
            "additional_config": None,
        },
    }

    # Note: no "filter_policy" entry in init_parameters -- that is the point of the test.
    serialized_retriever = {
        "type": "haystack_integrations.components.retrievers.weaviate.bm25_retriever.WeaviateBM25Retriever",
        "init_parameters": {
            "filters": {},
            "top_k": 10,
            "document_store": serialized_store,
        },
    }

    restored = WeaviateBM25Retriever.from_dict(serialized_retriever)

    assert restored._document_store
    assert restored._filters == {}
    assert restored._top_k == 10
    assert restored._filter_policy == FilterPolicy.REPLACE


@patch("haystack_integrations.components.retrievers.weaviate.bm25_retriever.WeaviateDocumentStore")
def test_run(mock_document_store):
retriever = WeaviateBM25Retriever(document_store=mock_document_store)
Expand Down
43 changes: 43 additions & 0 deletions integrations/weaviate/tests/test_embedding_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,49 @@ def test_from_dict(_mock_weaviate):
assert retriever._certainty is None


@patch("haystack_integrations.document_stores.weaviate.document_store.weaviate")
def test_from_dict_no_filter_policy(_mock_weaviate):
    """Deserialize an embedding retriever from a payload that has no ``filter_policy`` key.

    The serialized ``init_parameters`` deliberately omit ``filter_policy``;
    the test checks that ``from_dict`` still builds the retriever and that
    its filter policy ends up as ``FilterPolicy.REPLACE``.
    """
    # Serialized document store, nested inside the retriever payload below.
    serialized_store = {
        "type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore",
        "init_parameters": {
            "url": None,
            "collection_settings": {
                "class": "Default",
                "invertedIndexConfig": {"indexNullState": True},
                "properties": [
                    {"name": "_original_id", "dataType": ["text"]},
                    {"name": "content", "dataType": ["text"]},
                    {"name": "dataframe", "dataType": ["text"]},
                    {"name": "blob_data", "dataType": ["blob"]},
                    {"name": "blob_mime_type", "dataType": ["text"]},
                    {"name": "score", "dataType": ["number"]},
                ],
            },
            "auth_client_secret": None,
            "additional_headers": None,
            "embedded_options": None,
            "additional_config": None,
        },
    }

    # Note: no "filter_policy" entry in init_parameters -- that is the point of the test.
    serialized_retriever = {
        "type": "haystack_integrations.components.retrievers.weaviate.embedding_retriever.WeaviateEmbeddingRetriever",  # noqa: E501
        "init_parameters": {
            "filters": {},
            "top_k": 10,
            "distance": None,
            "certainty": None,
            "document_store": serialized_store,
        },
    }

    restored = WeaviateEmbeddingRetriever.from_dict(serialized_retriever)

    assert restored._document_store
    assert restored._filters == {}
    assert restored._top_k == 10
    assert restored._distance is None
    assert restored._certainty is None
    assert restored._filter_policy == FilterPolicy.REPLACE  # defaults to REPLACE


@patch("haystack_integrations.components.retrievers.weaviate.bm25_retriever.WeaviateDocumentStore")
def test_run(mock_document_store):
retriever = WeaviateEmbeddingRetriever(document_store=mock_document_store)
Expand Down

0 comments on commit db2b5f7

Please sign in to comment.