Skip to content

Commit

Permalink
Dh-5541/AstraDB Support (#421)
Browse files Browse the repository at this point in the history
  • Loading branch information
MohammadrezaPourreza authored and DishenWang2023 committed May 7, 2024
1 parent 11122e3 commit 2230e78
Show file tree
Hide file tree
Showing 7 changed files with 180 additions and 7 deletions.
4 changes: 4 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ GOLDEN_SQL_COLLECTION = 'my-golden-records'
#Pinecone info. These fields are required if the vector store used is Pinecone
PINECONE_API_KEY =
PINECONE_ENVIRONMENT =
#AstraDB info. These fields are required if the vector store used is AstraDB
ASTRA_DB_API_ENDPOINT =
ASTRA_DB_APPLICATION_TOKEN =


# Module implementations to be used names for each required component. You can use the default ones or create your own
API_SERVER = "dataherald.api.fastapi.FastAPI"
Expand Down
1 change: 0 additions & 1 deletion dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,6 @@ def add_golden_sqls(
{"items": [row.dict() for row in golden_sqls]},
"golden_sql_not_created",
)

return [GoldenSQLResponse(**golden_sql.dict()) for golden_sql in golden_sqls]

@override
Expand Down
1 change: 0 additions & 1 deletion dataherald/context_store/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ def add_golden_sqls(self, golden_sqls: List[GoldenSQLRequest]) -> List[GoldenSQL
metadata=record.metadata,
)
stored_golden_sqls.append(golden_sqls_repository.insert(golden_sql))

self.vector_store.add_records(stored_golden_sqls, self.golden_sql_collection)
return stored_golden_sqls

Expand Down
165 changes: 165 additions & 0 deletions dataherald/vector_store/astra.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
import os
from typing import Any, List

from astrapy.api import APIRequestError
from astrapy.db import AstraDB
from langchain_openai import OpenAIEmbeddings
from overrides import override
from sql_metadata import Parser

from dataherald.config import System
from dataherald.db import DB
from dataherald.repositories.database_connections import DatabaseConnectionRepository
from dataherald.types import GoldenSQL
from dataherald.vector_store import VectorStore

# OpenAI embedding model used for all vector operations; its 1536-dim output
# must match the `dimension` passed to create_collection below.
EMBEDDING_MODEL = "text-embedding-3-small"


class Astra(VectorStore):
    """Vector store implementation backed by DataStax Astra DB.

    Stores golden-SQL embeddings (produced with OpenAI's
    ``text-embedding-3-small`` model) and retrieves the most similar
    records for a query via Astra's vector search. Results are converted
    to the Pinecone-style ``{"id", "score"}`` shape so callers can treat
    all vector stores uniformly.
    """

    def __init__(self, system: System):
        """Connect to Astra DB using credentials from the environment.

        Raises:
            ValueError: if ``ASTRA_DB_API_ENDPOINT`` or
                ``ASTRA_DB_APPLICATION_TOKEN`` is not set.
        """
        super().__init__(system)
        astra_db_api_endpoint = os.environ.get("ASTRA_DB_API_ENDPOINT")
        astra_db_application_token = os.environ.get("ASTRA_DB_APPLICATION_TOKEN")
        if astra_db_api_endpoint is None:
            raise ValueError("ASTRA_DB_API_ENDPOINT environment variable not set")
        if astra_db_application_token is None:
            raise ValueError("ASTRA_DB_APPLICATION_TOKEN environment variable not set")
        # Reuse the validated values instead of re-reading os.environ.
        self.db = AstraDB(
            token=astra_db_application_token,
            api_endpoint=astra_db_api_endpoint,
            namespace="default_keyspace",
        )

    def collection_name_formatter(self, collection: str) -> str:
        """Astra collection names may not contain dashes; map them to underscores."""
        return collection.replace("-", "_")

    def _existing_collections(self) -> List[str]:
        """Return the names of all existing collections, or [] if the lookup fails."""
        try:
            return self.db.get_collections()["status"]["collections"]
        except APIRequestError:
            return []

    def _embedding_model(self, db_connection_id: str) -> OpenAIEmbeddings:
        """Build an OpenAI embedding client keyed by the connection's stored API key."""
        db_connection_repository = DatabaseConnectionRepository(
            self.system.instance(DB)
        )
        database_connection = db_connection_repository.find_by_id(db_connection_id)
        return OpenAIEmbeddings(
            openai_api_key=database_connection.decrypt_api_key(),
            model=EMBEDDING_MODEL,
        )

    @override
    def query(
        self,
        query_texts: List[str],
        db_connection_id: str,
        collection: str,
        num_results: int,
    ) -> list:
        """Return up to ``num_results`` records most similar to ``query_texts[0]``.

        Only the first query text is embedded; results are filtered to the
        given database connection.

        Raises:
            ValueError: if the collection does not exist.
        """
        collection = self.collection_name_formatter(collection)
        if collection not in self._existing_collections():
            raise ValueError(f"Collection {collection} does not exist")
        astra_collection = self.db.collection(collection)
        embedding = self._embedding_model(db_connection_id)
        query_vector = embedding.embed_query(query_texts[0])
        returned_results = astra_collection.vector_find(
            vector=query_vector,
            limit=num_results,
            filter={"db_connection_id": {"$eq": db_connection_id}},
            include_similarity=True,
        )
        return self.convert_to_pinecone_object_model(returned_results)

    @override
    def add_records(self, golden_sqls: List[GoldenSQL], collection: str):
        """Embed and store a batch of golden SQL records.

        Creates the collection if it does not exist. All records in a batch
        are assumed to share the same database connection (the first record's
        connection is used for the embedding API key).
        """
        if not golden_sqls:
            return
        collection = self.collection_name_formatter(collection)
        if collection not in self._existing_collections():
            self.create_collection(collection)
        astra_collection = self.db.collection(collection)
        embedding = self._embedding_model(str(golden_sqls[0].db_connection_id))
        embeds = embedding.embed_documents(
            [record.prompt_text for record in golden_sqls]
        )
        records = []
        for golden_sql, embed in zip(golden_sqls, embeds):
            # BUG FIX: the original checked isinstance(Parser(sql), list), which
            # is never true (Parser returns a Parser object), so tables_used was
            # always "". Use Parser(...).tables, which is the list of table names.
            try:
                tables = Parser(golden_sql.sql).tables
            except Exception:  # noqa: BLE001 — parser can fail on exotic SQL; store no tables
                tables = []
            records.append(
                {
                    "_id": str(golden_sql.id),
                    "$vector": embed,
                    "tables_used": ", ".join(tables),
                    "db_connection_id": str(golden_sql.db_connection_id),
                }
            )
        astra_collection.chunked_insert_many(
            documents=records, chunk_size=10, concurrency=1
        )

    @override
    def add_record(
        self,
        documents: str,
        db_connection_id: str,
        collection: str,
        metadata: Any,
        ids: List,
    ):
        """Embed and store a single document, creating the collection if needed.

        ``metadata`` and ``ids`` follow the list-based interface of the other
        vector stores; only the first element of each is used here.
        """
        collection = self.collection_name_formatter(collection)
        if collection not in self._existing_collections():
            self.create_collection(collection)
        astra_collection = self.db.collection(collection)
        embedding = self._embedding_model(db_connection_id)
        embeds = embedding.embed_documents([documents])
        # BUG FIX: embed_documents returns a list of vectors; store the single
        # vector (embeds[0]) as "$vector", not the wrapping list.
        astra_collection.insert_one(
            {"_id": ids[0], "$vector": embeds[0], **metadata[0]}
        )

    @override
    def delete_record(self, collection: str, id: str):
        """Delete the document with the given ``_id`` from a collection.

        Raises:
            ValueError: if the collection does not exist.
        """
        collection = self.collection_name_formatter(collection)
        if collection not in self._existing_collections():
            raise ValueError(f"Collection {collection} does not exist")
        astra_collection = self.db.collection(collection)
        astra_collection.delete_one(id)

    @override
    def delete_collection(self, collection: str):
        """Drop the whole collection (no error if it is already gone)."""
        collection = self.collection_name_formatter(collection)
        return self.db.delete_collection(collection_name=collection)

    @override
    def create_collection(self, collection: str):
        """Create a vector collection sized for the OpenAI embedding model."""
        collection = self.collection_name_formatter(collection)
        # dimension=1536 matches text-embedding-3-small's output size.
        return self.db.create_collection(collection, dimension=1536, metric="cosine")

    def convert_to_pinecone_object_model(self, astra_results: dict) -> List:
        """Map Astra vector_find results to the Pinecone-style {id, score} shape."""
        return [
            {"id": result["_id"], "score": result["$similarity"]}
            for result in astra_results
        ]
7 changes: 6 additions & 1 deletion docs/envars.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ provided in the .env.example file with the default values.
PINECONE_API_KEY =
PINECONE_ENVIRONMENT =
ASTRA_DB_API_ENDPOINT =
ASTRA_DB_APPLICATION_TOKEN =
API_SERVER = "dataherald.api.fastapi.FastAPI"
SQL_GENERATOR = "dataherald.sql_generator.dataherald_sqlagent.DataheraldSQLAgent"
Expand All @@ -24,7 +27,7 @@ provided in the .env.example file with the default values.
CONTEXT_STORE = 'dataherald.context_store.default.DefaultContextStore'
DB_SCANNER = 'dataherald.db_scanner.sqlalchemy.SqlAlchemyScanner'
MONGODB_URI = "mongodb://admin:admin@mongodb:27017"
MONGODB_DB_NAME = 'dataherald'
MONGODB_DB_USERNAME = 'admin'
Expand All @@ -51,6 +54,8 @@ provided in the .env.example file with the default values.
"GOLDEN_RECORD_COLLECTION", "The name of the collection in Mongo where golden records will be stored", "``my-golden-records``", "No"
"PINECONE_API_KEY", "The Pinecone API key used", "None", "Yes if using the Pinecone vector store"
"PINECONE_ENVIRONMENT", "The Pinecone environment", "None", "Yes if using the Pinecone vector store"
   "ASTRA_DB_API_ENDPOINT", "The Astra DB API endpoint", "None", "Yes if using the Astra DB vector store"
   "ASTRA_DB_APPLICATION_TOKEN", "The Astra DB application token", "None", "Yes if using the Astra DB vector store"
"API_SERVER", "The implementation of the API Module used by the Dataherald Engine.", "``dataherald.api.fastapi.FastAPI``", "Yes"
"SQL_GENERATOR", "The implementation of the SQLGenerator Module to be used.", "``dataherald.sql_generator. dataherald_sqlagent. DataheraldSQLAgent``", "Yes"
"EVALUATOR", "The implementation of the Evaluator Module to be used.", "``dataherald.eval. simple_evaluator.SimpleEvaluator``", "Yes"
Expand Down
6 changes: 3 additions & 3 deletions docs/vector_store.rst
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
Vector Store
====================

The Dataherald Engine uses a Vector store for retrieving similar few shot examples from previous Natural Language to SQL pairs that have been marked as correct. Currently Pinecone and ChromaDB are the
The Dataherald Engine uses a Vector store for retrieving similar few shot examples from previous Natural Language to SQL pairs that have been marked as correct. Currently Pinecone, AstraDB, and ChromaDB are the
supported vector stores, though developers can easily add support for other vector stores by implementing the abstract VectorStore class.

Abstract Vector Store Class
---------------------------

Both ChromaDB and Pinecone are implemented as subclasses of the abstract :class:`VectorStore` class. This abstract class provides a unified interface for working with different vector store implementations.
AstraDB, ChromaDB and Pinecone are implemented as subclasses of the abstract :class:`VectorStore` class. This abstract class provides a unified interface for working with different vector store implementations.

:class:`VectorStore`
^^^^^^^^^^^^^^^^^^^^^

This abstract class defines the common methods that both ChromaDB and Pinecone vector stores should implement.
This abstract class defines the common methods that AstraDB, ChromaDB, and Pinecone vector stores should implement.

.. method:: __init__(self, system: System)
:noindex:
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
dnspython==2.3.0
fastapi==0.98.0
httpx==0.24.1
httpx==0.27.0
langchain==0.1.11
langchain-community==0.0.25
langchain-openai==0.0.8
Expand Down Expand Up @@ -41,3 +41,4 @@ duckdb-engine==0.9.1
duckdb==0.9.1
PyMySQL==1.1.0
clickhouse-sqlalchemy==0.2.5
astrapy==0.7.6

0 comments on commit 2230e78

Please sign in to comment.