Skip to content

Commit

Permalink
Add return_direct support
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw committed Sep 20, 2024
1 parent b9fe384 commit fbc272a
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 36 deletions.
45 changes: 21 additions & 24 deletions libs/langgraph/langgraph/prebuilt/chat_agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,13 @@
)

from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
BaseMessage,
SystemMessage,
)
from langchain_core.messages import AIMessage, BaseMessage, SystemMessage, ToolMessage
from langchain_core.runnables import Runnable, RunnableConfig, RunnableLambda
from langchain_core.tools import BaseTool

from langgraph._api.deprecation import deprecated_parameter
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.graph import END, StateGraph
from langgraph.graph import StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.graph.message import add_messages
from langgraph.managed import IsLastStep
Expand Down Expand Up @@ -430,15 +426,15 @@ class Agent,Tools otherClass
model = model.bind_tools(tool_classes)

# Define the function that determines whether to continue or not
def should_continue(state: AgentState) -> Literal["continue", "end"]:
def should_continue(state: AgentState) -> Literal["tools", "__end__"]:
messages = state["messages"]
last_message = messages[-1]
# If there is no function call, then we finish
if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
return "end"
return "__end__"
# Otherwise if there is, we continue
else:
return "continue"
return "tools"

preprocessor = _get_model_preprocessing_runnable(state_modifier, messages_modifier)
model_runnable = preprocessor | model
Expand Down Expand Up @@ -498,23 +494,24 @@ async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState:
"agent",
# Next, we pass in the function that will determine which node is called next.
should_continue,
# Finally we pass in a mapping.
# The keys are strings, and the values are other nodes.
# END is a special node marking that the graph should finish.
# What will happen is we will call `should_continue`, and then the output of that
# will be matched against the keys in this mapping.
# Based on which one it matches, that node will then be called.
{
# If `tools`, then we call the tool node.
"continue": "tools",
# Otherwise we finish.
"end": END,
},
)

# We now add a normal edge from `tools` to `agent`.
# This means that after `tools` is called, `agent` node is called next.
workflow.add_edge("tools", "agent")
    # If any of the tools are configured to return directly after running,
    # our graph needs to check whether any of them were called
should_return_direct = {t.name for t in tool_classes if t.return_direct}

def route_tool_responses(state: AgentState) -> Literal["agent", "__end__"]:
for m in reversed(state["messages"]):
if not isinstance(m, ToolMessage):
break
if m.name in should_return_direct:
return "__end__"
return "agent"

if should_return_direct:
workflow.add_conditional_edges("tools", route_tool_responses)
else:
workflow.add_edge("tools", "agent")

