Skip to content

Commit

Permalink
core[patch]: return ToolMessage from tool (#28605)
Browse files Browse the repository at this point in the history
  • Loading branch information
baskaryan authored Dec 10, 2024
1 parent d0e9597 commit e24f86e
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 26 deletions.
11 changes: 10 additions & 1 deletion libs/core/langchain_core/messages/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,16 @@
from langchain_core.utils._merge import merge_dicts, merge_obj


class ToolMessage(BaseMessage):
class ToolOutputMixin:
"""Mixin for objects that tools can return directly.
If a custom BaseTool is invoked with a ToolCall and the output of custom code is
not an instance of ToolOutputMixin, the output will automatically be coerced to a
string and wrapped in a ToolMessage.
"""


class ToolMessage(BaseMessage, ToolOutputMixin):
"""Message for passing the result of executing a tool back to a model.
ToolMessages contain the result of a tool invocation. Typically, the result
Expand Down
103 changes: 80 additions & 23 deletions libs/core/langchain_core/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
CallbackManager,
Callbacks,
)
from langchain_core.messages.tool import ToolCall, ToolMessage
from langchain_core.messages.tool import ToolCall, ToolMessage, ToolOutputMixin
from langchain_core.runnables import (
RunnableConfig,
RunnableSerializable,
Expand Down Expand Up @@ -494,7 +494,9 @@ async def ainvoke(

# --- Tool ---

def _parse_input(self, tool_input: Union[str, dict]) -> Union[str, dict[str, Any]]:
def _parse_input(
self, tool_input: Union[str, dict], tool_call_id: Optional[str]
) -> Union[str, dict[str, Any]]:
"""Convert tool input to a pydantic model.
Args:
Expand All @@ -512,9 +514,39 @@ def _parse_input(self, tool_input: Union[str, dict]) -> Union[str, dict[str, Any
else:
if input_args is not None:
if issubclass(input_args, BaseModel):
for k, v in get_all_basemodel_annotations(input_args).items():
if (
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
and k not in tool_input
):
if tool_call_id is None:
msg = (
"When tool includes an InjectedToolCallId "
"argument, tool must always be invoked with a full "
"model ToolCall of the form: {'args': {...}, "
"'name': '...', 'type': 'tool_call', "
"'tool_call_id': '...'}"
)
raise ValueError(msg)
tool_input[k] = tool_call_id
result = input_args.model_validate(tool_input)
result_dict = result.model_dump()
elif issubclass(input_args, BaseModelV1):
for k, v in get_all_basemodel_annotations(input_args).items():
if (
_is_injected_arg_type(v, injected_type=InjectedToolCallId)
and k not in tool_input
):
if tool_call_id is None:
msg = (
"When tool includes an InjectedToolCallId "
"argument, tool must always be invoked with a full "
"model ToolCall of the form: {'args': {...}, "
"'name': '...', 'type': 'tool_call', "
"'tool_call_id': '...'}"
)
raise ValueError(msg)
tool_input[k] = tool_call_id
result = input_args.parse_obj(tool_input)
result_dict = result.dict()
else:
Expand Down Expand Up @@ -570,8 +602,10 @@ async def _arun(self, *args: Any, **kwargs: Any) -> Any:
kwargs["run_manager"] = kwargs["run_manager"].get_sync()
return await run_in_executor(None, self._run, *args, **kwargs)

def _to_args_and_kwargs(self, tool_input: Union[str, dict]) -> tuple[tuple, dict]:
tool_input = self._parse_input(tool_input)
def _to_args_and_kwargs(
self, tool_input: Union[str, dict], tool_call_id: Optional[str]
) -> tuple[tuple, dict]:
tool_input = self._parse_input(tool_input, tool_call_id)
# For backwards compatibility, if run_input is a string,
# pass as a positional argument.
if isinstance(tool_input, str):
Expand Down Expand Up @@ -648,10 +682,9 @@ def run(
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input, tool_call_id)
if signature(self._run).parameters.get("run_manager"):
tool_kwargs["run_manager"] = run_manager

if config_param := _get_runnable_config_param(self._run):
tool_kwargs[config_param] = config
response = context.run(self._run, *tool_args, **tool_kwargs)
Expand Down Expand Up @@ -755,7 +788,7 @@ async def arun(
artifact = None
error_to_raise: Optional[Union[Exception, KeyboardInterrupt]] = None
try:
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input, tool_call_id)
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
Expand Down Expand Up @@ -889,20 +922,23 @@ def _prep_run_args(


def _format_output(
content: Any, artifact: Any, tool_call_id: Optional[str], name: str, status: str
) -> Union[ToolMessage, Any]:
if tool_call_id:
if not _is_message_content_type(content):
content = _stringify(content)
return ToolMessage(
content,
artifact=artifact,
tool_call_id=tool_call_id,
name=name,
status=status,
)
else:
content: Any,
artifact: Any,
tool_call_id: Optional[str],
name: str,
status: str,
) -> Union[ToolOutputMixin, Any]:
if isinstance(content, ToolOutputMixin) or not tool_call_id:
return content
if not _is_message_content_type(content):
content = _stringify(content)
return ToolMessage(
content,
artifact=artifact,
tool_call_id=tool_call_id,
name=name,
status=status,
)


def _is_message_content_type(obj: Any) -> bool:
Expand Down Expand Up @@ -954,10 +990,31 @@ class InjectedToolArg:
"""Annotation for a Tool arg that is **not** meant to be generated by a model."""


def _is_injected_arg_type(type_: type) -> bool:
class InjectedToolCallId(InjectedToolArg):
r'''Annotation for injecting the tool_call_id.
Example:
..code-block:: python
from typing_extensions import Annotated
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool, InjectedToolCallID
@tool
def foo(x: int, tool_call_id: Annotated[str, InjectedToolCallID]) -> ToolMessage:
"""Return x."""
return ToolMessage(str(x), artifact=x, name="foo", tool_call_id=tool_call_id)
''' # noqa: E501


def _is_injected_arg_type(
type_: type, injected_type: Optional[type[InjectedToolArg]] = None
) -> bool:
injected_type = injected_type or InjectedToolArg
return any(
isinstance(arg, InjectedToolArg)
or (isinstance(arg, type) and issubclass(arg, InjectedToolArg))
isinstance(arg, injected_type)
or (isinstance(arg, type) and issubclass(arg, injected_type))
for arg in get_args(type_)[1:]
)

Expand Down
6 changes: 4 additions & 2 deletions libs/core/langchain_core/tools/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,11 @@ def args(self) -> dict:
# assume it takes a single string input.
return {"tool_input": {"type": "string"}}

def _to_args_and_kwargs(self, tool_input: Union[str, dict]) -> tuple[tuple, dict]:
def _to_args_and_kwargs(
self, tool_input: Union[str, dict], tool_call_id: Optional[str]
) -> tuple[tuple, dict]:
"""Convert tool input to pydantic model."""
args, kwargs = super()._to_args_and_kwargs(tool_input)
args, kwargs = super()._to_args_and_kwargs(tool_input, tool_call_id)
# For backwards compatibility. The tool must be run with a single input
all_args = list(args) + list(kwargs.values())
if len(all_args) != 1:
Expand Down
64 changes: 64 additions & 0 deletions libs/core/tests/unit_tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
CallbackManagerForToolRun,
)
from langchain_core.messages import ToolMessage
from langchain_core.messages.tool import ToolOutputMixin
from langchain_core.runnables import (
Runnable,
RunnableConfig,
Expand All @@ -46,6 +47,7 @@
)
from langchain_core.tools.base import (
InjectedToolArg,
InjectedToolCallId,
SchemaAnnotationError,
_is_message_content_block,
_is_message_content_type,
Expand Down Expand Up @@ -856,6 +858,7 @@ class _RaiseNonValidationErrorTool(BaseTool):
def _parse_input(
self,
tool_input: Union[str, dict],
tool_call_id: Optional[str],
) -> Union[str, dict[str, Any]]:
raise NotImplementedError

Expand Down Expand Up @@ -920,6 +923,7 @@ class _RaiseNonValidationErrorTool(BaseTool):
def _parse_input(
self,
tool_input: Union[str, dict],
tool_call_id: Optional[str],
) -> Union[str, dict[str, Any]]:
raise NotImplementedError

Expand Down Expand Up @@ -2110,3 +2114,63 @@ def injected_tool(x: int, foo: Annotated[Foo, InjectedToolArg]) -> str:
return foo.value

assert injected_tool.invoke({"x": 5, "foo": Foo()}) == "bar" # type: ignore


def test_tool_injected_tool_call_id() -> None:
@tool
def foo(x: int, tool_call_id: Annotated[str, InjectedToolCallId]) -> ToolMessage:
"""foo"""
return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore

assert foo.invoke(
{"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"}
) == ToolMessage(0, tool_call_id="bar") # type: ignore

with pytest.raises(ValueError):
assert foo.invoke({"x": 0})

@tool
def foo2(x: int, tool_call_id: Annotated[str, InjectedToolCallId()]) -> ToolMessage:
"""foo"""
return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore

assert foo2.invoke(
{"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"}
) == ToolMessage(0, tool_call_id="bar") # type: ignore


def test_tool_uninjected_tool_call_id() -> None:
@tool
def foo(x: int, tool_call_id: str) -> ToolMessage:
"""foo"""
return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore

with pytest.raises(ValueError):
foo.invoke({"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"})

assert foo.invoke(
{
"type": "tool_call",
"args": {"x": 0, "tool_call_id": "zap"},
"name": "foo",
"id": "bar",
}
) == ToolMessage(0, tool_call_id="zap") # type: ignore


def test_tool_return_output_mixin() -> None:
class Bar(ToolOutputMixin):
def __init__(self, x: int) -> None:
self.x = x

def __eq__(self, other: Any) -> bool:
return isinstance(other, self.__class__) and self.x == other.x

@tool
def foo(x: int) -> Bar:
"""Foo."""
return Bar(x=x)

assert foo.invoke(
{"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"}
) == Bar(x=0)

0 comments on commit e24f86e

Please sign in to comment.