diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py b/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py index a894b94c1..fdbf95eb0 100644 --- a/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py @@ -55,6 +55,7 @@ def __init__( embedding_dimension: int = 768, duplicates_policy: DuplicatePolicy = DuplicatePolicy.NONE, similarity: str = "cosine", + namespace: Optional[str] = None, ): """ The connection to Astra DB is established and managed through the JSON API. @@ -99,6 +100,7 @@ def __init__( self.embedding_dimension = embedding_dimension self.duplicates_policy = duplicates_policy self.similarity = similarity + self.namespace = namespace self.index = AstraClient( resolved_api_endpoint, @@ -106,6 +108,7 @@ def __init__( self.collection_name, self.embedding_dimension, self.similarity, + namespace, ) @classmethod @@ -128,6 +131,7 @@ def to_dict(self) -> Dict[str, Any]: :returns: Dictionary with serialized data. """ + return default_to_dict( self, api_endpoint=self.api_endpoint.to_dict(), @@ -136,6 +140,7 @@ def to_dict(self) -> Dict[str, Any]: embedding_dimension=self.embedding_dimension, duplicates_policy=self.duplicates_policy.name, similarity=self.similarity, + namespace=self.namespace, ) def write_documents( diff --git a/integrations/astra/tests/test_document_store.py b/integrations/astra/tests/test_document_store.py index f1fad4f5d..3650ffd61 100644 --- a/integrations/astra/tests/test_document_store.py +++ b/integrations/astra/tests/test_document_store.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import os from typing import List +from unittest import mock import pytest from haystack import Document @@ -13,6 +14,33 @@ from haystack_integrations.document_stores.astra import AstraDocumentStore +def test_namespace_init(): + with mock.patch("haystack_integrations.document_stores.astra.astra_client.AstraDB") as client: + AstraDocumentStore() + assert "namespace" in client.call_args.kwargs + assert client.call_args.kwargs["namespace"] is None + + AstraDocumentStore(namespace="foo") + assert "namespace" in client.call_args.kwargs + assert client.call_args.kwargs["namespace"] == "foo" + + +def test_to_dict(): + with mock.patch("haystack_integrations.document_stores.astra.astra_client.AstraDB"): + ds = AstraDocumentStore() + result = ds.to_dict() + assert result["type"] == "haystack_integrations.document_stores.astra.document_store.AstraDocumentStore" + assert set(result["init_parameters"]) == { + "api_endpoint", + "token", + "collection_name", + "embedding_dimension", + "duplicates_policy", + "similarity", + "namespace", + } + + @pytest.mark.integration @pytest.mark.skipif( os.environ.get("ASTRA_DB_APPLICATION_TOKEN", "") == "", reason="ASTRA_DB_APPLICATION_TOKEN env var not set" diff --git a/integrations/astra/tests/test_retriever.py b/integrations/astra/tests/test_retriever.py index 95ba7a263..b52cedf33 100644 --- a/integrations/astra/tests/test_retriever.py +++ b/integrations/astra/tests/test_retriever.py @@ -30,6 +30,7 @@ def test_retriever_to_json(*_): "embedding_dimension": 768, "duplicates_policy": "NONE", "similarity": "cosine", + "namespace": None, }, }, }, @@ -42,7 +43,6 @@ def test_retriever_to_json(*_): ) @patch("haystack_integrations.document_stores.astra.document_store.AstraClient") def test_retriever_from_json(*_): - data = { "type": "haystack_integrations.components.retrievers.astra.retriever.AstraEmbeddingRetriever", "init_parameters": {