diff --git a/dataherald/sql_generator/dataherald_finetuned_agent.py b/dataherald/sql_generator/dataherald_finetuned_agent.py new file mode 100644 index 00000000..4c9c9bb5 --- /dev/null +++ b/dataherald/sql_generator/dataherald_finetuned_agent.py @@ -0,0 +1,35 @@ +import os +import time +from typing import List, Tuple + +from overrides import override + +from dataherald.sql_database.models.types import ( + DatabaseConnection, +) +from dataherald.sql_generator import SQLGenerator +from dataherald.types import Question, Response + + +class DataheraldFineTunedAgent(SQLGenerator): + + @override + def generate_response( + self, + user_question: Question, + database_connection: DatabaseConnection, + context: Tuple[List[dict] | None, List[dict] | None], + ) -> Response: + start_time = time.time() + instructions = context[1] + + self.short_context_llm = self.model.get_model( + database_connection=database_connection, + temperature=0, + model_name=os.getenv("LLM_MODEL", "gpt-4"), + ) + self.long_context_llm = self.model.get_model( + database_connection=database_connection, + temperature=0, + model_name=os.getenv("AGENT_LLM_MODEL", "gpt-4-32k") + )