diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py b/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py index 1ef9af8c2..fdbf95eb0 100644 --- a/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py @@ -100,8 +100,7 @@ def __init__( self.embedding_dimension = embedding_dimension self.duplicates_policy = duplicates_policy self.similarity = similarity - if namespace: - self.namespace = namespace + self.namespace = namespace self.index = AstraClient( resolved_api_endpoint, @@ -132,6 +131,7 @@ def to_dict(self) -> Dict[str, Any]: :returns: Dictionary with serialized data. """ + return default_to_dict( self, api_endpoint=self.api_endpoint.to_dict(), @@ -140,6 +140,7 @@ def to_dict(self) -> Dict[str, Any]: embedding_dimension=self.embedding_dimension, duplicates_policy=self.duplicates_policy.name, similarity=self.similarity, + namespace=self.namespace, ) def write_documents( diff --git a/integrations/astra/tests/test_document_store.py b/integrations/astra/tests/test_document_store.py index a6f2f221c..3650ffd61 100644 --- a/integrations/astra/tests/test_document_store.py +++ b/integrations/astra/tests/test_document_store.py @@ -25,6 +25,22 @@ def test_namespace_init(): assert client.call_args.kwargs["namespace"] == "foo" +def test_to_dict(): + with mock.patch("haystack_integrations.document_stores.astra.astra_client.AstraDB"): + ds = AstraDocumentStore() + result = ds.to_dict() + assert result["type"] == "haystack_integrations.document_stores.astra.document_store.AstraDocumentStore" + assert set(result["init_parameters"]) == { + "api_endpoint", + "token", + "collection_name", + "embedding_dimension", + "duplicates_policy", + "similarity", + "namespace", + } + + @pytest.mark.integration @pytest.mark.skipif( os.environ.get("ASTRA_DB_APPLICATION_TOKEN", "") == "", reason="ASTRA_DB_APPLICATION_TOKEN env var not set" diff --git a/integrations/astra/tests/test_retriever.py b/integrations/astra/tests/test_retriever.py index 95ba7a263..b52cedf33 100644 --- a/integrations/astra/tests/test_retriever.py +++ b/integrations/astra/tests/test_retriever.py @@ -30,6 +30,7 @@ def test_retriever_to_json(*_): "embedding_dimension": 768, "duplicates_policy": "NONE", "similarity": "cosine", + "namespace": None, }, }, }, @@ -42,7 +43,6 @@ def test_retriever_to_json(*_): ) @patch("haystack_integrations.document_stores.astra.document_store.AstraClient") def test_retriever_from_json(*_): - data = { "type": "haystack_integrations.components.retrievers.astra.retriever.AstraEmbeddingRetriever", "init_parameters": {