Skip to content

Commit

Permalink
DH-4992/creating adaptive agent executor (#234)
Browse files Browse the repository at this point in the history
* DH-4992/creating adaptive agent executor

* DH-4922/reducing the env vars and update the token counter

* DH-4922/reformat with black
  • Loading branch information
MohammadrezaPourreza authored and DishenWang2023 committed May 7, 2024
1 parent aec7fc1 commit 6d07cb8
Show file tree
Hide file tree
Showing 3 changed files with 209 additions and 3 deletions.
175 changes: 175 additions & 0 deletions dataherald/sql_generator/adaptive_agent_executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

from langchain.agents.agent import (
AgentExecutor,
BaseMultiActionAgent,
BaseSingleActionAgent,
ExceptionTool,
)
from langchain.agents.tools import InvalidTool
from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks
from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.tools.base import BaseTool
from overrides import override
from tiktoken.core import Encoding


class AdaptiveAgentExecutor(AgentExecutor):
agent: Union[BaseSingleActionAgent, BaseMultiActionAgent] # noqa: UP007
tools: Sequence[BaseTool]
return_intermediate_steps: bool = False
max_iterations: Optional[int] = 15 # noqa: UP007
max_execution_time: Optional[float] = None # noqa: UP007
early_stopping_method: str = "force"
handle_parsing_errors: Union[ # noqa: UP007
bool, str, Callable[[OutputParserException], str]
] = False
trim_intermediate_steps: Union[ # noqa: UP007
int, Callable[[List[Tuple[AgentAction, str]]], List[Tuple[AgentAction, str]]]
] = -1
llm_list: dict
switch_to_larger_model_threshold: int
enc: Encoding
tokens: int

@classmethod
def from_agent_and_tools(
cls,
agent: Union[BaseSingleActionAgent, BaseMultiActionAgent], # noqa: UP007
tools: Sequence[BaseTool],
llm_list: dict,
switch_to_larger_model_threshold: int,
encoding: Encoding,
callbacks: Callbacks = None,
**kwargs: Any,
) -> AgentExecutor:
"""Create from agent and tools."""
return cls(
agent=agent,
tools=tools,
llm_list=llm_list,
switch_to_larger_model_threshold=switch_to_larger_model_threshold,
enc=encoding,
callbacks=callbacks,
tokens=len(encoding.encode(str(agent.llm_chain.prompt.template))),
**kwargs,
)

def token_counter(self, intermediate_steps: List[Tuple[AgentAction, str]]) -> int:
if len(intermediate_steps) == 0:
return self.tokens
new_item_text = (
str(intermediate_steps[-1][0].log)
+ str(intermediate_steps[-1][0].tool)
+ str(intermediate_steps[-1][0].tool_input)
+ str(intermediate_steps[-1][1])
)
self.tokens += len(self.enc.encode(new_item_text))
return self.tokens

@override
def _take_next_step( # noqa: PLR0912 C901 PLR0915
self,
name_to_tool_map: Dict[str, BaseTool],
color_mapping: Dict[str, str],
inputs: Dict[str, str],
intermediate_steps: List[Tuple[AgentAction, str]],
run_manager: Optional[CallbackManagerForChainRun] = None, # noqa: UP007
) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]: # noqa: UP007
try:
intermediate_steps = self._prepare_intermediate_steps(intermediate_steps)

if self.agent.llm_chain.llm == self.llm_list["short_context_llm"]:
if (
self.token_counter(intermediate_steps)
> self.switch_to_larger_model_threshold
):
self.agent.llm_chain.llm = self.llm_list["long_context_llm"]

