Update the Astra DB Integration to fit latest conventions (#428)
* FIX: use standard astra naming conventions for env vars

* Further astra integration updates

* Fix linting

* Fix for black linting as well

* Fix issues, add dep

* Update pyproject.toml

* Update test_retriever.py

* Update astra.yml

* Address feedback in review
erichare authored Feb 20, 2024
1 parent e43c12f commit b7e2f00
Showing 7 changed files with 149 additions and 172 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/astra.yml
@@ -59,6 +59,6 @@ jobs:
       - name: Run tests
         env:
-          ASTRA_API_ENDPOINT: ${{ secrets.ASTRA_API_ENDPOINT }}
-          ASTRA_TOKEN: ${{ secrets.ASTRA_TOKEN }}
+          ASTRA_DB_API_ENDPOINT: ${{ secrets.ASTRA_API_ENDPOINT }}
+          ASTRA_DB_APPLICATION_TOKEN: ${{ secrets.ASTRA_TOKEN }}
         run: hatch run cov
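The workflow now exports the variable names the integration actually reads. A minimal sketch of a test guard built on the renamed variables (the skip marker and test body are illustrative, not taken from this repository):

```python
import os

import pytest

# Hypothetical guard: skip integration tests unless the renamed
# credentials (matching the workflow's env block) are present.
missing_creds = not (
    os.environ.get("ASTRA_DB_API_ENDPOINT") and os.environ.get("ASTRA_DB_APPLICATION_TOKEN")
)


@pytest.mark.skipif(missing_creds, reason="Astra DB credentials are not exported")
def test_astra_credentials_look_sane():
    # The endpoint is an HTTPS URL and the token is a non-empty string.
    assert os.environ["ASTRA_DB_API_ENDPOINT"].startswith("https://")
    assert os.environ["ASTRA_DB_APPLICATION_TOKEN"]
```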
23 changes: 8 additions & 15 deletions integrations/astra/README.md
@@ -17,7 +17,7 @@ pyenv local 3.9
 Local install for the package
 `pip install -e .`
 To execute integration tests, add needed environment variables
-`ASTRA_DB_ID=<id>`
+`ASTRA_DB_API_ENDPOINT=<id>`
 `ASTRA_DB_APPLICATION_TOKEN=<token>`
 and execute
 `python examples/example.py`
@@ -27,12 +27,10 @@ Install requirements

 Export environment variables
 ```
-export KEYSPACE_NAME=
+export ASTRA_DB_API_ENDPOINT=
+export ASTRA_DB_APPLICATION_TOKEN=
 export COLLECTION_NAME=
 export OPENAI_API_KEY=
-export ASTRA_DB_ID=
-export ASTRA_DB_REGION=
-export ASTRA_DB_APPLICATION_TOKEN=
 ```

run the python examples
@@ -54,22 +52,17 @@ from haystack.preview.document_stores import DuplicatePolicy

 Load in environment variables:
 ```
-astra_id = os.getenv("ASTRA_DB_ID", "")
-astra_region = os.getenv("ASTRA_DB_REGION", "us-east1")
-astra_application_token = os.getenv("ASTRA_DB_APPLICATION_TOKEN", "")
+api_endpoint = os.getenv("ASTRA_DB_API_ENDPOINT", "")
+token = os.getenv("ASTRA_DB_APPLICATION_TOKEN", "")
 collection_name = os.getenv("COLLECTION_NAME", "haystack_vector_search")
-keyspace_name = os.getenv("KEYSPACE_NAME", "recommender_demo")
 ```

 Create the Document Store object:
 ```
 document_store = AstraDocumentStore(
-    astra_id=astra_id,
-    astra_region=astra_region,
-    astra_collection=collection_name,
-    astra_keyspace=keyspace_name,
-    astra_application_token=astra_application_token,
+    api_endpoint=api_endpoint,
+    token=token,
+    collection_name=collection_name,
     duplicates_policy=DuplicatePolicy.SKIP,
     embedding_dim=384,
 )
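With the constructor reduced to `api_endpoint`/`token`/`collection_name`, end-to-end usage shrinks accordingly. A sketch under the README's environment variables; the `write_documents`/`count_documents` calls follow the standard Haystack 2.x document store protocol, and the import path is an assumption:

```python
import os

from haystack import Document
from haystack_integrations.document_stores.astra import AstraDocumentStore

# Assumes ASTRA_DB_API_ENDPOINT and ASTRA_DB_APPLICATION_TOKEN are exported
# as shown in the README's env block above.
document_store = AstraDocumentStore(
    api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
    token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
    collection_name=os.getenv("COLLECTION_NAME", "haystack_vector_search"),
    embedding_dim=384,
)

document_store.write_documents([Document(content="Astra DB now speaks astrapy.")])
print(document_store.count_documents())
```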
2 changes: 2 additions & 0 deletions integrations/astra/pyproject.toml
@@ -28,6 +28,7 @@ dependencies = [
   "haystack-ai",
   "pydantic",
   "typing_extensions",
+  "astrapy",
 ]

[project.urls]
@@ -185,6 +186,7 @@ markers = [

 [[tool.mypy.overrides]]
 module = [
   "astra_client.*",
+  "astrapy.*",
   "pydantic.*",
   "haystack.*",
   "haystack_integrations.*",
@@ -1,12 +1,16 @@
 import json
 import logging
 from typing import Dict, List, Optional, Union
+from warnings import warn

-import requests
+from astrapy.api import APIRequestError
+from astrapy.db import AstraDB
 from pydantic.dataclasses import dataclass

 logger = logging.getLogger(__name__)

+NON_INDEXED_FIELDS = ["metadata._node_content", "content"]
+

 @dataclass
 class Response:
@@ -34,73 +38,77 @@ def __init__(
         self,
         api_endpoint: str,
         token: str,
-        keyspace_name: str,
         collection_name: str,
-        embedding_dim: int,
+        embedding_dimension: int,
         similarity_function: str,
+        namespace: Optional[str] = None,
     ):
         self.api_endpoint = api_endpoint
         self.token = token
-        self.keyspace_name = keyspace_name
         self.collection_name = collection_name
-        self.embedding_dim = embedding_dim
+        self.embedding_dimension = embedding_dimension
         self.similarity_function = similarity_function
+        self.namespace = namespace

-        self.request_url = f"{self.api_endpoint}/api/json/v1/{self.keyspace_name}/{self.collection_name}"
-        self.request_header = {
-            "x-cassandra-token": self.token,
-            "Content-Type": "application/json",
-        }
-        self.create_url = f"{self.api_endpoint}/api/json/v1/{self.keyspace_name}"
-
-        index_exists = self.find_index()
-        if not index_exists:
-            self.create_index()
-
-    def find_index(self):
-        find_query = {"findCollections": {"options": {"explain": True}}}
-        response = requests.request("POST", self.create_url, headers=self.request_header, data=json.dumps(find_query))
-        response.raise_for_status()
-        response_dict = json.loads(response.text)
-
-        if "status" in response_dict:
-            collection_name_matches = list(
-                filter(lambda d: d["name"] == self.collection_name, response_dict["status"]["collections"])
-            )
-
-            if len(collection_name_matches) == 0:
-                logger.warning(
-                    f"Astra collection {self.collection_name} not found under {self.keyspace_name}. Will be created."
-                )
-                return False
-
-            collection_embedding_dim = collection_name_matches[0]["options"]["vector"]["dimension"]
-            if collection_embedding_dim != self.embedding_dim:
-                msg = (
-                    f"Collection vector dimension is not valid, expected {self.embedding_dim}, "
-                    f"found {collection_embedding_dim}"
-                )
-                raise Exception(msg)
+        # Build the Astra DB object
+        self._astra_db = AstraDB(api_endpoint=api_endpoint, token=token, namespace=namespace)

-        else:
-            msg = f"status not in response: {response.text}"
-            raise Exception(msg)
-
-        return True
-
-    def create_index(self):
-        create_query = {
-            "createCollection": {
-                "name": self.collection_name,
-                "options": {"vector": {"dimension": self.embedding_dim, "metric": self.similarity_function}},
-            }
-        }
-        response = requests.request("POST", self.create_url, headers=self.request_header, data=json.dumps(create_query))
-        response.raise_for_status()
-        response_dict = json.loads(response.text)
-        if "errors" in response_dict:
-            raise Exception(response_dict["errors"])
-        logger.info(f"Collection {self.collection_name} created: {response.text}")
+        try:
+            # Create and connect to the newly created collection
+            self._astra_db_collection = self._astra_db.create_collection(
+                collection_name=collection_name,
+                dimension=embedding_dimension,
+                options={"indexing": {"deny": NON_INDEXED_FIELDS}},
+            )
+        except APIRequestError:
+            # possibly the collection is preexisting and has legacy
+            # indexing settings: verify
+            get_coll_response = self._astra_db.get_collections(options={"explain": True})
+
+            collections = (get_coll_response["status"] or {}).get("collections") or []
+
+            preexisting = [collection for collection in collections if collection["name"] == collection_name]
+
+            if preexisting:
+                pre_collection = preexisting[0]
+                # if it has no "indexing", it is a legacy collection;
+                # otherwise it's unexpected: warn and proceed at user's risk
+                pre_col_options = pre_collection.get("options") or {}
+                if "indexing" not in pre_col_options:
+                    warn(
+                        (
+                            f"Collection '{collection_name}' is detected as legacy"
+                            " and has indexing turned on for all fields. This"
+                            " implies stricter limitations on the amount of text"
+                            " each entry can store. Consider reindexing anew on a"
+                            " fresh collection to be able to store longer texts."
+                        ),
+                        UserWarning,
+                        stacklevel=2,
+                    )
+                    self._astra_db_collection = self._astra_db.collection(
+                        collection_name=collection_name,
+                    )
+                else:
+                    options_json = json.dumps(pre_col_options["indexing"])
+                    warn(
+                        (
+                            f"Collection '{collection_name}' has unexpected 'indexing'"
+                            f" settings (options.indexing = {options_json})."
+                            " This can result in odd behaviour when running "
+                            " metadata filtering and/or unwarranted limitations"
+                            " on storing long texts. Consider reindexing anew on a"
+                            " fresh collection."
+                        ),
+                        UserWarning,
+                        stacklevel=2,
+                    )
+                    self._astra_db_collection = self._astra_db.collection(
+                        collection_name=collection_name,
+                    )
+            else:
+                # other exception
+                raise

     def query(
         self,
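The constructor now leans on astrapy instead of hand-rolled JSON API requests: one `AstraDB` client, then `create_collection` with an indexing deny-list, falling back to `collection()` when the collection pre-exists with legacy options. Stripped of the warnings, the shape of that logic is roughly the following (a sketch against the pre-1.0 `astrapy.db` API used in this diff; the collection name and dimension are illustrative):

```python
import os

from astrapy.api import APIRequestError
from astrapy.db import AstraDB

db = AstraDB(
    api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
    token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
)

try:
    # Create a vector collection whose long text fields are excluded
    # from indexing, lifting the indexed-field size cap.
    collection = db.create_collection(
        collection_name="demo_collection",
        dimension=384,
        options={"indexing": {"deny": ["content"]}},
    )
except APIRequestError:
    # Pre-existing collection with different (e.g. legacy) options:
    # attach to it as-is instead of failing outright.
    collection = db.collection(collection_name="demo_collection")
```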
@@ -143,13 +151,16 @@ def query(

     def _query_without_vector(self, top_k, filters=None):
         query = {"filter": filters, "options": {"limit": top_k}}
+
         return self.find_documents(query)

     @staticmethod
     def _format_query_response(responses, include_metadata, include_values):
         final_res = []
+
         if responses is None:
             return QueryResponse(matches=[])
+
         for response in responses:
             _id = response.pop("_id")
             score = response.pop("$similarity", None)
@@ -158,27 +169,26 @@ def _format_query_response(responses, include_metadata, include_values):
             metadata = response if include_metadata else {}  # Add all remaining fields to the metadata
             rsp = Response(_id, text, values, metadata, score)
             final_res.append(rsp)
+
         return QueryResponse(final_res)

     def _query(self, vector, top_k, filters=None):
         query = {"sort": {"$vector": vector}, "options": {"limit": top_k, "includeSimilarity": True}}
+
         if filters is not None:
             query["filter"] = filters
+
         result = self.find_documents(query)
+
         return result

     def find_documents(self, find_query):
-        query = json.dumps({"find": find_query})
-        response = requests.request(
-            "POST",
-            self.request_url,
-            headers=self.request_header,
-            data=query,
+        response_dict = self._astra_db_collection.find(
+            filter=find_query["filter"],
+            projection=find_query["sort"],
+            options=find_query["options"],
         )
-        response.raise_for_status()
-        response_dict = json.loads(response.text)
-        if "errors" in response_dict:
-            raise Exception(response_dict["errors"])
+
         if "data" in response_dict and "documents" in response_dict["data"]:
             return response_dict["data"]["documents"]
         else:
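`find_documents` now forwards the pieces of the old `{"find": ...}` payload to the astrapy collection's `find` and unwraps `data.documents` from the response envelope. A hedged sketch of a vector query through the same client (the filter values and the 384-dim vector are made up; `sort={"$vector": ...}` is the JSON API's nearest-neighbour search):

```python
# Assumes `collection` is the astrapy collection from the earlier sketch.
response_dict = collection.find(
    filter={"metadata.language": "en"},  # optional metadata filter
    sort={"$vector": [0.1] * 384},       # ANN search on the stored embeddings
    options={"limit": 3, "includeSimilarity": True},
)

for doc in response_dict["data"]["documents"]:
    print(doc["_id"], doc.get("$similarity"))
```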
@@ -197,19 +207,13 @@ def batch_generator(chunks, batch_size):
             docs = self.find_documents({"filter": {"_id": {"$in": id_batch}}})
             if docs:
                 document_batch.extend(docs)
+
         formatted_docs = self._format_query_response(document_batch, include_metadata=True, include_values=True)
+
         return formatted_docs

     def insert(self, documents: List[Dict]):
-        query = json.dumps({"insertMany": {"options": {"ordered": False}, "documents": documents}})
-        response = requests.request(
-            "POST",
-            self.request_url,
-            headers=self.request_header,
-            data=query,
-        )
-        response.raise_for_status()
-        response_dict = json.loads(response.text)
+        response_dict = self._astra_db_collection.insert_many(documents=documents)

         inserted_ids = (
             response_dict["status"]["insertedIds"]
@@ -218,34 +222,27 @@ def insert(self, documents: List[Dict]):
         )
         if "errors" in response_dict:
             logger.error(response_dict["errors"])
+
         return inserted_ids

     def update_document(self, document: Dict, id_key: str):
         document_id = document.pop(id_key)
-        query = json.dumps(
-            {
-                "findOneAndUpdate": {
-                    "filter": {id_key: document_id},
-                    "update": {"$set": document},
-                    "options": {"returnDocument": "after"},
-                }
-            }
-        )
-        response = requests.request(
-            "POST",
-            self.request_url,
-            headers=self.request_header,
-            data=query,
+
+        response_dict = self._astra_db_collection.find_one_and_update(
+            filter={id_key: document_id},
+            update={"$set": document},
+            options={"returnDocument": "after"},
         )
-        response.raise_for_status()
-        response_dict = json.loads(response.text)
+
         document[id_key] = document_id
+
         if "status" in response_dict and "errors" not in response_dict:
             if "matchedCount" in response_dict["status"] and "modifiedCount" in response_dict["status"]:
                 if response_dict["status"]["matchedCount"] == 1 and response_dict["status"]["modifiedCount"] == 1:
                     return True
-        logger.warning(f"Documents {document_id} not updated in Astra {response.text}")
+
+        logger.warning(f"Documents {document_id} not updated in Astra DB.")
+
         return False

     def delete(
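`find_one_and_update` keeps the JSON API response envelope, so success is still read off `status.matchedCount` and `status.modifiedCount`. An illustrative check (the filter and update values are made up; assumes the `collection` from the earlier sketch):

```python
response_dict = collection.find_one_and_update(
    filter={"_id": "doc-1"},
    update={"$set": {"content": "refreshed text"}},
    options={"returnDocument": "after"},
)

status = response_dict.get("status", {})
if status.get("matchedCount") == 1 and status.get("modifiedCount") == 1:
    print("document updated")
else:
    print("no matching document")
```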
@@ -261,21 +258,18 @@ def delete(
         if filters is not None:
             query = {"deleteMany": {"filter": filters}}

+        filter_dict = {}
+        if "filter" in query["deleteMany"]:
+            filter_dict = query["deleteMany"]["filter"]
+
         deletion_counter = 0
         moredata = True
         while moredata:
-            response = requests.request(
-                "POST",
-                self.request_url,
-                headers=self.request_header,
-                data=json.dumps(query),
-            )
-            response.raise_for_status()
-            response_dict = response.json()
-            if "errors" in response_dict:
-                raise Exception(response_dict["errors"])
+            response_dict = self._astra_db_collection.delete_many(filter=filter_dict)
+
             if "moreData" not in response_dict.get("status", {}):
                 moredata = False
+
             deletion_counter += int(response_dict["status"].get("deletedCount", 0))

         return deletion_counter
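`delete_many` removes matches in server-side batches: while more matching documents remain, the response carries `status.moreData = True`, which is why the loop keeps calling it. The same pattern outside the class (the filter is illustrative; assumes the `collection` from the earlier sketch):

```python
deleted = 0
while True:
    response_dict = collection.delete_many(filter={"metadata.stale": True})
    status = response_dict.get("status", {})
    deleted += int(status.get("deletedCount", 0))
    if not status.get("moreData"):
        break

print(f"deleted {deleted} documents")
```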
@@ -284,13 +278,6 @@ def count_documents(self) -> int:
         """
         Returns how many documents are present in the document store.
         """
-        response = requests.request(
-            "POST",
-            self.request_url,
-            headers=self.request_header,
-            data=json.dumps({"countDocuments": {}}),
-        )
-        response.raise_for_status()
-        if "errors" in response.json():
-            raise Exception(response.json()["errors"])
-        return response.json()["status"]["count"]
+        documents_count = self._astra_db_collection.count_documents()
+
+        return documents_count["status"]["count"]