diff --git a/README.md b/README.md index fbb4aae0..32c6478b 100644 --- a/README.md +++ b/README.md @@ -370,11 +370,11 @@ curl -X 'PUT' \ ### Querying the Database in Natural Language -Once you have connected the engine to your data warehouse (and preferably added some context to the store), you can query your data warehouse using the `POST /api/v1/question` endpoint. +Once you have connected the engine to your data warehouse (and preferably added some context to the store), you can query your data warehouse using the `POST /api/v1/questions` endpoint. ``` curl -X 'POST' \ - '/api/v1/question' \ + '/api/v1/questions' \ -H 'accept: application/json' \ -H 'Content-Type: application/json' \ -d '{ diff --git a/dataherald/api/__init__.py b/dataherald/api/__init__.py index fbdd0dd1..a2bd90fe 100644 --- a/dataherald/api/__init__.py +++ b/dataherald/api/__init__.py @@ -5,21 +5,21 @@ from dataherald.api.types import Query from dataherald.config import Component -from dataherald.db_scanner.models.types import TableSchemaDetail +from dataherald.db_scanner.models.types import TableDescription from dataherald.sql_database.models.types import DatabaseConnection, SSHSettings from dataherald.types import ( + CreateResponseRequest, DatabaseConnectionRequest, - ExecuteTempQueryRequest, GoldenRecord, GoldenRecordRequest, Instruction, InstructionRequest, - NLQueryResponse, + Question, QuestionRequest, + Response, ScannerRequest, TableDescriptionRequest, UpdateInstruction, - UpdateQueryRequest, ) @@ -36,7 +36,11 @@ def scan_db( pass @abstractmethod - def answer_question(self, question_request: QuestionRequest) -> NLQueryResponse: + def answer_question(self, question_request: QuestionRequest) -> Response: + pass + + @abstractmethod + def get_questions(self, db_connection_id: str | None = None) -> list[Question]: pass @abstractmethod @@ -62,13 +66,17 @@ def update_table_description( self, table_description_id: str, table_description_request: TableDescriptionRequest, - ) -> 
TableSchemaDetail: + ) -> TableDescription: pass @abstractmethod def list_table_descriptions( self, db_connection_id: str | None = None, table_name: str | None = None - ) -> list[TableSchemaDetail]: + ) -> list[TableDescription]: + pass + + @abstractmethod + def get_responses(self, question_id: str | None = None) -> list[Response]: pass @abstractmethod @@ -82,15 +90,7 @@ def execute_sql_query(self, query: Query) -> tuple[str, dict]: pass @abstractmethod - def update_nl_query_response( - self, query_id: str, query: UpdateQueryRequest - ) -> NLQueryResponse: - pass - - @abstractmethod - def get_nl_query_response( - self, query_request: ExecuteTempQueryRequest - ) -> NLQueryResponse: + def create_response(self, query_request: CreateResponseRequest) -> Response: pass @abstractmethod diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py index c6fde9b4..f6076851 100644 --- a/dataherald/api/fastapi.py +++ b/dataherald/api/fastapi.py @@ -4,6 +4,7 @@ from typing import List from bson import json_util +from bson.objectid import ObjectId from fastapi import BackgroundTasks, HTTPException from overrides import override @@ -13,14 +14,14 @@ from dataherald.context_store import ContextStore from dataherald.db import DB from dataherald.db_scanner import Scanner -from dataherald.db_scanner.models.types import TableDescriptionStatus, TableSchemaDetail -from dataherald.db_scanner.repository.base import DBScannerRepository +from dataherald.db_scanner.models.types import TableDescription, TableDescriptionStatus +from dataherald.db_scanner.repository.base import TableDescriptionRepository from dataherald.eval import Evaluator -from dataherald.repositories.base import NLQueryResponseRepository +from dataherald.repositories.base import ResponseRepository from dataherald.repositories.database_connections import DatabaseConnectionRepository from dataherald.repositories.golden_records import GoldenRecordRepository from dataherald.repositories.instructions import 
InstructionRepository -from dataherald.repositories.nl_question import NLQuestionRepository +from dataherald.repositories.question import QuestionRepository from dataherald.sql_database.base import ( InvalidDBConnectionError, SQLDatabase, @@ -30,19 +31,18 @@ from dataherald.sql_generator import SQLGenerator from dataherald.sql_generator.generates_nl_answer import GeneratesNlAnswer from dataherald.types import ( + CreateResponseRequest, DatabaseConnectionRequest, - ExecuteTempQueryRequest, GoldenRecord, GoldenRecordRequest, Instruction, InstructionRequest, - NLQuery, - NLQueryResponse, + Question, QuestionRequest, + Response, ScannerRequest, TableDescriptionRequest, UpdateInstruction, - UpdateQueryRequest, ) logger = logging.getLogger(__name__) @@ -53,7 +53,7 @@ def async_scanning(scanner, database, scanner_request, storage): database, scanner_request.db_connection_id, scanner_request.table_names, - DBScannerRepository(storage), + TableDescriptionRepository(storage), ) @@ -103,7 +103,7 @@ def scan_db( scanner.synchronizing( scanner_request.table_names, scanner_request.db_connection_id, - DBScannerRepository(self.storage), + TableDescriptionRepository(self.storage), ) background_tasks.add_task( @@ -112,20 +112,20 @@ def scan_db( return True @override - def answer_question(self, question_request: QuestionRequest) -> NLQueryResponse: + def answer_question(self, question_request: QuestionRequest) -> Response: """Takes in an English question and answers it based on content from the registered databases""" logger.info(f"Answer question: {question_request.question}") sql_generation = self.system.instance(SQLGenerator) evaluator = self.system.instance(Evaluator) context_store = self.system.instance(ContextStore) - user_question = NLQuery( + user_question = Question( question=question_request.question, db_connection_id=question_request.db_connection_id, ) - nl_question_repository = NLQuestionRepository(self.storage) - user_question = 
nl_question_repository.insert(user_question) + question_repository = QuestionRepository(self.storage) + user_question = question_repository.insert(user_question) db_connection_repository = DatabaseConnectionRepository(self.storage) database_connection = db_connection_repository.find_by_id( @@ -149,9 +149,8 @@ def answer_question(self, question_request: QuestionRequest) -> NLQueryResponse: raise HTTPException(status_code=404, detail=str(e)) from e generated_answer.confidence_score = confidence_score generated_answer.exec_time = time.time() - start_generated_answer - nl_query_response_repository = NLQueryResponseRepository(self.storage) - nl_query_response = nl_query_response_repository.insert(generated_answer) - return json.loads(json_util.dumps(nl_query_response)) + response_repository = ResponseRepository(self.storage) + return response_repository.insert(generated_answer) @override def create_database_connection( @@ -217,8 +216,8 @@ def update_table_description( self, table_description_id: str, table_description_request: TableDescriptionRequest, - ) -> TableSchemaDetail: - scanner_repository = DBScannerRepository(self.storage) + ) -> TableDescription: + scanner_repository = TableDescriptionRepository(self.storage) table = scanner_repository.find_by_id(table_description_id) if not table: @@ -239,8 +238,8 @@ def update_table_description( @override def list_table_descriptions( self, db_connection_id: str | None = None, table_name: str | None = None - ) -> list[TableSchemaDetail]: - scanner_repository = DBScannerRepository(self.storage) + ) -> list[TableDescription]: + scanner_repository = TableDescriptionRepository(self.storage) table_descriptions = scanner_repository.find_by( {"db_connection_id": db_connection_id, "table_name": table_name} ) @@ -260,7 +259,7 @@ def list_table_descriptions( all_tables.remove(table_description.table_name) for table in all_tables: table_descriptions.append( - TableSchemaDetail( + TableDescription( table_name=table, 
status=TableDescriptionStatus.NOT_SYNCHRONIZED.value, db_connection_id=db_connection_id, @@ -270,6 +269,22 @@ def list_table_descriptions( return table_descriptions + @override + def get_responses(self, question_id: str | None = None) -> list[Response]: + response_repository = ResponseRepository(self.storage) + query = {} + if question_id: + query = {"question_id": ObjectId(question_id)} + return response_repository.find_by(query) + + @override + def get_questions(self, db_connection_id: str | None = None) -> list[Question]: + question_repository = QuestionRepository(self.storage) + query = {} + if db_connection_id: + query = {"db_connection_id": db_connection_id} + return question_repository.find_by(query) + @override def add_golden_records( self, golden_records: List[GoldenRecordRequest] @@ -295,53 +310,22 @@ def execute_sql_query(self, query: Query) -> tuple[str, dict]: return result @override - def update_nl_query_response( - self, query_id: str, query: UpdateQueryRequest # noqa: ARG002 - ) -> NLQueryResponse: - nl_query_response_repository = NLQueryResponseRepository(self.storage) - nl_question_repository = NLQuestionRepository(self.storage) - nl_query_response = nl_query_response_repository.find_by_id(query_id) - nl_question = nl_question_repository.find_by_id( - nl_query_response.nl_question_id + def create_response( + self, query_request: CreateResponseRequest # noqa: ARG002 + ) -> Response: + response = Response( + question_id=query_request.question_id, sql_query=query_request.sql_query ) - if nl_query_response.sql_query.strip() != query.sql_query.strip(): - nl_query_response.sql_query = query.sql_query - evaluator = self.system.instance(Evaluator) - db_connection_repository = DatabaseConnectionRepository(self.storage) - database_connection = db_connection_repository.find_by_id( - nl_question.db_connection_id - ) - if not database_connection: - raise HTTPException( - status_code=404, detail="Database connection not found" - ) - try: - confidence_score = 
evaluator.get_confidence_score( - nl_question, nl_query_response, database_connection - ) - nl_query_response.confidence_score = confidence_score - generates_nl_answer = GeneratesNlAnswer(self.system, self.storage) - nl_query_response = generates_nl_answer.execute(nl_query_response) - except SQLInjectionError as e: - raise HTTPException(status_code=404, detail=str(e)) from e - nl_query_response_repository.update(nl_query_response) - return json.loads(json_util.dumps(nl_query_response)) + response_repository = ResponseRepository(self.storage) + response_repository.insert(response) - @override - def get_nl_query_response( - self, query_request: ExecuteTempQueryRequest # noqa: ARG002 - ) -> NLQueryResponse: - nl_query_response_repository = NLQueryResponseRepository(self.storage) - nl_query_response = nl_query_response_repository.find_by_id( - query_request.query_id - ) - nl_query_response.sql_query = query_request.sql_query try: generates_nl_answer = GeneratesNlAnswer(self.system, self.storage) - nl_query_response = generates_nl_answer.execute(nl_query_response) + response = generates_nl_answer.execute(response) + response_repository.update(response) except SQLInjectionError as e: raise HTTPException(status_code=404, detail=str(e)) from e - return json.loads(json_util.dumps(nl_query_response)) + return response @override def delete_golden_record(self, golden_record_id: str) -> dict: diff --git a/dataherald/context_store/__init__.py b/dataherald/context_store/__init__.py index 693b5d6e..6189c8ca 100644 --- a/dataherald/context_store/__init__.py +++ b/dataherald/context_store/__init__.py @@ -4,7 +4,7 @@ from dataherald.config import Component, System from dataherald.db import DB -from dataherald.types import GoldenRecord, GoldenRecordRequest, NLQuery +from dataherald.types import GoldenRecord, GoldenRecordRequest, Question from dataherald.vector_store import VectorStore @@ -24,7 +24,7 @@ def __init__(self, system: System): @abstractmethod def 
retrieve_context_for_question( - self, nl_question: NLQuery, number_of_samples: int = 3 + self, nl_question: Question, number_of_samples: int = 3 ) -> Tuple[List[dict] | None, List[dict] | None]: pass diff --git a/dataherald/context_store/default.py b/dataherald/context_store/default.py index 453d34b8..bb910f1c 100644 --- a/dataherald/context_store/default.py +++ b/dataherald/context_store/default.py @@ -8,7 +8,7 @@ from dataherald.context_store import ContextStore from dataherald.repositories.golden_records import GoldenRecordRepository from dataherald.repositories.instructions import InstructionRepository -from dataherald.types import GoldenRecord, GoldenRecordRequest, NLQuery +from dataherald.types import GoldenRecord, GoldenRecordRequest, Question logger = logging.getLogger(__name__) @@ -19,7 +19,7 @@ def __init__(self, system: System): @override def retrieve_context_for_question( - self, nl_question: NLQuery, number_of_samples: int = 3 + self, nl_question: Question, number_of_samples: int = 3 ) -> Tuple[List[dict] | None, List[dict] | None]: logger.info(f"Getting context for {nl_question.question}") closest_questions = self.vector_store.query( diff --git a/dataherald/db/__init__.py b/dataherald/db/__init__.py index 27a40421..5646a90e 100644 --- a/dataherald/db/__init__.py +++ b/dataherald/db/__init__.py @@ -16,6 +16,12 @@ def insert_one(self, collection: str, obj: dict) -> int: def rename(self, old_collection_name: str, new_collection_name) -> None: pass + @abstractmethod + def rename_field( + self, collection_name: str, old_field_name: str, new_field_name: str + ) -> None: + pass + @abstractmethod def update_or_create(self, collection: str, query: dict, obj: dict) -> int: pass diff --git a/dataherald/db/mongo.py b/dataherald/db/mongo.py index 822632a1..81e510c2 100644 --- a/dataherald/db/mongo.py +++ b/dataherald/db/mongo.py @@ -27,6 +27,14 @@ def insert_one(self, collection: str, obj: dict) -> int: def rename(self, old_collection_name: str, 
new_collection_name) -> None: self._data_store[old_collection_name].rename(new_collection_name) + @override + def rename_field( + self, collection_name: str, old_field_name: str, new_field_name: str + ) -> None: + self._data_store[collection_name].update_many( + {}, {"$rename": {old_field_name: new_field_name}} + ) + @override def update_or_create(self, collection: str, query: dict, obj: dict) -> int: row = self.find_one(collection, query) diff --git a/dataherald/db_scanner/__init__.py b/dataherald/db_scanner/__init__.py index 4626f1e4..da5e2c68 100644 --- a/dataherald/db_scanner/__init__.py +++ b/dataherald/db_scanner/__init__.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from dataherald.config import Component -from dataherald.db_scanner.repository.base import DBScannerRepository +from dataherald.db_scanner.repository.base import TableDescriptionRepository from dataherald.sql_database.base import SQLDatabase @@ -13,13 +13,16 @@ def scan( db_engine: SQLDatabase, db_connection_id: str, table_names: list[str] | None, - repository: DBScannerRepository, + repository: TableDescriptionRepository, ) -> None: """ "Scan a db""" @abstractmethod def synchronizing( - self, tables: list[str], db_connection_id: str, repository: DBScannerRepository + self, + tables: list[str], + db_connection_id: str, + repository: TableDescriptionRepository, ) -> None: """ "Update table_description status""" diff --git a/dataherald/db_scanner/models/types.py b/dataherald/db_scanner/models/types.py index 9922b12b..cf29700d 100644 --- a/dataherald/db_scanner/models/types.py +++ b/dataherald/db_scanner/models/types.py @@ -28,8 +28,8 @@ class TableDescriptionStatus(Enum): FAILED = "FAILED" -class TableSchemaDetail(BaseModel): - id: Any +class TableDescription(BaseModel): + id: str | None db_connection_id: str table_name: str description: str | None diff --git a/dataherald/db_scanner/repository/base.py b/dataherald/db_scanner/repository/base.py index f69a5972..ae4a5078 100644 --- 
a/dataherald/db_scanner/repository/base.py +++ b/dataherald/db_scanner/repository/base.py @@ -3,77 +3,86 @@ from bson.objectid import ObjectId from pymongo import ASCENDING -from dataherald.db_scanner.models.types import TableSchemaDetail +from dataherald.db_scanner.models.types import TableDescription DB_COLLECTION = "table_descriptions" -class DBScannerRepository: +class TableDescriptionRepository: def __init__(self, storage): self.storage = storage - def find_by_id(self, id: str) -> TableSchemaDetail | None: + def find_by_id(self, id: str) -> TableDescription | None: row = self.storage.find_one(DB_COLLECTION, {"_id": ObjectId(id)}) if not row: return None - obj = TableSchemaDetail(**row) - obj.id = str(row["_id"]) - return obj + row["id"] = str(row["_id"]) + row["db_connection_id"] = str(row["db_connection_id"]) + return TableDescription(**row) def get_table_info( self, db_connection_id: str, table_name: str - ) -> TableSchemaDetail | None: + ) -> TableDescription | None: row = self.storage.find_one( DB_COLLECTION, {"db_connection_id": db_connection_id, "table_name": table_name}, ) if row: - row["id"] = row["_id"] - return TableSchemaDetail(**row) + row["id"] = str(row["_id"]) + row["db_connection_id"] = str(row["db_connection_id"]) + return TableDescription(**row) return None - def get_all_tables_by_db(self, query: dict) -> List[TableSchemaDetail]: + def get_all_tables_by_db(self, query: dict) -> List[TableDescription]: rows = self.storage.find(DB_COLLECTION, query) tables = [] for row in rows: - row["id"] = row["_id"] - tables.append(TableSchemaDetail(**row)) + row["id"] = str(row["_id"]) + row["db_connection_id"] = str(row["db_connection_id"]) + tables.append(TableDescription(**row)) return tables - def save_table_info(self, table_info: TableSchemaDetail) -> None: + def save_table_info(self, table_info: TableDescription) -> None: + table_info_dict = table_info.dict(exclude={"id"}) + table_info_dict["db_connection_id"] = ObjectId(table_info.db_connection_id) 
self.storage.update_or_create( DB_COLLECTION, { - "db_connection_id": table_info.db_connection_id, + "db_connection_id": table_info_dict["db_connection_id"], "table_name": table_info.table_name, }, table_info.dict(), ) - def update(self, table_info: TableSchemaDetail) -> TableSchemaDetail: + def update(self, table_info: TableDescription) -> TableDescription: + table_info_dict = table_info.dict(exclude={"id"}) + table_info_dict["db_connection_id"] = ObjectId(table_info.db_connection_id) + self.storage.update_or_create( DB_COLLECTION, {"_id": ObjectId(table_info.id)}, - table_info.dict(exclude={"id"}), + table_info_dict, ) return table_info - def find_all(self) -> list[TableSchemaDetail]: + def find_all(self) -> list[TableDescription]: rows = self.storage.find_all(DB_COLLECTION) result = [] for row in rows: - obj = TableSchemaDetail(**row) - obj.id = str(row["_id"]) + row["id"] = str(row["_id"]) + row["db_connection_id"] = str(row["db_connection_id"]) + obj = TableDescription(**row) result.append(obj) return result - def find_by(self, query: dict) -> list[TableSchemaDetail]: + def find_by(self, query: dict) -> list[TableDescription]: query = {k: v for k, v in query.items() if v} rows = self.storage.find(DB_COLLECTION, query, sort=[("table_name", ASCENDING)]) result = [] for row in rows: - obj = TableSchemaDetail(**row) + row["id"] = str(row["_id"]) + row["db_connection_id"] = str(row["db_connection_id"]) + obj = TableDescription(**row) obj.columns = sorted(obj.columns, key=lambda x: x.name) - obj.id = str(row["_id"]) result.append(obj) return result diff --git a/dataherald/db_scanner/sqlalchemy.py b/dataherald/db_scanner/sqlalchemy.py index 4f8e0d54..49d71aca 100644 --- a/dataherald/db_scanner/sqlalchemy.py +++ b/dataherald/db_scanner/sqlalchemy.py @@ -10,10 +10,10 @@ from dataherald.db_scanner import Scanner from dataherald.db_scanner.models.types import ( ColumnDetail, + TableDescription, TableDescriptionStatus, - TableSchemaDetail, ) -from 
dataherald.db_scanner.repository.base import DBScannerRepository +from dataherald.db_scanner.repository.base import TableDescriptionRepository from dataherald.sql_database.base import SQLDatabase MIN_CATEGORY_VALUE = 1 @@ -24,12 +24,15 @@ class SqlAlchemyScanner(Scanner): @override def synchronizing( - self, tables: list[str], db_connection_id: str, repository: DBScannerRepository + self, + tables: list[str], + db_connection_id: str, + repository: TableDescriptionRepository, ) -> None: # persist tables to be scanned for table in tables: repository.save_table_info( - TableSchemaDetail( + TableDescription( db_connection_id=db_connection_id, table_name=table, status=TableDescriptionStatus.SYNCHRONIZING.value, @@ -144,8 +147,8 @@ def scan_single_table( table: str, db_engine: SQLDatabase, db_connection_id: str, - repository: DBScannerRepository, - ) -> TableSchemaDetail: + repository: TableDescriptionRepository, + ) -> TableDescription: print(f"Scanning table: {table}") inspector = inspect(db_engine.engine) table_columns = [] @@ -160,7 +163,7 @@ def scan_single_table( ) ) - object = TableSchemaDetail( + object = TableDescription( db_connection_id=db_connection_id, table_name=table, columns=table_columns, @@ -183,7 +186,7 @@ def scan( db_engine: SQLDatabase, db_connection_id: str, table_names: list[str] | None, - repository: DBScannerRepository, + repository: TableDescriptionRepository, ) -> None: inspector = inspect(db_engine.engine) meta = MetaData(bind=db_engine.engine) @@ -208,7 +211,7 @@ def scan( ) except Exception as e: repository.save_table_info( - TableSchemaDetail( + TableDescription( db_connection_id=db_connection_id, table_name=table, status=TableDescriptionStatus.FAILED.value, diff --git a/dataherald/eval/__init__.py b/dataherald/eval/__init__.py index 050f0c65..a3d55d45 100644 --- a/dataherald/eval/__init__.py +++ b/dataherald/eval/__init__.py @@ -6,7 +6,7 @@ from dataherald.model.chat_model import ChatModel from dataherald.sql_database.base import 
SQLDatabase from dataherald.sql_database.models.types import DatabaseConnection -from dataherald.types import NLQuery, NLQueryResponse +from dataherald.types import Question, Response class Evaluation(BaseModel): @@ -27,8 +27,8 @@ def __init__(self, system: System): def get_confidence_score( self, - question: NLQuery, - generated_answer: NLQueryResponse, + question: Question, + generated_answer: Response, database_connection: DatabaseConnection, ) -> confloat: """Determines if a generated response from the engine is acceptable considering the ACCEPTANCE_THRESHOLD""" @@ -42,8 +42,8 @@ def get_confidence_score( @abstractmethod def evaluate( self, - question: NLQuery, - generated_answer: NLQueryResponse, + question: Question, + generated_answer: Response, database_connection: DatabaseConnection, ) -> Evaluation: """Evaluates a question with an SQL pair.""" diff --git a/dataherald/eval/eval_agent.py b/dataherald/eval/eval_agent.py index ea46d082..a48c2ea1 100644 --- a/dataherald/eval/eval_agent.py +++ b/dataherald/eval/eval_agent.py @@ -27,7 +27,7 @@ from dataherald.eval import Evaluation, Evaluator from dataherald.sql_database.base import SQLDatabase from dataherald.sql_database.models.types import DatabaseConnection -from dataherald.types import NLQuery, NLQueryResponse +from dataherald.types import Question, Response logger = logging.getLogger(__name__) @@ -239,8 +239,8 @@ def create_evaluation_agent( @override def evaluate( self, - question: NLQuery, - generated_answer: NLQueryResponse, + question: Question, + generated_answer: Response, database_connection: DatabaseConnection, ) -> Evaluation: start_time = time.time() diff --git a/dataherald/eval/simple_evaluator.py b/dataherald/eval/simple_evaluator.py index 35724e24..5ebe4b72 100644 --- a/dataherald/eval/simple_evaluator.py +++ b/dataherald/eval/simple_evaluator.py @@ -17,7 +17,7 @@ from dataherald.eval import Evaluation, Evaluator from dataherald.sql_database.base import SQLDatabase from 
dataherald.sql_database.models.types import DatabaseConnection -from dataherald.types import NLQuery, NLQueryResponse +from dataherald.types import Question, Response logger = logging.getLogger(__name__) @@ -80,8 +80,8 @@ def answer_parser(self, answer: str) -> int: @override def evaluate( self, - question: NLQuery, - generated_answer: NLQueryResponse, + question: Question, + generated_answer: Response, database_connection: DatabaseConnection, ) -> Evaluation: database = SQLDatabase.get_sql_engine(database_connection) diff --git a/dataherald/repositories/base.py b/dataherald/repositories/base.py index d4c7caff..5d589d19 100644 --- a/dataherald/repositories/base.py +++ b/dataherald/repositories/base.py @@ -1,38 +1,59 @@ from bson.objectid import ObjectId +from pymongo import DESCENDING -from dataherald.types import NLQueryResponse +from dataherald.types import Response -DB_COLLECTION = "nl_query_responses" +DB_COLLECTION = "responses" -class NLQueryResponseRepository: +class ResponseRepository: def __init__(self, storage): self.storage = storage - def insert(self, nl_query_response: NLQueryResponse) -> NLQueryResponse: - nl_query_response.id = self.storage.insert_one( - DB_COLLECTION, nl_query_response.dict(exclude={"id"}) - ) - return nl_query_response + def insert(self, response: Response) -> Response: + response_dict = response.dict(exclude={"id"}) + response_dict["question_id"] = ObjectId(response.question_id) + response.id = str(self.storage.insert_one(DB_COLLECTION, response_dict)) + return response - def find_one(self, query: dict) -> NLQueryResponse | None: + def find_one(self, query: dict) -> Response | None: row = self.storage.find_one(DB_COLLECTION, query) if not row: return None - row["id"] = row["_id"] - return NLQueryResponse(**row) + row["id"] = str(row["_id"]) + row["question_id"] = str(row["question_id"]) + return Response(**row) + + def update(self, response: Response) -> Response: + response_dict = response.dict(exclude={"id"}) + 
response_dict["question_id"] = ObjectId(response.question_id) - def update(self, nl_query_response: NLQueryResponse) -> NLQueryResponse: self.storage.update_or_create( DB_COLLECTION, - {"_id": ObjectId(nl_query_response.id)}, - nl_query_response.dict(exclude={"id"}), + {"_id": ObjectId(response.id)}, + response_dict, ) - return nl_query_response + return response - def find_by_id(self, id: str) -> NLQueryResponse | None: + def find_by_id(self, id: str) -> Response | None: row = self.storage.find_one(DB_COLLECTION, {"_id": ObjectId(id)}) if not row: return None - row["id"] = row["_id"] - return NLQueryResponse(**row) + row["id"] = str(row["_id"]) + row["question_id"] = str(row["question_id"]) + return Response(**row) + + def find_by(self, query: dict, page: int = 1, limit: int = 10) -> list[Response]: + rows = self.storage.find( + DB_COLLECTION, + query, + page=page, + limit=limit, + sort=[("created_at", DESCENDING)], + ) + result = [] + for row in rows: + row["id"] = str(row["_id"]) + row["question_id"] = str(row["question_id"]) + result.append(Response(**row)) + return result diff --git a/dataherald/repositories/database_connections.py b/dataherald/repositories/database_connections.py index af2019bb..31cf6522 100644 --- a/dataherald/repositories/database_connections.py +++ b/dataherald/repositories/database_connections.py @@ -21,7 +21,9 @@ def find_one(self, query: dict) -> DatabaseConnection | None: row = self.storage.find_one(DB_COLLECTION, query) if not row: return None - return DatabaseConnection(**row) + obj = DatabaseConnection(**row) + obj.id = str(row["_id"]) + return obj def update(self, database_connection: DatabaseConnection) -> DatabaseConnection: self.storage.update_or_create( diff --git a/dataherald/repositories/golden_records.py b/dataherald/repositories/golden_records.py index 7775d6cd..460145b2 100644 --- a/dataherald/repositories/golden_records.py +++ b/dataherald/repositories/golden_records.py @@ -10,8 +10,12 @@ def __init__(self, storage): 
self.storage = storage def insert(self, golden_record: GoldenRecord) -> GoldenRecord: + golden_record_dict = golden_record.dict(exclude={"id"}) + golden_record_dict["db_connection_id"] = ObjectId( + golden_record.db_connection_id + ) golden_record.id = str( - self.storage.insert_one(DB_COLLECTION, golden_record.dict(exclude={"id"})) + self.storage.insert_one(DB_COLLECTION, golden_record_dict) ) return golden_record @@ -19,13 +23,20 @@ def find_one(self, query: dict) -> GoldenRecord | None: row = self.storage.find_one(DB_COLLECTION, query) if not row: return None + row["id"] = str(row["_id"]) + row["db_connection_id"] = str(row["db_connection_id"]) return GoldenRecord(**row) def update(self, golden_record: GoldenRecord) -> GoldenRecord: + golden_record_dict = golden_record.dict(exclude={"id"}) + golden_record_dict["db_connection_id"] = ObjectId( + golden_record.db_connection_id + ) + self.storage.update_or_create( DB_COLLECTION, {"_id": ObjectId(golden_record.id)}, - golden_record.dict(exclude={"id"}), + golden_record_dict, ) return golden_record @@ -33,17 +44,29 @@ def find_by_id(self, id: str) -> GoldenRecord | None: row = self.storage.find_one(DB_COLLECTION, {"_id": ObjectId(id)}) if not row: return None + row["id"] = str(row["_id"]) + row["db_connection_id"] = str(row["db_connection_id"]) return GoldenRecord(**row) def find_by( self, query: dict, page: int = 1, limit: int = 10 ) -> list[GoldenRecord]: rows = self.storage.find(DB_COLLECTION, query, page=page, limit=limit) - return [GoldenRecord(id=str(row["_id"]), **row) for row in rows] + golden_records = [] + for row in rows: + row["id"] = str(row["_id"]) + row["db_connection_id"] = str(row["db_connection_id"]) + golden_records.append(GoldenRecord(**row)) + return golden_records def find_all(self, page: int = 0, limit: int = 0) -> list[GoldenRecord]: rows = self.storage.find_all(DB_COLLECTION, page=page, limit=limit) - return [GoldenRecord(id=str(row["_id"]), **row) for row in rows] + golden_records = [] + for 
row in rows: + row["id"] = str(row["_id"]) + row["db_connection_id"] = str(row["db_connection_id"]) + golden_records.append(GoldenRecord(**row)) + return golden_records def delete_by_id(self, id: str) -> int: return self.storage.delete_by_id(DB_COLLECTION, id) diff --git a/dataherald/repositories/instructions.py b/dataherald/repositories/instructions.py index 8087729e..ad88fac5 100644 --- a/dataherald/repositories/instructions.py +++ b/dataherald/repositories/instructions.py @@ -10,22 +10,28 @@ def __init__(self, storage): self.storage = storage def insert(self, instruction: Instruction) -> Instruction: - instruction.id = str( - self.storage.insert_one(DB_COLLECTION, instruction.dict(exclude={"id"})) - ) + instruction_dict = instruction.dict(exclude={"id"}) + instruction_dict["db_connection_id"] = ObjectId(instruction.db_connection_id) + instruction.id = str(self.storage.insert_one(DB_COLLECTION, instruction_dict)) + return instruction def find_one(self, query: dict) -> Instruction | None: row = self.storage.find_one(DB_COLLECTION, query) if not row: return None + row["id"] = str(row["_id"]) + row["db_connection_id"] = str(row["db_connection_id"]) return Instruction(**row) def update(self, instruction: Instruction) -> Instruction: + instruction_dict = instruction.dict(exclude={"id"}) + instruction_dict["db_connection_id"] = ObjectId(instruction.db_connection_id) + self.storage.update_or_create( DB_COLLECTION, {"_id": ObjectId(instruction.id)}, - instruction.dict(exclude={"id"}), + instruction_dict, ) return instruction @@ -33,15 +39,27 @@ def find_by_id(self, id: str) -> Instruction | None: row = self.storage.find_one(DB_COLLECTION, {"_id": ObjectId(id)}) if not row: return None + row["id"] = str(row["_id"]) + row["db_connection_id"] = str(row["db_connection_id"]) return Instruction(**row) def find_by(self, query: dict, page: int = 1, limit: int = 10) -> list[Instruction]: rows = self.storage.find(DB_COLLECTION, query, page=page, limit=limit) - return 
[Instruction(id=str(row["_id"]), **row) for row in rows] + result = [] + for row in rows: + row["id"] = str(row["_id"]) + row["db_connection_id"] = str(row["db_connection_id"]) + result.append(Instruction(**row)) + return result def find_all(self, page: int = 0, limit: int = 0) -> list[Instruction]: rows = self.storage.find_all(DB_COLLECTION, page=page, limit=limit) - return [Instruction(id=str(row["_id"]), **row) for row in rows] + result = [] + for row in rows: + row["id"] = str(row["_id"]) + row["db_connection_id"] = str(row["db_connection_id"]) + result.append(Instruction(**row)) + return result def delete_by_id(self, id: str) -> int: return self.storage.delete_by_id(DB_COLLECTION, id) diff --git a/dataherald/repositories/nl_question.py b/dataherald/repositories/nl_question.py deleted file mode 100644 index 65edcee6..00000000 --- a/dataherald/repositories/nl_question.py +++ /dev/null @@ -1,28 +0,0 @@ -from bson.objectid import ObjectId - -from dataherald.types import NLQuery - -DB_COLLECTION = "nl_questions" - - -class NLQuestionRepository: - def __init__(self, storage): - self.storage = storage - - def insert(self, nl_query: NLQuery) -> NLQuery: - nl_query.id = self.storage.insert_one( - DB_COLLECTION, nl_query.dict(exclude={"id"}) - ) - return nl_query - - def find_one(self, query: dict) -> NLQuery | None: - row = self.storage.find_one(DB_COLLECTION, query) - if not row: - return None - return NLQuery(**row) - - def find_by_id(self, id: str) -> NLQuery | None: - row = self.storage.find_one(DB_COLLECTION, {"_id": ObjectId(id)}) - if not row: - return None - return NLQuery(**row) diff --git a/dataherald/repositories/question.py b/dataherald/repositories/question.py new file mode 100644 index 00000000..c3112e74 --- /dev/null +++ b/dataherald/repositories/question.py @@ -0,0 +1,41 @@ +from bson.objectid import ObjectId + +from dataherald.types import Question + +DB_COLLECTION = "questions" + + +class QuestionRepository: + def __init__(self, storage): + 
self.storage = storage + + def insert(self, question: Question) -> Question: + question_dict = question.dict(exclude={"id"}) + question_dict["db_connection_id"] = ObjectId(question.db_connection_id) + question.id = str(self.storage.insert_one(DB_COLLECTION, question_dict)) + return question + + def find_one(self, query: dict) -> Question | None: + row = self.storage.find_one(DB_COLLECTION, query) + if not row: + return None + row["id"] = str(row["_id"]) + row["db_connection_id"] = str(row["db_connection_id"]) + return Question(**row) + + def find_by_id(self, id: str) -> Question | None: + row = self.storage.find_one(DB_COLLECTION, {"_id": ObjectId(id)}) + if not row: + return None + row["id"] = str(row["_id"]) + row["db_connection_id"] = str(row["db_connection_id"]) + return Question(**row) + + def find_by(self, query: dict, page: int = 1, limit: int = 10) -> list[Question]: + rows = self.storage.find(DB_COLLECTION, query, page=page, limit=limit) + result = [] + for row in rows: + row["id"] = str(row["_id"]) + row["db_connection_id"] = str(row["db_connection_id"]) + result.append(Question(**row)) + return result diff --git a/dataherald/scripts/migrate_v003_to_v004.py b/dataherald/scripts/migrate_v003_to_v004.py new file mode 100644 index 00000000..e1953343 --- /dev/null +++ b/dataherald/scripts/migrate_v003_to_v004.py @@ -0,0 +1,44 @@ +from datetime import datetime + +from bson.objectid import ObjectId + +import dataherald.config +from dataherald.config import System +from dataherald.db import DB + + +def update_object_id_fields(field_name: str, collection_name: str): + for obj in storage.find_all(collection_name): + if obj[field_name] and obj[field_name] != "": + obj[field_name] = ObjectId(obj[field_name]) + storage.update_or_create(collection_name, {"_id": obj["_id"]}, obj) + + +if __name__ == "__main__": + settings = dataherald.config.Settings() + system = System(settings) + system.start() + storage = system.instance(DB) + + # Rename collections + try: + 
storage.rename("nl_questions", "questions") + storage.rename("nl_query_responses", "responses") + except Exception: # noqa: S110 + pass + + # Rename fields + storage.rename_field("responses", "nl_question_id", "question_id") + storage.rename_field("responses", "nl_response", "response") + + # Add field + for response in storage.find_all("responses"): + if "created_at" not in response: + response["created_at"] = datetime.now() + storage.update_or_create("responses", {"_id": response["_id"]}, response) + + # Change datatype + update_object_id_fields("db_connection_id", "table_descriptions") + update_object_id_fields("db_connection_id", "golden_records") + update_object_id_fields("db_connection_id", "questions") + update_object_id_fields("question_id", "responses") diff --git a/dataherald/server/fastapi/__init__.py b/dataherald/server/fastapi/__init__.py index e18cb4f8..215c6742 100644 --- a/dataherald/server/fastapi/__init__.py +++ b/dataherald/server/fastapi/__init__.py @@ -9,21 +9,21 @@ import dataherald from dataherald.api.types import Query from dataherald.config import Settings -from dataherald.db_scanner.models.types import TableSchemaDetail +from dataherald.db_scanner.models.types import TableDescription from dataherald.sql_database.models.types import DatabaseConnection, SSHSettings from dataherald.types import ( + CreateResponseRequest, DatabaseConnectionRequest, - ExecuteTempQueryRequest, GoldenRecord, GoldenRecordRequest, Instruction, InstructionRequest, - NLQueryResponse, + Question, QuestionRequest, + Response, ScannerRequest, TableDescriptionRequest, UpdateInstruction, - UpdateQueryRequest, ) @@ -110,24 +110,31 @@ def __init__(self, settings: Settings): ) self.router.add_api_route( - "/api/v1/question", + "/api/v1/questions", self.answer_question, methods=["POST"], - tags=["Question"], + tags=["Questions"], ) self.router.add_api_route( - "/api/v1/nl-query-responses", - self.get_nl_query_response, + "/api/v1/questions", + self.get_questions, + 
methods=["GET"], + tags=["Questions"], + ) + + self.router.add_api_route( + "/api/v1/responses", + self.create_response, methods=["POST"], - tags=["NL query responses"], + tags=["Responses"], ) self.router.add_api_route( - "/api/v1/nl-query-responses/{query_id}", - self.update_nl_query_response, - methods=["PATCH"], - tags=["NL query responses"], + "/api/v1/responses", + self.get_responses, + methods=["GET"], + tags=["Responses"], ) self.router.add_api_route( @@ -180,9 +187,12 @@ def scan_db( ) -> bool: return self._api.scan_db(scanner_request, background_tasks) - def answer_question(self, question_request: QuestionRequest) -> NLQueryResponse: + def answer_question(self, question_request: QuestionRequest) -> Response: return self._api.answer_question(question_request) + def get_questions(self, db_connection_id: str | None = None) -> list[Question]: + return self._api.get_questions(db_connection_id) + def root(self) -> dict[str, int]: return {"nanosecond heartbeat": self._api.heartbeat()} @@ -213,7 +223,7 @@ def update_table_description( self, table_description_id: str, table_description_request: TableDescriptionRequest, - ) -> TableSchemaDetail: + ) -> TableDescription: """Add descriptions for tables and columns""" return self._api.update_table_description( table_description_id, table_description_request @@ -221,25 +231,21 @@ def update_table_description( def list_table_descriptions( self, db_connection_id: str | None = None, table_name: str | None = None - ) -> list[TableSchemaDetail]: + ) -> list[TableDescription]: """List table descriptions""" return self._api.list_table_descriptions(db_connection_id, table_name) + def get_responses(self, question_id: str | None = None) -> list[Response]: + """List responses""" + return self._api.get_responses(question_id) + def execute_sql_query(self, query: Query) -> tuple[str, dict]: """Executes a query on the given db_connection_id""" return self._api.execute_sql_query(query) - def update_nl_query_response( - self, query_id: 
str, query: UpdateQueryRequest - ) -> NLQueryResponse: - """Executes a query on the given db_connection_id""" - return self._api.update_nl_query_response(query_id, query) - - def get_nl_query_response( - self, query_request: ExecuteTempQueryRequest - ) -> NLQueryResponse: + def create_response(self, query_request: CreateResponseRequest) -> Response: """Executes a query on the given db_connection_id""" - return self._api.get_nl_query_response(query_request) + return self._api.create_response(query_request) def delete_golden_record(self, golden_record_id: str) -> dict: """Deletes a golden record""" diff --git a/dataherald/smart_cache/__init__.py b/dataherald/smart_cache/__init__.py index f81289cc..060152d5 100644 --- a/dataherald/smart_cache/__init__.py +++ b/dataherald/smart_cache/__init__.py @@ -3,12 +3,12 @@ from typing import Any, Union from dataherald.config import Component -from dataherald.types import NLQueryResponse +from dataherald.types import Response class SmartCache(Component, ABC): @abstractmethod - def add(self, key: str, value: NLQueryResponse) -> dict[str, Any]: + def add(self, key: str, value: Response) -> dict[str, Any]: """Adds a key-value pair to the cache.""" @abstractmethod diff --git a/dataherald/smart_cache/in_memory.py b/dataherald/smart_cache/in_memory.py index de68cfa8..e4b20ab4 100644 --- a/dataherald/smart_cache/in_memory.py +++ b/dataherald/smart_cache/in_memory.py @@ -5,7 +5,7 @@ from dataherald.config import Settings from dataherald.smart_cache import SmartCache -from dataherald.types import NLQueryResponse +from dataherald.types import Response logger = logging.getLogger(__name__) @@ -19,7 +19,7 @@ def __init__(self, settings: Settings): self.cache = {} @override - def add(self, key: str, value: NLQueryResponse) -> dict[str, Any]: + def add(self, key: str, value: Response) -> dict[str, Any]: logger.info(f"Adding to cache: {key}") self.cache[key] = value return {key: value} diff --git a/dataherald/sql_database/models/types.py 
b/dataherald/sql_database/models/types.py index cf06d415..3fd25377 100644 --- a/dataherald/sql_database/models/types.py +++ b/dataherald/sql_database/models/types.py @@ -52,7 +52,7 @@ def __getitem__(self, key: str) -> Any: class DatabaseConnection(BaseModel): - id: Any + id: str | None alias: str use_ssh: bool = False uri: str | None diff --git a/dataherald/sql_generator/__init__.py b/dataherald/sql_generator/__init__.py index 8f8c1fe1..fae2dd8d 100644 --- a/dataherald/sql_generator/__init__.py +++ b/dataherald/sql_generator/__init__.py @@ -12,7 +12,7 @@ from dataherald.sql_database.base import SQLDatabase from dataherald.sql_database.models.types import DatabaseConnection from dataherald.sql_generator.create_sql_query_status import create_sql_query_status -from dataherald.types import NLQuery, NLQueryResponse, SQLQueryResult +from dataherald.types import Question, Response, SQLQueryResult from dataherald.utils.strings import contains_line_breaks @@ -25,8 +25,8 @@ def __init__(self, system: System): # noqa: ARG002 self.model = ChatModel(self.system) def create_sql_query_status( - self, db: SQLDatabase, query: str, response: NLQueryResponse - ) -> NLQueryResponse: + self, db: SQLDatabase, query: str, response: Response + ) -> Response: return create_sql_query_status(db, query, response) def format_intermediate_representations( @@ -59,9 +59,9 @@ def format_sql_query(self, sql_query: str) -> str: @abstractmethod def generate_response( self, - user_question: NLQuery, + user_question: Question, database_connection: DatabaseConnection, context: List[dict] = None, - ) -> NLQueryResponse: + ) -> Response: """Generates a response to a user question.""" pass diff --git a/dataherald/sql_generator/create_sql_query_status.py b/dataherald/sql_generator/create_sql_query_status.py index 88518034..16ce3b48 100644 --- a/dataherald/sql_generator/create_sql_query_status.py +++ b/dataherald/sql_generator/create_sql_query_status.py @@ -4,12 +4,12 @@ from sqlalchemy import text from 
dataherald.sql_database.base import SQLDatabase, SQLInjectionError -from dataherald.types import NLQueryResponse, SQLQueryResult +from dataherald.types import Response, SQLQueryResult def create_sql_query_status( - db: SQLDatabase, query: str, response: NLQueryResponse -) -> NLQueryResponse: + db: SQLDatabase, query: str, response: Response +) -> Response: """Find the sql query status and populate the fields sql_query_result, sql_generation_status, and error_message""" if query == "": response.sql_generation_status = "NONE" diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py index 16ceaeb5..45cc9eef 100644 --- a/dataherald/sql_generator/dataherald_sqlagent.py +++ b/dataherald/sql_generator/dataherald_sqlagent.py @@ -30,14 +30,14 @@ from dataherald.context_store import ContextStore from dataherald.db import DB -from dataherald.db_scanner.models.types import TableDescriptionStatus, TableSchemaDetail -from dataherald.db_scanner.repository.base import DBScannerRepository +from dataherald.db_scanner.models.types import TableDescription, TableDescriptionStatus +from dataherald.db_scanner.repository.base import TableDescriptionRepository from dataherald.sql_database.base import SQLDatabase, SQLInjectionError from dataherald.sql_database.models.types import ( DatabaseConnection, ) from dataherald.sql_generator import SQLGenerator -from dataherald.types import NLQuery, NLQueryResponse +from dataherald.types import Question, Response logger = logging.getLogger(__name__) @@ -230,7 +230,7 @@ class TablesSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): Output: Comma-separated list of tables with their relevance scores, indicating their relevance to the question. Use this tool to identify the relevant tables for the given question. 
""" - db_scan: List[TableSchemaDetail] + db_scan: List[TableDescription] def get_embedding( self, text: str, model: str = "text-embedding-ada-002" @@ -364,7 +364,7 @@ class SchemaSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): Example Input: table1, table2, table3 """ - db_scan: List[TableSchemaDetail] + db_scan: List[TableDescription] @catch_exceptions() def _run( @@ -403,7 +403,7 @@ class InfoRelevantColumns(BaseSQLDatabaseTool, BaseTool): Example Input: table1 -> column1, table1 -> column2, table2 -> column1 """ - db_scan: List[TableSchemaDetail] + db_scan: List[TableDescription] @catch_exceptions() def _run( @@ -493,7 +493,7 @@ class SQLDatabaseToolkit(BaseToolkit): context: List[dict] | None = Field(exclude=True, default=None) few_shot_examples: List[dict] | None = Field(exclude=True, default=None) instructions: List[dict] | None = Field(exclude=True, default=None) - db_scan: List[TableSchemaDetail] = Field(exclude=True) + db_scan: List[TableDescription] = Field(exclude=True) @property def dialect(self) -> str: @@ -607,10 +607,10 @@ def create_sql_agent( @override def generate_response( self, - user_question: NLQuery, + user_question: Question, database_connection: DatabaseConnection, context: List[dict] = None, - ) -> NLQueryResponse: + ) -> Response: start_time = time.time() context_store = self.system.instance(ContextStore) storage = self.system.instance(DB) @@ -618,7 +618,7 @@ def generate_response( database_connection=database_connection, temperature=0, ) - repository = DBScannerRepository(storage) + repository = TableDescriptionRepository(storage) db_scan = repository.get_all_tables_by_db( { "db_connection_id": str(database_connection.id), @@ -658,8 +658,8 @@ def generate_response( except SQLInjectionError as e: raise SQLAlchemyError(e) from e except Exception as e: - return NLQueryResponse( - nl_question_id=user_question.id, + return Response( + question_id=user_question.id, total_tokens=cb.total_tokens, total_cost=cb.total_cost, sql_query="", @@ -679,9 
+679,9 @@ def generate_response( logger.info( f"cost: {str(cb.total_cost)} tokens: {str(cb.total_tokens)} time: {str(exec_time)}" ) - response = NLQueryResponse( - nl_question_id=user_question.id, - nl_response=result["output"], + response = Response( + question_id=user_question.id, + response=result["output"], intermediate_steps=intermediate_steps, exec_time=exec_time, total_tokens=cb.total_tokens, diff --git a/dataherald/sql_generator/generates_nl_answer.py b/dataherald/sql_generator/generates_nl_answer.py index 89dd12fb..7bcc8845 100644 --- a/dataherald/sql_generator/generates_nl_answer.py +++ b/dataherald/sql_generator/generates_nl_answer.py @@ -7,10 +7,10 @@ from dataherald.model.chat_model import ChatModel from dataherald.repositories.database_connections import DatabaseConnectionRepository -from dataherald.repositories.nl_question import NLQuestionRepository +from dataherald.repositories.question import QuestionRepository from dataherald.sql_database.base import SQLDatabase from dataherald.sql_generator.create_sql_query_status import create_sql_query_status -from dataherald.types import NLQueryResponse +from dataherald.types import Response SYSTEM_TEMPLATE = """ Given a Question, a Sql query and the sql query result try to answer the question If the sql query result doesn't answer the question just say 'I don't know' @@ -29,23 +29,21 @@ def __init__(self, system, storage): self.storage = storage self.model = ChatModel(self.system) - def execute(self, nl_query_response: NLQueryResponse) -> NLQueryResponse: - nl_question_repository = NLQuestionRepository(self.storage) - nl_question = nl_question_repository.find_by_id( - nl_query_response.nl_question_id - ) + def execute(self, query_response: Response) -> Response: + question_repository = QuestionRepository(self.storage) + question = question_repository.find_by_id(query_response.question_id) db_connection_repository = DatabaseConnectionRepository(self.storage) database_connection = 
db_connection_repository.find_by_id( - nl_question.db_connection_id + question.db_connection_id ) self.llm = self.model.get_model( database_connection=database_connection, temperature=0, ) database = SQLDatabase.get_sql_engine(database_connection) - nl_query_response = create_sql_query_status( - database, nl_query_response.sql_query, nl_query_response + query_response = create_sql_query_status( + database, query_response.sql_query, query_response ) system_message_prompt = SystemMessagePromptTemplate.from_template( SYSTEM_TEMPLATE @@ -56,9 +54,9 @@ def execute(self, nl_query_response: NLQueryResponse) -> NLQueryResponse: ) chain = LLMChain(llm=self.llm, prompt=chat_prompt) nl_resp = chain.run( - question=nl_question.question, - sql_query=nl_query_response.sql_query, - sql_query_result=str(nl_query_response.sql_query_result), + question=question.question, + sql_query=query_response.sql_query, + sql_query_result=str(query_response.sql_query_result), ) - nl_query_response.nl_response = nl_resp - return nl_query_response + query_response.response = nl_resp + return query_response diff --git a/dataherald/sql_generator/langchain_sqlagent.py b/dataherald/sql_generator/langchain_sqlagent.py index 936e94ac..bad3d822 100644 --- a/dataherald/sql_generator/langchain_sqlagent.py +++ b/dataherald/sql_generator/langchain_sqlagent.py @@ -14,7 +14,7 @@ from dataherald.sql_database.base import SQLDatabase from dataherald.sql_database.models.types import DatabaseConnection from dataherald.sql_generator import SQLGenerator -from dataherald.types import NLQuery, NLQueryResponse +from dataherald.types import Question, Response logger = logging.getLogger(__name__) @@ -25,10 +25,10 @@ class LangChainSQLAgentSQLGenerator(SQLGenerator): @override def generate_response( self, - user_question: NLQuery, + user_question: Question, database_connection: DatabaseConnection, context: List[dict] = None, - ) -> NLQueryResponse: # type: ignore + ) -> Response: # type: ignore logger.info(f"Generating 
SQL response to question: {str(user_question.dict())}") self.llm = self.model.get_model( database_connection=database_connection, temperature=0 @@ -73,9 +73,9 @@ def generate_response( logger.info( f"cost: {str(cb.total_cost)} tokens: {str(cb.total_tokens)} time: {str(exec_time)}" ) - response = NLQueryResponse( - nl_question_id=user_question.id, - nl_response=result["output"], + response = Response( + question_id=user_question.id, + response=result["output"], intermediate_steps=intermediate_steps, exec_time=exec_time, total_tokens=cb.total_tokens, diff --git a/dataherald/sql_generator/langchain_sqlchain.py b/dataherald/sql_generator/langchain_sqlchain.py index 20bda59e..f4e3bac0 100644 --- a/dataherald/sql_generator/langchain_sqlchain.py +++ b/dataherald/sql_generator/langchain_sqlchain.py @@ -11,7 +11,7 @@ from dataherald.sql_database.base import SQLDatabase from dataherald.sql_database.models.types import DatabaseConnection from dataherald.sql_generator import SQLGenerator -from dataherald.types import NLQuery, NLQueryResponse +from dataherald.types import Question, Response logger = logging.getLogger(__name__) @@ -43,10 +43,10 @@ class LangChainSQLChainSQLGenerator(SQLGenerator): @override def generate_response( self, - user_question: NLQuery, + user_question: Question, database_connection: DatabaseConnection, context: List[dict] = None, - ) -> NLQueryResponse: + ) -> Response: start_time = time.time() self.llm = self.model.get_model( database_connection=database_connection, temperature=0 @@ -82,9 +82,9 @@ def generate_response( logger.info( f"cost: {str(cb.total_cost)} tokens: {str(cb.total_tokens)} time: {str(exec_time)}" ) - response = NLQueryResponse( - nl_question_id=user_question.id, - nl_response=result["result"], + response = Response( + question_id=user_question.id, + response=result["result"], intermediate_steps=intermediate_steps, exec_time=exec_time, total_cost=cb.total_cost, diff --git a/dataherald/sql_generator/llamaindex.py 
b/dataherald/sql_generator/llamaindex.py index 3aa996cb..7ef7d4a9 100644 --- a/dataherald/sql_generator/llamaindex.py +++ b/dataherald/sql_generator/llamaindex.py @@ -20,7 +20,7 @@ from dataherald.sql_database.base import SQLDatabase from dataherald.sql_database.models.types import DatabaseConnection from dataherald.sql_generator import SQLGenerator -from dataherald.types import NLQuery, NLQueryResponse +from dataherald.types import Question, Response logger = logging.getLogger(__name__) @@ -31,10 +31,10 @@ class LlamaIndexSQLGenerator(SQLGenerator): @override def generate_response( self, - user_question: NLQuery, + user_question: Question, database_connection: DatabaseConnection, context: List[dict] = None, - ) -> NLQueryResponse: + ) -> Response: start_time = time.time() logger.info(f"Generating SQL response to question: {str(user_question.dict())}") self.llm = self.model.get_model( @@ -97,9 +97,9 @@ def generate_response( f"total cost: {str(total_cost)} {str(token_counter.total_llm_token_count)}" ) exec_time = time.time() - start_time - response = NLQueryResponse( - nl_question_id=user_question.id, - nl_response=result.response, + response = Response( + question_id=user_question.id, + response=result.response, exec_time=exec_time, total_tokens=token_counter.total_llm_token_count, total_cost=total_cost, diff --git a/dataherald/tests/db/test_db.py b/dataherald/tests/db/test_db.py index 7bcc615c..713d15be 100644 --- a/dataherald/tests/db/test_db.py +++ b/dataherald/tests/db/test_db.py @@ -89,3 +89,9 @@ def delete_by_id(self, collection: str, id: str) -> int: @override def rename(self, old_collection_name: str, new_collection_name) -> None: pass + + @override + def rename_field( + self, collection_name: str, old_field_name: str, new_field_name: str + ) -> None: + pass diff --git a/dataherald/tests/evaluator/test_eval.py b/dataherald/tests/evaluator/test_eval.py index c51f9e81..505f7d68 100644 --- a/dataherald/tests/evaluator/test_eval.py +++ 
b/dataherald/tests/evaluator/test_eval.py @@ -4,7 +4,7 @@ from dataherald.config import System from dataherald.eval import Evaluation, Evaluator from dataherald.sql_database.models.types import DatabaseConnection -from dataherald.types import NLQuery, NLQueryResponse +from dataherald.types import Question, Response class TestEvaluator(Evaluator): @@ -14,8 +14,8 @@ def __init__(self, system: System): @override def get_confidence_score( self, - question: NLQuery, - generated_answer: NLQueryResponse, + question: Question, + generated_answer: Response, database_connection: DatabaseConnection, ) -> confloat: score: confloat(ge=0, le=1) = 1.0 @@ -24,8 +24,8 @@ def get_confidence_score( @override def evaluate( self, - question: NLQuery, - generated_answer: NLQueryResponse, + question: Question, + generated_answer: Response, database_connection: DatabaseConnection, ) -> Evaluation: return Evaluation(question_id="0", answer_id="0", score=0.8) diff --git a/dataherald/tests/sql_generator/test_generator.py b/dataherald/tests/sql_generator/test_generator.py index f8e29eec..9862583c 100644 --- a/dataherald/tests/sql_generator/test_generator.py +++ b/dataherald/tests/sql_generator/test_generator.py @@ -5,7 +5,7 @@ from dataherald.config import System from dataherald.sql_database.models.types import DatabaseConnection from dataherald.sql_generator import SQLGenerator -from dataherald.types import NLQuery, NLQueryResponse +from dataherald.types import Question, Response class TestGenerator(SQLGenerator): @@ -15,13 +15,13 @@ def __init__(self, system: System): @override def generate_response( self, - user_question: NLQuery, + user_question: Question, database_connection: DatabaseConnection, context: List[dict] = None, # noqa: ARG002 - ) -> NLQueryResponse: - return NLQueryResponse( - nl_question_id=None, - nl_response="Foo response", + ) -> Response: + return Response( + question_id=None, + response="Foo response", intermediate_steps=["foo"], sql_query="bar", ) diff --git 
a/dataherald/tests/test_api.py b/dataherald/tests/test_api.py index 445a4ddd..b6fe8384 100644 --- a/dataherald/tests/test_api.py +++ b/dataherald/tests/test_api.py @@ -36,7 +36,7 @@ def test_scan_one_table(): def test_answer_question(): response = client.post( - "/api/v1/question", + "/api/v1/questions", json={"question": "Who am I?", "db_connection_id": "64dfa0e103f5134086f7090c"}, ) assert response.status_code == HTTP_200_CODE diff --git a/dataherald/types.py b/dataherald/types.py index 1e46a0ad..344943ba 100644 --- a/dataherald/types.py +++ b/dataherald/types.py @@ -1,5 +1,5 @@ +from datetime import datetime, timezone from enum import Enum -from typing import Any from bson.errors import InvalidId from bson.objectid import ObjectId @@ -20,12 +20,8 @@ def object_id_validation(cls, v: str): return v -class UpdateQueryRequest(BaseModel): - sql_query: str - - -class ExecuteTempQueryRequest(BaseModel): - query_id: str +class CreateResponseRequest(BaseModel): + question_id: str sql_query: str @@ -34,8 +30,8 @@ class SQLQueryResult(BaseModel): rows: list[dict] -class NLQuery(BaseModel): - id: Any +class Question(BaseModel): + id: str | None = None question: str db_connection_id: str @@ -49,7 +45,7 @@ class InstructionRequest(DBConnectionValidation): class Instruction(BaseModel): - id: Any + id: str | None = None instruction: str db_connection_id: str @@ -60,7 +56,7 @@ class GoldenRecordRequest(DBConnectionValidation): class GoldenRecord(BaseModel): - id: Any + id: str | None = None question: str sql_query: str db_connection_id: str @@ -72,10 +68,10 @@ class SQLGenerationStatus(Enum): INVALID = "INVALID" -class NLQueryResponse(BaseModel): - id: Any - nl_question_id: Any - nl_response: str | None = None +class Response(BaseModel): + id: str | None = None + question_id: str | None = None + response: str | None = None intermediate_steps: list[str] | None = None sql_query: str sql_query_result: SQLQueryResult | None @@ -85,7 +81,17 @@ class NLQueryResponse(BaseModel): 
total_tokens: int | None = None total_cost: float | None = None confidence_score: float | None = None - # date_entered: datetime = datetime.now() add this later + created_at: datetime = datetime.now() + + @validator("created_at", pre=True) + def parse_datetime_with_timezone(cls, value): + if not value: + return None + return value.replace(tzinfo=timezone.utc) # Set the timezone to UTC + + @validator("question_id", pre=True) + def parse_question_id(cls, value): + return str(value) class SupportedDatabase(Enum): diff --git a/docs/api.process_nl_query_response.rst b/docs/api.process_nl_query_response.rst index 6f14c5ab..91cf353a 100644 --- a/docs/api.process_nl_query_response.rst +++ b/docs/api.process_nl_query_response.rst @@ -5,7 +5,7 @@ Once you made a question you can try sending a new sql query to improve the resp Request this ``POST`` endpoint:: - /api/v1/nl-query-responses + /api/v1/responses **Request body** @@ -24,8 +24,8 @@ HTTP 200 code response { "id": "string", - "nl_question_id": "string", - "nl_response": "string", + "question_id": "string", + "response": "string", "intermediate_steps": [ "string" ], @@ -52,7 +52,7 @@ HTTP 200 code response .. 
code-block:: rst curl -X 'POST' \ - '/api/v1/nl-query-responses' \ + '/api/v1/responses' \ -H 'accept: application/json' \ -H 'Content-Type: application/json' \ -d '{ @@ -73,10 +73,8 @@ HTTP 200 code response "id": { "$oid": "64c424fa3f4036441e882352" }, - "nl_question_id": { - "$oid": "64dbd8cf944f867b3c450467" - }, - "nl_response": "The most expensive zip to rent in Los Angeles city is 90210", + "question_id": "64dbd8cf944f867b3c450467", + "response": "The most expensive zip to rent in Los Angeles city is 90210", "intermediate_steps": [ "", ], diff --git a/docs/api.question.rst b/docs/api.question.rst index 10537307..dcf251d5 100644 --- a/docs/api.question.rst +++ b/docs/api.question.rst @@ -6,7 +6,7 @@ you should be able to ask natural language questions to retrieve an accurate res Request this ``POST`` endpoint:: - /api/v1/question + /api/v1/questions **Request body** @@ -24,8 +24,8 @@ HTTP 200 code response { "id": "string", - "nl_question_id": "string", - "nl_response": "string", + "question_id": "string", + "response": "string", "intermediate_steps": [ "string" ], @@ -52,7 +52,7 @@ HTTP 200 code response .. 
code-block:: rst curl -X 'POST' \ - '/api/v1/question' \ + '/api/v1/questions' \ -H 'accept: application/json' \ -H 'Content-Type: application/json' \ -d '{ @@ -68,10 +68,8 @@ HTTP 200 code response "id": { "$oid": "64dbd8f4944f867b3c450468" }, - "nl_question_id": { - "$oid": "64dbd8cf944f867b3c450467" - }, - "nl_response": "The median rent price for single homes in Los Angeles city is approximately $2827.65.", + "question_id": "64dbd8cf944f867b3c450467", + "response": "The median rent price for single homes in Los Angeles city is approximately $2827.65.", "intermediate_steps": [ "", ], diff --git a/docs/api.rst b/docs/api.rst index 0239a804..20ac5249 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -49,9 +49,7 @@ The ``query-response`` object is created from the answering natural language que The related endpoints are: -* :doc:`process_nl_query_response ` -- ``POST api/v1/nl-query-responses`` -* :doc:`update_nl_query_response ` -- ``PATCH api/v1/nl-query-responses/{query_id}`` - +* :doc:`process_nl_query_response ` -- ``POST api/v1/responses`` .. code-block:: json @@ -60,8 +58,8 @@ The related endpoints are: "error_message": "string", "exec_time": "float", "intermediate_steps":["string"], - "nl_question_id": "string", - "nl_response": "string", + "question_id": "string", + "response": "string", "sql_generation_status": "string", "sql_query": "string", "sql_query_result": {}, @@ -132,5 +130,4 @@ Related endpoints are: api.question - api.update_nl_query_response.rst api.process_nl_query_response diff --git a/docs/api.update_nl_query_response.rst b/docs/api.update_nl_query_response.rst deleted file mode 100644 index 9447afa0..00000000 --- a/docs/api.update_nl_query_response.rst +++ /dev/null @@ -1,113 +0,0 @@ -Update a NL query response -============================ - -Once you ask a question, you can give feedback to improve the queries - -Request this ``PATCH`` endpoint:: - - /api/v1/nl-query-responses/{query_id} - -**Parameters** - -.. 
csv-table:: - :header: "Name", "Type", "Description" - :widths: 15, 10, 30 - - "query_id", "string", "Generated query id, ``Required``" - -**Request body** - -.. code-block:: rst - - { - "sql_query": "string", # required - } - -**Responses** - -HTTP 200 code response - -.. code-block:: rst - - { - "id": "string", - "nl_question_id": "string", - "nl_response": "string", - "intermediate_steps": [ - "string" - ], - "sql_query": "string", - "sql_query_result": { - "columns": [ - "string" - ], - "rows": [ - {} - ] - }, - "sql_generation_status": "NONE", - "error_message": "string", - "exec_time": 0, - "total_tokens": 0, - "total_cost": 0, - "confidence_score": 0 - } - -**Request example** - - -.. code-block:: rst - - curl -X 'POST' \ - '/api/v1/nl-query-responses/64c424fa3f4036441e882352' \ - -H 'accept: application/json' \ - -H 'Content-Type: application/json' \ - -d '{ - "sql_query": "SELECT "dh_zip_code", MAX("metric_value") as max_rent - FROM db_table - WHERE "dh_county_name" = 'Los Angeles' AND "period_start" = '2022-05-01' AND "period_end" = '2022-05-31' - GROUP BY "zip_code" - ORDER BY max_rent DESC - LIMIT 1;" - }' - -**Response example** - -.. 
code-block:: rst - - { - "id": { - "$oid": "64c424fa3f4036441e882352" - }, - "nl_question_id": { - "$oid": "64dbd8cf944f867b3c450467" - }, - "nl_response": "The most expensive zip to rent in Los Angeles city is 90210", - "intermediate_steps": [ - "", - ], - "sql_query": "SELECT "zip_code", MAX("metric_value") as max_rent - FROM db_table - WHERE "dh_county_name" = 'Los Angeles' AND "period_start" = '2022-05-01' AND "period_end" = '2022-05-31' - GROUP BY "zip_code" - ORDER BY max_rent DESC - LIMIT 1;", - "sql_query_result": { - "columns": [ - "zip_code", - "max_rent" - ], - "rows": [ - { - "zip_code": "90210", - "max_rent": 58279.6479072398192 - } - ] - }, - "sql_generation_status": "VALID", - "error_message": null, - "exec_time": 37.183526277542114, - "total_tokens": 17816, - "total_cost": 1.1087399999999998 - "confidence_score": 0.95 - } diff --git a/docs/api_server.rst b/docs/api_server.rst index 168b2b62..bd645759 100644 --- a/docs/api_server.rst +++ b/docs/api_server.rst @@ -31,15 +31,15 @@ All implementations of the API module must inherit and implement the abstract :c :return: True if the scanning was initiated successfully; otherwise, False. :rtype: bool -.. method:: answer_question(self, question_request: QuestionRequest) -> NLQueryResponse +.. method:: answer_question(self, question_request: QuestionRequest) -> Response :noindex: Provides a response to a user's question based on the provided question request. :param question_request: The question request. :type question_request: QuestionRequest - :return: The NLQueryResponse containing the response to the user's question. - :rtype: NLQueryResponse + :return: The Response containing the response to the user's question. + :rtype: Response .. method:: create_database_connection(self, database_connection_request: DatabaseConnectionRequest) -> bool :noindex: @@ -85,7 +85,7 @@ All implementations of the API module must inherit and implement the abstract :c :return: A tuple containing the query status and result. 
:rtype: tuple[str, dict] -.. method:: update_query(self, query_id: str, query: UpdateQueryRequest) -> NLQueryResponse +.. method:: update_query(self, query_id: str, query: UpdateQueryRequest) -> Response :noindex: Updates a query using the provided query ID and UpdateQueryRequest. @@ -94,20 +94,20 @@ All implementations of the API module must inherit and implement the abstract :c :type query_id: str :param query: The update query request. :type query: UpdateQueryRequest - :return: The NLQueryResponse containing the result of the query update. - :rtype: NLQueryResponse + :return: The Response containing the result of the query update. + :rtype: Response -.. method:: execute_temp_query(self, query_id: str, query: ExecuteTempQueryRequest) -> NLQueryResponse +.. method:: execute_temp_query(self, question_id: str, query: CreateResponseRequest) -> Response :noindex: - Executes a temporary query using the provided query ID and ExecuteTempQueryRequest. + Executes a temporary query using the provided question ID and CreateResponseRequest. - :param query_id: The ID of the temporary query to execute. - :type query_id: str - :param query: The temporary query request. - :type query: ExecuteTempQueryRequest - :return: The NLQueryResponse containing the result of the temporary query execution. - :rtype: NLQueryResponse + :param question_id: The ID of the question to execute. + :type question_id: str + :param query: The query request. + :type query: CreateResponseRequest + :return: The Response containing the result. + :rtype: Response .. method:: get_scanned_databases(self, db_connection_id: str) -> ScannedDBResponse :noindex: diff --git a/docs/evaluator.rst b/docs/evaluator.rst index ab16d1c3..21c88357 100644 --- a/docs/evaluator.rst +++ b/docs/evaluator.rst @@ -70,9 +70,9 @@ All implementations of the Evaluation component must inherit from the ``Evaluato Determines if a generated response from the engine is acceptable based on the ACCEPTANCE_THRESHOLD.
:param question: The natural language question. - :type question: NLQuery + :type question: Question :param generated_answer: The generated SQL query response. - :type generated_answer: NLQueryResponse + :type generated_answer: Response :param database_connection: The database connection. :type database_connection: DatabaseConnection :return: The confidence score. @@ -83,9 +83,9 @@ All implementations of the Evaluation component must inherit from the ``Evaluato Abstract method to evaluate a question with an SQL pair. Subclasses must implement this method. :param question: The natural language question. - :type question: NLQuery + :type question: Question :param generated_answer: The generated SQL query response. - :type generated_answer: NLQueryResponse + :type generated_answer: Response :param database_connection: The database connection. :type database_connection: DatabaseConnection :return: An Evaluation instance. diff --git a/docs/quickstart.rst b/docs/quickstart.rst index e0114e08..b40dd0c3 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -139,12 +139,12 @@ The details of how to use these endpoints are outside the scope of this quicksta Querying the Database in Natural Language ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Once you have connected the engine to your data warehouse (and preferably added some context to the store), you can query your data warehouse using the ``POST /api/v1/question`` endpoint. +Once you have connected the engine to your data warehouse (and preferably added some context to the store), you can query your data warehouse using the ``POST /api/v1/questions`` endpoint. .. 
code-block:: rst curl -X 'POST' \ - '/api/v1/question' \ + '/api/v1/questions' \ -H 'accept: application/json' \ -H 'Content-Type: application/json' \ -d '{ diff --git a/docs/text_to_sql_engine.rst b/docs/text_to_sql_engine.rst index f09601f9..1627b2c5 100644 --- a/docs/text_to_sql_engine.rst +++ b/docs/text_to_sql_engine.rst @@ -51,10 +51,10 @@ This base class defines the common structure for SQL generation classes. :type db: SQLDatabase :param query: The SQL query. :type query: str - :param response: The NLQueryResponse instance. - :type response: NLQueryResponse - :return: The updated NLQueryResponse instance with the SQL query status. - :rtype: NLQueryResponse + :param response: The Response instance. + :type response: Response + :return: The updated Response instance with the SQL query status. + :rtype: Response .. method:: generate_response(user_question, database_connection, context=None) :noindex: @@ -62,13 +62,13 @@ This base class defines the common structure for SQL generation classes. Generates a response to a user question based on the given user question, database connection, and optional context. :param user_question: The user's natural language question. - :type user_question: NLQuery + :type user_question: Question :param database_connection: The database connection information. :type database_connection: DatabaseConnection :param context: (Optional) Additional context information. :type context: List[dict], optional - :return: The NLQueryResponse containing the generated response. - :rtype: NLQueryResponse + :return: The Response containing the generated response. + :rtype: Response