# Call the LLM to see what to do.
output = self.agent.plan(
intermediate_steps,
callbacks=run_manager.get_child() if run_manager else None,
**inputs,
)
except OutputParserException as e:
if isinstance(self.handle_parsing_errors, bool):
raise_error = not self.handle_parsing_errors
else:
raise_error = False
if raise_error:
raise ValueError( # noqa: B904
"An output parsing error occurred. "
"In order to pass this error back to the agent and have it try "
"again, pass `handle_parsing_errors=True` to the AgentExecutor. "
f"This is the error: {str(e)}"
)
text = str(e)
if isinstance(self.handle_parsing_errors, bool):
if e.send_to_llm:
observation = str(e.observation)
text = str(e.llm_output)
else:
observation = "Invalid or incomplete response"
elif isinstance(self.handle_parsing_errors, str):
observation = self.handle_parsing_errors
elif callable(self.handle_parsing_errors):
observation = self.handle_parsing_errors(e)
else:
raise ValueError(
"Got unexpected type of `handle_parsing_errors`"
) from e
output = AgentAction("_Exception", observation, text)
if run_manager:
run_manager.on_agent_action(output, color="green")
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
observation = ExceptionTool().run(
output.tool_input,
verbose=self.verbose,
color=None,
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
return [(output, observation)]
# If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish):
return output
actions: List[AgentAction]
if isinstance(output, AgentAction):
actions = [output]
else:
actions = output
result = []
for agent_action in actions:
if run_manager:
run_manager.on_agent_action(agent_action, color="green")
# Otherwise we lookup the tool
if agent_action.tool in name_to_tool_map:
tool = name_to_tool_map[agent_action.tool]
return_direct = tool.return_direct
color = color_mapping[agent_action.tool]
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
if return_direct:
tool_run_kwargs["llm_prefix"] = ""
# We then call the tool on the tool input to get an observation
observation = tool.run(
agent_action.tool_input,
verbose=self.verbose,
color=color,
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
else:
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
observation = InvalidTool().run(
{
"requested_tool_name": agent_action.tool,
"available_tool_names": list(name_to_tool_map.keys()),
},
verbose=self.verbose,
color=None,
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
result.append((agent_action, observation))
return result
23 changes: 20 additions & 3 deletions dataherald/sql_generator/dataherald_sqlagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import openai
import pandas as pd
import sqlalchemy
import tiktoken
from bson.objectid import ObjectId
from google.api_core.exceptions import GoogleAPIError
from langchain.agents.agent import AgentExecutor
Expand Down Expand Up @@ -40,6 +41,7 @@
DatabaseConnection,
)
from dataherald.sql_generator import EngineTimeOutORItemLimitError, SQLGenerator
from dataherald.sql_generator.adaptive_agent_executor import AdaptiveAgentExecutor
from dataherald.types import Question, Response
from dataherald.utils.agent_prompts import (
AGENT_PREFIX,
Expand All @@ -51,6 +53,7 @@
SUFFIX_WITH_FEW_SHOT_SAMPLES,
SUFFIX_WITHOUT_FEW_SHOT_SAMPLES,
)
from dataherald.utils.models_context_window import OPENAI_CONTEXT_WIDNOW_SIZES

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -578,15 +581,24 @@ def create_sql_agent(
input_variables=input_variables,
)
llm_chain = LLMChain(
llm=self.llm,
llm=self.short_context_llm,
prompt=prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
return AgentExecutor.from_agent_and_tools(
return AdaptiveAgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
llm_list={
"short_context_llm": self.short_context_llm,
"long_context_llm": self.long_context_llm,
},
switch_to_larger_model_threshold=OPENAI_CONTEXT_WIDNOW_SIZES[
self.short_context_llm.model_name
]
- 500,
encoding=tiktoken.encoding_for_model(self.short_context_llm.model_name),
callback_manager=callback_manager,
verbose=verbose,
max_iterations=max_iterations,
Expand All @@ -605,7 +617,12 @@ def generate_response(
start_time = time.time()
context_store = self.system.instance(ContextStore)
storage = self.system.instance(DB)
self.llm = self.model.get_model(
self.short_context_llm = self.model.get_model(
database_connection=database_connection,
temperature=0,
model_name=os.getenv("LLM_MODEL", "gpt-4"),
)
self.long_context_llm = self.model.get_model(
database_connection=database_connection,
temperature=0,
model_name=os.getenv("AGENT_LLM_MODEL", "gpt-4-32k"),
Expand Down
14 changes: 14 additions & 0 deletions dataherald/utils/models_context_window.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
OPENAI_CONTEXT_WIDNOW_SIZES = {
"gpt-4": 8000,
"gpt-4-0613": 8000,
"gpt-4-32k": 32000,
"gpt-4-32k-0613": 32000,
"gpt-4-0314": 8000,
"gpt-4-32k-0314": 32000,
"gpt-3.5-turbo": 4000,
"gpt-3.5-turbo-16k": 16000,
"gpt-3.5-turbo-instruct": 4000,
"gpt-3.5-turbo-0613": 4000,
"gpt-3.5-turbo-16k-0613": 16000,
"gpt-3.5-turbo-0301": 4000,
}

0 comments on commit 6d07cb8

Please sign in to comment.