# Finally, we compile it!
# This compiles it into a LangChain Runnable,
Expand Down
23 changes: 14 additions & 9 deletions libs/langgraph/tests/__snapshots__/test_pregel.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -4851,19 +4851,23 @@
"target": "agent"
},
{
"source": "tools",
"target": "agent"
"source": "agent",
"target": "tools",
"conditional": true
},
{
"source": "agent",
"target": "__end__",
"conditional": true
},
{
"source": "tools",
"target": "tools",
"data": "continue",
"conditional": true
},
{
"source": "agent",
"source": "tools",
"target": "__end__",
"data": "end",
"conditional": true
}
]
Expand All @@ -4874,9 +4878,10 @@
'''
graph TD;
__start__ --> agent;
tools --> agent;
agent -.  continue  .-> tools;
agent -.  end  .-> __end__;
agent -.-> tools;
agent -.-> __end__;
tools -.-> __end__;
tools -.-> tools;

'''
# ---
Expand Down Expand Up @@ -5017,7 +5022,7 @@
'{"title": "LangGraphOutput", "type": "object", "properties": {"input": {"title": "Input", "type": "string"}, "agent_outcome": {"title": "Agent Outcome", "anyOf": [{"$ref": "#/definitions/AgentAction"}, {"$ref": "#/definitions/AgentFinish"}]}, "intermediate_steps": {"title": "Intermediate Steps", "type": "array", "items": {"type": "array", "minItems": 2, "maxItems": 2, "items": [{"$ref": "#/definitions/AgentAction"}, {"type": "string"}]}}}, "definitions": {"AgentAction": {"title": "AgentAction", "description": "Represents a request to execute an action by an agent.\\n\\nThe action consists of the name of the tool to execute and the input to pass\\nto the tool. The log is used to pass along extra information about the action.", "type": "object", "properties": {"tool": {"title": "Tool", "type": "string"}, "tool_input": {"title": "Tool Input", "anyOf": [{"type": "string"}, {"type": "object"}]}, "log": {"title": "Log", "type": "string"}, "type": {"title": "Type", "default": "AgentAction", "enum": ["AgentAction"], "type": "string"}}, "required": ["tool", "tool_input", "log"]}, "AgentFinish": {"title": "AgentFinish", "description": "Final return value of an ActionAgent.\\n\\nAgents return an AgentFinish when they have reached a stopping condition.", "type": "object", "properties": {"return_values": {"title": "Return Values", "type": "object"}, "log": {"title": "Log", "type": "string"}, "type": {"title": "Type", "default": "AgentFinish", "enum": ["AgentFinish"], "type": "string"}}, "required": ["return_values", "log"]}}}'
# ---
# name: test_state_graph_w_config_inherited_state_keys
'{"$defs": {"Configurable": {"properties": {"tools": {"default": null, "items": {"type": "string"}, "title": "Tools", "type": "array"}}, "title": "Configurable", "type": "object"}}, "properties": {"configurable": {"allOf": [{"$ref": "#/$defs/Configurable"}], "default": null}}, "title": "LangGraphConfig", "type": "object"}'
'{"$defs": {"Configurable": {"properties": {"tools": {"default": null, "items": {"type": "string"}, "title": "Tools", "type": "array"}}, "title": "Configurable", "type": "object"}}, "properties": {"configurable": {"$ref": "#/$defs/Configurable", "default": null}}, "title": "LangGraphConfig", "type": "object"}'
# ---
# name: test_state_graph_w_config_inherited_state_keys.1
'{"$defs": {"AgentAction": {"description": "Represents a request to execute an action by an agent.\\n\\nThe action consists of the name of the tool to execute and the input to pass\\nto the tool. The log is used to pass along extra information about the action.", "properties": {"tool": {"title": "Tool", "type": "string"}, "tool_input": {"anyOf": [{"type": "string"}, {"type": "object"}], "title": "Tool Input"}, "log": {"title": "Log", "type": "string"}, "type": {"const": "AgentAction", "default": "AgentAction", "enum": ["AgentAction"], "title": "Type", "type": "string"}}, "required": ["tool", "tool_input", "log"], "title": "AgentAction", "type": "object"}, "AgentFinish": {"description": "Final return value of an ActionAgent.\\n\\nAgents return an AgentFinish when they have reached a stopping condition.", "properties": {"return_values": {"title": "Return Values", "type": "object"}, "log": {"title": "Log", "type": "string"}, "type": {"const": "AgentFinish", "default": "AgentFinish", "enum": ["AgentFinish"], "title": "Type", "type": "string"}}, "required": ["return_values", "log"], "title": "AgentFinish", "type": "object"}}, "properties": {"input": {"title": "Input", "type": "string"}, "agent_outcome": {"anyOf": [{"$ref": "#/$defs/AgentAction"}, {"$ref": "#/$defs/AgentFinish"}, {"type": "null"}], "default": null, "title": "Agent Outcome"}, "intermediate_steps": {"default": null, "items": {"maxItems": 2, "minItems": 2, "prefixItems": [{"$ref": "#/$defs/AgentAction"}, {"type": "string"}], "type": "array"}, "title": "Intermediate Steps", "type": "array"}}, "required": ["input"], "title": "LangGraphInput", "type": "object"}'
Expand Down
112 changes: 109 additions & 3 deletions libs/langgraph/tests/test_prebuilt.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@


