From b9da65a0efa6298aaf78a1cff6d304d7c782ebc2 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Wed, 24 Apr 2024 11:45:01 +0200 Subject: [PATCH 1/2] pass namespace in the docstore init --- .../document_stores/astra/document_store.py | 4 ++++ integrations/astra/tests/test_document_store.py | 12 ++++++++++++ 2 files changed, 16 insertions(+) diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py b/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py index a894b94c1..1ef9af8c2 100644 --- a/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py @@ -55,6 +55,7 @@ def __init__( embedding_dimension: int = 768, duplicates_policy: DuplicatePolicy = DuplicatePolicy.NONE, similarity: str = "cosine", + namespace: Optional[str] = None, ): """ The connection to Astra DB is established and managed through the JSON API. @@ -99,6 +100,8 @@ def __init__( self.embedding_dimension = embedding_dimension self.duplicates_policy = duplicates_policy self.similarity = similarity + if namespace: + self.namespace = namespace self.index = AstraClient( resolved_api_endpoint, @@ -106,6 +109,7 @@ def __init__( self.collection_name, self.embedding_dimension, self.similarity, + namespace, ) @classmethod diff --git a/integrations/astra/tests/test_document_store.py b/integrations/astra/tests/test_document_store.py index f1fad4f5d..a6f2f221c 100644 --- a/integrations/astra/tests/test_document_store.py +++ b/integrations/astra/tests/test_document_store.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import os from typing import List +from unittest import mock import pytest from haystack import Document @@ -13,6 +14,17 @@ from haystack_integrations.document_stores.astra import AstraDocumentStore +def test_namespace_init(): + with mock.patch("haystack_integrations.document_stores.astra.astra_client.AstraDB") as client: + AstraDocumentStore() + assert "namespace" in client.call_args.kwargs + assert client.call_args.kwargs["namespace"] is None + + AstraDocumentStore(namespace="foo") + assert "namespace" in client.call_args.kwargs + assert client.call_args.kwargs["namespace"] == "foo" + + @pytest.mark.integration @pytest.mark.skipif( os.environ.get("ASTRA_DB_APPLICATION_TOKEN", "") == "", reason="ASTRA_DB_APPLICATION_TOKEN env var not set" From 15038b64bae12a0ae44f0afefdae14095f556e20 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Wed, 24 Apr 2024 12:08:39 +0200 Subject: [PATCH 2/2] manage serialization --- .../document_stores/astra/document_store.py | 5 +++-- integrations/astra/tests/test_document_store.py | 16 ++++++++++++++++ integrations/astra/tests/test_retriever.py | 2 +- 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py b/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py index 1ef9af8c2..fdbf95eb0 100644 --- a/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py @@ -100,8 +100,7 @@ def __init__( self.embedding_dimension = embedding_dimension self.duplicates_policy = duplicates_policy self.similarity = similarity - if namespace: - self.namespace = namespace + self.namespace = namespace self.index = AstraClient( resolved_api_endpoint, @@ -132,6 +131,7 @@ def to_dict(self) -> Dict[str, Any]: :returns: Dictionary with serialized data. """ + return default_to_dict( self, api_endpoint=self.api_endpoint.to_dict(), @@ -140,6 +140,7 @@ def to_dict(self) -> Dict[str, Any]: embedding_dimension=self.embedding_dimension, duplicates_policy=self.duplicates_policy.name, similarity=self.similarity, + namespace=self.namespace, ) def write_documents( diff --git a/integrations/astra/tests/test_document_store.py b/integrations/astra/tests/test_document_store.py index a6f2f221c..3650ffd61 100644 --- a/integrations/astra/tests/test_document_store.py +++ b/integrations/astra/tests/test_document_store.py @@ -25,6 +25,22 @@ def test_namespace_init(): assert client.call_args.kwargs["namespace"] == "foo" +def test_to_dict(): + with mock.patch("haystack_integrations.document_stores.astra.astra_client.AstraDB"): + ds = AstraDocumentStore() + result = ds.to_dict() + assert result["type"] == "haystack_integrations.document_stores.astra.document_store.AstraDocumentStore" + assert set(result["init_parameters"]) == { + "api_endpoint", + "token", + "collection_name", + "embedding_dimension", + "duplicates_policy", + "similarity", + "namespace", + } + + @pytest.mark.integration @pytest.mark.skipif( os.environ.get("ASTRA_DB_APPLICATION_TOKEN", "") == "", reason="ASTRA_DB_APPLICATION_TOKEN env var not set" diff --git a/integrations/astra/tests/test_retriever.py b/integrations/astra/tests/test_retriever.py index 95ba7a263..b52cedf33 100644 --- a/integrations/astra/tests/test_retriever.py +++ b/integrations/astra/tests/test_retriever.py @@ -30,6 +30,7 @@ def test_retriever_to_json(*_): "embedding_dimension": 768, "duplicates_policy": "NONE", "similarity": "cosine", + "namespace": None, }, }, }, @@ -42,7 +43,6 @@ def test_retriever_to_json(*_): ) @patch("haystack_integrations.document_stores.astra.document_store.AstraClient") def test_retriever_from_json(*_): - data = { "type": "haystack_integrations.components.retrievers.astra.retriever.AstraEmbeddingRetriever", "init_parameters": {