diff --git a/libs/langgraph/langgraph/prebuilt/chat_agent_executor.py b/libs/langgraph/langgraph/prebuilt/chat_agent_executor.py index dd2cddb5e7..1e2209bd77 100644 --- a/libs/langgraph/langgraph/prebuilt/chat_agent_executor.py +++ b/libs/langgraph/langgraph/prebuilt/chat_agent_executor.py @@ -11,17 +11,13 @@ ) from langchain_core.language_models import BaseChatModel -from langchain_core.messages import ( - AIMessage, - BaseMessage, - SystemMessage, -) +from langchain_core.messages import AIMessage, BaseMessage, SystemMessage, ToolMessage from langchain_core.runnables import Runnable, RunnableConfig, RunnableLambda from langchain_core.tools import BaseTool from langgraph._api.deprecation import deprecated_parameter from langgraph.checkpoint.base import BaseCheckpointSaver -from langgraph.graph import END, StateGraph +from langgraph.graph import StateGraph from langgraph.graph.graph import CompiledGraph from langgraph.graph.message import add_messages from langgraph.managed import IsLastStep @@ -430,15 +426,15 @@ class Agent,Tools otherClass model = model.bind_tools(tool_classes) # Define the function that determines whether to continue or not - def should_continue(state: AgentState) -> Literal["continue", "end"]: + def should_continue(state: AgentState) -> Literal["tools", "__end__"]: messages = state["messages"] last_message = messages[-1] # If there is no function call, then we finish if not isinstance(last_message, AIMessage) or not last_message.tool_calls: - return "end" + return "__end__" # Otherwise if there is, we continue else: - return "continue" + return "tools" preprocessor = _get_model_preprocessing_runnable(state_modifier, messages_modifier) model_runnable = preprocessor | model @@ -498,23 +494,24 @@ async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState: "agent", # Next, we pass in the function that will determine which node is called next. should_continue, - # Finally we pass in a mapping. - # The keys are strings, and the values are other nodes. - # END is a special node marking that the graph should finish. - # What will happen is we will call `should_continue`, and then the output of that - # will be matched against the keys in this mapping. - # Based on which one it matches, that node will then be called. - { - # If `tools`, then we call the tool node. - "continue": "tools", - # Otherwise we finish. - "end": END, - }, ) - # We now add a normal edge from `tools` to `agent`. - # This means that after `tools` is called, `agent` node is called next. - workflow.add_edge("tools", "agent") + # If any of the tools are configured to return_directly after running, + # our graph needs to check if these were called + should_return_direct = {t.name for t in tool_classes if t.return_direct} + + def route_tool_responses(state: AgentState) -> Literal["agent", "__end__"]: + for m in reversed(state["messages"]): + if not isinstance(m, ToolMessage): + break + if m.name in should_return_direct: + return "__end__" + return "agent" + + if should_return_direct: + workflow.add_conditional_edges("tools", route_tool_responses) + else: + workflow.add_edge("tools", "agent") # Finally, we compile it! # This compiles it into a LangChain Runnable, diff --git a/libs/langgraph/tests/__snapshots__/test_pregel.ambr b/libs/langgraph/tests/__snapshots__/test_pregel.ambr index 9354409a71..2e50030e3e 100644 --- a/libs/langgraph/tests/__snapshots__/test_pregel.ambr +++ b/libs/langgraph/tests/__snapshots__/test_pregel.ambr @@ -4857,13 +4857,11 @@ { "source": "agent", "target": "tools", - "data": "continue", "conditional": true }, { "source": "agent", "target": "__end__", - "data": "end", "conditional": true } ] @@ -4875,8 +4873,8 @@ graph TD; __start__ --> agent; tools --> agent; - agent -.  continue  .-> tools; - agent -.  end  .-> __end__; + agent -.-> tools; + agent -.-> __end__; ''' # --- diff --git a/libs/langgraph/tests/test_prebuilt.py b/libs/langgraph/tests/test_prebuilt.py index 42dca7ce96..0fc1685e56 100644 --- a/libs/langgraph/tests/test_prebuilt.py +++ b/libs/langgraph/tests/test_prebuilt.py @@ -47,6 +47,9 @@ class FakeToolCallingModel(BaseChatModel): + tool_calls: Optional[list[list[ToolCall]]] = None + index: int = 0 + def _generate( self, messages: List[BaseMessage], @@ -56,7 +59,15 @@ def _generate( ) -> ChatResult: """Top Level call""" messages_string = "-".join([m.content for m in messages]) - message = AIMessage(content=messages_string, id="0") + tool_calls = ( + self.tool_calls[self.index % len(self.tool_calls)] + if self.tool_calls + else [] + ) + message = AIMessage( + content=messages_string, id=str(self.index), tool_calls=tool_calls.copy() + ) + self.index += 1 return ChatResult(generations=[ChatGeneration(message=message)]) @property @@ -68,8 +79,6 @@ def bind_tools( tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], **kwargs: Any, ) -> Runnable[LanguageModelInput, BaseMessage]: - if len(tools) > 0: - raise ValueError("Not supported yet!") return self @@ -144,29 +153,35 @@ def test_passing_two_modifiers(): def test_system_message_modifier(): - model = FakeToolCallingModel() messages_modifier = SystemMessage(content="Foo") - agent_1 = create_react_agent(model, [], messages_modifier=messages_modifier) - agent_2 = create_react_agent(model, [], state_modifier=messages_modifier) + agent_1 = create_react_agent( + FakeToolCallingModel(), [], messages_modifier=messages_modifier + ) + agent_2 = create_react_agent( + FakeToolCallingModel(), [], state_modifier=messages_modifier + ) for agent in [agent_1, agent_2]: inputs = [HumanMessage("hi?")] response = agent.invoke({"messages": inputs}) expected_response = { - "messages": inputs + [AIMessage(content="Foo-hi?", id="0")] + "messages": inputs + [AIMessage(content="Foo-hi?", id="0", tool_calls=[])] } assert response == expected_response def test_system_message_string_modifier(): - model = FakeToolCallingModel() messages_modifier = "Foo" - agent_1 = create_react_agent(model, [], messages_modifier=messages_modifier) - agent_2 = create_react_agent(model, [], state_modifier=messages_modifier) + agent_1 = create_react_agent( + FakeToolCallingModel(), [], messages_modifier=messages_modifier + ) + agent_2 = create_react_agent( + FakeToolCallingModel(), [], state_modifier=messages_modifier + ) for agent in [agent_1, agent_2]: inputs = [HumanMessage("hi?")] response = agent.invoke({"messages": inputs}) expected_response = { - "messages": inputs + [AIMessage(content="Foo-hi?", id="0")] + "messages": inputs + [AIMessage(content="Foo-hi?", id="0", tool_calls=[])] } assert response == expected_response @@ -584,3 +599,100 @@ def get_day_list(days: list[str]) -> list[str]: [AIMessage(content="", tool_calls=tool_calls)] ) assert outputs[0].content == json.dumps(data, ensure_ascii=False) + + +async def test_return_direct() -> None: + @dec_tool(return_direct=True) + def tool_return_direct(input: str) -> str: + """A tool that returns directly.""" + return f"Direct result: {input}" + + @dec_tool + def tool_normal(input: str) -> str: + """A normal tool.""" + return f"Normal result: {input}" + + first_tool_call = [ + ToolCall( + name="tool_return_direct", + args={"input": "Test direct"}, + id="1", + ), + ] + expected_ai = AIMessage( + content="Test direct", + id="0", + tool_calls=first_tool_call, + ) + model = FakeToolCallingModel(tool_calls=[first_tool_call, []]) + agent = create_react_agent(model, [tool_return_direct, tool_normal]) + + # Test direct return for tool_return_direct + result = agent.invoke( + {"messages": [HumanMessage(content="Test direct", id="hum0")]} + ) + assert result["messages"] == [ + HumanMessage(content="Test direct", id="hum0"), + expected_ai, + ToolMessage( + content="Direct result: Test direct", + name="tool_return_direct", + tool_call_id="1", + id=result["messages"][2].id, + ), + ] + second_tool_call = [ + ToolCall( + name="tool_normal", + args={"input": "Test normal"}, + id="2", + ), + ] + model = FakeToolCallingModel(tool_calls=[second_tool_call, []]) + agent = create_react_agent(model, [tool_return_direct, tool_normal]) + result = agent.invoke( + {"messages": [HumanMessage(content="Test normal", id="hum1")]} + ) + assert result["messages"] == [ + HumanMessage(content="Test normal", id="hum1"), + AIMessage(content="Test normal", id="0", tool_calls=second_tool_call), + ToolMessage( + content="Normal result: Test normal", + name="tool_normal", + tool_call_id="2", + id=result["messages"][2].id, + ), + AIMessage(content="Test normal-Test normal-Normal result: Test normal", id="1"), + ] + + both_tool_calls = [ + ToolCall( + name="tool_return_direct", + args={"input": "Test both direct"}, + id="3", + ), + ToolCall( + name="tool_normal", + args={"input": "Test both normal"}, + id="4", + ), + ] + model = FakeToolCallingModel(tool_calls=[both_tool_calls, []]) + agent = create_react_agent(model, [tool_return_direct, tool_normal]) + result = agent.invoke({"messages": [HumanMessage(content="Test both", id="hum2")]}) + assert result["messages"] == [ + HumanMessage(content="Test both", id="hum2"), + AIMessage(content="Test both", id="0", tool_calls=both_tool_calls), + ToolMessage( + content="Direct result: Test both direct", + name="tool_return_direct", + tool_call_id="3", + id=result["messages"][2].id, + ), + ToolMessage( + content="Normal result: Test both normal", + name="tool_normal", + tool_call_id="4", + id=result["messages"][3].id, + ), + ]