From 09b56423086be0cb578486f80f9bb16da31b67da Mon Sep 17 00:00:00 2001 From: mohammadrezapourreza Date: Mon, 16 Oct 2023 15:06:25 -0400 Subject: [PATCH] DH-4851/using gpt-3.5--turbo for nl answer generation --- dataherald/eval/eval_agent.py | 5 ++++- dataherald/eval/simple_evaluator.py | 5 ++++- dataherald/model/__init__.py | 3 ++- dataherald/model/base_models.py | 12 ++++++------ dataherald/model/chat_model.py | 4 ++-- dataherald/sql_generator/dataherald_sqlagent.py | 1 + dataherald/sql_generator/generates_nl_answer.py | 3 +++ dataherald/sql_generator/langchain_sqlagent.py | 5 ++++- dataherald/sql_generator/langchain_sqlchain.py | 5 ++++- dataherald/sql_generator/llamaindex.py | 5 ++++- 10 files changed, 34 insertions(+), 14 deletions(-) diff --git a/dataherald/eval/eval_agent.py b/dataherald/eval/eval_agent.py index a48c2ea1..a3588464 100644 --- a/dataherald/eval/eval_agent.py +++ b/dataherald/eval/eval_agent.py @@ -1,4 +1,5 @@ import logging +import os import re import time from difflib import SequenceMatcher @@ -248,7 +249,9 @@ def evaluate( f"Generating score for the question/sql pair: {str(question.question)}/ {str(generated_answer.sql_query)}" ) self.llm = self.model.get_model( - database_connection=database_connection, temperature=0 + database_connection=database_connection, + model_name = os.environ.get("LLM_MODEL", "gpt-4-32k"), + temperature=0 ) database = SQLDatabase.get_sql_engine(database_connection) user_question = question.question diff --git a/dataherald/eval/simple_evaluator.py b/dataherald/eval/simple_evaluator.py index 7b54d418..b9e6f9d0 100644 --- a/dataherald/eval/simple_evaluator.py +++ b/dataherald/eval/simple_evaluator.py @@ -1,4 +1,5 @@ import logging +import os import re import time from typing import Any @@ -101,7 +102,9 @@ def evaluate( } ) self.llm = self.model.get_model( - database_connection=database_connection, temperature=0 + database_connection=database_connection, + model_name = os.environ.get("LLM_MODEL", "gpt-4-32k"), + 
temperature=0 ) start_time = time.time() system_message_prompt = SystemMessagePromptTemplate.from_template( diff --git a/dataherald/model/__init__.py b/dataherald/model/__init__.py index 9712298c..2069ff7d 100644 --- a/dataherald/model/__init__.py +++ b/dataherald/model/__init__.py @@ -16,7 +16,8 @@ def __init__(self, system: System): def get_model( self, database_connection: DatabaseConnection, - model_family="openai", + model_family: str ="openai", + model_name: str = "gpt-4-32k", **kwargs: Any ) -> Any: pass diff --git a/dataherald/model/base_models.py b/dataherald/model/base_models.py index e27dedb9..5ecb3bdd 100644 --- a/dataherald/model/base_models.py +++ b/dataherald/model/base_models.py @@ -12,7 +12,6 @@ class BaseModel(LLMModel): def __init__(self, system): super().__init__(system) - self.model_name = os.environ.get("LLM_MODEL", "text-davinci-003") self.openai_api_key = os.environ.get("OPENAI_API_KEY") self.aleph_alpha_api_key = os.environ.get("ALEPH_ALPHA_API_KEY") self.anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY") @@ -22,7 +21,8 @@ def __init__(self, system): def get_model( self, database_connection: DatabaseConnection, - model_family="openai", + model_family: str = "openai", + model_name: str = "text-davinci-003", **kwargs: Any ) -> Any: if database_connection.llm_credentials is not None: @@ -37,13 +37,13 @@ def get_model( elif model_family == "google": self.google_api_key = api_key if self.openai_api_key: - self.model = OpenAI(model_name=self.model_name, **kwargs) + self.model = OpenAI(model_name=model_name, **kwargs) elif self.aleph_alpha_api_key: - self.model = AlephAlpha(model=self.model_name, **kwargs) + self.model = AlephAlpha(model=model_name, **kwargs) elif self.anthropic_api_key: - self.model = Anthropic(model=self.model, **kwargs) + self.model = Anthropic(model=model_name, **kwargs) elif self.cohere_api_key: - self.model = Cohere(model=self.model, **kwargs) + self.model = Cohere(model=model_name, **kwargs) else: raise ValueError("No valid 
API key environment variable found") return self.model diff --git a/dataherald/model/chat_model.py b/dataherald/model/chat_model.py index 5d637f68..00375dac 100644 --- a/dataherald/model/chat_model.py +++ b/dataherald/model/chat_model.py @@ -12,13 +12,13 @@ class ChatModel(LLMModel): def __init__(self, system): super().__init__(system) - self.model_name = os.environ.get("LLM_MODEL", "gpt-4-32k") @override def get_model( self, database_connection: DatabaseConnection, model_family="openai", + model_name: str = "gpt-4-32k", **kwargs: Any ) -> Any: if database_connection.llm_credentials is not None: @@ -35,6 +35,6 @@ def get_model( elif model_family == "cohere": os.environ["COHERE_API_KEY"] = api_key try: - return ChatLiteLLM(model_name=self.model_name, **kwargs) + return ChatLiteLLM(model_name=model_name, **kwargs) except Exception as e: raise ValueError("No valid API key environment variable found") from e diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py index e2e075fa..839a1304 100644 --- a/dataherald/sql_generator/dataherald_sqlagent.py +++ b/dataherald/sql_generator/dataherald_sqlagent.py @@ -616,6 +616,7 @@ def generate_response( storage = self.system.instance(DB) self.llm = self.model.get_model( database_connection=database_connection, + model_name = os.environ.get("LLM_MODEL", "gpt-4-32k"), temperature=0, ) repository = TableDescriptionRepository(storage) diff --git a/dataherald/sql_generator/generates_nl_answer.py b/dataherald/sql_generator/generates_nl_answer.py index 7bcc8845..a9619200 100644 --- a/dataherald/sql_generator/generates_nl_answer.py +++ b/dataherald/sql_generator/generates_nl_answer.py @@ -1,3 +1,5 @@ +import os + from langchain.chains import LLMChain from langchain.prompts.chat import ( ChatPromptTemplate, @@ -39,6 +41,7 @@ def execute(self, query_response: Response) -> Response: ) self.llm = self.model.get_model( database_connection=database_connection, + model_name = 
os.environ.get("RESPONSE_GENERATOR_LLM_MODEL", "gpt-3.5-turbo-16k"), temperature=0, ) database = SQLDatabase.get_sql_engine(database_connection) diff --git a/dataherald/sql_generator/langchain_sqlagent.py b/dataherald/sql_generator/langchain_sqlagent.py index bad3d822..f13cc0cd 100644 --- a/dataherald/sql_generator/langchain_sqlagent.py +++ b/dataherald/sql_generator/langchain_sqlagent.py @@ -1,6 +1,7 @@ """A wrapper for the SQL generation functions in langchain""" import logging +import os import time from typing import Any, List @@ -31,7 +32,9 @@ def generate_response( ) -> Response: # type: ignore logger.info(f"Generating SQL response to question: {str(user_question.dict())}") self.llm = self.model.get_model( - database_connection=database_connection, temperature=0 + database_connection=database_connection, + model_name = os.environ.get("LLM_MODEL", "gpt-4-32k"), + temperature=0 ) self.database = SQLDatabase.get_sql_engine(database_connection) tools = SQLDatabaseToolkit(db=self.database, llm=self.llm).get_tools() diff --git a/dataherald/sql_generator/langchain_sqlchain.py b/dataherald/sql_generator/langchain_sqlchain.py index f4e3bac0..77493600 100644 --- a/dataherald/sql_generator/langchain_sqlchain.py +++ b/dataherald/sql_generator/langchain_sqlchain.py @@ -1,6 +1,7 @@ """A wrapper for the SQL generation functions in langchain""" import logging +import os import time from typing import Any, List @@ -49,7 +50,9 @@ def generate_response( ) -> Response: start_time = time.time() self.llm = self.model.get_model( - database_connection=database_connection, temperature=0 + database_connection=database_connection, + model_name = os.environ.get("LLM_MODEL", "gpt-4-32k"), + temperature=0 ) self.database = SQLDatabase.get_sql_engine(database_connection) logger.info( diff --git a/dataherald/sql_generator/llamaindex.py b/dataherald/sql_generator/llamaindex.py index 7ef7d4a9..14287b46 100644 --- a/dataherald/sql_generator/llamaindex.py +++ 
b/dataherald/sql_generator/llamaindex.py @@ -1,6 +1,7 @@ """A wrapper for the SQL generation functions in langchain""" import logging +import os import time from typing import Any, List @@ -38,7 +39,9 @@ def generate_response( start_time = time.time() logger.info(f"Generating SQL response to question: {str(user_question.dict())}") self.llm = self.model.get_model( - database_connection=database_connection, temperature=0 + database_connection=database_connection, + model_name = os.environ.get("LLM_MODEL", "gpt-4-32k"), + temperature=0 ) token_counter = TokenCountingHandler( tokenizer=tiktoken.encoding_for_model(self.llm.model_name).encode,