diff --git a/libs/langgraph/langgraph/prebuilt/tool_node.py b/libs/langgraph/langgraph/prebuilt/tool_node.py index 38c349edb..d3d0751e2 100644 --- a/libs/langgraph/langgraph/prebuilt/tool_node.py +++ b/libs/langgraph/langgraph/prebuilt/tool_node.py @@ -1,7 +1,7 @@ import asyncio import inspect import json -from copy import copy +from copy import copy, deepcopy from typing import ( Any, Callable, @@ -20,6 +20,7 @@ AnyMessage, ToolCall, ToolMessage, + convert_to_messages, ) from langchain_core.runnables import RunnableConfig from langchain_core.runnables.config import ( @@ -35,6 +36,7 @@ from langgraph.errors import GraphBubbleUp from langgraph.store.base import BaseStore +from langgraph.types import Command from langgraph.utils.runnable import RunnableCallable INVALID_TOOL_NAME_ERROR_TEMPLATE = ( @@ -47,7 +49,7 @@ def msg_content_output(output: Any) -> Union[str, list[dict]]: recognized_content_block_types = ("image", "image_url", "text", "json") if isinstance(output, str): return output - elif all( + elif isinstance(output, list) and all( [ isinstance(x, dict) and x.get("type") in recognized_content_block_types for x in output @@ -210,12 +212,31 @@ def _func( *, store: BaseStore, ) -> Any: - tool_calls, output_type = self._parse_input(input, store) + tool_calls, input_type = self._parse_input(input, store) config_list = get_config_list(config, len(tool_calls)) + input_types = [input_type] * len(tool_calls) with get_executor_for_config(config) as executor: - outputs = [*executor.map(self._run_one, tool_calls, config_list)] - # TypedDict, pydantic, dataclass, etc. should all be able to load from dict - return outputs if output_type == "list" else {self.messages_key: outputs} + outputs = [ + *executor.map(self._run_one, tool_calls, input_types, config_list) + ] + + # preserve existing behavior for non-command tool outputs for backwards compatibility + if not any(isinstance(output, Command) for output in outputs): + # TypedDict, pydantic, dataclass, etc. should all be able to load from dict + return outputs if input_type == "list" else {self.messages_key: outputs} + + # LangGraph will automatically handle list of Command and non-command node updates + combined_outputs: list[ + Command | list[ToolMessage] | dict[str, list[ToolMessage]] + ] = [] + for output in outputs: + if isinstance(output, Command): + combined_outputs.append(output) + else: + combined_outputs.append( + [output] if input_type == "list" else {self.messages_key: [output]} + ) + return combined_outputs def invoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any @@ -242,26 +263,42 @@ async def _afunc( *, store: BaseStore, ) -> Any: - tool_calls, output_type = self._parse_input(input, store) + tool_calls, input_type = self._parse_input(input, store) outputs = await asyncio.gather( - *(self._arun_one(call, config) for call in tool_calls) + *(self._arun_one(call, input_type, config) for call in tool_calls) ) - # TypedDict, pydantic, dataclass, etc. should all be able to load from dict - return outputs if output_type == "list" else {self.messages_key: outputs} - def _run_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage: + # preserve existing behavior for non-command tool outputs for backwards compatibility + if not any(isinstance(output, Command) for output in outputs): + # TypedDict, pydantic, dataclass, etc. should all be able to load from dict + return outputs if input_type == "list" else {self.messages_key: outputs} + + # LangGraph will automatically handle list of Command and non-command node updates + combined_outputs: list[ + Command | list[ToolMessage] | dict[str, list[ToolMessage]] + ] = [] + for output in outputs: + if isinstance(output, Command): + combined_outputs.append(output) + else: + combined_outputs.append( + [output] if input_type == "list" else {self.messages_key: [output]} + ) + return combined_outputs + + def _run_one( + self, + call: ToolCall, + input_type: Literal["list", "dict"], + config: RunnableConfig, + ) -> ToolMessage: if invalid_tool_message := self._validate_tool_call(call): return invalid_tool_message try: input = {**call, **{"type": "tool_call"}} - tool_message: ToolMessage = self.tools_by_name[call["name"]].invoke( - input, config - ) - tool_message.content = cast( - Union[str, list], msg_content_output(tool_message.content) - ) - return tool_message + response = self.tools_by_name[call["name"]].invoke(input) + # GraphInterrupt is a special exception that will always be raised. # It can be triggered in the following scenarios: # (1) a NodeInterrupt is raised inside a tool @@ -285,24 +322,38 @@ def _run_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage: # Handled else: content = _handle_tool_error(e, flag=self.handle_tool_errors) + return ToolMessage( + content=content, + name=call["name"], + tool_call_id=call["id"], + status="error", + ) - return ToolMessage( - content=content, name=call["name"], tool_call_id=call["id"], status="error" - ) + if isinstance(response, Command): + return self._validate_tool_command(response, call, input_type) + elif isinstance(response, ToolMessage): + response.content = cast( + Union[str, list], msg_content_output(response.content) + ) + return response + else: + raise TypeError( + f"Tool {call['name']} returned unexpected type: {type(response)}" + ) - async def _arun_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage: + async def _arun_one( + self, + call: ToolCall, + input_type: Literal["list", "dict"], + config: RunnableConfig, + ) -> ToolMessage: if invalid_tool_message := self._validate_tool_call(call): return invalid_tool_message try: input = {**call, **{"type": "tool_call"}} - tool_message: ToolMessage = await self.tools_by_name[call["name"]].ainvoke( - input, config - ) - tool_message.content = cast( - Union[str, list], msg_content_output(tool_message.content) - ) - return tool_message + response = await self.tools_by_name[call["name"]].ainvoke(input) + # GraphInterrupt is a special exception that will always be raised. # It can be triggered in the following scenarios: # (1) a NodeInterrupt is raised inside a tool @@ -327,9 +378,24 @@ async def _arun_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage else: content = _handle_tool_error(e, flag=self.handle_tool_errors) - return ToolMessage( - content=content, name=call["name"], tool_call_id=call["id"], status="error" - ) + return ToolMessage( + content=content, + name=call["name"], + tool_call_id=call["id"], + status="error", + ) + + if isinstance(response, Command): + return self._validate_tool_command(response, call, input_type) + elif isinstance(response, ToolMessage): + response.content = cast( + Union[str, list], msg_content_output(response.content) + ) + return response + else: + raise TypeError( + f"Tool {call['name']} returned unexpected type: {type(response)}" + ) def _parse_input( self, @@ -341,14 +407,14 @@ def _parse_input( store: BaseStore, ) -> Tuple[list[ToolCall], Literal["list", "dict"]]: if isinstance(input, list): - output_type = "list" + input_type = "list" message: AnyMessage = input[-1] elif isinstance(input, dict) and (messages := input.get(self.messages_key, [])): - output_type = "dict" + input_type = "dict" message = messages[-1] elif messages := getattr(input, self.messages_key, None): # Assume dataclass-like state that can coerce from dict - output_type = "dict" + input_type = "dict" message = messages[-1] else: raise ValueError("No message found in input") @@ -359,7 +425,7 @@ def _parse_input( tool_calls = [ self._inject_tool_args(call, input, store) for call in message.tool_calls ] - return tool_calls, output_type + return tool_calls, input_type def _validate_tool_call(self, call: ToolCall) -> Optional[ToolMessage]: if (requested_tool := call["name"]) not in self.tools_by_name: @@ -453,6 +519,67 @@ def _inject_tool_args( tool_call_with_store = self._inject_store(tool_call_with_state, store) return tool_call_with_store + def _validate_tool_command( + self, command: Command, call: ToolCall, input_type: Literal["list", "dict"] + ) -> Command: + if isinstance(command.update, dict): + # input type is dict when ToolNode is invoked with a dict input (e.g. {"messages": [AIMessage(..., tool_calls=[...])]}) + if input_type != "dict": + raise ValueError( + f"Tools can provide a dict in Command.update only when using dict with '{self.messages_key}' key as ToolNode input, " + f"got: {command.update} for tool '{call['name']}'" + ) + + updated_command = deepcopy(command) + state_update = cast(dict[str, Any], updated_command.update) or {} + messages_update = state_update.get(self.messages_key, []) + elif isinstance(command.update, list): + # input type is list when ToolNode is invoked with a list input (e.g. [AIMessage(..., tool_calls=[...])]) + if input_type != "list": + raise ValueError( + f"Tools can provide a list of messages in Command.update only when using list of messages as ToolNode input, " + f"got: {command.update} for tool '{call['name']}'" + ) + + updated_command = deepcopy(command) + messages_update = updated_command.update + else: + return command + + # convert to message objects if updates are in a dict format + messages_update = convert_to_messages(messages_update) + have_seen_tool_messages = False + for message in messages_update: + if not isinstance(message, ToolMessage): + continue + + if have_seen_tool_messages: + raise ValueError( + f"Expected at most one ToolMessage in Command.update for tool '{call['name']}', got multiple: {messages_update}." + ) + + if message.tool_call_id != call["id"]: + raise ValueError( + f"ToolMessage.tool_call_id must match the tool call id. Expected: {call['id']}, got: {message.tool_call_id} for tool '{call['name']}'." + ) + + message.name = call["name"] + have_seen_tool_messages = True + + # validate that we always have exactly one ToolMessage in Command.update if command is sent to the CURRENT graph + if updated_command.graph is None and not have_seen_tool_messages: + example_update = ( + '`Command(update={"messages": [ToolMessage("Success", tool_call_id=tool_call_id), ...]}, ...)`' + if input_type == "dict" + else '`Command(update=[ToolMessage("Success", tool_call_id=tool_call_id), ...], ...)`' + ) + raise ValueError( + f"Expected exactly one message (ToolMessage) in Command.update for tool '{call['name']}', got: {messages_update}. " + "Every tool call (LLM requesting to call a tool) in the message history MUST have a corresponding ToolMessage. " + f"You can fix it by modifying the tool to return {example_update}." + ) + return updated_command + def tools_condition( state: Union[list[AnyMessage], dict[str, Any], BaseModel], diff --git a/libs/langgraph/langgraph/types.py b/libs/langgraph/langgraph/types.py index 1c23ad4de..2086552f9 100644 --- a/libs/langgraph/langgraph/types.py +++ b/libs/langgraph/langgraph/types.py @@ -32,6 +32,14 @@ from langgraph.store.base import BaseStore +try: + from langchain_core.messages.tool import ToolOutputMixin +except ImportError: + + class ToolOutputMixin: # type: ignore[no-redef] + pass + + All = Literal["*"] """Special value to indicate that graph should interrupt on all nodes.""" @@ -244,7 +252,7 @@ def __eq__(self, value: object) -> bool: @dataclasses.dataclass(**_DC_KWARGS) -class Command(Generic[N]): +class Command(Generic[N], ToolOutputMixin): """One or more commands to update the graph's state and send messages to nodes. Args: diff --git a/libs/langgraph/poetry.lock b/libs/langgraph/poetry.lock index 634b9d8db..bdb2a4da6 100644 --- a/libs/langgraph/poetry.lock +++ b/libs/langgraph/poetry.lock @@ -1325,13 +1325,13 @@ files = [ [[package]] name = "langchain-core" -version = "0.3.15" +version = "0.3.23" description = "Building applications with LLMs through composability" optional = false python-versions = "<4.0,>=3.9" files = [ - {file = "langchain_core-0.3.15-py3-none-any.whl", hash = "sha256:3d4ca6dbb8ed396a6ee061063832a2451b0ce8c345570f7b086ffa7288e4fa29"}, - {file = "langchain_core-0.3.15.tar.gz", hash = "sha256:b1a29787a4ffb7ec2103b4e97d435287201da7809b369740dd1e32f176325aba"}, + {file = "langchain_core-0.3.23-py3-none-any.whl", hash = "sha256:550c0b996990830fa6515a71a1192a8a0343367999afc36d4ede14222941e420"}, + {file = "langchain_core-0.3.23.tar.gz", hash = "sha256:f9e175e3b82063cc3b160c2ca2b155832e1c6f915312e1204828f97d4aabf6e1"}, ] [package.dependencies] @@ -1382,7 +1382,7 @@ url = "../checkpoint-duckdb" [[package]] name = "langgraph-checkpoint-postgres" -version = "2.0.7" +version = "2.0.8" description = "Library with a Postgres implementation of LangGraph checkpoint saver." optional = false python-versions = "^3.9.0,<4.0" @@ -1418,7 +1418,7 @@ url = "../checkpoint-sqlite" [[package]] name = "langgraph-sdk" -version = "0.1.42" +version = "0.1.43" description = "SDK for interacting with LangGraph API" optional = false python-versions = "^3.9.0,<4.0" @@ -3413,4 +3413,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = ">=3.9.0,<4.0" -content-hash = "2df4d5d5e61917bdfff0ba430067a17662666eedee2858d841fa02e594cf69d0" +content-hash = "936530a5f00f329aeff2e6e921fe64480be317fad2c0a59cd54ea9018d089304" diff --git a/libs/langgraph/pyproject.toml b/libs/langgraph/pyproject.toml index 925be91ff..cf29b2375 100644 --- a/libs/langgraph/pyproject.toml +++ b/libs/langgraph/pyproject.toml @@ -9,7 +9,7 @@ repository = "https://www.github.com/langchain-ai/langgraph" [tool.poetry.dependencies] python = ">=3.9.0,<4.0" -langchain-core = ">=0.2.43,<0.4.0,!=0.3.0,!=0.3.1,!=0.3.2,!=0.3.3,!=0.3.4,!=0.3.5,!=0.3.6,!=0.3.7,!=0.3.8,!=0.3.9,!=0.3.10,!=0.3.11,!=0.3.12,!=0.3.13,!=0.3.14" +langchain-core = ">=0.2.43,<0.4.0,!=0.3.0,!=0.3.1,!=0.3.2,!=0.3.3,!=0.3.4,!=0.3.5,!=0.3.6,!=0.3.7,!=0.3.8,!=0.3.9,!=0.3.10,!=0.3.11,!=0.3.12,!=0.3.13,!=0.3.14,!=0.3.15,!=0.3.16,!=0.3.17,!=0.3.18,!=0.3.19,!=0.3.20,!=0.3.21,!=0.3.22" langgraph-checkpoint = "^2.0.4" langgraph-sdk = "^0.1.42" diff --git a/libs/langgraph/tests/test_prebuilt.py b/libs/langgraph/tests/test_prebuilt.py index 0997668b2..9c6541d9a 100644 --- a/libs/langgraph/tests/test_prebuilt.py +++ b/libs/langgraph/tests/test_prebuilt.py @@ -46,7 +46,7 @@ create_react_agent, tools_condition, ) -from langgraph.prebuilt.chat_agent_executor import _validate_chat_history +from langgraph.prebuilt.chat_agent_executor import AgentState, _validate_chat_history from langgraph.prebuilt.tool_node import ( TOOL_CALL_ERROR_TEMPLATE, InjectedState, @@ -56,7 +56,7 @@ ) from langgraph.store.base import BaseStore from langgraph.store.memory import InMemoryStore -from langgraph.types import Interrupt +from langgraph.types import Command, Interrupt, interrupt from tests.conftest import ( ALL_CHECKPOINTERS_ASYNC, ALL_CHECKPOINTERS_SYNC, @@ -988,6 +988,645 @@ def handle(e: NodeInterrupt): assert task.interrupts == (Interrupt(value="foo", when="during"),) +@pytest.mark.skipif( + not IS_LANGCHAIN_CORE_030_OR_GREATER, + reason="Langchain core 0.3.0 or greater is required", +) +async def test_tool_node_command(): + from langchain_core.tools.base import InjectedToolCallId + + @dec_tool + def transfer_to_bob(tool_call_id: Annotated[str, InjectedToolCallId]): + """Transfer to Bob""" + return Command( + update={ + "messages": [ + ToolMessage(content="Transferred to Bob", tool_call_id=tool_call_id) + ] + }, + goto="bob", + graph=Command.PARENT, + ) + + @dec_tool + async def async_transfer_to_bob(tool_call_id: Annotated[str, InjectedToolCallId]): + """Transfer to Bob""" + return Command( + update={ + "messages": [ + ToolMessage(content="Transferred to Bob", tool_call_id=tool_call_id) + ] + }, + goto="bob", + graph=Command.PARENT, + ) + + class CustomToolSchema(BaseModel): + tool_call_id: Annotated[str, InjectedToolCallId] + + class MyCustomTool(BaseTool): + def _run(*args: Any, **kwargs: Any): + return Command( + update={ + "messages": [ + ToolMessage( + content="Transferred to Bob", + tool_call_id=kwargs["tool_call_id"], + ) + ] + }, + goto="bob", + graph=Command.PARENT, + ) + + async def _arun(*args: Any, **kwargs: Any): + return Command( + update={ + "messages": [ + ToolMessage( + content="Transferred to Bob", + tool_call_id=kwargs["tool_call_id"], + ) + ] + }, + goto="bob", + graph=Command.PARENT, + ) + + custom_tool = MyCustomTool( + name="custom_transfer_to_bob", + description="Transfer to bob", + args_schema=CustomToolSchema, + ) + async_custom_tool = MyCustomTool( + name="async_custom_transfer_to_bob", + description="Transfer to bob", + args_schema=CustomToolSchema, + ) + + # test mixing regular tools and tools returning commands + def add(a: int, b: int) -> int: + """Add two numbers""" + return a + b + + result = ToolNode([add, transfer_to_bob]).invoke( + { + "messages": [ + AIMessage( + "", + tool_calls=[ + {"args": {"a": 1, "b": 2}, "id": "1", "name": "add"}, + {"args": {}, "id": "2", "name": "transfer_to_bob"}, + ], + ) + ] + } + ) + + assert result == [ + { + "messages": [ + ToolMessage( + content="3", + tool_call_id="1", + name="add", + ) + ] + }, + Command( + update={ + "messages": [ + ToolMessage( + content="Transferred to Bob", + tool_call_id="2", + name="transfer_to_bob", + ) + ] + }, + goto="bob", + graph=Command.PARENT, + ), + ] + + # test tools returning commands + + # test sync tools + for tool in [transfer_to_bob, custom_tool]: + result = ToolNode([tool]).invoke( + { + "messages": [ + AIMessage( + "", tool_calls=[{"args": {}, "id": "1", "name": tool.name}] + ) + ] + } + ) + assert result == [ + Command( + update={ + "messages": [ + ToolMessage( + content="Transferred to Bob", + tool_call_id="1", + name=tool.name, + ) + ] + }, + goto="bob", + graph=Command.PARENT, + ) + ] + + # test async tools + for tool in [async_transfer_to_bob, async_custom_tool]: + result = await ToolNode([tool]).ainvoke( + { + "messages": [ + AIMessage( + "", tool_calls=[{"args": {}, "id": "1", "name": tool.name}] + ) + ] + } + ) + assert result == [ + Command( + update={ + "messages": [ + ToolMessage( + content="Transferred to Bob", + tool_call_id="1", + name=tool.name, + ) + ] + }, + goto="bob", + graph=Command.PARENT, + ) + ] + + # test multiple commands + result = ToolNode([transfer_to_bob, custom_tool]).invoke( + { + "messages": [ + AIMessage( + "", + tool_calls=[ + {"args": {}, "id": "1", "name": "transfer_to_bob"}, + {"args": {}, "id": "2", "name": "custom_transfer_to_bob"}, + ], + ) + ] + } + ) + assert result == [ + Command( + update={ + "messages": [ + ToolMessage( + content="Transferred to Bob", + tool_call_id="1", + name="transfer_to_bob", + ) + ] + }, + goto="bob", + graph=Command.PARENT, + ), + Command( + update={ + "messages": [ + ToolMessage( + content="Transferred to Bob", + tool_call_id="2", + name="custom_transfer_to_bob", + ) + ] + }, + goto="bob", + graph=Command.PARENT, + ), + ] + + # test validation (mismatch between input type and command.update type) + with pytest.raises(ValueError): + + @dec_tool + def list_update_tool(tool_call_id: Annotated[str, InjectedToolCallId]): + """My tool""" + return Command( + update=[ToolMessage(content="foo", tool_call_id=tool_call_id)] + ) + + ToolNode([list_update_tool]).invoke( + { + "messages": [ + AIMessage( + "", + tool_calls=[ + {"args": {}, "id": "1", "name": "list_update_tool"} + ], + ) + ] + } + ) + + # test validation (missing tool message in the update for current graph) + with pytest.raises(ValueError): + + @dec_tool + def no_update_tool(): + """My tool""" + return Command(update={"messages": []}) + + ToolNode([no_update_tool]).invoke( + { + "messages": [ + AIMessage( + "", + tool_calls=[{"args": {}, "id": "1", "name": "no_update_tool"}], + ) + ] + } + ) + + # test validation (missing tool message in the update for parent graph is OK) + @dec_tool + def node_update_parent_tool(): + """No update""" + return Command(update={"messages": []}, graph=Command.PARENT) + + assert ToolNode([node_update_parent_tool]).invoke( + { + "messages": [ + AIMessage( + "", + tool_calls=[ + {"args": {}, "id": "1", "name": "node_update_parent_tool"} + ], + ) + ] + } + ) == [Command(update={"messages": []}, graph=Command.PARENT)] + + # test validation (multiple tool messages) + with pytest.raises(ValueError): + for graph in (None, Command.PARENT): + + @dec_tool + def multiple_tool_messages_tool(): + """My tool""" + return Command( + update={ + "messages": [ + ToolMessage(content="foo", tool_call_id=""), + ToolMessage(content="bar", tool_call_id=""), + ] + }, + graph=graph, + ) + + ToolNode([multiple_tool_messages_tool]).invoke( + { + "messages": [ + AIMessage( + "", + tool_calls=[ + { + "args": {}, + "id": "1", + "name": "multiple_tool_messages_tool", + } + ], + ) + ] + } + ) + + +@pytest.mark.skipif( + not IS_LANGCHAIN_CORE_030_OR_GREATER, + reason="Langchain core 0.3.0 or greater is required", +) +async def test_tool_node_command_list_input(): + from langchain_core.tools.base import InjectedToolCallId + + @dec_tool + def transfer_to_bob(tool_call_id: Annotated[str, InjectedToolCallId]): + """Transfer to Bob""" + return Command( + update=[ + ToolMessage(content="Transferred to Bob", tool_call_id=tool_call_id) + ], + goto="bob", + graph=Command.PARENT, + ) + + @dec_tool + async def async_transfer_to_bob(tool_call_id: Annotated[str, InjectedToolCallId]): + """Transfer to Bob""" + return Command( + update=[ + ToolMessage(content="Transferred to Bob", tool_call_id=tool_call_id) + ], + goto="bob", + graph=Command.PARENT, + ) + + class CustomToolSchema(BaseModel): + tool_call_id: Annotated[str, InjectedToolCallId] + + class MyCustomTool(BaseTool): + def _run(*args: Any, **kwargs: Any): + return Command( + update=[ + ToolMessage( + content="Transferred to Bob", + tool_call_id=kwargs["tool_call_id"], + ) + ], + goto="bob", + graph=Command.PARENT, + ) + + async def _arun(*args: Any, **kwargs: Any): + return Command( + update=[ + ToolMessage( + content="Transferred to Bob", + tool_call_id=kwargs["tool_call_id"], + ) + ], + goto="bob", + graph=Command.PARENT, + ) + + custom_tool = MyCustomTool( + name="custom_transfer_to_bob", + description="Transfer to bob", + args_schema=CustomToolSchema, + ) + async_custom_tool = MyCustomTool( + name="async_custom_transfer_to_bob", + description="Transfer to bob", + args_schema=CustomToolSchema, + ) + + # test mixing regular tools and tools returning commands + def add(a: int, b: int) -> int: + """Add two numbers""" + return a + b + + result = ToolNode([add, transfer_to_bob]).invoke( + [ + AIMessage( + "", + tool_calls=[ + {"args": {"a": 1, "b": 2}, "id": "1", "name": "add"}, + {"args": {}, "id": "2", "name": "transfer_to_bob"}, + ], + ) + ] + ) + + assert result == [ + [ + ToolMessage( + content="3", + tool_call_id="1", + name="add", + ) + ], + Command( + update=[ + ToolMessage( + content="Transferred to Bob", + tool_call_id="2", + name="transfer_to_bob", + ) + ], + goto="bob", + graph=Command.PARENT, + ), + ] + + # test tools returning commands + + # test sync tools + for tool in [transfer_to_bob, custom_tool]: + result = ToolNode([tool]).invoke( + [AIMessage("", tool_calls=[{"args": {}, "id": "1", "name": tool.name}])] + ) + assert result == [ + Command( + update=[ + ToolMessage( + content="Transferred to Bob", + tool_call_id="1", + name=tool.name, + ) + ], + goto="bob", + graph=Command.PARENT, + ) + ] + + # test async tools + for tool in [async_transfer_to_bob, async_custom_tool]: + result = await ToolNode([tool]).ainvoke( + [AIMessage("", tool_calls=[{"args": {}, "id": "1", "name": tool.name}])] + ) + assert result == [ + Command( + update=[ + ToolMessage( + content="Transferred to Bob", + tool_call_id="1", + name=tool.name, + ) + ], + goto="bob", + graph=Command.PARENT, + ) + ] + + # test multiple commands + result = ToolNode([transfer_to_bob, custom_tool]).invoke( + [ + AIMessage( + "", + tool_calls=[ + {"args": {}, "id": "1", "name": "transfer_to_bob"}, + {"args": {}, "id": "2", "name": "custom_transfer_to_bob"}, + ], + ) + ] + ) + assert result == [ + Command( + update=[ + ToolMessage( + content="Transferred to Bob", + tool_call_id="1", + name="transfer_to_bob", + ) + ], + goto="bob", + graph=Command.PARENT, + ), + Command( + update=[ + ToolMessage( + content="Transferred to Bob", + tool_call_id="2", + name="custom_transfer_to_bob", + ) + ], + goto="bob", + graph=Command.PARENT, + ), + ] + + # test validation (mismatch between input type and command.update type) + with pytest.raises(ValueError): + + @dec_tool + def list_update_tool(tool_call_id: Annotated[str, InjectedToolCallId]): + """My tool""" + return Command( + update={ + "messages": [ToolMessage(content="foo", tool_call_id=tool_call_id)] + } + ) + + ToolNode([list_update_tool]).invoke( + [ + AIMessage( + "", + tool_calls=[{"args": {}, "id": "1", "name": "list_update_tool"}], + ) + ] + ) + + # test validation (missing tool message in the update for current graph) + with pytest.raises(ValueError): + + @dec_tool + def no_update_tool(): + """My tool""" + return Command(update=[]) + + ToolNode([no_update_tool]).invoke( + [ + AIMessage( + "", + tool_calls=[{"args": {}, "id": "1", "name": "no_update_tool"}], + ) + ] + ) + + # test validation (missing tool message in the update for parent graph is OK) + @dec_tool + def node_update_parent_tool(): + """No update""" + return Command(update=[], graph=Command.PARENT) + + assert ToolNode([node_update_parent_tool]).invoke( + [ + AIMessage( + "", + tool_calls=[{"args": {}, "id": "1", "name": "node_update_parent_tool"}], + ) + ] + ) == [Command(update=[], graph=Command.PARENT)] + + # test validation (multiple tool messages) + with pytest.raises(ValueError): + for graph in (None, Command.PARENT): + + @dec_tool + def multiple_tool_messages_tool(): + """My tool""" + return Command( + update=[ + ToolMessage(content="foo", tool_call_id=""), + ToolMessage(content="bar", tool_call_id=""), + ], + graph=graph, + ) + + ToolNode([multiple_tool_messages_tool]).invoke( + [ + AIMessage( + "", + tool_calls=[ + { + "args": {}, + "id": "1", + "name": "multiple_tool_messages_tool", + } + ], + ) + ] + ) + + +@pytest.mark.skipif( + not IS_LANGCHAIN_CORE_030_OR_GREATER, + reason="Langchain core 0.3.0 or greater is required", +) +def test_react_agent_update_state(): + from langchain_core.tools.base import InjectedToolCallId + + class State(AgentState): + user_name: str + + @dec_tool + def get_user_name(tool_call_id: Annotated[str, InjectedToolCallId]): + """Retrieve user name""" + user_name = interrupt("Please provider user name:") + return Command( + update={ + "user_name": user_name, + "messages": [ + ToolMessage( + "Successfully retrieved user name", tool_call_id=tool_call_id + ) + ], + } + ) + + def state_modifier(state: State): + user_name = state.get("user_name") + if user_name is None: + return state["messages"] + + system_msg = f"User name is {user_name}" + return [{"role": "system", "content": system_msg}] + state["messages"] + + checkpointer = MemorySaver() + tool_calls = [[{"args": {}, "id": "1", "name": "get_user_name"}]] + model = FakeToolCallingModel(tool_calls=tool_calls) + agent = create_react_agent( + model, + [get_user_name], + state_schema=State, + state_modifier=state_modifier, + checkpointer=checkpointer, + ) + config = {"configurable": {"thread_id": "1"}} + # run until interrpupted + agent.invoke({"messages": [("user", "what's my name")]}, config) + # supply the value for the interrupt + response = agent.invoke(Command(resume="Archibald"), config) + # confirm that the state was updated + assert response["user_name"] == "Archibald" + assert len(response["messages"]) == 4 + tool_message: ToolMessage = response["messages"][-2] + assert tool_message.content == "Successfully retrieved user name" + assert tool_message.tool_call_id == "1" + assert tool_message.name == "get_user_name" + + def my_function(some_val: int, some_other_val: str) -> str: return f"{some_val} - {some_other_val}"