diff --git a/libs/langgraph/langgraph/prebuilt/tool_node.py b/libs/langgraph/langgraph/prebuilt/tool_node.py index 91e6483a1..b7db3c1ce 100644 --- a/libs/langgraph/langgraph/prebuilt/tool_node.py +++ b/libs/langgraph/langgraph/prebuilt/tool_node.py @@ -70,8 +70,8 @@ def msg_content_output(output: Any) -> str | List[dict]: class ToolNode(RunnableCallable): """A node that runs the tools called in the last AIMessage. - It can be used either in StateGraph with a "messages" key or in MessageGraph. If - multiple tool calls are requested, they will be run in parallel. The output will be + It can be used either in StateGraph with a "messages" key (or a custom key passed via ToolNode's 'messages_key'). + If multiple tool calls are requested, they will be run in parallel. The output will be a list of ToolMessages, one for each tool call. The `ToolNode` is roughly analogous to: @@ -102,12 +102,14 @@ def __init__( name: str = "tools", tags: Optional[list[str]] = None, handle_tool_errors: Optional[bool] = True, + messages_key: str = "messages", ) -> None: super().__init__(self._func, self._afunc, name=name, tags=tags, trace=False) self.tools_by_name: Dict[str, BaseTool] = {} self.tool_to_state_args: Dict[str, Dict[str, Optional[str]]] = {} self.tool_to_store_arg: Dict[str, Optional[str]] = {} self.handle_tool_errors = handle_tool_errors + self.messages_key = messages_key for tool_ in tools: if not isinstance(tool_, BaseTool): tool_ = cast(BaseTool, create_tool(tool_)) @@ -131,7 +133,7 @@ def _func( with get_executor_for_config(config) as executor: outputs = [*executor.map(self._run_one, tool_calls, config_list)] # TypedDict, pydantic, dataclass, etc. should all be able to load from dict - return outputs if output_type == "list" else {"messages": outputs} + return outputs if output_type == "list" else {self.messages_key: outputs} def invoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any @@ -163,7 +165,7 @@ async def _afunc( *(self._arun_one(call, config) for call in tool_calls) ) # TypedDict, pydantic, dataclass, etc. should all be able to load from dict - return outputs if output_type == "list" else {"messages": outputs} + return outputs if output_type == "list" else {self.messages_key: outputs} def _run_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage: if invalid_tool_message := self._validate_tool_call(call): @@ -214,10 +216,10 @@ def _parse_input( if isinstance(input, list): output_type = "list" message: AnyMessage = input[-1] - elif isinstance(input, dict) and (messages := input.get("messages", [])): + elif isinstance(input, dict) and (messages := input.get(self.messages_key, [])): output_type = "dict" message = messages[-1] - elif messages := getattr(input, "messages", None): + elif messages := getattr(input, self.messages_key, None): # Assume dataclass-like state that can coerce from dict output_type = "dict" message = messages[-1] @@ -256,10 +258,10 @@ def _inject_state( required_fields = list(state_args.values()) if ( len(required_fields) == 1 - and required_fields[0] == "messages" + and required_fields[0] == self.messages_key or required_fields[0] is None ): - input = {"messages": input} + input = {self.messages_key: input} else: err_msg = ( f"Invalid input to ToolNode. Tool {tool_call['name']} requires " @@ -325,6 +327,7 @@ def _inject_tool_args( def tools_condition( state: Union[list[AnyMessage], dict[str, Any], BaseModel], + messages_key: str = "messages", ) -> Literal["tools", "__end__"]: """Use in the conditional_edge to route to the ToolNode if the last message @@ -377,9 +380,9 @@ def tools_condition( """ if isinstance(state, list): ai_message = state[-1] - elif isinstance(state, dict) and (messages := state.get("messages", [])): + elif isinstance(state, dict) and (messages := state.get(messages_key, [])): ai_message = messages[-1] - elif messages := getattr(state, "messages", []): + elif messages := getattr(state, messages_key, []): ai_message = messages[-1] else: raise ValueError(f"No messages found in input state to tool_edge: {state}") diff --git a/libs/langgraph/tests/test_prebuilt.py b/libs/langgraph/tests/test_prebuilt.py index c1176f365..2ebc2f6cb 100644 --- a/libs/langgraph/tests/test_prebuilt.py +++ b/libs/langgraph/tests/test_prebuilt.py @@ -1,5 +1,6 @@ import dataclasses import json +from functools import partial from typing import ( Annotated, Any, @@ -35,8 +36,13 @@ from typing_extensions import TypedDict from langgraph.checkpoint.base import BaseCheckpointSaver -from langgraph.graph import START, MessagesState, StateGraph -from langgraph.prebuilt import ToolNode, ValidationNode, create_react_agent +from langgraph.graph import START, MessagesState, StateGraph, add_messages +from langgraph.prebuilt import ( + ToolNode, + ValidationNode, + create_react_agent, + tools_condition, +) from langgraph.prebuilt.tool_node import InjectedState, InjectedStore from langgraph.store.base import BaseStore from langgraph.store.memory import InMemoryStore @@ -46,7 +52,7 @@ IS_LANGCHAIN_CORE_030_OR_GREATER, awith_checkpointer, ) -from tests.messages import _AnyIdHumanMessage +from tests.messages import _AnyIdHumanMessage, _AnyIdToolMessage pytestmark = pytest.mark.anyio @@ -826,6 +832,47 @@ def get_day_list(days: list[str]) -> list[str]: assert outputs[0].content == json.dumps(data, ensure_ascii=False) +def test_tool_node_messages_key() -> None: + @dec_tool + def add(a: int, b: int): + """Adds a and b.""" + return a + b + + model = FakeToolCallingModel( + tool_calls=[[ToolCall(name=add.name, args={"a": 1, "b": 2}, id="test_id")]] + ) + + class State(TypedDict): + subgraph_messages: Annotated[list[AnyMessage], add_messages] + + def call_model(state: State): + response = model.invoke(state["subgraph_messages"]) + model.tool_calls = [] + return {"subgraph_messages": response} + + builder = StateGraph(State) + builder.add_node("agent", call_model) + builder.add_node("tools", ToolNode([add], messages_key="subgraph_messages")) + builder.add_conditional_edges( + "agent", partial(tools_condition, messages_key="subgraph_messages") + ) + builder.add_edge(START, "agent") + builder.add_edge("tools", "agent") + + graph = builder.compile() + result = graph.invoke({"subgraph_messages": [HumanMessage(content="hi")]}) + assert result["subgraph_messages"] == [ + _AnyIdHumanMessage(content="hi"), + AIMessage( + content="hi", + id="0", + tool_calls=[ToolCall(name=add.name, args={"a": 1, "b": 2}, id="test_id")], + ), + _AnyIdToolMessage(content="3", name=add.name, tool_call_id="test_id"), + AIMessage(content="hi-hi-3", id="1"), + ] + + async def test_return_direct() -> None: @dec_tool(return_direct=True) def tool_return_direct(input: str) -> str: