diff --git a/.env.example b/.env.example index bdae77c8..aa3f8ba7 100644 --- a/.env.example +++ b/.env.example @@ -1,7 +1,8 @@ # Openai info. All these fields are required for the engine to work. OPENAI_API_KEY = #This field is required for the engine to work. ORG_ID = -LLM_MODEL = 'gpt-4-32k' #the openAI llm model that you want to use. possible values: gpt-4-32k, gpt-4, gpt-3.5-turbo, gpt-3.5-turbo-16k +LLM_MODEL = 'gpt-4' #the openAI llm model that you want to use for evaluation and generating the nl answer. possible values: gpt-4, gpt-3.5-turbo +AGENT_LLM_MODEL = 'gpt-4-32k' # the llm model that you want to use for the agent, it should have a large context window. possible values: gpt-4-32k, gpt-3.5-turbo-16k DH_ENGINE_TIMEOUT = #timeout in seconds for the engine to return a response diff --git a/README.md b/README.md index 3af0e7ea..d397cd95 100644 --- a/README.md +++ b/README.md @@ -66,12 +66,17 @@ You can also self-host the engine locally using Docker. By default the engine us cp .env.example .env ``` -Specifically the following 4 fields must be manually set before the engine is started. +Specifically the following 5 fields must be manually set before the engine is started. + +LLM_MODEL is employed by evaluators and natural language generators that do not necessitate an extensive context window. + +AGENT_LLM_MODEL, on the other hand, is utilized by the NL-to-SQL generator, which relies on a larger context window. 
``` #OpenAI credentials and model OPENAI_API_KEY = -LLM_MODEL = +LLM_MODEL = +AGENT_LLM_MODEL = ORG_ID = #Encryption key for storing DB connection data in Mongo diff --git a/dataherald/eval/eval_agent.py b/dataherald/eval/eval_agent.py index a48c2ea1..a0d8e888 100644 --- a/dataherald/eval/eval_agent.py +++ b/dataherald/eval/eval_agent.py @@ -1,4 +1,5 @@ import logging +import os import re import time from difflib import SequenceMatcher @@ -248,7 +249,9 @@ def evaluate( f"Generating score for the question/sql pair: {str(question.question)}/ {str(generated_answer.sql_query)}" ) self.llm = self.model.get_model( - database_connection=database_connection, temperature=0 + database_connection=database_connection, + temperature=0, + model_name=os.getenv("LLM_MODEL", "gpt-4"), ) database = SQLDatabase.get_sql_engine(database_connection) user_question = question.question diff --git a/dataherald/eval/simple_evaluator.py b/dataherald/eval/simple_evaluator.py index 7b54d418..0d45cefd 100644 --- a/dataherald/eval/simple_evaluator.py +++ b/dataherald/eval/simple_evaluator.py @@ -1,4 +1,5 @@ import logging +import os import re import time from typing import Any @@ -101,7 +102,9 @@ def evaluate( } ) self.llm = self.model.get_model( - database_connection=database_connection, temperature=0 + database_connection=database_connection, + temperature=0, + model_name=os.getenv("LLM_MODEL", "gpt-4"), ) start_time = time.time() system_message_prompt = SystemMessagePromptTemplate.from_template( diff --git a/dataherald/model/__init__.py b/dataherald/model/__init__.py index 9712298c..b8f4333b 100644 --- a/dataherald/model/__init__.py +++ b/dataherald/model/__init__.py @@ -17,6 +17,7 @@ def get_model( self, database_connection: DatabaseConnection, model_family="openai", + model_name="gpt-4", **kwargs: Any ) -> Any: pass diff --git a/dataherald/model/base_models.py b/dataherald/model/base_models.py index e27dedb9..507e40dc 100644 --- a/dataherald/model/base_models.py +++ 
b/dataherald/model/base_models.py @@ -12,7 +12,6 @@ class BaseModel(LLMModel): def __init__(self, system): super().__init__(system) - self.model_name = os.environ.get("LLM_MODEL", "text-davinci-003") self.openai_api_key = os.environ.get("OPENAI_API_KEY") self.aleph_alpha_api_key = os.environ.get("ALEPH_ALPHA_API_KEY") self.anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY") @@ -23,6 +22,7 @@ def get_model( self, database_connection: DatabaseConnection, model_family="openai", + model_name="davinci-003", **kwargs: Any ) -> Any: if database_connection.llm_credentials is not None: @@ -37,13 +37,13 @@ def get_model( elif model_family == "google": self.google_api_key = api_key if self.openai_api_key: - self.model = OpenAI(model_name=self.model_name, **kwargs) + self.model = OpenAI(model_name=model_name, **kwargs) elif self.aleph_alpha_api_key: - self.model = AlephAlpha(model=self.model_name, **kwargs) + self.model = AlephAlpha(model=model_name, **kwargs) elif self.anthropic_api_key: - self.model = Anthropic(model=self.model, **kwargs) + self.model = Anthropic(model=model_name, **kwargs) elif self.cohere_api_key: - self.model = Cohere(model=self.model, **kwargs) + self.model = Cohere(model=model_name, **kwargs) else: raise ValueError("No valid API key environment variable found") return self.model diff --git a/dataherald/model/chat_model.py b/dataherald/model/chat_model.py index dd78fe86..6e55ad4a 100644 --- a/dataherald/model/chat_model.py +++ b/dataherald/model/chat_model.py @@ -12,13 +12,13 @@ class ChatModel(LLMModel): def __init__(self, system): super().__init__(system) - self.model_name = os.environ.get("LLM_MODEL", "gpt-4-32k") @override def get_model( self, database_connection: DatabaseConnection, model_family="openai", + model_name="gpt-4-32k", **kwargs: Any ) -> Any: if database_connection.llm_credentials is not None: @@ -35,11 +35,11 @@ def get_model( elif model_family == "cohere": os.environ["COHERE_API_KEY"] = api_key if os.environ.get("OPENAI_API_KEY") 
is not None: - return ChatOpenAI(model_name=self.model_name, **kwargs) + return ChatOpenAI(model_name=model_name, **kwargs) if os.environ.get("ANTHROPIC_API_KEY") is not None: - return ChatAnthropic(model_name=self.model_name, **kwargs) + return ChatAnthropic(model_name=model_name, **kwargs) if os.environ.get("GOOGLE_API_KEY") is not None: - return ChatGooglePalm(model_name=self.model_name, **kwargs) + return ChatGooglePalm(model_name=model_name, **kwargs) if os.environ.get("COHERE_API_KEY") is not None: - return ChatCohere(model_name=self.model_name, **kwargs) + return ChatCohere(model_name=model_name, **kwargs) raise ValueError("No valid API key environment variable found") diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py index 625647d8..5296a442 100644 --- a/dataherald/sql_generator/dataherald_sqlagent.py +++ b/dataherald/sql_generator/dataherald_sqlagent.py @@ -617,6 +617,7 @@ def generate_response( self.llm = self.model.get_model( database_connection=database_connection, temperature=0, + model_name=os.getenv("AGENT_LLM_MODEL", "gpt-4-32k"), ) repository = TableDescriptionRepository(storage) db_scan = repository.get_all_tables_by_db( diff --git a/dataherald/sql_generator/generates_nl_answer.py b/dataherald/sql_generator/generates_nl_answer.py index d9d7fe82..8fd503a8 100644 --- a/dataherald/sql_generator/generates_nl_answer.py +++ b/dataherald/sql_generator/generates_nl_answer.py @@ -1,3 +1,5 @@ +import os + from langchain.chains import LLMChain from langchain.prompts.chat import ( ChatPromptTemplate, @@ -40,6 +42,7 @@ def execute(self, query_response: Response) -> Response: self.llm = self.model.get_model( database_connection=database_connection, temperature=0, + model_name=os.getenv("LLM_MODEL", "gpt-4"), ) database = SQLDatabase.get_sql_engine(database_connection) query_response = create_sql_query_status( diff --git a/dataherald/sql_generator/langchain_sqlagent.py 
b/dataherald/sql_generator/langchain_sqlagent.py index bad3d822..99201478 100644 --- a/dataherald/sql_generator/langchain_sqlagent.py +++ b/dataherald/sql_generator/langchain_sqlagent.py @@ -1,6 +1,7 @@ """A wrapper for the SQL generation functions in langchain""" import logging +import os import time from typing import Any, List @@ -31,7 +32,9 @@ def generate_response( ) -> Response: # type: ignore logger.info(f"Generating SQL response to question: {str(user_question.dict())}") self.llm = self.model.get_model( - database_connection=database_connection, temperature=0 + database_connection=database_connection, + temperature=0, + model_name=os.getenv("AGENT_LLM_MODEL", "gpt-4-32k"), ) self.database = SQLDatabase.get_sql_engine(database_connection) tools = SQLDatabaseToolkit(db=self.database, llm=self.llm).get_tools() diff --git a/dataherald/sql_generator/langchain_sqlchain.py b/dataherald/sql_generator/langchain_sqlchain.py index f4e3bac0..48e0bdad 100644 --- a/dataherald/sql_generator/langchain_sqlchain.py +++ b/dataherald/sql_generator/langchain_sqlchain.py @@ -1,6 +1,7 @@ """A wrapper for the SQL generation functions in langchain""" import logging +import os import time from typing import Any, List @@ -49,7 +50,9 @@ def generate_response( ) -> Response: start_time = time.time() self.llm = self.model.get_model( - database_connection=database_connection, temperature=0 + database_connection=database_connection, + temperature=0, + model_name=os.getenv("AGENT_LLM_MODEL", "gpt-4-32k"), ) self.database = SQLDatabase.get_sql_engine(database_connection) logger.info( diff --git a/dataherald/sql_generator/llamaindex.py b/dataherald/sql_generator/llamaindex.py index 7ef7d4a9..59de3c90 100644 --- a/dataherald/sql_generator/llamaindex.py +++ b/dataherald/sql_generator/llamaindex.py @@ -1,6 +1,7 @@ """A wrapper for the SQL generation functions in langchain""" import logging +import os import time from typing import Any, List @@ -38,7 +39,9 @@ def generate_response( start_time = 
time.time() logger.info(f"Generating SQL response to question: {str(user_question.dict())}") self.llm = self.model.get_model( - database_connection=database_connection, temperature=0 + database_connection=database_connection, + temperature=0, + model_name=os.getenv("AGENT_LLM_MODEL", "gpt-4-32k"), ) token_counter = TokenCountingHandler( tokenizer=tiktoken.encoding_for_model(self.llm.model_name).encode,