Skip to content

Commit

Permalink
Remove adaptive executor
Browse files Browse the repository at this point in the history
  • Loading branch information
MohammadrezaPourreza committed Nov 15, 2023
1 parent b8dc51f commit 5fe4474
Showing 1 changed file with 4 additions and 21 deletions.
25 changes: 4 additions & 21 deletions dataherald/sql_generator/dataherald_sqlagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import openai
import pandas as pd
import sqlalchemy
import tiktoken
from bson.objectid import ObjectId
from google.api_core.exceptions import GoogleAPIError
from langchain.agents.agent import AgentExecutor
Expand Down Expand Up @@ -41,7 +40,6 @@
DatabaseConnection,
)
from dataherald.sql_generator import EngineTimeOutORItemLimitError, SQLGenerator
from dataherald.sql_generator.adaptive_agent_executor import AdaptiveAgentExecutor
from dataherald.types import Question, Response
from dataherald.utils.agent_prompts import (
AGENT_PREFIX,
Expand All @@ -53,7 +51,6 @@
SUFFIX_WITH_FEW_SHOT_SAMPLES,
SUFFIX_WITHOUT_FEW_SHOT_SAMPLES,
)
from dataherald.utils.models_context_window import OPENAI_CONTEXT_WIDNOW_SIZES

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -585,24 +582,15 @@ def create_sql_agent(
input_variables=input_variables,
)
llm_chain = LLMChain(
llm=self.short_context_llm,
llm=self.llm,
prompt=prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
return AdaptiveAgentExecutor.from_agent_and_tools(
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
llm_list={
"short_context_llm": self.short_context_llm,
"long_context_llm": self.long_context_llm,
},
switch_to_larger_model_threshold=OPENAI_CONTEXT_WIDNOW_SIZES[
self.short_context_llm.model_name
]
- 500,
encoding=tiktoken.encoding_for_model(self.short_context_llm.model_name),
callback_manager=callback_manager,
verbose=verbose,
max_iterations=max_iterations,
Expand All @@ -622,15 +610,10 @@ def generate_response(
start_time = time.time()
context_store = self.system.instance(ContextStore)
storage = self.system.instance(DB)
self.short_context_llm = self.model.get_model(
self.llm = self.model.get_model(
database_connection=database_connection,
temperature=0,
model_name=os.getenv("LLM_MODEL", "gpt-4"),
)
self.long_context_llm = self.model.get_model(
database_connection=database_connection,
temperature=0,
model_name=os.getenv("AGENT_LLM_MODEL", "gpt-4-32k"),
model_name=os.getenv("LLM_MODEL", "gpt-4-1106-preview"),
)
repository = TableDescriptionRepository(storage)
db_scan = repository.get_all_tables_by_db(
Expand Down

0 comments on commit 5fe4474

Please sign in to comment.