From eeecd3c68cbc2d4c67c8e936c177ad83b76f8fd9 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Thu, 15 Feb 2024 18:50:31 +0100 Subject: [PATCH] [Astra] Change authentication parameters (#423) * change auth params * make tests unit * linting * fix examples * adjust env vars from secrets * remove conftest --- .github/workflows/astra.yml | 4 +- integrations/astra/examples/example.py | 9 +--- .../astra/examples/pipeline_example.py | 25 ++++----- .../document_stores/astra/astra_client.py | 22 ++++---- .../document_stores/astra/document_store.py | 50 ++++++++++------- integrations/astra/tests/conftest.py | 35 ------------ .../astra/tests/test_document_store.py | 17 +++--- integrations/astra/tests/test_retriever.py | 53 +++++++++---------- 8 files changed, 92 insertions(+), 123 deletions(-) delete mode 100644 integrations/astra/tests/conftest.py diff --git a/.github/workflows/astra.yml b/.github/workflows/astra.yml index a1aab7154..6e0976ebe 100644 --- a/.github/workflows/astra.yml +++ b/.github/workflows/astra.yml @@ -59,6 +59,6 @@ jobs: - name: Run tests env: - ASTRA_DB_APPLICATION_TOKEN: ${{ secrets.ASTRA_DB_APPLICATION_TOKEN }} - ASTRA_DB_ID: ${{ secrets.ASTRA_DB_ID }} + ASTRA_API_ENDPOINT: ${{ secrets.ASTRA_API_ENDPOINT }} + ASTRA_TOKEN: ${{ secrets.ASTRA_TOKEN }} run: hatch run cov \ No newline at end of file diff --git a/integrations/astra/examples/example.py b/integrations/astra/examples/example.py index 6d88f3929..eda62834a 100644 --- a/integrations/astra/examples/example.py +++ b/integrations/astra/examples/example.py @@ -21,20 +21,15 @@ file_paths = [HERE / "data" / Path(name) for name in os.listdir("integrations/astra/examples/data")] logger.info(file_paths) -astra_id = os.getenv("ASTRA_DB_ID", "") -astra_region = os.getenv("ASTRA_DB_REGION", "us-east1") - -astra_application_token = os.getenv("ASTRA_DB_APPLICATION_TOKEN", "") collection_name = os.getenv("COLLECTION_NAME", "haystack_vector_search") keyspace_name = os.getenv("KEYSPACE_NAME", "recommender_demo") +# Make sure ASTRA_API_ENDPOINT and ASTRA_TOKEN environment variables are set before proceeding + # We support many different databases. Here, we load a simple and lightweight in-memory database. document_store = AstraDocumentStore( - astra_id=astra_id, - astra_region=astra_region, astra_collection=collection_name, astra_keyspace=keyspace_name, - astra_application_token=astra_application_token, duplicates_policy=DuplicatePolicy.OVERWRITE, embedding_dim=384, ) diff --git a/integrations/astra/examples/pipeline_example.py b/integrations/astra/examples/pipeline_example.py index 09521dd64..731bffd54 100644 --- a/integrations/astra/examples/pipeline_example.py +++ b/integrations/astra/examples/pipeline_example.py @@ -17,32 +17,27 @@ # Create a RAG query pipeline prompt_template = """ - Given these documents, answer the question. +Given these documents, answer the question. - Documents: - {% for doc in documents %} - {{ doc.content }} - {% endfor %} +Documents: +{% for doc in documents %} + {{ doc.content }} +{% endfor %} - Question: {{question}} +Question: {{question}} - Answer: - """ +Answer: +""" -astra_id = os.getenv("ASTRA_DB_ID", "") -astra_region = os.getenv("ASTRA_DB_REGION", "us-east1") - -astra_application_token = os.getenv("ASTRA_DB_APPLICATION_TOKEN", "") collection_name = os.getenv("COLLECTION_NAME", "haystack_vector_search") keyspace_name = os.getenv("KEYSPACE_NAME", "recommender_demo") +# Make sure ASTRA_API_ENDPOINT and ASTRA_TOKEN environment variables are set before proceeding + # We support many different databases. Here, we load a simple and lightweight in-memory database. document_store = AstraDocumentStore( - astra_id=astra_id, - astra_region=astra_region, astra_collection=collection_name, astra_keyspace=keyspace_name, - astra_application_token=astra_application_token, duplicates_policy=DuplicatePolicy.SKIP, embedding_dim=384, ) diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py b/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py index ec0263a5a..99094a6a7 100644 --- a/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py @@ -32,30 +32,26 @@ class AstraClient: def __init__( self, - astra_id: str, - astra_region: str, - astra_application_token: str, + api_endpoint: str, + token: str, keyspace_name: str, collection_name: str, embedding_dim: int, similarity_function: str, ): - self.astra_id = astra_id - self.astra_application_token = astra_application_token - self.astra_region = astra_region + self.api_endpoint = api_endpoint + self.token = token self.keyspace_name = keyspace_name self.collection_name = collection_name self.embedding_dim = embedding_dim self.similarity_function = similarity_function - self.request_url = f"https://{self.astra_id}-{self.astra_region}.apps.astra.datastax.com/api/json/v1/{self.keyspace_name}/{self.collection_name}" + self.request_url = f"{self.api_endpoint}/api/json/v1/{self.keyspace_name}/{self.collection_name}" self.request_header = { - "x-cassandra-token": self.astra_application_token, + "x-cassandra-token": self.token, "Content-Type": "application/json", } - self.create_url = ( - f"https://{self.astra_id}-{self.astra_region}.apps.astra.datastax.com/api/json/v1/{self.keyspace_name}" - ) + self.create_url = f"{self.api_endpoint}/api/json/v1/{self.keyspace_name}" index_exists = self.find_index() if not index_exists: @@ -198,7 +194,9 @@ def batch_generator(chunks, batch_size): yield batch for id_batch in batch_generator(ids, batch_size): - document_batch.extend(self.find_documents({"filter": {"_id": {"$in": id_batch}}})) + docs = self.find_documents({"filter": {"_id": {"$in": id_batch}}}) + if docs: + document_batch.extend(docs) formatted_docs = self._format_query_response(document_batch, include_metadata=True, include_values=True) return formatted_docs diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py b/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py index 8e03de4a6..8708ee785 100644 --- a/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py @@ -11,6 +11,7 @@ from haystack.dataclasses import Document from haystack.document_stores.errors import DuplicateDocumentError, MissingDocumentError from haystack.document_stores.types import DuplicatePolicy +from haystack.utils import Secret, deserialize_secrets_inplace from .astra_client import AstraClient from .errors import AstraDocumentStoreFilterError @@ -35,11 +36,10 @@ class AstraDocumentStore: def __init__( self, - astra_id: str, - astra_region: str, - astra_application_token: str, - astra_keyspace: str, - astra_collection: str, + api_endpoint: Secret = Secret.from_env_var("ASTRA_API_ENDPOINT"), # noqa: B008 + token: Secret = Secret.from_env_var("ASTRA_TOKEN"), # noqa: B008 + astra_keyspace: str = "default_keyspace", + astra_collection: str = "documents", embedding_dim: int = 768, duplicates_policy: DuplicatePolicy = DuplicatePolicy.NONE, similarity: str = "cosine", @@ -65,36 +65,50 @@ def __init__( - `DuplicatePolicy.OVERWRITE`: If a Document with the same id already exists, it is overwritten. - `DuplicatePolicy.FAIL`: If a Document with the same id already exists, an error is raised. """ + resolved_api_endpoint = api_endpoint.resolve_value() + if resolved_api_endpoint is None: + msg = ( + "AstraDocumentStore expects the API endpoint. " + "Set the ASTRA_API_ENDPOINT environment variable (recommended) or pass it explicitly." + ) + raise ValueError(msg) - self.duplicates_policy = duplicates_policy - self.astra_id = astra_id - self.astra_region = astra_region - self.astra_application_token = astra_application_token + resolved_token = token.resolve_value() + if resolved_token is None: + msg = ( + "AstraDocumentStore expects an authentication token. " + "Set the ASTRA_TOKEN environment variable (recommended) or pass it explicitly." + ) + raise ValueError(msg) + + self.api_endpoint = api_endpoint + self.token = token self.astra_keyspace = astra_keyspace self.astra_collection = astra_collection self.embedding_dim = embedding_dim + self.duplicates_policy = duplicates_policy self.similarity = similarity self.index = AstraClient( - astra_id=self.astra_id, - astra_region=self.astra_region, - astra_application_token=self.astra_application_token, - keyspace_name=self.astra_keyspace, - collection_name=self.astra_collection, - embedding_dim=self.embedding_dim, - similarity_function=self.similarity, + resolved_api_endpoint, + resolved_token, + self.astra_keyspace, + self.astra_collection, + self.embedding_dim, + self.similarity, ) @classmethod def from_dict(cls, data: Dict[str, Any]) -> "AstraDocumentStore": + deserialize_secrets_inplace(data["init_parameters"], keys=["api_endpoint", "token"]) return default_from_dict(cls, data) def to_dict(self) -> Dict[str, Any]: return default_to_dict( self, + api_endpoint=self.api_endpoint.to_dict(), + token=self.token.to_dict(), duplicates_policy=self.duplicates_policy.name, - astra_id=self.astra_id, - astra_region=self.astra_region, astra_keyspace=self.astra_keyspace, astra_collection=self.astra_collection, embedding_dim=self.embedding_dim, diff --git a/integrations/astra/tests/conftest.py b/integrations/astra/tests/conftest.py deleted file mode 100644 index 274b38352..000000000 --- a/integrations/astra/tests/conftest.py +++ /dev/null @@ -1,35 +0,0 @@ -import os - -import pytest -from haystack.document_stores.types import DuplicatePolicy - -from haystack_integrations.document_stores.astra import AstraDocumentStore - - -@pytest.fixture -def document_store() -> AstraDocumentStore: - """ - This is the most basic requirement for the child class: provide - an instance of this document store so the base class can use it. - """ - astra_id = os.getenv("ASTRA_DB_ID", "") - astra_region = os.getenv("ASTRA_DB_REGION", "us-east-2") - - astra_application_token = os.getenv( - "ASTRA_DB_APPLICATION_TOKEN", - "", - ) - - keyspace_name = "astra_haystack_test" - collection_name = "haystack_integration" - - astra_store = AstraDocumentStore( - astra_id=astra_id, - astra_region=astra_region, - astra_application_token=astra_application_token, - astra_keyspace=keyspace_name, - astra_collection=collection_name, - duplicates_policy=DuplicatePolicy.OVERWRITE, - embedding_dim=768, - ) - return astra_store diff --git a/integrations/astra/tests/test_document_store.py b/integrations/astra/tests/test_document_store.py index 019a66398..e1a4a5dfd 100644 --- a/integrations/astra/tests/test_document_store.py +++ b/integrations/astra/tests/test_document_store.py @@ -13,10 +13,9 @@ from haystack_integrations.document_stores.astra import AstraDocumentStore -@pytest.mark.skipif( - os.environ.get("ASTRA_DB_APPLICATION_TOKEN", "") == "", reason="ASTRA_DB_APPLICATION_TOKEN is not set" -) -@pytest.mark.skipif(os.environ.get("ASTRA_DB_ID", "") == "", reason="ASTRA_DB_ID is not set") +@pytest.mark.integration +@pytest.mark.skipif(os.environ.get("ASTRA_TOKEN", "") == "", reason="ASTRA_TOKEN env var not set") +@pytest.mark.skipif(os.environ.get("ASTRA_API_ENDPOINT", "") == "", reason="ASTRA_API_ENDPOINT env var not set") class TestDocumentStore(DocumentStoreBaseTests): """ Common test cases will be provided by `DocumentStoreBaseTests` but @@ -24,9 +23,13 @@ class TestDocumentStore(DocumentStoreBaseTests): """ @pytest.fixture - @pytest.mark.usefixtures - def document_store(self, document_store) -> AstraDocumentStore: - return document_store + def document_store(self) -> AstraDocumentStore: + return AstraDocumentStore( + astra_keyspace="astra_haystack_test", + astra_collection="haystack_integration", + duplicates_policy=DuplicatePolicy.OVERWRITE, + embedding_dim=768, + ) @pytest.fixture(autouse=True) def run_before_and_after_tests(self, document_store: AstraDocumentStore): diff --git a/integrations/astra/tests/test_retriever.py b/integrations/astra/tests/test_retriever.py index c06a52edb..f66475167 100644 --- a/integrations/astra/tests/test_retriever.py +++ b/integrations/astra/tests/test_retriever.py @@ -1,64 +1,63 @@ # SPDX-FileCopyrightText: 2023-present Anant Corporation # # SPDX-License-Identifier: Apache-2.0 -import os - -import pytest +from unittest.mock import patch from haystack_integrations.components.retrievers.astra import AstraEmbeddingRetriever +from haystack_integrations.document_stores.astra import AstraDocumentStore -@pytest.mark.skipif( - os.environ.get("ASTRA_DB_APPLICATION_TOKEN", "") == "", reason="ASTRA_DB_APPLICATION_TOKEN is not set" +@patch.dict( + "os.environ", {"ASTRA_TOKEN": "fake-token", "ASTRA_API_ENDPOINT": "http://fake-url.apps.astra.datastax.com"} ) -@pytest.mark.skipif(os.environ.get("ASTRA_DB_ID", "") == "", reason="ASTRA_DB_ID is not set") -@pytest.mark.integration -def test_retriever_to_json(document_store): - retriever = AstraEmbeddingRetriever(document_store, filters={"foo": "bar"}, top_k=99) +@patch("haystack_integrations.document_stores.astra.document_store.AstraClient") +def test_retriever_to_json(*_): + ds = AstraDocumentStore() + + retriever = AstraEmbeddingRetriever(ds, filters={"foo": "bar"}, top_k=99) assert retriever.to_dict() == { "type": "haystack_integrations.components.retrievers.astra.retriever.AstraEmbeddingRetriever", "init_parameters": { "filters": {"foo": "bar"}, "top_k": 99, "document_store": { + "type": "haystack_integrations.document_stores.astra.document_store.AstraDocumentStore", "init_parameters": { - "astra_collection": "haystack_integration", - "astra_id": "63195634-ba44-49be-8a3c-12e830eb1c01", - "astra_keyspace": "astra_haystack_test", - "astra_region": "us-east-2", - "duplicates_policy": "OVERWRITE", + "api_endpoint": {"type": "env_var", "env_vars": ["ASTRA_API_ENDPOINT"], "strict": True}, + "token": {"type": "env_var", "env_vars": ["ASTRA_TOKEN"], "strict": True}, + "duplicates_policy": "NONE", + "astra_keyspace": "default_keyspace", + "astra_collection": "documents", "embedding_dim": 768, "similarity": "cosine", }, - "type": "haystack_integrations.document_stores.astra.document_store.AstraDocumentStore", }, }, } -@pytest.mark.skipif( - os.environ.get("ASTRA_DB_APPLICATION_TOKEN", "") == "", reason="ASTRA_DB_APPLICATION_TOKEN is not set" +@patch.dict( + "os.environ", {"ASTRA_TOKEN": "fake-token", "ASTRA_API_ENDPOINT": "http://fake-url.apps.astra.datastax.com"} ) -@pytest.mark.skipif(os.environ.get("ASTRA_DB_ID", "") == "", reason="ASTRA_DB_ID is not set") -@pytest.mark.integration -def test_retriever_from_json(): +@patch("haystack_integrations.document_stores.astra.document_store.AstraClient") +def test_retriever_from_json(*_): + data = { "type": "haystack_integrations.components.retrievers.astra.retriever.AstraEmbeddingRetriever", "init_parameters": { "filters": {"bar": "baz"}, "top_k": 42, "document_store": { + "type": "haystack_integrations.document_stores.astra.document_store.AstraDocumentStore", "init_parameters": { - "astra_collection": "haystack_integration", - "astra_id": "63195634-ba44-49be-8a3c-12e830eb1c01", - "astra_application_token": os.getenv("ASTRA_DB_APPLICATION_TOKEN", ""), - "astra_keyspace": "astra_haystack_test", - "astra_region": "us-east-2", - "duplicates_policy": "overwrite", + "api_endpoint": {"type": "env_var", "env_vars": ["ASTRA_API_ENDPOINT"], "strict": True}, + "token": {"type": "env_var", "env_vars": ["ASTRA_TOKEN"], "strict": True}, + "duplicates_policy": "NONE", + "astra_keyspace": "default_keyspace", + "astra_collection": "documents", "embedding_dim": 768, "similarity": "cosine", }, - "type": "haystack_integrations.document_stores.astra.document_store.AstraDocumentStore", }, }, }