Skip to content

Commit

Permalink
[Astra] Change authentication parameters (#423)
Browse files Browse the repository at this point in the history
* change auth params

* make tests unit

* linting

* fix examples

* adjust env vars from secrets

* remove conftest
  • Loading branch information
masci authored Feb 15, 2024
1 parent 5ca05dc commit eeecd3c
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 123 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/astra.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,6 @@ jobs:

- name: Run tests
env:
ASTRA_DB_APPLICATION_TOKEN: ${{ secrets.ASTRA_DB_APPLICATION_TOKEN }}
ASTRA_DB_ID: ${{ secrets.ASTRA_DB_ID }}
ASTRA_API_ENDPOINT: ${{ secrets.ASTRA_API_ENDPOINT }}
ASTRA_TOKEN: ${{ secrets.ASTRA_TOKEN }}
run: hatch run cov
9 changes: 2 additions & 7 deletions integrations/astra/examples/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,15 @@
file_paths = [HERE / "data" / Path(name) for name in os.listdir("integrations/astra/examples/data")]
logger.info(file_paths)

astra_id = os.getenv("ASTRA_DB_ID", "")
astra_region = os.getenv("ASTRA_DB_REGION", "us-east1")

astra_application_token = os.getenv("ASTRA_DB_APPLICATION_TOKEN", "")
collection_name = os.getenv("COLLECTION_NAME", "haystack_vector_search")
keyspace_name = os.getenv("KEYSPACE_NAME", "recommender_demo")

# Make sure ASTRA_API_ENDPOINT and ASTRA_TOKEN environment variables are set before proceeding

# We support many different databases. Here, we connect to a remote Astra DB collection through AstraDocumentStore.
document_store = AstraDocumentStore(
astra_id=astra_id,
astra_region=astra_region,
astra_collection=collection_name,
astra_keyspace=keyspace_name,
astra_application_token=astra_application_token,
duplicates_policy=DuplicatePolicy.OVERWRITE,
embedding_dim=384,
)
Expand Down
25 changes: 10 additions & 15 deletions integrations/astra/examples/pipeline_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,32 +17,27 @@

# Create a RAG query pipeline
prompt_template = """
Given these documents, answer the question.
Given these documents, answer the question.
Documents:
{% for doc in documents %}
{{ doc.content }}
{% endfor %}
Documents:
{% for doc in documents %}
{{ doc.content }}
{% endfor %}
Question: {{question}}
Question: {{question}}
Answer:
"""
Answer:
"""

astra_id = os.getenv("ASTRA_DB_ID", "")
astra_region = os.getenv("ASTRA_DB_REGION", "us-east1")

astra_application_token = os.getenv("ASTRA_DB_APPLICATION_TOKEN", "")
collection_name = os.getenv("COLLECTION_NAME", "haystack_vector_search")
keyspace_name = os.getenv("KEYSPACE_NAME", "recommender_demo")

# Make sure ASTRA_API_ENDPOINT and ASTRA_TOKEN environment variables are set before proceeding

# We support many different databases. Here, we connect to a remote Astra DB collection through AstraDocumentStore.
document_store = AstraDocumentStore(
astra_id=astra_id,
astra_region=astra_region,
astra_collection=collection_name,
astra_keyspace=keyspace_name,
astra_application_token=astra_application_token,
duplicates_policy=DuplicatePolicy.SKIP,
embedding_dim=384,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,30 +32,26 @@ class AstraClient:

def __init__(
self,
astra_id: str,
astra_region: str,
astra_application_token: str,
api_endpoint: str,
token: str,
keyspace_name: str,
collection_name: str,
embedding_dim: int,
similarity_function: str,
):
self.astra_id = astra_id
self.astra_application_token = astra_application_token
self.astra_region = astra_region
self.api_endpoint = api_endpoint
self.token = token
self.keyspace_name = keyspace_name
self.collection_name = collection_name
self.embedding_dim = embedding_dim
self.similarity_function = similarity_function

self.request_url = f"https://{self.astra_id}-{self.astra_region}.apps.astra.datastax.com/api/json/v1/{self.keyspace_name}/{self.collection_name}"
self.request_url = f"{self.api_endpoint}/api/json/v1/{self.keyspace_name}/{self.collection_name}"
self.request_header = {
"x-cassandra-token": self.astra_application_token,
"x-cassandra-token": self.token,
"Content-Type": "application/json",
}
self.create_url = (
f"https://{self.astra_id}-{self.astra_region}.apps.astra.datastax.com/api/json/v1/{self.keyspace_name}"
)
self.create_url = f"{self.api_endpoint}/api/json/v1/{self.keyspace_name}"

index_exists = self.find_index()
if not index_exists:
Expand Down Expand Up @@ -198,7 +194,9 @@ def batch_generator(chunks, batch_size):
yield batch

for id_batch in batch_generator(ids, batch_size):
document_batch.extend(self.find_documents({"filter": {"_id": {"$in": id_batch}}}))
docs = self.find_documents({"filter": {"_id": {"$in": id_batch}}})
if docs:
document_batch.extend(docs)
formatted_docs = self._format_query_response(document_batch, include_metadata=True, include_values=True)
return formatted_docs

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from haystack.dataclasses import Document
from haystack.document_stores.errors import DuplicateDocumentError, MissingDocumentError
from haystack.document_stores.types import DuplicatePolicy
from haystack.utils import Secret, deserialize_secrets_inplace

from .astra_client import AstraClient
from .errors import AstraDocumentStoreFilterError
Expand All @@ -35,11 +36,10 @@ class AstraDocumentStore:

def __init__(
self,
astra_id: str,
astra_region: str,
astra_application_token: str,
astra_keyspace: str,
astra_collection: str,
api_endpoint: Secret = Secret.from_env_var("ASTRA_API_ENDPOINT"), # noqa: B008
token: Secret = Secret.from_env_var("ASTRA_TOKEN"), # noqa: B008
astra_keyspace: str = "default_keyspace",
astra_collection: str = "documents",
embedding_dim: int = 768,
duplicates_policy: DuplicatePolicy = DuplicatePolicy.NONE,
similarity: str = "cosine",
Expand All @@ -65,36 +65,50 @@ def __init__(
- `DuplicatePolicy.OVERWRITE`: If a Document with the same id already exists, it is overwritten.
- `DuplicatePolicy.FAIL`: If a Document with the same id already exists, an error is raised.
"""
resolved_api_endpoint = api_endpoint.resolve_value()
if resolved_api_endpoint is None:
msg = (
"AstraDocumentStore expects the API endpoint. "
"Set the ASTRA_API_ENDPOINT environment variable (recommended) or pass it explicitly."
)
raise ValueError(msg)

self.duplicates_policy = duplicates_policy
self.astra_id = astra_id
self.astra_region = astra_region
self.astra_application_token = astra_application_token
resolved_token = token.resolve_value()
if resolved_token is None:
msg = (
"AstraDocumentStore expects an authentication token. "
"Set the ASTRA_TOKEN environment variable (recommended) or pass it explicitly."
)
raise ValueError(msg)

self.api_endpoint = api_endpoint
self.token = token
self.astra_keyspace = astra_keyspace
self.astra_collection = astra_collection
self.embedding_dim = embedding_dim
self.duplicates_policy = duplicates_policy
self.similarity = similarity

self.index = AstraClient(
astra_id=self.astra_id,
astra_region=self.astra_region,
astra_application_token=self.astra_application_token,
keyspace_name=self.astra_keyspace,
collection_name=self.astra_collection,
embedding_dim=self.embedding_dim,
similarity_function=self.similarity,
resolved_api_endpoint,
resolved_token,
self.astra_keyspace,
self.astra_collection,
self.embedding_dim,
self.similarity,
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "AstraDocumentStore":
deserialize_secrets_inplace(data["init_parameters"], keys=["api_endpoint", "token"])
return default_from_dict(cls, data)

def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
self,
api_endpoint=self.api_endpoint.to_dict(),
token=self.token.to_dict(),
duplicates_policy=self.duplicates_policy.name,
astra_id=self.astra_id,
astra_region=self.astra_region,
astra_keyspace=self.astra_keyspace,
astra_collection=self.astra_collection,
embedding_dim=self.embedding_dim,
Expand Down
35 changes: 0 additions & 35 deletions integrations/astra/tests/conftest.py

This file was deleted.

17 changes: 10 additions & 7 deletions integrations/astra/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,23 @@
from haystack_integrations.document_stores.astra import AstraDocumentStore


@pytest.mark.skipif(
os.environ.get("ASTRA_DB_APPLICATION_TOKEN", "") == "", reason="ASTRA_DB_APPLICATION_TOKEN is not set"
)
@pytest.mark.skipif(os.environ.get("ASTRA_DB_ID", "") == "", reason="ASTRA_DB_ID is not set")
@pytest.mark.integration
@pytest.mark.skipif(os.environ.get("ASTRA_TOKEN", "") == "", reason="ASTRA_TOKEN env var not set")
@pytest.mark.skipif(os.environ.get("ASTRA_API_ENDPOINT", "") == "", reason="ASTRA_API_ENDPOINT env var not set")
class TestDocumentStore(DocumentStoreBaseTests):
"""
Common test cases will be provided by `DocumentStoreBaseTests` but
you can add more to this class.
"""

@pytest.fixture
@pytest.mark.usefixtures
def document_store(self, document_store) -> AstraDocumentStore:
return document_store
def document_store(self) -> AstraDocumentStore:
return AstraDocumentStore(
astra_keyspace="astra_haystack_test",
astra_collection="haystack_integration",
duplicates_policy=DuplicatePolicy.OVERWRITE,
embedding_dim=768,
)

@pytest.fixture(autouse=True)
def run_before_and_after_tests(self, document_store: AstraDocumentStore):
Expand Down
53 changes: 26 additions & 27 deletions integrations/astra/tests/test_retriever.py
Original file line number Diff line number Diff line change
@@ -1,64 +1,63 @@
# SPDX-FileCopyrightText: 2023-present Anant Corporation <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
import os

import pytest
from unittest.mock import patch

from haystack_integrations.components.retrievers.astra import AstraEmbeddingRetriever
from haystack_integrations.document_stores.astra import AstraDocumentStore


@pytest.mark.skipif(
os.environ.get("ASTRA_DB_APPLICATION_TOKEN", "") == "", reason="ASTRA_DB_APPLICATION_TOKEN is not set"
@patch.dict(
"os.environ", {"ASTRA_TOKEN": "fake-token", "ASTRA_API_ENDPOINT": "http://fake-url.apps.astra.datastax.com"}
)
@pytest.mark.skipif(os.environ.get("ASTRA_DB_ID", "") == "", reason="ASTRA_DB_ID is not set")
@pytest.mark.integration
def test_retriever_to_json(document_store):
retriever = AstraEmbeddingRetriever(document_store, filters={"foo": "bar"}, top_k=99)
@patch("haystack_integrations.document_stores.astra.document_store.AstraClient")
def test_retriever_to_json(*_):
ds = AstraDocumentStore()

retriever = AstraEmbeddingRetriever(ds, filters={"foo": "bar"}, top_k=99)
assert retriever.to_dict() == {
"type": "haystack_integrations.components.retrievers.astra.retriever.AstraEmbeddingRetriever",
"init_parameters": {
"filters": {"foo": "bar"},
"top_k": 99,
"document_store": {
"type": "haystack_integrations.document_stores.astra.document_store.AstraDocumentStore",
"init_parameters": {
"astra_collection": "haystack_integration",
"astra_id": "63195634-ba44-49be-8a3c-12e830eb1c01",
"astra_keyspace": "astra_haystack_test",
"astra_region": "us-east-2",
"duplicates_policy": "OVERWRITE",
"api_endpoint": {"type": "env_var", "env_vars": ["ASTRA_API_ENDPOINT"], "strict": True},
"token": {"type": "env_var", "env_vars": ["ASTRA_TOKEN"], "strict": True},
"duplicates_policy": "NONE",
"astra_keyspace": "default_keyspace",
"astra_collection": "documents",
"embedding_dim": 768,
"similarity": "cosine",
},
"type": "haystack_integrations.document_stores.astra.document_store.AstraDocumentStore",
},
},
}


@pytest.mark.skipif(
os.environ.get("ASTRA_DB_APPLICATION_TOKEN", "") == "", reason="ASTRA_DB_APPLICATION_TOKEN is not set"
@patch.dict(
"os.environ", {"ASTRA_TOKEN": "fake-token", "ASTRA_API_ENDPOINT": "http://fake-url.apps.astra.datastax.com"}
)
@pytest.mark.skipif(os.environ.get("ASTRA_DB_ID", "") == "", reason="ASTRA_DB_ID is not set")
@pytest.mark.integration
def test_retriever_from_json():
@patch("haystack_integrations.document_stores.astra.document_store.AstraClient")
def test_retriever_from_json(*_):

data = {
"type": "haystack_integrations.components.retrievers.astra.retriever.AstraEmbeddingRetriever",
"init_parameters": {
"filters": {"bar": "baz"},
"top_k": 42,
"document_store": {
"type": "haystack_integrations.document_stores.astra.document_store.AstraDocumentStore",
"init_parameters": {
"astra_collection": "haystack_integration",
"astra_id": "63195634-ba44-49be-8a3c-12e830eb1c01",
"astra_application_token": os.getenv("ASTRA_DB_APPLICATION_TOKEN", ""),
"astra_keyspace": "astra_haystack_test",
"astra_region": "us-east-2",
"duplicates_policy": "overwrite",
"api_endpoint": {"type": "env_var", "env_vars": ["ASTRA_API_ENDPOINT"], "strict": True},
"token": {"type": "env_var", "env_vars": ["ASTRA_TOKEN"], "strict": True},
"duplicates_policy": "NONE",
"astra_keyspace": "default_keyspace",
"astra_collection": "documents",
"embedding_dim": 768,
"similarity": "cosine",
},
"type": "haystack_integrations.document_stores.astra.document_store.AstraDocumentStore",
},
},
}
Expand Down

0 comments on commit eeecd3c

Please sign in to comment.