diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py index e7b3cdbd..64b949ba 100644 --- a/dataherald/sql_generator/dataherald_sqlagent.py +++ b/dataherald/sql_generator/dataherald_sqlagent.py @@ -10,7 +10,6 @@ import openai import pandas as pd import sqlalchemy -import tiktoken from bson.objectid import ObjectId from google.api_core.exceptions import GoogleAPIError from langchain.agents.agent import AgentExecutor @@ -41,7 +40,6 @@ DatabaseConnection, ) from dataherald.sql_generator import EngineTimeOutORItemLimitError, SQLGenerator -from dataherald.sql_generator.adaptive_agent_executor import AdaptiveAgentExecutor from dataherald.types import Question, Response from dataherald.utils.agent_prompts import ( AGENT_PREFIX, @@ -53,7 +51,6 @@ SUFFIX_WITH_FEW_SHOT_SAMPLES, SUFFIX_WITHOUT_FEW_SHOT_SAMPLES, ) -from dataherald.utils.models_context_window import OPENAI_CONTEXT_WIDNOW_SIZES logger = logging.getLogger(__name__) @@ -585,24 +582,15 @@ def create_sql_agent( input_variables=input_variables, ) llm_chain = LLMChain( - llm=self.short_context_llm, + llm=self.llm, prompt=prompt, callback_manager=callback_manager, ) tool_names = [tool.name for tool in tools] agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) - return AdaptiveAgentExecutor.from_agent_and_tools( + return AgentExecutor.from_agent_and_tools( agent=agent, tools=tools, - llm_list={ - "short_context_llm": self.short_context_llm, - "long_context_llm": self.long_context_llm, - }, - switch_to_larger_model_threshold=OPENAI_CONTEXT_WIDNOW_SIZES[ - self.short_context_llm.model_name - ] - - 500, - encoding=tiktoken.encoding_for_model(self.short_context_llm.model_name), callback_manager=callback_manager, verbose=verbose, max_iterations=max_iterations, @@ -622,15 +610,10 @@ def generate_response( start_time = time.time() context_store = self.system.instance(ContextStore) storage = self.system.instance(DB) - self.short_context_llm = self.model.get_model( + self.llm = self.model.get_model( database_connection=database_connection, temperature=0, - model_name=os.getenv("LLM_MODEL", "gpt-4"), - ) - self.long_context_llm = self.model.get_model( - database_connection=database_connection, - temperature=0, - model_name=os.getenv("AGENT_LLM_MODEL", "gpt-4-32k"), + model_name=os.getenv("LLM_MODEL", "gpt-4-1106-preview"), ) repository = TableDescriptionRepository(storage) db_scan = repository.get_all_tables_by_db(