Skip to content

Commit

Permalink
core[major]: On Tool End Observation Casting Fix (#18798)
Browse files Browse the repository at this point in the history
This PR updates the on_tool_end handlers to return the raw output from the tool instead of casting it to a string. 

This is technically a breaking change, though it's impact is expected to be somewhat minimal. It will fix behavior in `astream_events` as well.

Fixes the following issue #18760 raised by @eyurtsev

---------

Co-authored-by: Eugene Yurtsev <[email protected]>
  • Loading branch information
keenborder786 and eyurtsev authored Mar 11, 2024
1 parent a96a6e0 commit 43db4cd
Show file tree
Hide file tree
Showing 20 changed files with 42 additions and 34 deletions.
4 changes: 2 additions & 2 deletions docs/docs/modules/agents/how_to/streaming.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1068,15 +1068,15 @@
"\n",
" def on_tool_end(\n",
" self,\n",
" output: str,\n",
" output: Any,\n",
" *,\n",
" run_id: UUID,\n",
" parent_run_id: Optional[UUID] = None,\n",
" **kwargs: Any,\n",
" ) -> Any:\n",
" \"\"\"Run when tool ends running.\"\"\"\n",
" print(\"Tool end\")\n",
" print(output)\n",
" print(str(output))\n",
"\n",
" async def on_llm_end(\n",
" self,\n",
Expand Down
2 changes: 1 addition & 1 deletion docs/docs/modules/callbacks/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class BaseCallbackHandler:
) -> Any:
"""Run when tool starts running."""

def on_tool_end(self, output: str, **kwargs: Any) -> Any:
def on_tool_end(self, output: Any, **kwargs: Any) -> Any:
"""Run when tool ends running."""

