
langgraph: allow tools to return Command in tool node #2656
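In short: with this PR, a tool run inside the prebuilt `ToolNode` can return a `Command` (a state update and/or routing instruction) instead of a plain `ToolMessage`. A minimal sketch of the idea using stand-in classes — the real ones live in `langchain_core.messages` and `langgraph.types`, and the tool and node names here are illustrative, not from the PR:

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ToolMessage:
    # stand-in for langchain_core.messages.ToolMessage
    content: str
    tool_call_id: str


@dataclass
class Command:
    # stand-in for langgraph.types.Command
    update: Any = None
    goto: Optional[str] = None


def transfer_to_support(tool_call_id: str) -> Command:
    # The tool returns a Command: ToolNode now passes it through,
    # and LangGraph applies it as a state update plus a jump to "support".
    return Command(
        goto="support",
        update={"messages": [ToolMessage("Transferred", tool_call_id)]},
    )


cmd = transfer_to_support("call_1")
```

Note the `ToolMessage` inside `Command.update`: as the validation added below enforces, every tool call must still be answered by exactly one matching `ToolMessage`.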

Merged — 28 commits, Dec 10, 2024
0bd42c0
langgraph: allow tools to return Command in tool node
vbarda Dec 5, 2024
1b51f23
add test
vbarda Dec 5, 2024
bf85c09
spelling
vbarda Dec 5, 2024
4ded6df
Merge branch 'main' into vb/tool-node-support-command
vbarda Dec 5, 2024
983ad82
Merge branch 'main' into vb/tool-node-support-command
vbarda Dec 6, 2024
ac30db5
update to support multiple commands
vbarda Dec 6, 2024
b9e4dbf
use args
vbarda Dec 6, 2024
c75d743
don't use type hints
vbarda Dec 6, 2024
740bb05
combine updates
vbarda Dec 6, 2024
108ed2e
lint
vbarda Dec 6, 2024
d6e6df9
Merge branch 'main' into vb/tool-node-support-command
vbarda Dec 6, 2024
681fbcf
don't wrap in command, let langgraph handle
vbarda Dec 6, 2024
d35c1fa
factor out + test
vbarda Dec 6, 2024
2604a77
better error messages
vbarda Dec 6, 2024
5e8de8a
relax validation for parent graph updates
vbarda Dec 7, 2024
607d7ce
simplify list updates
vbarda Dec 8, 2024
ca3698f
output_type -> input_type
vbarda Dec 8, 2024
55a22f3
simplify and match core changes
vbarda Dec 10, 2024
442ad8c
update tests to match core changes
vbarda Dec 10, 2024
24515bf
match name in core
vbarda Dec 10, 2024
6a79de0
update
vbarda Dec 10, 2024
91313f3
update tests
vbarda Dec 10, 2024
d0c6690
remove special exception
vbarda Dec 10, 2024
61569fc
lint
vbarda Dec 10, 2024
347a345
update core
vbarda Dec 10, 2024
3a98ce0
Merge branch 'main' into vb/tool-node-support-command
vbarda Dec 10, 2024
a65e625
move import in tests
vbarda Dec 10, 2024
86ce622
update pyproject
vbarda Dec 10, 2024
199 changes: 163 additions & 36 deletions libs/langgraph/langgraph/prebuilt/tool_node.py
@@ -1,7 +1,7 @@
import asyncio

Check notice on line 1 in libs/langgraph/langgraph/prebuilt/tool_node.py — GitHub Actions / benchmark

Benchmark results (Mean ± std dev):

fanout_to_subgraph_10x:                   62.1 ms ± 1.5 ms
fanout_to_subgraph_10x_sync:              52.5 ms ± 0.7 ms
fanout_to_subgraph_10x_checkpoint:        93.7 ms ± 7.1 ms
fanout_to_subgraph_10x_checkpoint_sync:   95.0 ms ± 1.5 ms
fanout_to_subgraph_100x:                  632 ms ± 20 ms
fanout_to_subgraph_100x_sync:             514 ms ± 9 ms
fanout_to_subgraph_100x_checkpoint:       960 ms ± 45 ms
fanout_to_subgraph_100x_checkpoint_sync:  948 ms ± 18 ms
react_agent_10x:                          31.0 ms ± 0.6 ms
react_agent_10x_sync:                     22.9 ms ± 0.4 ms
react_agent_10x_checkpoint:               46.9 ms ± 0.8 ms
react_agent_10x_checkpoint_sync:          36.9 ms ± 0.4 ms
react_agent_100x:                         346 ms ± 7 ms
react_agent_100x_sync:                    275 ms ± 3 ms
react_agent_100x_checkpoint:              935 ms ± 10 ms
react_agent_100x_checkpoint_sync:         833 ms ± 10 ms
wide_state_25x300:                        23.6 ms ± 0.4 ms
wide_state_25x300_sync:                   14.9 ms ± 0.1 ms
wide_state_25x300_checkpoint:             286 ms ± 13 ms
wide_state_25x300_checkpoint_sync:        273 ms ± 14 ms
wide_state_15x600:                        27.6 ms ± 0.6 ms
wide_state_15x600_sync:                   17.4 ms ± 0.2 ms
wide_state_15x600_checkpoint:             486 ms ± 14 ms
wide_state_15x600_checkpoint_sync:        472 ms ± 13 ms
wide_state_9x1200:                        27.6 ms ± 0.5 ms
wide_state_9x1200_sync:                   17.3 ms ± 0.1 ms
wide_state_9x1200_checkpoint:             320 ms ± 15 ms
wide_state_9x1200_checkpoint_sync:        306 ms ± 13 ms

GitHub Actions / benchmark — Comparison against main:

| Benchmark                              | main    | changes               |
| wide_state_9x1200_sync                 | 17.3 ms | 17.3 ms: 1.00x slower |
| wide_state_25x300_sync                 | 14.9 ms | 14.9 ms: 1.00x slower |
| fanout_to_subgraph_100x_sync           | 511 ms  | 514 ms: 1.00x slower  |
| wide_state_9x1200                      | 27.4 ms | 27.6 ms: 1.01x slower |
| wide_state_15x600_sync                 | 17.3 ms | 17.4 ms: 1.01x slower |
| react_agent_100x_checkpoint_sync       | 828 ms  | 833 ms: 1.01x slower  |
| fanout_to_subgraph_10x                 | 61.7 ms | 62.1 ms: 1.01x slower |
| fanout_to_subgraph_10x_checkpoint_sync | 94.4 ms | 95.0 ms: 1.01x slower |
| react_agent_10x                        | 30.8 ms | 31.0 ms: 1.01x slower |
| react_agent_100x_checkpoint            | 927 ms  | 935 ms: 1.01x slower  |
| fanout_to_subgraph_10x_sync            | 52.0 ms | 52.5 ms: 1.01x slower |
| react_agent_100x_sync                  | 272 ms  | 275 ms: 1.01x slower  |
| react_agent_10x_sync                   | 22.6 ms | 22.9 ms: 1.01x slower |
| react_agent_10x_checkpoint_sync        | 36.4 ms | 36.9 ms: 1.01x slower |
| fanout_to_subgraph_100x_checkpoint     | 942 ms  | 960 ms: 1.02x slower  |
| fanout_to_subgraph_100x                | 615 ms  | 632 ms: 1.03x slower  |
| Geometric mean                         | (ref)   | 1.01x slower          |

Benchmarks hidden because not significant (12): wide_state_25x300_checkpoint_sync, wide_state_15x600_checkpoint_sync, wide_state_25x300_checkpoint, wide_state_9x1200_checkpoint_sync, wide_state_15x600_checkpoint, wide_state_9x1200_checkpoint, wide_state_25x300, react_agent_10x_checkpoint, fanout_to_subgraph_10x_checkpoint, react_agent_100x, wide_state_15x600, fanout_to_subgraph_100x_checkpoint_sync
import inspect
import json
from copy import copy
from copy import copy, deepcopy
from typing import (
Any,
Callable,
@@ -20,6 +20,7 @@
AnyMessage,
ToolCall,
ToolMessage,
convert_to_messages,
)
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.config import (
@@ -35,6 +36,7 @@

from langgraph.errors import GraphBubbleUp
from langgraph.store.base import BaseStore
from langgraph.types import Command
from langgraph.utils.runnable import RunnableCallable

INVALID_TOOL_NAME_ERROR_TEMPLATE = (
@@ -47,7 +49,7 @@
recognized_content_block_types = ("image", "image_url", "text", "json")
if isinstance(output, str):
return output
elif all(
elif isinstance(output, list) and all(
[
isinstance(x, dict) and x.get("type") in recognized_content_block_types
for x in output
@@ -210,12 +212,31 @@
*,
store: BaseStore,
) -> Any:
tool_calls, output_type = self._parse_input(input, store)
tool_calls, input_type = self._parse_input(input, store)
config_list = get_config_list(config, len(tool_calls))
input_types = [input_type] * len(tool_calls)
with get_executor_for_config(config) as executor:
outputs = [*executor.map(self._run_one, tool_calls, config_list)]
# TypedDict, pydantic, dataclass, etc. should all be able to load from dict
return outputs if output_type == "list" else {self.messages_key: outputs}
outputs = [
*executor.map(self._run_one, tool_calls, input_types, config_list)
]

# preserve existing behavior for non-command tool outputs for backwards compatibility
if not any(isinstance(output, Command) for output in outputs):
# TypedDict, pydantic, dataclass, etc. should all be able to load from dict
return outputs if input_type == "list" else {self.messages_key: outputs}

# LangGraph will automatically handle list of Command and non-command node updates
combined_outputs: list[
Command | list[ToolMessage] | dict[str, list[ToolMessage]]
] = []
for output in outputs:
if isinstance(output, Command):
combined_outputs.append(output)
else:
combined_outputs.append(
[output] if input_type == "list" else {self.messages_key: [output]}
)
return combined_outputs
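The combining logic above can be sketched standalone. This is a simplified restatement of the diff, using a minimal `Command` stand-in rather than the real `langgraph.types.Command`:

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class Command:
    # minimal stand-in for langgraph.types.Command
    update: Any = None


def combine_outputs(outputs, input_type, messages_key="messages"):
    # Backwards-compatible path: no tool returned a Command, so keep
    # the old single-update return shape.
    if not any(isinstance(o, Command) for o in outputs):
        return outputs if input_type == "list" else {messages_key: outputs}
    # Mixed case: wrap each plain output as its own node update, so that
    # LangGraph can apply Commands and message updates side by side.
    combined = []
    for o in outputs:
        if isinstance(o, Command):
            combined.append(o)
        else:
            combined.append([o] if input_type == "list" else {messages_key: [o]})
    return combined


plain = combine_outputs(["msg_a", "msg_b"], "dict")
mixed = combine_outputs(["msg_a", Command(update={"x": 1})], "dict")
```

`plain` keeps the pre-PR shape (`{"messages": [...]}`), while `mixed` becomes a list of per-tool updates with the `Command` passed through untouched.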

def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
@@ -242,26 +263,42 @@
*,
store: BaseStore,
) -> Any:
tool_calls, output_type = self._parse_input(input, store)
tool_calls, input_type = self._parse_input(input, store)
outputs = await asyncio.gather(
*(self._arun_one(call, config) for call in tool_calls)
*(self._arun_one(call, input_type, config) for call in tool_calls)
)
# TypedDict, pydantic, dataclass, etc. should all be able to load from dict
return outputs if output_type == "list" else {self.messages_key: outputs}

def _run_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage:
# preserve existing behavior for non-command tool outputs for backwards compatibility
if not any(isinstance(output, Command) for output in outputs):
# TypedDict, pydantic, dataclass, etc. should all be able to load from dict
return outputs if input_type == "list" else {self.messages_key: outputs}

# LangGraph will automatically handle list of Command and non-command node updates
combined_outputs: list[
Command | list[ToolMessage] | dict[str, list[ToolMessage]]
] = []
for output in outputs:
if isinstance(output, Command):
combined_outputs.append(output)
else:
combined_outputs.append(
[output] if input_type == "list" else {self.messages_key: [output]}
)
return combined_outputs

def _run_one(
self,
call: ToolCall,
input_type: Literal["list", "dict"],
config: RunnableConfig,
) -> ToolMessage:
if invalid_tool_message := self._validate_tool_call(call):
return invalid_tool_message

try:
input = {**call, **{"type": "tool_call"}}
tool_message: ToolMessage = self.tools_by_name[call["name"]].invoke(
input, config
)
tool_message.content = cast(
Union[str, list], msg_content_output(tool_message.content)
)
return tool_message
response = self.tools_by_name[call["name"]].invoke(input)

# GraphInterrupt is a special exception that will always be raised.
# It can be triggered in the following scenarios:
# (1) a NodeInterrupt is raised inside a tool
@@ -285,24 +322,38 @@
# Handled
else:
content = _handle_tool_error(e, flag=self.handle_tool_errors)
return ToolMessage(
content=content,
name=call["name"],
tool_call_id=call["id"],
status="error",
)

return ToolMessage(
content=content, name=call["name"], tool_call_id=call["id"], status="error"
)
if isinstance(response, Command):
return self._validate_tool_command(response, call, input_type)
elif isinstance(response, ToolMessage):
response.content = cast(
Union[str, list], msg_content_output(response.content)
)
return response
else:
raise TypeError(
f"Tool {call['name']} returned unexpected type: {type(response)}"
)

async def _arun_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage:
async def _arun_one(
self,
call: ToolCall,
input_type: Literal["list", "dict"],
config: RunnableConfig,
) -> ToolMessage:
if invalid_tool_message := self._validate_tool_call(call):
return invalid_tool_message

try:
input = {**call, **{"type": "tool_call"}}
tool_message: ToolMessage = await self.tools_by_name[call["name"]].ainvoke(
input, config
)
tool_message.content = cast(
Union[str, list], msg_content_output(tool_message.content)
)
return tool_message
response = await self.tools_by_name[call["name"]].ainvoke(input)

# GraphInterrupt is a special exception that will always be raised.
# It can be triggered in the following scenarios:
# (1) a NodeInterrupt is raised inside a tool
@@ -327,9 +378,24 @@
else:
content = _handle_tool_error(e, flag=self.handle_tool_errors)

return ToolMessage(
content=content, name=call["name"], tool_call_id=call["id"], status="error"
)
return ToolMessage(
content=content,
name=call["name"],
tool_call_id=call["id"],
status="error",
)

if isinstance(response, Command):
return self._validate_tool_command(response, call, input_type)
elif isinstance(response, ToolMessage):
response.content = cast(
Union[str, list], msg_content_output(response.content)
)
return response
else:
raise TypeError(
f"Tool {call['name']} returned unexpected type: {type(response)}"
)

def _parse_input(
self,
@@ -341,14 +407,14 @@
store: BaseStore,
) -> Tuple[list[ToolCall], Literal["list", "dict"]]:
if isinstance(input, list):
output_type = "list"
input_type = "list"
message: AnyMessage = input[-1]
elif isinstance(input, dict) and (messages := input.get(self.messages_key, [])):
output_type = "dict"
input_type = "dict"
message = messages[-1]
elif messages := getattr(input, self.messages_key, None):
# Assume dataclass-like state that can coerce from dict
output_type = "dict"
input_type = "dict"
message = messages[-1]
else:
raise ValueError("No message found in input")
@@ -359,7 +425,7 @@
tool_calls = [
self._inject_tool_args(call, input, store) for call in message.tool_calls
]
return tool_calls, output_type
return tool_calls, input_type
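The renamed `input_type` detection in `_parse_input` can be mirrored with plain lists and dicts. A sketch of just the type-dispatch (real inputs would be message objects, and the messages key is configurable):

```python
def parse_input_type(input, messages_key="messages"):
    # list input, e.g. [..., AIMessage(tool_calls=[...])]
    if isinstance(input, list):
        return "list", input[-1]
    # dict state, e.g. {"messages": [..., AIMessage(tool_calls=[...])]}
    if isinstance(input, dict) and input.get(messages_key):
        return "dict", input[messages_key][-1]
    # dataclass- or pydantic-like state exposing a messages attribute
    if getattr(input, messages_key, None):
        return "dict", getattr(input, messages_key)[-1]
    raise ValueError("No message found in input")
```

The returned `"list"`/`"dict"` tag is what later decides both the output shape and which `Command.update` shapes a tool is allowed to return.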

def _validate_tool_call(self, call: ToolCall) -> Optional[ToolMessage]:
if (requested_tool := call["name"]) not in self.tools_by_name:
@@ -453,6 +519,67 @@
tool_call_with_store = self._inject_store(tool_call_with_state, store)
return tool_call_with_store

def _validate_tool_command(
self, command: Command, call: ToolCall, input_type: Literal["list", "dict"]
) -> Command:
if isinstance(command.update, dict):
# input type is dict when ToolNode is invoked with a dict input (e.g. {"messages": [AIMessage(..., tool_calls=[...])]})
if input_type != "dict":
raise ValueError(
f"Tools can provide a dict in Command.update only when using dict with '{self.messages_key}' key as ToolNode input, "
f"got: {command.update} for tool '{call['name']}'"
)

updated_command = deepcopy(command)
state_update = cast(dict[str, Any], updated_command.update) or {}
messages_update = state_update.get(self.messages_key, [])
elif isinstance(command.update, list):
# input type is list when ToolNode is invoked with a list input (e.g. [AIMessage(..., tool_calls=[...])])
if input_type != "list":
raise ValueError(
f"Tools can provide a list of messages in Command.update only when using list of messages as ToolNode input, "
f"got: {command.update} for tool '{call['name']}'"
)

updated_command = deepcopy(command)
messages_update = updated_command.update
else:
return command

# convert to message objects if updates are in a dict format
messages_update = convert_to_messages(messages_update)
have_seen_tool_messages = False
for message in messages_update:
if not isinstance(message, ToolMessage):
continue

if have_seen_tool_messages:
raise ValueError(
f"Expected at most one ToolMessage in Command.update for tool '{call['name']}', got multiple: {messages_update}."
)

if message.tool_call_id != call["id"]:
raise ValueError(
f"ToolMessage.tool_call_id must match the tool call id. Expected: {call['id']}, got: {message.tool_call_id} for tool '{call['name']}'."
)

message.name = call["name"]
have_seen_tool_messages = True

# validate that we always have exactly one ToolMessage in Command.update if command is sent to the CURRENT graph
if updated_command.graph is None and not have_seen_tool_messages:
example_update = (
'`Command(update={"messages": [ToolMessage("Success", tool_call_id=tool_call_id), ...]}, ...)`'
if input_type == "dict"
else '`Command(update=[ToolMessage("Success", tool_call_id=tool_call_id), ...], ...)`'
)
raise ValueError(
f"Expected exactly one message (ToolMessage) in Command.update for tool '{call['name']}', got: {messages_update}. "
"Every tool call (LLM requesting to call a tool) in the message history MUST have a corresponding ToolMessage. "
f"You can fix it by modifying the tool to return {example_update}."
)
return updated_command
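The rules `_validate_tool_command` enforces can be restated compactly: at most one `ToolMessage` in the update, its `tool_call_id` must match the originating call, and a `Command` aimed at the current graph (`graph is None`) must contain exactly one. A simplified sketch, with messages reduced to `(kind, tool_call_id)` pairs instead of real message objects:

```python
def validate_command_messages(messages, call_id, graph=None):
    # messages: list of (kind, tool_call_id) pairs standing in for
    # message objects; kind == "tool" plays the role of ToolMessage.
    seen_tool_message = False
    for kind, tool_call_id in messages:
        if kind != "tool":
            continue
        if seen_tool_message:
            raise ValueError("Expected at most one ToolMessage in Command.update")
        if tool_call_id != call_id:
            raise ValueError("ToolMessage.tool_call_id must match the tool call id")
        seen_tool_message = True
    # Only a Command sent to the CURRENT graph must answer the tool call;
    # updates routed to a parent graph are exempt (the relaxed validation
    # from commit 5e8de8a).
    if graph is None and not seen_tool_message:
        raise ValueError("Expected exactly one ToolMessage in Command.update")
```

This mirrors why the error message above insists on a `ToolMessage` per tool call: without one, the message history would contain a tool call with no response.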


def tools_condition(
state: Union[list[AnyMessage], dict[str, Any], BaseModel],
10 changes: 9 additions & 1 deletion libs/langgraph/langgraph/types.py
@@ -32,6 +32,14 @@
from langgraph.store.base import BaseStore


try:
from langchain_core.messages.tool import ToolOutputMixin
except ImportError:

class ToolOutputMixin: # type: ignore[no-redef]
pass
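The `try`/`except ImportError` above is the standard pattern for staying compatible with `langchain-core` versions that predate `ToolOutputMixin`: if the class is missing, a no-op stand-in is defined so `Command` can subclass it either way. A self-contained equivalent — the module name here is deliberately fake so the fallback branch runs:

```python
try:
    # hypothetical optional dependency; older versions don't ship this class
    from nonexistent_optional_dep import ToolOutputMixin
except ImportError:

    class ToolOutputMixin:  # type: ignore[no-redef]
        # no-op fallback base class
        pass


class Command(ToolOutputMixin):
    """Subclassing works the same whether the real mixin or the fallback is used."""
```

The mixin is what lets tool-calling machinery in `langchain-core` recognize a `Command` as a legitimate tool output rather than coercing it into a `ToolMessage`.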


All = Literal["*"]
"""Special value to indicate that graph should interrupt on all nodes."""

@@ -244,7 +252,7 @@ def __eq__(self, value: object) -> bool:


@dataclasses.dataclass(**_DC_KWARGS)
class Command(Generic[N]):
class Command(Generic[N], ToolOutputMixin):
Review comment:
@nfcampos could you confirm you're OK with this change?

"""One or more commands to update the graph's state and send messages to nodes.
Args:
Expand Down
12 changes: 6 additions & 6 deletions libs/langgraph/poetry.lock

2 changes: 1 addition & 1 deletion libs/langgraph/pyproject.toml
@@ -9,7 +9,7 @@ repository = "https://www.github.com/langchain-ai/langgraph"

[tool.poetry.dependencies]
python = ">=3.9.0,<4.0"
langchain-core = ">=0.2.43,<0.4.0,!=0.3.0,!=0.3.1,!=0.3.2,!=0.3.3,!=0.3.4,!=0.3.5,!=0.3.6,!=0.3.7,!=0.3.8,!=0.3.9,!=0.3.10,!=0.3.11,!=0.3.12,!=0.3.13,!=0.3.14"
langchain-core = ">=0.2.43,<0.4.0,!=0.3.0,!=0.3.1,!=0.3.2,!=0.3.3,!=0.3.4,!=0.3.5,!=0.3.6,!=0.3.7,!=0.3.8,!=0.3.9,!=0.3.10,!=0.3.11,!=0.3.12,!=0.3.13,!=0.3.14,!=0.3.15,!=0.3.16,!=0.3.17,!=0.3.18,!=0.3.19,!=0.3.20,!=0.3.21,!=0.3.22"
langgraph-checkpoint = "^2.0.4"
langgraph-sdk = "^0.1.42"
