diff --git a/.github/workflows/CI_readme_sync.yml b/.github/workflows/CI_readme_sync.yml index e921998ff..a09abdc65 100644 --- a/.github/workflows/CI_readme_sync.yml +++ b/.github/workflows/CI_readme_sync.yml @@ -49,6 +49,7 @@ jobs: for d in $ALL_CHANGED_DIRS; do cd $d hatch run docs + hatch env prune # clean up the environment after docs generation cd - done mkdir tmp diff --git a/.github/workflows/astra.yml b/.github/workflows/astra.yml index 6e0976ebe..d859626ff 100644 --- a/.github/workflows/astra.yml +++ b/.github/workflows/astra.yml @@ -59,6 +59,6 @@ jobs: - name: Run tests env: - ASTRA_API_ENDPOINT: ${{ secrets.ASTRA_API_ENDPOINT }} - ASTRA_TOKEN: ${{ secrets.ASTRA_TOKEN }} + ASTRA_DB_API_ENDPOINT: ${{ secrets.ASTRA_API_ENDPOINT }} + ASTRA_DB_APPLICATION_TOKEN: ${{ secrets.ASTRA_TOKEN }} run: hatch run cov \ No newline at end of file diff --git a/integrations/astra/README.md b/integrations/astra/README.md index 75fdfeb6d..d14544df4 100644 --- a/integrations/astra/README.md +++ b/integrations/astra/README.md @@ -17,7 +17,7 @@ pyenv local 3.9 Local install for the package `pip install -e .` To execute integration tests, add needed environment variables -`ASTRA_DB_ID=` +`ASTRA_DB_API_ENDPOINT=` `ASTRA_DB_APPLICATION_TOKEN=` and execute `python examples/example.py` @@ -27,12 +27,10 @@ Install requirements Export environment variables ``` -export KEYSPACE_NAME= +export ASTRA_DB_API_ENDPOINT= +export ASTRA_DB_APPLICATION_TOKEN= export COLLECTION_NAME= export OPENAI_API_KEY= -export ASTRA_DB_ID= -export ASTRA_DB_REGION= -export ASTRA_DB_APPLICATION_TOKEN= ``` run the python examples @@ -54,22 +52,17 @@ from haystack.preview.document_stores import DuplicatePolicy Load in environment variables: ``` -astra_id = os.getenv("ASTRA_DB_ID", "") -astra_region = os.getenv("ASTRA_DB_REGION", "us-east1") - -astra_application_token = os.getenv("ASTRA_DB_APPLICATION_TOKEN", "") +api_endpoint = os.getenv("ASTRA_DB_API_ENDPOINT", "") +token = 
os.getenv("ASTRA_DB_APPLICATION_TOKEN", "") collection_name = os.getenv("COLLECTION_NAME", "haystack_vector_search") -keyspace_name = os.getenv("KEYSPACE_NAME", "recommender_demo") ``` Create the Document Store object: ``` document_store = AstraDocumentStore( - astra_id=astra_id, - astra_region=astra_region, - astra_collection=collection_name, - astra_keyspace=keyspace_name, - astra_application_token=astra_application_token, + api_endpoint=api_endpoint, + token=token, + collection_name=collection_name, duplicates_policy=DuplicatePolicy.SKIP, embedding_dim=384, ) diff --git a/integrations/astra/pyproject.toml b/integrations/astra/pyproject.toml index 7599797a8..0e52d1e62 100644 --- a/integrations/astra/pyproject.toml +++ b/integrations/astra/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "haystack-ai", "pydantic", "typing_extensions", + "astrapy", ] [project.urls] @@ -185,6 +186,7 @@ markers = [ [[tool.mypy.overrides]] module = [ "astra_client.*", + "astrapy.*", "pydantic.*", "haystack.*", "haystack_integrations.*", diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py b/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py index 99094a6a7..bb0687a07 100644 --- a/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py @@ -1,12 +1,16 @@ import json import logging from typing import Dict, List, Optional, Union +from warnings import warn -import requests +from astrapy.api import APIRequestError +from astrapy.db import AstraDB from pydantic.dataclasses import dataclass logger = logging.getLogger(__name__) +NON_INDEXED_FIELDS = ["metadata._node_content", "content"] + @dataclass class Response: @@ -34,73 +38,77 @@ def __init__( self, api_endpoint: str, token: str, - keyspace_name: str, collection_name: str, - embedding_dim: int, + embedding_dimension: int, similarity_function: str, + 
namespace: Optional[str] = None, ): self.api_endpoint = api_endpoint self.token = token - self.keyspace_name = keyspace_name self.collection_name = collection_name - self.embedding_dim = embedding_dim + self.embedding_dimension = embedding_dimension self.similarity_function = similarity_function + self.namespace = namespace - self.request_url = f"{self.api_endpoint}/api/json/v1/{self.keyspace_name}/{self.collection_name}" - self.request_header = { - "x-cassandra-token": self.token, - "Content-Type": "application/json", - } - self.create_url = f"{self.api_endpoint}/api/json/v1/{self.keyspace_name}" - - index_exists = self.find_index() - if not index_exists: - self.create_index() - - def find_index(self): - find_query = {"findCollections": {"options": {"explain": True}}} - response = requests.request("POST", self.create_url, headers=self.request_header, data=json.dumps(find_query)) - response.raise_for_status() - response_dict = json.loads(response.text) - - if "status" in response_dict: - collection_name_matches = list( - filter(lambda d: d["name"] == self.collection_name, response_dict["status"]["collections"]) - ) - - if len(collection_name_matches) == 0: - logger.warning( - f"Astra collection {self.collection_name} not found under {self.keyspace_name}. Will be created." 
- ) - return False - - collection_embedding_dim = collection_name_matches[0]["options"]["vector"]["dimension"] - if collection_embedding_dim != self.embedding_dim: - msg = ( - f"Collection vector dimension is not valid, expected {self.embedding_dim}, " - f"found {collection_embedding_dim}" - ) - raise Exception(msg) + # Build the Astra DB object + self._astra_db = AstraDB(api_endpoint=api_endpoint, token=token, namespace=namespace) - else: - msg = f"status not in response: {response.text}" - raise Exception(msg) - - return True - - def create_index(self): - create_query = { - "createCollection": { - "name": self.collection_name, - "options": {"vector": {"dimension": self.embedding_dim, "metric": self.similarity_function}}, - } - } - response = requests.request("POST", self.create_url, headers=self.request_header, data=json.dumps(create_query)) - response.raise_for_status() - response_dict = json.loads(response.text) - if "errors" in response_dict: - raise Exception(response_dict["errors"]) - logger.info(f"Collection {self.collection_name} created: {response.text}") + try: + # Create and connect to the newly created collection + self._astra_db_collection = self._astra_db.create_collection( + collection_name=collection_name, + dimension=embedding_dimension, + options={"indexing": {"deny": NON_INDEXED_FIELDS}}, + ) + except APIRequestError: + # possibly the collection is preexisting and has legacy + # indexing settings: verify + get_coll_response = self._astra_db.get_collections(options={"explain": True}) + + collections = (get_coll_response["status"] or {}).get("collections") or [] + + preexisting = [collection for collection in collections if collection["name"] == collection_name] + + if preexisting: + pre_collection = preexisting[0] + # if it has no "indexing", it is a legacy collection; + # otherwise it's unexpected warn and proceed at user's risk + pre_col_options = pre_collection.get("options") or {} + if "indexing" not in pre_col_options: + warn( + ( + 
f"Collection '{collection_name}' is detected as legacy" + " and has indexing turned on for all fields. This" + " implies stricter limitations on the amount of text" + " each entry can store. Consider reindexing anew on a" + " fresh collection to be able to store longer texts." + ), + UserWarning, + stacklevel=2, + ) + self._astra_db_collection = self._astra_db.collection( + collection_name=collection_name, + ) + else: + options_json = json.dumps(pre_col_options["indexing"]) + warn( + ( + f"Collection '{collection_name}' has unexpected 'indexing'" + f" settings (options.indexing = {options_json})." + " This can result in odd behaviour when running " + " metadata filtering and/or unwarranted limitations" + " on storing long texts. Consider reindexing anew on a" + " fresh collection." + ), + UserWarning, + stacklevel=2, + ) + self._astra_db_collection = self._astra_db.collection( + collection_name=collection_name, + ) + else: + # other exception + raise def query( self, @@ -143,13 +151,16 @@ def query( def _query_without_vector(self, top_k, filters=None): query = {"filter": filters, "options": {"limit": top_k}} + return self.find_documents(query) @staticmethod def _format_query_response(responses, include_metadata, include_values): final_res = [] + if responses is None: return QueryResponse(matches=[]) + for response in responses: _id = response.pop("_id") score = response.pop("$similarity", None) @@ -158,27 +169,26 @@ def _format_query_response(responses, include_metadata, include_values): metadata = response if include_metadata else {} # Add all remaining fields to the metadata rsp = Response(_id, text, values, metadata, score) final_res.append(rsp) + return QueryResponse(final_res) def _query(self, vector, top_k, filters=None): query = {"sort": {"$vector": vector}, "options": {"limit": top_k, "includeSimilarity": True}} + if filters is not None: query["filter"] = filters + result = self.find_documents(query) + return result def find_documents(self, find_query): - 
query = json.dumps({"find": find_query}) - response = requests.request( - "POST", - self.request_url, - headers=self.request_header, - data=query, + response_dict = self._astra_db_collection.find( + filter=find_query["filter"], + projection=find_query["sort"], + options=find_query["options"], ) - response.raise_for_status() - response_dict = json.loads(response.text) - if "errors" in response_dict: - raise Exception(response_dict["errors"]) + if "data" in response_dict and "documents" in response_dict["data"]: return response_dict["data"]["documents"] else: @@ -197,19 +207,13 @@ def batch_generator(chunks, batch_size): docs = self.find_documents({"filter": {"_id": {"$in": id_batch}}}) if docs: document_batch.extend(docs) + formatted_docs = self._format_query_response(document_batch, include_metadata=True, include_values=True) + return formatted_docs def insert(self, documents: List[Dict]): - query = json.dumps({"insertMany": {"options": {"ordered": False}, "documents": documents}}) - response = requests.request( - "POST", - self.request_url, - headers=self.request_header, - data=query, - ) - response.raise_for_status() - response_dict = json.loads(response.text) + response_dict = self._astra_db_collection.insert_many(documents=documents) inserted_ids = ( response_dict["status"]["insertedIds"] @@ -218,34 +222,27 @@ def insert(self, documents: List[Dict]): ) if "errors" in response_dict: logger.error(response_dict["errors"]) + return inserted_ids def update_document(self, document: Dict, id_key: str): document_id = document.pop(id_key) - query = json.dumps( - { - "findOneAndUpdate": { - "filter": {id_key: document_id}, - "update": {"$set": document}, - "options": {"returnDocument": "after"}, - } - } - ) - response = requests.request( - "POST", - self.request_url, - headers=self.request_header, - data=query, + + response_dict = self._astra_db_collection.find_one_and_update( + filter={id_key: document_id}, + update={"$set": document}, + options={"returnDocument": 
"after"}, ) - response.raise_for_status() - response_dict = json.loads(response.text) + document[id_key] = document_id if "status" in response_dict and "errors" not in response_dict: if "matchedCount" in response_dict["status"] and "modifiedCount" in response_dict["status"]: if response_dict["status"]["matchedCount"] == 1 and response_dict["status"]["modifiedCount"] == 1: return True - logger.warning(f"Documents {document_id} not updated in Astra {response.text}") + + logger.warning(f"Documents {document_id} not updated in Astra DB.") + return False def delete( @@ -261,21 +258,18 @@ def delete( if filters is not None: query = {"deleteMany": {"filter": filters}} + filter_dict = {} + if "filter" in query["deleteMany"]: + filter_dict = query["deleteMany"]["filter"] + deletion_counter = 0 moredata = True while moredata: - response = requests.request( - "POST", - self.request_url, - headers=self.request_header, - data=json.dumps(query), - ) - response.raise_for_status() - response_dict = response.json() - if "errors" in response_dict: - raise Exception(response_dict["errors"]) + response_dict = self._astra_db_collection.delete_many(filter=filter_dict) + if "moreData" not in response_dict.get("status", {}): moredata = False + deletion_counter += int(response_dict["status"].get("deletedCount", 0)) return deletion_counter @@ -284,13 +278,6 @@ def count_documents(self) -> int: """ Returns how many documents are present in the document store. 
""" - response = requests.request( - "POST", - self.request_url, - headers=self.request_header, - data=json.dumps({"countDocuments": {}}), - ) - response.raise_for_status() - if "errors" in response.json(): - raise Exception(response.json()["errors"]) - return response.json()["status"]["count"] + documents_count = self._astra_db_collection.count_documents() + + return documents_count["status"]["count"] diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py b/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py index 8708ee785..1bbf3a6ec 100644 --- a/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py @@ -36,27 +36,23 @@ class AstraDocumentStore: def __init__( self, - api_endpoint: Secret = Secret.from_env_var("ASTRA_API_ENDPOINT"), # noqa: B008 - token: Secret = Secret.from_env_var("ASTRA_TOKEN"), # noqa: B008 - astra_keyspace: str = "default_keyspace", - astra_collection: str = "documents", - embedding_dim: int = 768, + api_endpoint: Secret = Secret.from_env_var("ASTRA_DB_API_ENDPOINT"), # noqa: B008 + token: Secret = Secret.from_env_var("ASTRA_DB_APPLICATION_TOKEN"), # noqa: B008 + collection_name: str = "documents", + embedding_dimension: int = 768, duplicates_policy: DuplicatePolicy = DuplicatePolicy.NONE, similarity: str = "cosine", ): """ The connection to Astra DB is established and managed through the JSON API. - The required credentials (database ID, region, and application token) can be generated + The required credentials (api endpoint andapplication token) can be generated through the UI by clicking and the connect tab, and then selecting JSON API and Generate Configuration. - :param astra_id: id of the Astra DB instance. - :param astra_region: Region of cloud servers (can be found when generating the token). 
- :param astra_application_token: the connection token for Astra. - :param astra_keyspace: The keyspace for the current Astra DB. - :param astra_collection: The current collection in the keyspace in the current Astra DB. - :param embedding_dim: Dimension of embedding vector. - :param similarity: The similarity function used to compare document vectors. + :param api_endpoint: The Astra DB API endpoint. + :param token: The Astra DB application token. + :param collection_name: The current collection in the keyspace in the current Astra DB. + :param embedding_dimension: Dimension of embedding vector. :param duplicates_policy: Handle duplicate documents based on DuplicatePolicy parameter options. Parameter options : (SKIP, OVERWRITE, FAIL, NONE) - `DuplicatePolicy.NONE`: Default policy, If a Document with the same id already exists, @@ -64,12 +60,13 @@ def __init__( - `DuplicatePolicy.SKIP`: If a Document with the same id already exists, it is skipped and not written. - `DuplicatePolicy.OVERWRITE`: If a Document with the same id already exists, it is overwritten. - `DuplicatePolicy.FAIL`: If a Document with the same id already exists, an error is raised. + :param similarity: The similarity function used to compare document vectors. """ resolved_api_endpoint = api_endpoint.resolve_value() if resolved_api_endpoint is None: msg = ( "AstraDocumentStore expects the API endpoint. " - "Set the ASTRA_API_ENDPOINT environment variable (recommended) or pass it explicitly." + "Set the ASTRA_DB_API_ENDPOINT environment variable (recommended) or pass it explicitly." ) raise ValueError(msg) @@ -77,24 +74,22 @@ def __init__( if resolved_token is None: msg = ( "AstraDocumentStore expects an authentication token. " - "Set the ASTRA_TOKEN environment variable (recommended) or pass it explicitly." + "Set the ASTRA_DB_APPLICATION_TOKEN environment variable (recommended) or pass it explicitly." 
) raise ValueError(msg) self.api_endpoint = api_endpoint self.token = token - self.astra_keyspace = astra_keyspace - self.astra_collection = astra_collection - self.embedding_dim = embedding_dim + self.collection_name = collection_name + self.embedding_dimension = embedding_dimension self.duplicates_policy = duplicates_policy self.similarity = similarity self.index = AstraClient( resolved_api_endpoint, resolved_token, - self.astra_keyspace, - self.astra_collection, - self.embedding_dim, + self.collection_name, + self.embedding_dimension, self.similarity, ) @@ -108,10 +103,9 @@ def to_dict(self) -> Dict[str, Any]: self, api_endpoint=self.api_endpoint.to_dict(), token=self.token.to_dict(), + collection_name=self.collection_name, + embedding_dimension=self.embedding_dimension, duplicates_policy=self.duplicates_policy.name, - astra_keyspace=self.astra_keyspace, - astra_collection=self.astra_collection, - embedding_dim=self.embedding_dim, similarity=self.similarity, ) diff --git a/integrations/astra/tests/test_document_store.py b/integrations/astra/tests/test_document_store.py index e1a4a5dfd..df70b2d13 100644 --- a/integrations/astra/tests/test_document_store.py +++ b/integrations/astra/tests/test_document_store.py @@ -14,8 +14,10 @@ @pytest.mark.integration -@pytest.mark.skipif(os.environ.get("ASTRA_TOKEN", "") == "", reason="ASTRA_TOKEN env var not set") -@pytest.mark.skipif(os.environ.get("ASTRA_API_ENDPOINT", "") == "", reason="ASTRA_API_ENDPOINT env var not set") +@pytest.mark.skipif( + os.environ.get("ASTRA_DB_APPLICATION_TOKEN", "") == "", reason="ASTRA_DB_APPLICATION_TOKEN env var not set" +) +@pytest.mark.skipif(os.environ.get("ASTRA_DB_API_ENDPOINT", "") == "", reason="ASTRA_DB_API_ENDPOINT env var not set") class TestDocumentStore(DocumentStoreBaseTests): """ Common test cases will be provided by `DocumentStoreBaseTests` but @@ -25,8 +27,7 @@ class TestDocumentStore(DocumentStoreBaseTests): @pytest.fixture def document_store(self) -> AstraDocumentStore: 
return AstraDocumentStore( - astra_keyspace="astra_haystack_test", - astra_collection="haystack_integration", + collection_name="haystack_integration", duplicates_policy=DuplicatePolicy.OVERWRITE, - embedding_dim=768, + embedding_dimension=768, ) diff --git a/integrations/astra/tests/test_retriever.py b/integrations/astra/tests/test_retriever.py index f66475167..95ba7a263 100644 --- a/integrations/astra/tests/test_retriever.py +++ b/integrations/astra/tests/test_retriever.py @@ -8,7 +8,8 @@ @patch.dict( - "os.environ", {"ASTRA_TOKEN": "fake-token", "ASTRA_API_ENDPOINT": "http://fake-url.apps.astra.datastax.com"} + "os.environ", + {"ASTRA_DB_APPLICATION_TOKEN": "fake-token", "ASTRA_DB_API_ENDPOINT": "http://fake-url.apps.astra.datastax.com"}, ) @patch("haystack_integrations.document_stores.astra.document_store.AstraClient") def test_retriever_to_json(*_): @@ -23,12 +24,11 @@ def test_retriever_to_json(*_): "document_store": { "type": "haystack_integrations.document_stores.astra.document_store.AstraDocumentStore", "init_parameters": { - "api_endpoint": {"type": "env_var", "env_vars": ["ASTRA_API_ENDPOINT"], "strict": True}, - "token": {"type": "env_var", "env_vars": ["ASTRA_TOKEN"], "strict": True}, + "api_endpoint": {"type": "env_var", "env_vars": ["ASTRA_DB_API_ENDPOINT"], "strict": True}, + "token": {"type": "env_var", "env_vars": ["ASTRA_DB_APPLICATION_TOKEN"], "strict": True}, + "collection_name": "documents", + "embedding_dimension": 768, "duplicates_policy": "NONE", - "astra_keyspace": "default_keyspace", - "astra_collection": "documents", - "embedding_dim": 768, "similarity": "cosine", }, }, @@ -37,7 +37,8 @@ @patch.dict( - "os.environ", {"ASTRA_TOKEN": "fake-token", "ASTRA_API_ENDPOINT": "http://fake-url.apps.astra.datastax.com"} + "os.environ", + {"ASTRA_DB_APPLICATION_TOKEN": "fake-token", "ASTRA_DB_API_ENDPOINT": "http://fake-url.apps.astra.datastax.com"}, ) @patch("haystack_integrations.document_stores.astra.document_store.AstraClient") def 
test_retriever_from_json(*_): @@ -50,12 +51,11 @@ def test_retriever_from_json(*_): "document_store": { "type": "haystack_integrations.document_stores.astra.document_store.AstraDocumentStore", "init_parameters": { - "api_endpoint": {"type": "env_var", "env_vars": ["ASTRA_API_ENDPOINT"], "strict": True}, - "token": {"type": "env_var", "env_vars": ["ASTRA_TOKEN"], "strict": True}, + "api_endpoint": {"type": "env_var", "env_vars": ["ASTRA_DB_API_ENDPOINT"], "strict": True}, + "token": {"type": "env_var", "env_vars": ["ASTRA_DB_APPLICATION_TOKEN"], "strict": True}, + "collection_name": "documents", + "embedding_dimension": 768, "duplicates_policy": "NONE", - "astra_keyspace": "default_keyspace", - "astra_collection": "documents", - "embedding_dim": 768, "similarity": "cosine", }, },