def on_tool_error(
Expand Down
3 changes: 2 additions & 1 deletion libs/community/langchain_community/callbacks/aim_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,8 +314,9 @@ def on_tool_start(

self._run.track(aim.Text(input_str), name="on_tool_start", context=resp)

def on_tool_end(self, output: str, **kwargs: Any) -> None:
def on_tool_end(self, output: Any, **kwargs: Any) -> None:
"""Run when tool ends running."""
output = str(output)
aim = import_aim()
self.step += 1
self.tool_ends += 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:

def on_tool_end(
self,
output: str,
output: Any,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:

def on_tool_end(
self,
output: str,
output: Any,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:

def on_tool_end(
self,
output: str,
output: Any,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,9 @@ def on_tool_start(
if self.stream_logs:
self.logger.report_text(resp)

def on_tool_end(self, output: str, **kwargs: Any) -> None:
def on_tool_end(self, output: Any, **kwargs: Any) -> None:
"""Run when tool ends running."""
output = str(output)
self.step += 1
self.tool_ends += 1
self.ends += 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -303,8 +303,9 @@ def on_tool_start(
resp.update({"input_str": input_str})
self.action_records.append(resp)

def on_tool_end(self, output: str, **kwargs: Any) -> None:
def on_tool_end(self, output: Any, **kwargs: Any) -> None:
"""Run when tool ends running."""
output = str(output)
self.step += 1
self.tool_ends += 1
self.ends += 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:

def on_tool_end(
self,
output: str,
output: Any,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -465,13 +465,14 @@ def on_tool_start(

def on_tool_end(
self,
output: str,
output: Any,
*,
run_id: UUID,
parent_run_id: Union[UUID, None] = None,
tags: Union[List[str], None] = None,
**kwargs: Any,
) -> None:
output = str(output)
if self.__has_valid_config is False:
return
try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -518,8 +518,9 @@ def on_tool_start(
self.records["action_records"].append(resp)
self.mlflg.jsonf(resp, f"tool_start_{tool_starts}")

def on_tool_end(self, output: str, **kwargs: Any) -> None:
def on_tool_end(self, output: Any, **kwargs: Any) -> None:
"""Run when tool ends running."""
output = str(output)
self.metrics["step"] += 1
self.metrics["tool_ends"] += 1
self.metrics["ends"] += 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,9 @@ def on_tool_start(

self.jsonf(resp, self.temp_dir, f"tool_start_{tool_starts}")

def on_tool_end(self, output: str, **kwargs: Any) -> None:
def on_tool_end(self, output: Any, **kwargs: Any) -> None:
"""Run when tool ends running."""
output = str(output)
self.metrics["step"] += 1
self.metrics["tool_ends"] += 1
self.metrics["ends"] += 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,13 +183,13 @@ def on_tool_start(

def on_tool_end(
self,
output: str,
output: Any,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
self._container.markdown(f"**{output}**")
self._container.markdown(f"**{str(output)}**")

def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
self._container.markdown("**Tool encountered an error...**")
Expand Down Expand Up @@ -363,12 +363,13 @@ def on_tool_start(

def on_tool_end(
self,
output: str,
output: Any,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
output = str(output)
self._require_current_thought().on_tool_end(
output, color, observation_prefix, llm_prefix, **kwargs
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -356,8 +356,9 @@ def on_tool_start(
if self.stream_logs:
self.run.log(resp)

def on_tool_end(self, output: str, **kwargs: Any) -> None:
def on_tool_end(self, output: Any, **kwargs: Any) -> None:
"""Run when tool ends running."""
output = str(output)
self.step += 1
self.tool_ends += 1
self.ends += 1
Expand Down
4 changes: 2 additions & 2 deletions libs/core/langchain_core/callbacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ class ToolManagerMixin:

def on_tool_end(
self,
output: str,
output: Any,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
Expand Down Expand Up @@ -440,7 +440,7 @@ async def on_tool_start(

async def on_tool_end(
self,
output: str,
output: Any,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
Expand Down
10 changes: 6 additions & 4 deletions libs/core/langchain_core/callbacks/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -976,14 +976,15 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin):

def on_tool_end(
self,
output: str,
output: Any,
**kwargs: Any,
) -> None:
"""Run when tool ends running.
Args:
output (str): The output of the tool.
output (Any): The output of the tool.
"""
output = str(output)
handle_event(
self.handlers,
"on_tool_end",
Expand Down Expand Up @@ -1038,12 +1039,13 @@ def get_sync(self) -> CallbackManagerForToolRun:
)

@shielded
async def on_tool_end(self, output: str, **kwargs: Any) -> None:
async def on_tool_end(self, output: Any, **kwargs: Any) -> None:
"""Run when tool ends running.
Args:
output (str): The output of the tool.
output (Any): The output of the tool.
"""
output = str(output)
await ahandle_event(
self.handlers,
"on_tool_end",
Expand Down
3 changes: 2 additions & 1 deletion libs/core/langchain_core/callbacks/stdout.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,14 @@ def on_agent_action(

def on_tool_end(
self,
output: str,
output: Any,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""If not the final action, print out observation."""
output = str(output)
if observation_prefix is not None:
print_text(f"\n{observation_prefix}")
print_text(output, color=color or self.color)
Expand Down
2 changes: 1 addition & 1 deletion libs/core/langchain_core/callbacks/streaming_stdout.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
pass

def on_tool_end(self, output: str, **kwargs: Any) -> None:
def on_tool_end(self, output: Any, **kwargs: Any) -> None:
"""Run when tool ends running."""

def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
Expand Down
12 changes: 4 additions & 8 deletions libs/core/langchain_core/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,17 +410,13 @@ def run(
f"Got unexpected type of `handle_tool_error`. Expected bool, str "
f"or callable. Received: {self.handle_tool_error}"
)
run_manager.on_tool_end(
str(observation), color="red", name=self.name, **kwargs
)
run_manager.on_tool_end(observation, color="red", name=self.name, **kwargs)
return observation
except (Exception, KeyboardInterrupt) as e:
run_manager.on_tool_error(e)
raise e
else:
run_manager.on_tool_end(
str(observation), color=color, name=self.name, **kwargs
)
run_manager.on_tool_end(observation, color=color, name=self.name, **kwargs)
return observation

async def arun(
Expand Down Expand Up @@ -502,15 +498,15 @@ async def arun(
f"or callable. Received: {self.handle_tool_error}"
)
await run_manager.on_tool_end(
str(observation), color="red", name=self.name, **kwargs
observation, color="red", name=self.name, **kwargs
)
return observation
except (Exception, KeyboardInterrupt) as e:
await run_manager.on_tool_error(e)
raise e
else:
await run_manager.on_tool_end(
str(observation), color=color, name=self.name, **kwargs
observation, color=color, name=self.name, **kwargs
)
return observation

Expand Down
3 changes: 2 additions & 1 deletion libs/core/langchain_core/tracers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,8 +504,9 @@ def on_tool_start(
self._on_tool_start(tool_run)
return tool_run

def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> Run:
def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> Run:
"""End a trace for a tool run."""
output = str(output)
tool_run = self._get_run(run_id, run_type="tool")
tool_run.outputs = {"output": output}
tool_run.end_time = datetime.now(timezone.utc)
Expand Down

0 comments on commit 43db4cd

Please sign in to comment.