From 3a6cf81767999cc50494662d862261112e2a9093 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Tue, 20 Feb 2024 15:34:27 +0100 Subject: [PATCH] fix integration tests (#450) --- .../document_stores/astra/astra_client.py | 6 +++--- .../document_stores/astra/document_store.py | 3 ++- integrations/astra/tests/test_document_store.py | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py b/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py index bb0687a07..2a29cbc00 100644 --- a/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py @@ -184,9 +184,9 @@ def _query(self, vector, top_k, filters=None): def find_documents(self, find_query): response_dict = self._astra_db_collection.find( - filter=find_query["filter"], - projection=find_query["sort"], - options=find_query["options"], + filter=find_query.get("filter"), + projection=find_query.get("sort"), + options=find_query.get("options"), ) if "data" in response_dict and "documents" in response_dict["data"]: diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py b/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py index 1bbf3a6ec..1a4ec9d17 100644 --- a/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py @@ -161,7 +161,8 @@ def _convert_input_document(document: Union[dict, Document]): if "dataframe" in document_dict and document_dict["dataframe"] is not None: document_dict["dataframe"] = document_dict.pop("dataframe").to_json() - document_dict["$vector"] = document_dict.pop("embedding", None) + if embedding := document_dict.pop("embedding", []): + document_dict["$vector"] = embedding return document_dict diff --git a/integrations/astra/tests/test_document_store.py b/integrations/astra/tests/test_document_store.py index df70b2d13..7669fa8e1 100644 --- a/integrations/astra/tests/test_document_store.py +++ b/integrations/astra/tests/test_document_store.py @@ -29,7 +29,7 @@ def document_store(self) -> AstraDocumentStore: return AstraDocumentStore( collection_name="haystack_integration", duplicates_policy=DuplicatePolicy.OVERWRITE, - embedding_dim=768, + embedding_dimension=768, ) @pytest.fixture(autouse=True)