From 2230e7853f7f3887bacbbf65e01b922e9b55cfd8 Mon Sep 17 00:00:00 2001
From: Mohammadreza Pourreza <71866535+MohammadrezaPourreza@users.noreply.github.com>
Date: Thu, 7 Mar 2024 10:48:49 -0500
Subject: [PATCH] Dh-5541/AstraDB Support (#421)

---
 .env.example                        |   4 +
 dataherald/api/fastapi.py           |   1 -
 dataherald/context_store/default.py |   1 -
 dataherald/vector_store/astra.py    | 165 ++++++++++++++++++++++++++++
 docs/envars.rst                     |   7 +-
 docs/vector_store.rst               |   6 +-
 requirements.txt                    |   3 +-
 7 files changed, 180 insertions(+), 7 deletions(-)
 create mode 100644 dataherald/vector_store/astra.py

diff --git a/.env.example b/.env.example
index ea45d0da..c81c5444 100644
--- a/.env.example
+++ b/.env.example
@@ -19,6 +19,10 @@ GOLDEN_SQL_COLLECTION = 'my-golden-records'
 #Pinecone info. These fields are required if the vector store used is Pinecone
 PINECONE_API_KEY = 
 PINECONE_ENVIRONMENT = 
+#AstraDB info. These fields are required if the vector store used is AstraDB
+ASTRA_DB_API_ENDPOINT = 
+ASTRA_DB_APPLICATION_TOKEN = 
+
 
 # Module implementations to be used names for each required component. You can use the default ones or create your own
 API_SERVER = "dataherald.api.fastapi.FastAPI"
diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py
index 7098f5b1..6eaf2c3b 100644
--- a/dataherald/api/fastapi.py
+++ b/dataherald/api/fastapi.py
@@ -422,7 +422,6 @@ def add_golden_sqls(
                 {"items": [row.dict() for row in golden_sqls]},
                 "golden_sql_not_created",
             )
-
         return [GoldenSQLResponse(**golden_sql.dict()) for golden_sql in golden_sqls]
 
     @override
diff --git a/dataherald/context_store/default.py b/dataherald/context_store/default.py
index 8c69b782..df0f40bd 100644
--- a/dataherald/context_store/default.py
+++ b/dataherald/context_store/default.py
@@ -94,7 +94,6 @@ def add_golden_sqls(self, golden_sqls: List[GoldenSQLRequest]) -> List[GoldenSQL
                 metadata=record.metadata,
             )
             stored_golden_sqls.append(golden_sqls_repository.insert(golden_sql))
-
         self.vector_store.add_records(stored_golden_sqls, self.golden_sql_collection)
         return stored_golden_sqls
 
diff --git a/dataherald/vector_store/astra.py b/dataherald/vector_store/astra.py
new file mode 100644
index 00000000..c3307b59
--- /dev/null
+++ b/dataherald/vector_store/astra.py
@@ -0,0 +1,165 @@
+import os
+from typing import Any, List
+
+from astrapy.api import APIRequestError
+from astrapy.db import AstraDB
+from langchain_openai import OpenAIEmbeddings
+from overrides import override
+from sql_metadata import Parser
+
+from dataherald.config import System
+from dataherald.db import DB
+from dataherald.repositories.database_connections import DatabaseConnectionRepository
+from dataherald.types import GoldenSQL
+from dataherald.vector_store import VectorStore
+
+EMBEDDING_MODEL = "text-embedding-3-small"
+
+
+class Astra(VectorStore):
+    def __init__(self, system: System):
+        super().__init__(system)
+        astra_db_api_endpoint = os.environ.get("ASTRA_DB_API_ENDPOINT")
+        astra_db_application_token = os.environ.get("ASTRA_DB_APPLICATION_TOKEN")
+        if astra_db_api_endpoint is None:
+            raise ValueError("ASTRA_DB_API_ENDPOINT environment variable not set")
+        if astra_db_application_token is None:
+            raise ValueError("ASTRA_DB_APPLICATION_TOKEN environment variable not set")
+        self.db = AstraDB(
+            token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
+            api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
+            namespace="default_keyspace",
+        )
+
+    def collection_name_formatter(self, collection: str) -> str:
+        return collection.replace("-", "_")
+
+    @override
+    def query(
+        self,
+        query_texts: List[str],
+        db_connection_id: str,
+        collection: str,
+        num_results: int,
+    ) -> list:
+        collection = self.collection_name_formatter(collection)
+        try:
+            existing_collections = self.db.get_collections()["status"]["collections"]
+        except APIRequestError:
+            existing_collections = []
+        if collection not in existing_collections:
+            raise ValueError(f"Collection {collection} does not exist")
+        astra_collection = self.db.collection(collection)
+        db_connection_repository = DatabaseConnectionRepository(
+            self.system.instance(DB)
+        )
+        database_connection = db_connection_repository.find_by_id(db_connection_id)
+        embedding = OpenAIEmbeddings(
+            openai_api_key=database_connection.decrypt_api_key(), model=EMBEDDING_MODEL
+        )
+        xq = embedding.embed_query(query_texts[0])
+        returned_results = astra_collection.vector_find(
+            vector=xq,
+            limit=num_results,
+            filter={"db_connection_id": {"$eq": db_connection_id}},
+            include_similarity=True,
+        )
+        return self.convert_to_pinecone_object_model(returned_results)
+
+    @override
+    def add_records(self, golden_sqls: List[GoldenSQL], collection: str):
+        collection = self.collection_name_formatter(collection)
+        try:
+            existing_collections = self.db.get_collections()["status"]["collections"]
+        except APIRequestError:
+            existing_collections = []
+        if collection not in existing_collections:
+            self.create_collection(collection)
+        astra_collection = self.db.collection(collection)
+        db_connection_repository = DatabaseConnectionRepository(
+            self.system.instance(DB)
+        )
+        database_connection = db_connection_repository.find_by_id(
+            str(golden_sqls[0].db_connection_id)
+        )
+        embedding = OpenAIEmbeddings(
+            openai_api_key=database_connection.decrypt_api_key(), model=EMBEDDING_MODEL
+        )
+        embeds = embedding.embed_documents(
+            [record.prompt_text for record in golden_sqls]
+        )
+        records = []
+        for key in range(len(golden_sqls)):
+            records.append(
+                {
+                    "_id": str(golden_sqls[key].id),
+                    "$vector": embeds[key],
+                    "tables_used": ", ".join(Parser(golden_sqls[key].sql).tables)
+                    if Parser(golden_sqls[key].sql).tables
+                    else "",
+                    "db_connection_id": str(golden_sqls[key].db_connection_id),
+                }
+            )
+        astra_collection.chunked_insert_many(
+            documents=records, chunk_size=10, concurrency=1
+        )
+
+    @override
+    def add_record(
+        self,
+        documents: str,
+        db_connection_id: str,
+        collection: str,
+        metadata: Any,
+        ids: List,
+    ):
+        collection = self.collection_name_formatter(collection)
+        try:
+            existing_collections = self.db.get_collections()["status"]["collections"]
+        except APIRequestError:
+            existing_collections = []
+        if collection not in existing_collections:
+            self.create_collection(collection)
+        astra_collection = self.db.collection(collection)
+        db_connection_repository = DatabaseConnectionRepository(
+            self.system.instance(DB)
+        )
+        database_connection = db_connection_repository.find_by_id(db_connection_id)
+        embedding = OpenAIEmbeddings(
+            openai_api_key=database_connection.decrypt_api_key(), model=EMBEDDING_MODEL
+        )
+        embeds = embedding.embed_documents([documents])
+        astra_collection.insert_one({"_id": ids[0], "$vector": embeds[0], **metadata[0]})
+
+    @override
+    def delete_record(self, collection: str, id: str):
+        collection = self.collection_name_formatter(collection)
+        try:
+            existing_collections = self.db.get_collections()["status"]["collections"]
+        except APIRequestError:
+            existing_collections = []
+        if collection not in existing_collections:
+            raise ValueError(f"Collection {collection} does not exist")
+        astra_collection = self.db.collection(collection)
+        astra_collection.delete_one(id)
+
+    @override
+    def delete_collection(self, collection: str):
+        collection = self.collection_name_formatter(collection)
+        return self.db.delete_collection(collection_name=collection)
+
+    @override
+    def create_collection(self, collection: str):
+        collection = self.collection_name_formatter(collection)
+        return self.db.create_collection(collection, dimension=1536, metric="cosine")
+
+    def convert_to_pinecone_object_model(self, astra_results: list) -> List:
+        results = []
+        for i in range(len(astra_results)):
+            results.append(
+                {
+                    "id": astra_results[i]["_id"],
+                    "score": astra_results[i]["$similarity"],
+                }
+            )
+        return results
diff --git a/docs/envars.rst b/docs/envars.rst
index 46d717ca..efdd9331 100644
--- a/docs/envars.rst
+++ b/docs/envars.rst
@@ -15,6 +15,9 @@ provided in the .env.example file with the default values.
     PINECONE_API_KEY = 
     PINECONE_ENVIRONMENT = 
 
+    ASTRA_DB_API_ENDPOINT = 
+    ASTRA_DB_APPLICATION_TOKEN = 
+
     API_SERVER = "dataherald.api.fastapi.FastAPI"
     SQL_GENERATOR = "dataherald.sql_generator.dataherald_sqlagent.DataheraldSQLAgent"
 
@@ -24,7 +27,7 @@ provided in the .env.example file with the default values.
     CONTEXT_STORE = 'dataherald.context_store.default.DefaultContextStore'
     DB_SCANNER = 'dataherald.db_scanner.sqlalchemy.SqlAlchemyScanner'
 
-    
+
     MONGODB_URI = "mongodb://admin:admin@mongodb:27017"
     MONGODB_DB_NAME = 'dataherald'
     MONGODB_DB_USERNAME = 'admin'
@@ -51,6 +54,8 @@ provided in the .env.example file with the default values.
     "GOLDEN_RECORD_COLLECTION", "The name of the collection in Mongo where golden records will be stored", "``my-golden-records``", "No"
     "PINECONE_API_KEY", "The Pinecone API key used", "None", "Yes if using the Pinecone vector store"
     "PINECONE_ENVIRONMENT", "The Pinecone environment", "None", "Yes if using the Pinecone vector store"
+    "ASTRA_DB_API_ENDPOINT", "The Astra DB API endpoint", "None", "Yes if using the Astra DB vector store"
+    "ASTRA_DB_APPLICATION_TOKEN", "The Astra DB application token", "None", "Yes if using the Astra DB vector store"
     "API_SERVER", "The implementation of the API Module used by the Dataherald Engine.", "``dataherald.api.fastapi.FastAPI``", "Yes"
     "SQL_GENERATOR", "The implementation of the SQLGenerator Module to be used.", "``dataherald.sql_generator. dataherald_sqlagent. DataheraldSQLAgent``", "Yes"
     "EVALUATOR", "The implementation of the Evaluator Module to be used.", "``dataherald.eval. simple_evaluator.SimpleEvaluator``", "Yes"
diff --git a/docs/vector_store.rst b/docs/vector_store.rst
index 3cc64421..01ea043e 100644
--- a/docs/vector_store.rst
+++ b/docs/vector_store.rst
@@ -1,18 +1,18 @@
 Vector Store
 ====================
 
-The Dataherald Engine uses a Vector store for retrieving similar few shot examples from previous Natural Language to SQL pairs that have been marked as correct. Currently Pinecone and ChromaDB are the
+The Dataherald Engine uses a Vector store for retrieving similar few shot examples from previous Natural Language to SQL pairs that have been marked as correct. Currently Pinecone, AstraDB, and ChromaDB are the
 supported vector stores, though developers can easily add support for other vector stores by implementing the abstract VectorStore class.
 
 Abstract Vector Store Class
 ---------------------------
 
-Both ChromaDB and Pinecone are implemented as subclasses of the abstract :class:`VectorStore` class. This abstract class provides a unified interface for working with different vector store implementations.
+AstraDB, ChromaDB, and Pinecone are implemented as subclasses of the abstract :class:`VectorStore` class. This abstract class provides a unified interface for working with different vector store implementations.
 
 :class:`VectorStore`
 ^^^^^^^^^^^^^^^^^^^^^
 
-This abstract class defines the common methods that both ChromaDB and Pinecone vector stores should implement.
+This abstract class defines the common methods that AstraDB, ChromaDB, and Pinecone vector stores should implement.
 
 .. method:: __init__(self, system: System)
    :noindex:
diff --git a/requirements.txt b/requirements.txt
index c8ace24e..d26176b3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
 dnspython==2.3.0
 fastapi==0.98.0
-httpx==0.24.1
+httpx==0.27.0
 langchain==0.1.11
 langchain-community==0.0.25
 langchain-openai==0.0.8
@@ -41,3 +41,4 @@ duckdb-engine==0.9.1
 duckdb==0.9.1
 PyMySQL==1.1.0
 clickhouse-sqlalchemy==0.2.5
+astrapy==0.7.6
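
Below is a minimal usage sketch, not part of the patch, of how the new Astra vector store might be exercised once the two new environment variables are set. The endpoint, token, collection name, and `db_connection_id` values are hypothetical placeholders, and the `System(Settings())` wiring is an assumed stand-in for the engine's normal startup; `Astra`, `create_collection`, and `query` are the class and methods added in `dataherald/vector_store/astra.py` above.

```python
import os

from dataherald.config import Settings, System
from dataherald.vector_store.astra import Astra

# Both variables are required: Astra.__init__ raises ValueError if either is unset.
os.environ["ASTRA_DB_API_ENDPOINT"] = "https://<db-id>-<region>.apps.astra.datastax.com"  # placeholder
os.environ["ASTRA_DB_APPLICATION_TOKEN"] = "AstraCS:<token>"  # placeholder

system = System(Settings())  # assumed wiring; the engine normally builds this at startup
store = Astra(system)

# Collections are created with 1536 dimensions and cosine similarity, matching
# the text-embedding-3-small embeddings the store uses. Dashes in the name are
# rewritten to underscores by collection_name_formatter.
store.create_collection("my-golden-records")

# Fetch the five most similar golden SQL records for a question, scoped to one
# database connection. The id must reference a stored connection, since its
# decrypted OpenAI key is used to embed the question.
hits = store.query(
    query_texts=["How many customers signed up last month?"],  # only the first text is embedded
    db_connection_id="65f0c8...",  # placeholder ObjectId string
    collection="my-golden-records",
    num_results=5,
)
for hit in hits:
    print(hit["id"], hit["score"])
```

Results pass through `convert_to_pinecone_object_model`, so callers written against the Pinecone store's `{"id", "score"}` result shape keep working unchanged.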