class FakeToolCallingModel(BaseChatModel):
tool_calls: Optional[list[list[ToolCall]]] = None
index: int = 0

def _generate(
self,
messages: List[BaseMessage],
Expand All @@ -56,7 +59,15 @@ def _generate(
) -> ChatResult:
"""Top Level call"""
messages_string = "-".join([m.content for m in messages])
message = AIMessage(content=messages_string, id="0")
tool_calls = (
self.tool_calls[self.index % len(self.tool_calls)]
if self.tool_calls
else []
)
message = AIMessage(
content=messages_string, id=str(self.index), tool_calls=tool_calls.copy()
)
self.index += 1
return ChatResult(generations=[ChatGeneration(message=message)])

@property
Expand All @@ -68,8 +79,6 @@ def bind_tools(
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
if len(tools) > 0:
raise ValueError("Not supported yet!")
return self


Expand Down Expand Up @@ -584,3 +593,100 @@ def get_day_list(days: list[str]) -> list[str]:
[AIMessage(content="", tool_calls=tool_calls)]
)
assert outputs[0].content == json.dumps(data, ensure_ascii=False)


async def test_return_direct() -> None:
@dec_tool(return_direct=True)
def tool_return_direct(input: str) -> str:
"""A tool that returns directly."""
return f"Direct result: {input}"

@dec_tool
def tool_normal(input: str) -> str:
"""A normal tool."""
return f"Normal result: {input}"

first_tool_call = [
ToolCall(
name="tool_return_direct",
args={"input": "Test direct"},
id="1",
),
]
expected_ai = AIMessage(
content="Test direct",
id="0",
tool_calls=first_tool_call,
)
model = FakeToolCallingModel(tool_calls=[first_tool_call, []])
agent = create_react_agent(model, [tool_return_direct, tool_normal])

# Test direct return for tool_return_direct
result = agent.invoke(
{"messages": [HumanMessage(content="Test direct", id="hum0")]}
)
assert result["messages"] == [
HumanMessage(content="Test direct", id="hum0"),
expected_ai,
ToolMessage(
content="Direct result: Test direct",
name="tool_return_direct",
tool_call_id="1",
id=result["messages"][2].id,
),
]
second_tool_call = [
ToolCall(
name="tool_normal",
args={"input": "Test normal"},
id="2",
),
]
model = FakeToolCallingModel(tool_calls=[second_tool_call, []])
agent = create_react_agent(model, [tool_return_direct, tool_normal])
result = agent.invoke(
{"messages": [HumanMessage(content="Test normal", id="hum1")]}
)
assert result["messages"] == [
HumanMessage(content="Test normal", id="hum1"),
AIMessage(content="Test normal", id="0", tool_calls=second_tool_call),
ToolMessage(
content="Normal result: Test normal",
name="tool_normal",
tool_call_id="2",
id=result["messages"][2].id,
),
AIMessage(content="Test normal-Test normal-Normal result: Test normal", id="1"),
]

both_tool_calls = [
ToolCall(
name="tool_return_direct",
args={"input": "Test both direct"},
id="3",
),
ToolCall(
name="tool_normal",
args={"input": "Test both normal"},
id="4",
),
]
model = FakeToolCallingModel(tool_calls=[both_tool_calls, []])
agent = create_react_agent(model, [tool_return_direct, tool_normal])
result = agent.invoke({"messages": [HumanMessage(content="Test both", id="hum2")]})
assert result["messages"] == [
HumanMessage(content="Test both", id="hum2"),
AIMessage(content="Test both", id="0", tool_calls=both_tool_calls),
ToolMessage(
content="Direct result: Test both direct",
name="tool_return_direct",
tool_call_id="3",
id=result["messages"][2].id,
),
ToolMessage(
content="Normal result: Test both normal",
name="tool_normal",
tool_call_id="4",
id=result["messages"][3].id,
),
]

0 comments on commit fbc272a

Please sign in to comment.