From 019c696ee8696604a23f0bfa37e52783ad165d1b Mon Sep 17 00:00:00 2001 From: mohammadrezapourreza Date: Fri, 15 Dec 2023 18:20:41 -0500 Subject: [PATCH] DH-5080/refactor_engine_with_new_resources --- dataherald/context_store/__init__.py | 4 +- dataherald/context_store/default.py | 12 +- dataherald/sql_generator/__init__.py | 27 +---- .../sql_generator/create_sql_query_status.py | 105 +++--------------- .../dataherald_finetuning_agent.py | 43 ++++--- .../sql_generator/dataherald_sqlagent.py | 45 ++++---- .../sql_generator/generates_nl_answer.py | 96 ++++++++++------ .../sql_generator/langchain_sqlagent.py | 31 +++--- .../sql_generator/langchain_sqlchain.py | 31 +++--- dataherald/sql_generator/llamaindex.py | 35 +++--- 10 files changed, 172 insertions(+), 257 deletions(-) diff --git a/dataherald/context_store/__init__.py b/dataherald/context_store/__init__.py index 6189c8ca..ce644633 100644 --- a/dataherald/context_store/__init__.py +++ b/dataherald/context_store/__init__.py @@ -4,7 +4,7 @@ from dataherald.config import Component, System from dataherald.db import DB -from dataherald.types import GoldenRecord, GoldenRecordRequest, Question +from dataherald.types import GoldenRecord, GoldenRecordRequest, Prompt from dataherald.vector_store import VectorStore @@ -24,7 +24,7 @@ def __init__(self, system: System): @abstractmethod def retrieve_context_for_question( - self, nl_question: Question, number_of_samples: int = 3 + self, prompt: Prompt, number_of_samples: int = 3 ) -> Tuple[List[dict] | None, List[dict] | None]: pass diff --git a/dataherald/context_store/default.py b/dataherald/context_store/default.py index 0e9cbc3b..a057305b 100644 --- a/dataherald/context_store/default.py +++ b/dataherald/context_store/default.py @@ -8,7 +8,7 @@ from dataherald.context_store import ContextStore from dataherald.repositories.golden_records import GoldenRecordRepository from dataherald.repositories.instructions import InstructionRepository -from dataherald.types import GoldenRecord, GoldenRecordRequest, Question +from dataherald.types import GoldenRecord, GoldenRecordRequest, Prompt logger = logging.getLogger(__name__) @@ -19,12 +19,12 @@ def __init__(self, system: System): @override def retrieve_context_for_question( - self, nl_question: Question, number_of_samples: int = 3 + self, prompt: Prompt, number_of_samples: int = 3 ) -> Tuple[List[dict] | None, List[dict] | None]: - logger.info(f"Getting context for {nl_question.question}") + logger.info(f"Getting context for {prompt.text}") closest_questions = self.vector_store.query( - query_texts=[nl_question.question], - db_connection_id=nl_question.db_connection_id, + query_texts=[prompt.text], + db_connection_id=prompt.db_connection_id, collection=self.golden_record_collection, num_results=number_of_samples, ) @@ -47,7 +47,7 @@ def retrieve_context_for_question( instruction_repository = InstructionRepository(self.db) all_instructions = instruction_repository.find_all() for instruction in all_instructions: - if instruction.db_connection_id == nl_question.db_connection_id: + if instruction.db_connection_id == prompt.db_connection_id: instructions.append( { "instruction": instruction.instruction, diff --git a/dataherald/sql_generator/__init__.py b/dataherald/sql_generator/__init__.py index 35ce895d..7899a5b9 100644 --- a/dataherald/sql_generator/__init__.py +++ b/dataherald/sql_generator/__init__.py @@ -1,7 +1,6 @@ """Base class that all sql generation classes inherit from.""" import re from abc import ABC, abstractmethod -from datetime import date, datetime from typing import Any, List, Tuple import sqlparse @@ -12,7 +11,7 @@ from dataherald.sql_database.base import SQLDatabase from dataherald.sql_database.models.types import DatabaseConnection from dataherald.sql_generator.create_sql_query_status import create_sql_query_status -from dataherald.types import Question, Response, SQLQueryResult +from dataherald.types import Prompt, SQLGeneration from dataherald.utils.strings import contains_line_breaks @@ -44,22 +43,9 @@ def remove_markdown(self, query: str) -> str: return query def create_sql_query_status( - self, - db: SQLDatabase, - query: str, - response: Response, - top_k: int = None, - generate_csv: bool = False, - database_connection: DatabaseConnection | None = None, - ) -> Response: - return create_sql_query_status( - db, - query, - response, - top_k, - generate_csv, - database_connection=database_connection, - ) + self, db: SQLDatabase, query: str, sql_generation: SQLGeneration + ) -> SQLGeneration: + return create_sql_query_status(db, query, sql_generation) def format_intermediate_representations( self, intermediate_representation: List[Tuple[AgentAction, str]] @@ -91,10 +77,9 @@ def format_sql_query(self, sql_query: str) -> str: @abstractmethod def generate_response( self, - user_question: Question, + user_prompt: Prompt, database_connection: DatabaseConnection, context: List[dict] = None, - generate_csv: bool = False, - ) -> Response: + ) -> SQLGeneration: """Generates a response to a user question.""" pass diff --git a/dataherald/sql_generator/create_sql_query_status.py b/dataherald/sql_generator/create_sql_query_status.py index 18e3da4a..0095f065 100644 --- a/dataherald/sql_generator/create_sql_query_status.py +++ b/dataherald/sql_generator/create_sql_query_status.py @@ -1,116 +1,43 @@ -import csv -import uuid -from datetime import date, datetime -from decimal import Decimal - from sqlalchemy import text -from dataherald.config import Settings from dataherald.sql_database.base import SQLDatabase, SQLInjectionError -from dataherald.sql_database.models.types import DatabaseConnection -from dataherald.types import Response, SQLQueryResult -from dataherald.utils.s3 import S3 +from dataherald.types import SQLGeneration -def format_error_message(response: Response, error_message: str) -> Response: +def format_error_message( + sql_generation: SQLGeneration, error_message: str +) -> SQLGeneration: # Remove the complete query if error_message.find("[") > 0 and error_message.find("]") > 0: error_message = ( error_message[0 : error_message.find("[")] + error_message[error_message.rfind("]") + 1 :] ) - response.sql_generation_status = "INVALID" - response.response = "" - response.sql_query_result = None - response.error_message = error_message - return response - - -def create_csv_file( - generate_csv: bool, - columns: list, - rows: list, - response: Response, - database_connection: DatabaseConnection | None = None, -): - if generate_csv: - file_location = f"tmp/{str(uuid.uuid4())}.csv" - with open(file_location, "w", newline="") as file: - writer = csv.writer(file) - - writer.writerow(rows[0].keys()) - for row in rows: - writer.writerow(row.values()) - if Settings().only_store_csv_files_locally: - response.csv_file_path = file_location - else: - s3 = S3() - response.csv_file_path = s3.upload( - file_location, database_connection.file_storage - ) - response.sql_query_result = SQLQueryResult(columns=columns, rows=rows) + sql_generation.status = "INVALID" + sql_generation.error = error_message + return sql_generation def create_sql_query_status( db: SQLDatabase, query: str, - response: Response, - top_k: int = None, - generate_csv: bool = False, - database_connection: DatabaseConnection | None = None, -) -> Response: + sql_generation: SQLGeneration, +) -> SQLGeneration: """Find the sql query status and populate the fields sql_query_result, sql_generation_status, and error_message""" if query == "": - response.sql_generation_status = "INVALID" - response.response = "" - response.sql_query_result = None - response.error_message = None + sql_generation.status = "INVALID" + sql_generation.error = "Empty query" else: try: query = db.parser_to_filter_commands(query) with db._engine.connect() as connection: - execution = connection.execute(text(query)) - columns = execution.keys() - if top_k: - result = execution.fetchmany(top_k) - else: - result = execution.fetchall() - if len(result) == 0: - response.sql_query_result = None - else: - columns = [item for item in columns] # noqa: C416 - rows = [] - for row in result: - modified_row = {} - for key, value in zip(row.keys(), row, strict=True): - if type(value) in [ - date, - datetime, - ]: # Check if the value is an instance of datetime.date - modified_row[key] = str(value) - elif ( - type(value) is Decimal - ): # Check if the value is an instance of decimal.Decimal - modified_row[key] = float(value) - else: - modified_row[key] = value - rows.append(modified_row) - - create_csv_file( - generate_csv, - columns, - rows, - response, - database_connection, - ) - - response.sql_generation_status = "VALID" - response.error_message = None + connection.execute(text(query)) + sql_generation.status = "VALID" + sql_generation.error = None except SQLInjectionError as e: raise SQLInjectionError( "Sensitive SQL keyword detected in the query." ) from e except Exception as e: - response = format_error_message(response, str(e)) - - return response + sql_generation = format_error_message(sql_generation, str(e)) + return sql_generation diff --git a/dataherald/sql_generator/dataherald_finetuning_agent.py b/dataherald/sql_generator/dataherald_finetuning_agent.py index 6bd1924e..48a3cbc4 100644 --- a/dataherald/sql_generator/dataherald_finetuning_agent.py +++ b/dataherald/sql_generator/dataherald_finetuning_agent.py @@ -32,7 +32,7 @@ DatabaseConnection, ) from dataherald.sql_generator import EngineTimeOutORItemLimitError, SQLGenerator -from dataherald.types import Question, Response +from dataherald.types import Prompt, SQLGeneration from dataherald.utils.agent_prompts import ( FINETUNING_AGENT_PREFIX, FINETUNING_AGENT_SUFFIX, @@ -436,11 +436,10 @@ def create_sql_agent( @override def generate_response( self, - user_question: Question, + user_prompt: Prompt, database_connection: DatabaseConnection, context: List[dict] = None, # noqa: ARG002 - generate_csv: bool = False, - ) -> Response: + ) -> SQLGeneration: """ generate_response generates a response to a user question using a Finetuning model. @@ -470,7 +469,7 @@ def generate_response( if not db_scan: raise ValueError("No scanned tables found for database") _, instructions = context_store.retrieve_context_for_question( - user_question, number_of_samples=1 + user_prompt, number_of_samples=1 ) self.database = SQLDatabase.get_sql_engine(database_connection) @@ -489,21 +488,21 @@ def generate_response( agent_executor.handle_parsing_errors = True with get_openai_callback() as cb: try: - result = agent_executor({"input": user_question.question}) + result = agent_executor({"input": user_prompt.text}) result = self.check_for_time_out_or_tool_limit(result) except SQLInjectionError as e: raise SQLInjectionError(e) from e except EngineTimeOutORItemLimitError as e: raise EngineTimeOutORItemLimitError(e) from e except Exception as e: - return Response( - question_id=user_question.id, - total_tokens=cb.total_tokens, - total_cost=cb.total_cost, - sql_query="", - sql_generation_status="INVALID", - sql_query_result=None, - error_message=str(e), + return SQLGeneration( + prompt_id=user_prompt.id, + tokens_used=cb.total_tokens, + model="FineTuning_Agent", + completed_at=datetime.datetime.now(), + sql="", + status="INVALID", + error=str(e), ) if "```sql" in result["output"]: sql_query = self.remove_markdown(result["output"]) @@ -514,17 +513,15 @@ def generate_response( sql_query = self.remove_markdown(sql_query) sql_query = self.format_sql_query(action.tool_input) logger.info(f"cost: {str(cb.total_cost)} tokens: {str(cb.total_tokens)}") - response = Response( - question_id=user_question.id, - total_tokens=cb.total_tokens, - total_cost=cb.total_cost, - sql_query=sql_query, + response = SQLGeneration( + prompt_id=user_prompt.id, + tokens_used=cb.total_tokens, + model="RAG_AGENT", + completed_at=datetime.datetime.now(), + sql=sql_query, ) return self.create_sql_query_status( self.database, - response.sql_query, + response.sql, response, - top_k=TOP_K, - generate_csv=generate_csv, - database_connection=database_connection, ) diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py index 64cdb1b9..17f9f014 100644 --- a/dataherald/sql_generator/dataherald_sqlagent.py +++ b/dataherald/sql_generator/dataherald_sqlagent.py @@ -39,7 +39,7 @@ DatabaseConnection, ) from dataherald.sql_generator import EngineTimeOutORItemLimitError, SQLGenerator -from dataherald.types import Question, Response +from dataherald.types import Prompt, SQLGeneration from dataherald.utils.agent_prompts import ( AGENT_PREFIX, FORMAT_INSTRUCTIONS, @@ -598,11 +598,10 @@ def create_sql_agent( @override def generate_response( self, - user_question: Question, + user_prompt: Prompt, database_connection: DatabaseConnection, context: List[dict] = None, - generate_csv: bool = False, - ) -> Response: + ) -> SQLGeneration: context_store = self.system.instance(ContextStore) storage = self.system.instance(DB) self.llm = self.model.get_model( @@ -620,7 +619,7 @@ def generate_response( if not db_scan: raise ValueError("No scanned tables found for database") few_shot_examples, instructions = context_store.retrieve_context_for_question( - user_question, number_of_samples=self.max_number_of_examples + user_prompt, number_of_samples=self.max_number_of_examples ) if few_shot_examples is not None: new_fewshot_examples = self.remove_duplicate_examples(few_shot_examples) @@ -628,7 +627,7 @@ def generate_response( else: new_fewshot_examples = None number_of_samples = 0 - logger.info(f"Generating SQL response to question: {str(user_question.dict())}") + logger.info(f"Generating SQL response to question: {str(user_prompt.dict())}") self.database = SQLDatabase.get_sql_engine(database_connection) toolkit = SQLDatabaseToolkit( db=self.database, @@ -652,21 +651,21 @@ def generate_response( agent_executor.handle_parsing_errors = True with get_openai_callback() as cb: try: - result = agent_executor({"input": user_question.question}) + result = agent_executor({"input": user_prompt.text}) result = self.check_for_time_out_or_tool_limit(result) except SQLInjectionError as e: raise SQLInjectionError(e) from e except EngineTimeOutORItemLimitError as e: raise EngineTimeOutORItemLimitError(e) from e except Exception as e: - return Response( - question_id=user_question.id, - total_tokens=cb.total_tokens, - total_cost=cb.total_cost, - sql_query="", - sql_generation_status="INVALID", - sql_query_result=None, - error_message=str(e), + return SQLGeneration( + prompt_id=user_prompt.id, + tokens_used=cb.total_tokens, + model="RAG_AGENT", + completed_at=datetime.datetime.now(), + sql="", + status="INVALID", + error=str(e), ) if "```sql" in result["output"]: sql_query = self.remove_markdown(result["output"]) @@ -677,17 +676,15 @@ def generate_response( sql_query = self.remove_markdown(sql_query) sql_query = self.format_sql_query(action.tool_input) logger.info(f"cost: {str(cb.total_cost)} tokens: {str(cb.total_tokens)}") - response = Response( - question_id=user_question.id, - total_tokens=cb.total_tokens, - total_cost=cb.total_cost, - sql_query=sql_query, + response = SQLGeneration( + prompt_id=user_prompt.id, + tokens_used=cb.total_tokens, + model="RAG_AGENT", + completed_at=datetime.datetime.now(), + sql=sql_query, ) return self.create_sql_query_status( self.database, - response.sql_query, + response.sql, response, - top_k=TOP_K, - generate_csv=generate_csv, - database_connection=database_connection, ) diff --git a/dataherald/sql_generator/generates_nl_answer.py b/dataherald/sql_generator/generates_nl_answer.py index 89500f0d..2b550455 100644 --- a/dataherald/sql_generator/generates_nl_answer.py +++ b/dataherald/sql_generator/generates_nl_answer.py @@ -1,4 +1,6 @@ import os +from datetime import date, datetime +from decimal import Decimal from langchain.chains import LLMChain from langchain.prompts.chat import ( @@ -6,13 +8,13 @@ HumanMessagePromptTemplate, SystemMessagePromptTemplate, ) +from sqlalchemy import text from dataherald.model.chat_model import ChatModel from dataherald.repositories.database_connections import DatabaseConnectionRepository from dataherald.repositories.question import QuestionRepository -from dataherald.sql_database.base import SQLDatabase -from dataherald.sql_generator.create_sql_query_status import create_sql_query_status -from dataherald.types import Response +from dataherald.sql_database.base import SQLDatabase, SQLInjectionError +from dataherald.types import NLGeneration, SQLGeneration SYSTEM_TEMPLATE = """ Given a Question, a Sql query and the sql query result try to answer the question If the sql query result doesn't answer the question just say 'I don't know' @@ -33,12 +35,11 @@ def __init__(self, system, storage): def execute( self, - query_response: Response, - sql_response_only: bool = False, - generate_csv: bool = False, - ) -> Response: + sql_generation: SQLGeneration, + top_k: int = 100, + ) -> NLGeneration: question_repository = QuestionRepository(self.storage) - question = question_repository.find_by_id(query_response.question_id) + question = question_repository.find_by_id(sql_generation.prompt_id) db_connection_repository = DatabaseConnectionRepository(self.storage) database_connection = db_connection_repository.find_by_id( @@ -51,35 +52,56 @@ def execute( ) database = SQLDatabase.get_sql_engine(database_connection) - if not query_response.sql_query_result: - query_response = create_sql_query_status( - database, - query_response.sql_query, - query_response, - top_k=int(os.getenv("UPPER_LIMIT_QUERY_RETURN_ROWS", "50")), - generate_csv=generate_csv, - database_connection=database_connection, + if sql_generation.status == "INVALID": + return NLGeneration( + sql_generation_id=sql_generation.id, + nl_answer="I don't know", + created_at=datetime.now(), ) - if query_response.csv_file_path: - query_response.response = None - return query_response + try: + query = database.parser_to_filter_commands(sql_generation.sql) + with database._engine.connect() as connection: + execution = connection.execute(text(query)) + result = execution.fetchmany(top_k) + rows = [] + for row in result: + modified_row = {} + for key, value in zip(row.keys(), row, strict=True): + if type(value) in [ + date, + datetime, + ]: # Check if the value is an instance of datetime.date + modified_row[key] = str(value) + elif ( + type(value) is Decimal + ): # Check if the value is an instance of decimal.Decimal + modified_row[key] = float(value) + else: + modified_row[key] = value + rows.append(modified_row) - if not sql_response_only: - system_message_prompt = SystemMessagePromptTemplate.from_template( - SYSTEM_TEMPLATE - ) - human_message_prompt = HumanMessagePromptTemplate.from_template( - HUMAN_TEMPLATE - ) - chat_prompt = ChatPromptTemplate.from_messages( - [system_message_prompt, human_message_prompt] - ) - chain = LLMChain(llm=self.llm, prompt=chat_prompt) - nl_resp = chain.run( - question=question.question, - sql_query=query_response.sql_query, - sql_query_result=str(query_response.sql_query_result), - ) - query_response.response = nl_resp - return query_response + except SQLInjectionError as e: + raise SQLInjectionError( + "Sensitive SQL keyword detected in the query." + ) from e + + system_message_prompt = SystemMessagePromptTemplate.from_template( + SYSTEM_TEMPLATE + ) + human_message_prompt = HumanMessagePromptTemplate.from_template(HUMAN_TEMPLATE) + chat_prompt = ChatPromptTemplate.from_messages( + [system_message_prompt, human_message_prompt] + ) + chain = LLMChain(llm=self.llm, prompt=chat_prompt) + nl_resp = chain.run( + question=question.question, + sql_query=sql_generation.sql, + sql_query_result="\n".join([str(row) for row in rows]), + ) + + return NLGeneration( + sql_generation_id=sql_generation.id, + nl_answer=nl_resp, + created_at=datetime.now(), + ) diff --git a/dataherald/sql_generator/langchain_sqlagent.py b/dataherald/sql_generator/langchain_sqlagent.py index 2235f2d1..5c26548a 100644 --- a/dataherald/sql_generator/langchain_sqlagent.py +++ b/dataherald/sql_generator/langchain_sqlagent.py @@ -1,5 +1,6 @@ """A wrapper for the SQL generation functions in langchain""" +import datetime import logging import os import time @@ -15,7 +16,7 @@ from dataherald.sql_database.base import SQLDatabase from dataherald.sql_database.models.types import DatabaseConnection from dataherald.sql_generator import SQLGenerator -from dataherald.types import Question, Response +from dataherald.types import Prompt, SQLGeneration logger = logging.getLogger(__name__) @@ -26,12 +27,11 @@ class LangChainSQLAgentSQLGenerator(SQLGenerator): @override def generate_response( self, - user_question: Question, + user_prompt: Prompt, database_connection: DatabaseConnection, context: List[dict] = None, - generate_csv: bool = False, - ) -> Response: # type: ignore - logger.info(f"Generating SQL response to question: {str(user_question.dict())}") + ) -> SQLGeneration: # type: ignore + logger.info(f"Generating SQL response to question: {str(user_prompt.dict())}") self.llm = self.model.get_model( database_connection=database_connection, temperature=0, @@ -58,10 +58,10 @@ def generate_response( ) question_with_context = ( - f"{user_question.question} An example of a similar question and the query that was generated \ + f"{user_prompt.text} An example of a similar question and the query that was generated \ to answer it is the following {samples_prompt_string}" if context is not None - else user_question.question + else user_prompt.text ) with get_openai_callback() as cb: result = agent_executor(question_with_context) @@ -74,18 +74,15 @@ def generate_response( logger.info( f"cost: {str(cb.total_cost)} tokens: {str(cb.total_tokens)} time: {str(exec_time)}" ) - response = Response( - question_id=user_question.id, - response=result["output"], - exec_time=exec_time, - total_tokens=cb.total_tokens, - total_cost=cb.total_cost, - sql_query=sql_query_list[-1] if len(sql_query_list) > 0 else "", + response = SQLGeneration( + prompt_id=user_prompt.id, + tokens_used=cb.total_tokens, + model="LANGCHAIN_SQLAGENT", + completed_at=datetime.datetime.now(), + sql=sql_query_list[-1] if len(sql_query_list) > 0 else "", ) return self.create_sql_query_status( self.database, - response.sql_query, + response.sql, response, - generate_csv=generate_csv, - database_connection=database_connection, ) diff --git a/dataherald/sql_generator/langchain_sqlchain.py b/dataherald/sql_generator/langchain_sqlchain.py index c001a147..b41a0d70 100644 --- a/dataherald/sql_generator/langchain_sqlchain.py +++ b/dataherald/sql_generator/langchain_sqlchain.py @@ -1,5 +1,6 @@ """A wrapper for the SQL generation functions in langchain""" +import datetime import logging import os import time @@ -12,7 +13,7 @@ from dataherald.sql_database.base import SQLDatabase from dataherald.sql_database.models.types import DatabaseConnection from dataherald.sql_generator import SQLGenerator -from dataherald.types import Question, Response +from dataherald.types import Prompt, SQLGeneration logger = logging.getLogger(__name__) @@ -44,11 +45,10 @@ class LangChainSQLChainSQLGenerator(SQLGenerator): @override def generate_response( self, - user_question: Question, + user_prompt: Prompt, database_connection: DatabaseConnection, context: List[dict] = None, - generate_csv: bool = False, - ) -> Response: + ) -> SQLGeneration: start_time = time.time() self.llm = self.model.get_model( database_connection=database_connection, @@ -57,7 +57,7 @@ def generate_response( ) self.database = SQLDatabase.get_sql_engine(database_connection) logger.info( - f"Generating SQL response to question: {str(user_question.dict())} with passed context {context}" + f"Generating SQL response to question: {str(user_prompt.dict())} with passed context {context}" ) if context is not None: samples_prompt_string = "The following are some similar previous questions and their correct SQL queries from these databases: \ @@ -68,10 +68,10 @@ def generate_response( ) prompt = PROMPT_WITH_CONTEXT.format( - user_question=user_question.question, context=samples_prompt_string + user_question=user_prompt.text, context=samples_prompt_string ) else: - prompt = PROMPT_WITHOUT_CONTEXT.format(user_question=user_question.question) + prompt = PROMPT_WITHOUT_CONTEXT.format(user_question=user_prompt.text) # should top_k be an argument? db_chain = SQLDatabaseChain.from_llm( self.llm, self.database, top_k=3, return_intermediate_steps=True @@ -83,18 +83,15 @@ def generate_response( logger.info( f"cost: {str(cb.total_cost)} tokens: {str(cb.total_tokens)} time: {str(exec_time)}" ) - response = Response( - question_id=user_question.id, - response=result["result"], - exec_time=exec_time, - total_cost=cb.total_cost, - total_tokens=cb.total_tokens, - sql_query=self.format_sql_query(result["intermediate_steps"][1]), + response = SQLGeneration( + prompt_id=user_prompt.id, + tokens_used=cb.total_tokens, + model="LANGCHAIN_SQLCHAIN", + completed_at=datetime.datetime.now(), + sql=self.format_sql_query(result["intermediate_steps"][1]), ) return self.create_sql_query_status( self.database, - response.sql_query, + response.sql, response, - generate_csv=generate_csv, - database_connection=database_connection, ) diff --git a/dataherald/sql_generator/llamaindex.py b/dataherald/sql_generator/llamaindex.py index 6818f885..e4f08eaa 100644 --- a/dataherald/sql_generator/llamaindex.py +++ b/dataherald/sql_generator/llamaindex.py @@ -1,8 +1,7 @@ """A wrapper for the SQL generation functions in langchain""" - +import datetime import logging import os -import time from typing import Any, List import tiktoken @@ -21,7 +20,7 @@ from dataherald.sql_database.base import SQLDatabase from dataherald.sql_database.models.types import DatabaseConnection from dataherald.sql_generator import SQLGenerator -from dataherald.types import Question, Response +from dataherald.types import Prompt, SQLGeneration logger = logging.getLogger(__name__) @@ -32,13 +31,11 @@ class LlamaIndexSQLGenerator(SQLGenerator): @override def generate_response( self, - user_question: Question, + user_prompt: Prompt, database_connection: DatabaseConnection, context: List[dict] = None, - generate_csv: bool = False, - ) -> Response: - start_time = time.time() - logger.info(f"Generating SQL response to question: {str(user_question.dict())}") + ) -> SQLGeneration: + logger.info(f"Generating SQL response to question: {str(user_prompt.dict())}") self.llm = self.model.get_model( database_connection=database_connection, temperature=0, @@ -64,10 +61,10 @@ def generate_response( f"Question: {sample['nl_question']} \nSQL: {sample['sql_query']} \n" ) question_with_context = ( - f"{user_question.question} An example of a similar question and the query that was generated to answer it \ + f"{user_prompt.text} An example of a similar question and the query that was generated to answer it \ is the following {samples_prompt_string}" if context is not None - else user_question.question + else user_prompt.text ) for table_name in metadata_obj.tables.keys(): table_schema_objs.append(SQLTableSchema(table_name=table_name)) @@ -100,19 +97,15 @@ def generate_response( logger.info( f"total cost: {str(total_cost)} {str(token_counter.total_llm_token_count)}" ) - exec_time = time.time() - start_time - response = Response( - question_id=user_question.id, - response=result.response, - exec_time=exec_time, - total_tokens=token_counter.total_llm_token_count, - total_cost=total_cost, - sql_query=self.format_sql_query(result.metadata["sql_query"]), + response = SQLGeneration( + prompt_id=user_prompt.id, + tokens_used=token_counter.total_llm_token_count, + model="LLAMA_INDEX", + completed_at=datetime.datetime.now(), + sql=self.format_sql_query(result.metadata["sql_query"]), ) return self.create_sql_query_status( self.database, - response.sql_query, + response.sql, response, - generate_csv=generate_csv, - database_connection=database_connection, )