diff --git a/.env.example b/.env.example
index 0f83b133..724d7fa2 100644
--- a/.env.example
+++ b/.env.example
@@ -1,8 +1,7 @@
 # Openai info. All these fields are required for the engine to work.
 OPENAI_API_KEY = #This field is required for the engine to work.
 ORG_ID =
-LLM_MODEL = 'gpt-4' #the openAI llm model that you want to use for evaluation and generating the nl answer. possible values: gpt-4, gpt-3.5-turbo
-AGENT_LLM_MODEL = 'gpt-4-32k' # the llm model that you want to use for the agent, it should have a lrage context window. possible values: gpt-4-32k, gpt-3.5-turbo-16k
+LLM_MODEL = 'gpt-4-1106-preview' #the openAI llm model that you want to use. possible values: gpt-4-1106-preview.

 DH_ENGINE_TIMEOUT = #timeout in seconds for the engine to return a response
 UPPER_LIMIT_QUERY_RETURN_ROWS = #The upper limit on number of rows returned from the query engine (equivalent to using LIMIT N in PostgreSQL/MySQL/SQlite). Defauls to 50
diff --git a/README.md b/README.md
index 397ae325..1b4fa954 100644
--- a/README.md
+++ b/README.md
@@ -68,15 +68,12 @@ cp .env.example .env

 Specifically the following 5 fields must be manually set before the engine is started.

-LLM_MODEL is employed by evaluators and natural language generators that do not necessitate an extensive context window.
-
-AGENT_LLM_MODEL, on the other hand, is utilized by the NL-to-SQL generator, which relies on a larger context window.
+LLM_MODEL is employed by the engine to generate SQL from natural language. You can use the default model (gpt-4-1106-preview) or use your own.

 ```
 #OpenAI credentials and model
 OPENAI_API_KEY =
 LLM_MODEL =
-AGENT_LLM_MODEL =
 ORG_ID =

 #Encryption key for storing DB connection data in Mongo
diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py
index e7b3cdbd..64b949ba 100644
--- a/dataherald/sql_generator/dataherald_sqlagent.py
+++ b/dataherald/sql_generator/dataherald_sqlagent.py
@@ -10,7 +10,6 @@
 import openai
 import pandas as pd
 import sqlalchemy
-import tiktoken
 from bson.objectid import ObjectId
 from google.api_core.exceptions import GoogleAPIError
 from langchain.agents.agent import AgentExecutor
@@ -41,7 +40,6 @@
     DatabaseConnection,
 )
 from dataherald.sql_generator import EngineTimeOutORItemLimitError, SQLGenerator
-from dataherald.sql_generator.adaptive_agent_executor import AdaptiveAgentExecutor
 from dataherald.types import Question, Response
 from dataherald.utils.agent_prompts import (
     AGENT_PREFIX,
@@ -53,7 +51,6 @@
     SUFFIX_WITH_FEW_SHOT_SAMPLES,
     SUFFIX_WITHOUT_FEW_SHOT_SAMPLES,
 )
-from dataherald.utils.models_context_window import OPENAI_CONTEXT_WIDNOW_SIZES

 logger = logging.getLogger(__name__)

@@ -585,24 +582,15 @@ def create_sql_agent(
             input_variables=input_variables,
         )
         llm_chain = LLMChain(
-            llm=self.short_context_llm,
+            llm=self.llm,
             prompt=prompt,
             callback_manager=callback_manager,
         )
         tool_names = [tool.name for tool in tools]
         agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
-        return AdaptiveAgentExecutor.from_agent_and_tools(
+        return AgentExecutor.from_agent_and_tools(
             agent=agent,
             tools=tools,
-            llm_list={
-                "short_context_llm": self.short_context_llm,
-                "long_context_llm": self.long_context_llm,
-            },
-            switch_to_larger_model_threshold=OPENAI_CONTEXT_WIDNOW_SIZES[
-                self.short_context_llm.model_name
-            ]
-            - 500,
-            encoding=tiktoken.encoding_for_model(self.short_context_llm.model_name),
             callback_manager=callback_manager,
             verbose=verbose,
             max_iterations=max_iterations,
@@ -622,15 +610,10 @@ def generate_response(
         start_time = time.time()
         context_store = self.system.instance(ContextStore)
         storage = self.system.instance(DB)
-        self.short_context_llm = self.model.get_model(
+        self.llm = self.model.get_model(
             database_connection=database_connection,
             temperature=0,
-            model_name=os.getenv("LLM_MODEL", "gpt-4"),
-        )
-        self.long_context_llm = self.model.get_model(
-            database_connection=database_connection,
-            temperature=0,
-            model_name=os.getenv("AGENT_LLM_MODEL", "gpt-4-32k"),
+            model_name=os.getenv("LLM_MODEL", "gpt-4-1106-preview"),
         )
         repository = TableDescriptionRepository(storage)
         db_scan = repository.get_all_tables_by_db(
diff --git a/dataherald/sql_generator/langchain_sqlagent.py b/dataherald/sql_generator/langchain_sqlagent.py
index 14a3ffcf..93f1f121 100644
--- a/dataherald/sql_generator/langchain_sqlagent.py
+++ b/dataherald/sql_generator/langchain_sqlagent.py
@@ -35,7 +35,7 @@ def generate_response(
         self.llm = self.model.get_model(
             database_connection=database_connection,
             temperature=0,
-            model_name=os.getenv("AGENT_LLM_MODEL", "gpt-4-32k"),
+            model_name=os.getenv("LLM_MODEL", "gpt-4-1106-preview"),
         )
         self.database = SQLDatabase.get_sql_engine(database_connection)
         tools = SQLDatabaseToolkit(db=self.database, llm=self.llm).get_tools()
diff --git a/dataherald/sql_generator/langchain_sqlchain.py b/dataherald/sql_generator/langchain_sqlchain.py
index 4abaef4a..6d903c44 100644
--- a/dataherald/sql_generator/langchain_sqlchain.py
+++ b/dataherald/sql_generator/langchain_sqlchain.py
@@ -53,7 +53,7 @@ def generate_response(
         self.llm = self.model.get_model(
             database_connection=database_connection,
             temperature=0,
-            model_name=os.getenv("AGENT_LLM_MODEL", "gpt-4-32k"),
+            model_name=os.getenv("LLM_MODEL", "gpt-4-1106-preview"),
         )
         self.database = SQLDatabase.get_sql_engine(database_connection)
         logger.info(
diff --git a/dataherald/sql_generator/llamaindex.py b/dataherald/sql_generator/llamaindex.py
index 689729af..de9d0ddc 100644
--- a/dataherald/sql_generator/llamaindex.py
+++ b/dataherald/sql_generator/llamaindex.py
@@ -42,7 +42,7 @@ def generate_response(
         self.llm = self.model.get_model(
             database_connection=database_connection,
             temperature=0,
-            model_name=os.getenv("AGENT_LLM_MODEL", "gpt-4-32k"),
+            model_name=os.getenv("LLM_MODEL", "gpt-4-1106-preview"),
         )
         token_counter = TokenCountingHandler(
             tokenizer=tiktoken.encoding_for_model(self.llm.model_name).encode,
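Taken together, the patch collapses the engine's two-model setup (a short-context LLM_MODEL plus a long-context AGENT_LLM_MODEL) into a single LLM_MODEL, and swaps the custom AdaptiveAgentExecutor for LangChain's stock AgentExecutor. The sketch below is a minimal illustration of the resulting resolution path, not the engine's actual code: the helper name resolve_llm is hypothetical, and LangChain's ChatOpenAI wrapper stands in for the self.model.get_model(...) call shown in the diff.

```python
import os

from langchain.chat_models import ChatOpenAI


def resolve_llm() -> ChatOpenAI:
    """Hypothetical helper mirroring how each generator now picks its model."""
    # A single env var drives every SQL generator after this patch.
    model_name = os.getenv("LLM_MODEL", "gpt-4-1106-preview")
    return ChatOpenAI(model_name=model_name, temperature=0)
```

The design rationale: the removed AdaptiveAgentExecutor counted tokens with tiktoken and switched from the short-context model to the long-context one as the prompt approached the context ceiling (OPENAI_CONTEXT_WIDNOW_SIZES[...] - 500). With gpt-4-1106-preview's 128k-token window, that switching logic, and this module's tiktoken dependency, can simply be dropped.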