Skip to content

Commit

Permalink
langgraph: allow tools to return Command in tool node (#2656)
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarda authored Dec 10, 2024
1 parent 038bec2 commit 59bfa5d
Show file tree
Hide file tree
Showing 5 changed files with 820 additions and 46 deletions.
199 changes: 163 additions & 36 deletions libs/langgraph/langgraph/prebuilt/tool_node.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import inspect
import json
from copy import copy
from copy import copy, deepcopy
from typing import (
Any,
Callable,
Expand All @@ -20,6 +20,7 @@
AnyMessage,
ToolCall,
ToolMessage,
convert_to_messages,
)
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.config import (
Expand All @@ -35,6 +36,7 @@

from langgraph.errors import GraphBubbleUp
from langgraph.store.base import BaseStore
from langgraph.types import Command
from langgraph.utils.runnable import RunnableCallable

INVALID_TOOL_NAME_ERROR_TEMPLATE = (
Expand All @@ -47,7 +49,7 @@ def msg_content_output(output: Any) -> Union[str, list[dict]]:
recognized_content_block_types = ("image", "image_url", "text", "json")
if isinstance(output, str):
return output
elif all(
elif isinstance(output, list) and all(
[
isinstance(x, dict) and x.get("type") in recognized_content_block_types
for x in output
Expand Down Expand Up @@ -210,12 +212,31 @@ def _func(
*,
store: BaseStore,
) -> Any:
tool_calls, output_type = self._parse_input(input, store)
tool_calls, input_type = self._parse_input(input, store)
config_list = get_config_list(config, len(tool_calls))
input_types = [input_type] * len(tool_calls)
with get_executor_for_config(config) as executor:
outputs = [*executor.map(self._run_one, tool_calls, config_list)]
# TypedDict, pydantic, dataclass, etc. should all be able to load from dict
return outputs if output_type == "list" else {self.messages_key: outputs}
outputs = [
*executor.map(self._run_one, tool_calls, input_types, config_list)
]

# preserve existing behavior for non-command tool outputs for backwards compatibility
if not any(isinstance(output, Command) for output in outputs):
# TypedDict, pydantic, dataclass, etc. should all be able to load from dict
return outputs if input_type == "list" else {self.messages_key: outputs}

# LangGraph will automatically handle list of Command and non-command node updates
combined_outputs: list[
Command | list[ToolMessage] | dict[str, list[ToolMessage]]
] = []
for output in outputs:
if isinstance(output, Command):
combined_outputs.append(output)
else:
combined_outputs.append(
[output] if input_type == "list" else {self.messages_key: [output]}
)
return combined_outputs

def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
Expand All @@ -242,26 +263,42 @@ async def _afunc(
*,
store: BaseStore,
) -> Any:
tool_calls, output_type = self._parse_input(input, store)
tool_calls, input_type = self._parse_input(input, store)
outputs = await asyncio.gather(
*(self._arun_one(call, config) for call in tool_calls)
*(self._arun_one(call, input_type, config) for call in tool_calls)
)
# TypedDict, pydantic, dataclass, etc. should all be able to load from dict
return outputs if output_type == "list" else {self.messages_key: outputs}

def _run_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage:
# preserve existing behavior for non-command tool outputs for backwards compatibility
if not any(isinstance(output, Command) for output in outputs):
# TypedDict, pydantic, dataclass, etc. should all be able to load from dict
return outputs if input_type == "list" else {self.messages_key: outputs}

# LangGraph will automatically handle list of Command and non-command node updates
combined_outputs: list[
Command | list[ToolMessage] | dict[str, list[ToolMessage]]
] = []
for output in outputs:
if isinstance(output, Command):
combined_outputs.append(output)
else:
combined_outputs.append(
[output] if input_type == "list" else {self.messages_key: [output]}
)
return combined_outputs

def _run_one(
self,
call: ToolCall,
input_type: Literal["list", "dict"],
config: RunnableConfig,
) -> ToolMessage:
if invalid_tool_message := self._validate_tool_call(call):
return invalid_tool_message

try:
input = {**call, **{"type": "tool_call"}}
tool_message: ToolMessage = self.tools_by_name[call["name"]].invoke(
input, config
)
tool_message.content = cast(
Union[str, list], msg_content_output(tool_message.content)
)
return tool_message
response = self.tools_by_name[call["name"]].invoke(input)

# GraphInterrupt is a special exception that will always be raised.
# It can be triggered in the following scenarios:
# (1) a NodeInterrupt is raised inside a tool
Expand All @@ -285,24 +322,38 @@ def _run_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage:
# Handled
else:
content = _handle_tool_error(e, flag=self.handle_tool_errors)
return ToolMessage(
content=content,
name=call["name"],
tool_call_id=call["id"],
status="error",
)

return ToolMessage(
content=content, name=call["name"], tool_call_id=call["id"], status="error"
)
if isinstance(response, Command):
return self._validate_tool_command(response, call, input_type)
elif isinstance(response, ToolMessage):
response.content = cast(
Union[str, list], msg_content_output(response.content)
)
return response
else:
raise TypeError(
f"Tool {call['name']} returned unexpected type: {type(response)}"
)

async def _arun_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage:
async def _arun_one(
self,
call: ToolCall,
input_type: Literal["list", "dict"],
config: RunnableConfig,
) -> ToolMessage:
if invalid_tool_message := self._validate_tool_call(call):
return invalid_tool_message

try:
input = {**call, **{"type": "tool_call"}}
tool_message: ToolMessage = await self.tools_by_name[call["name"]].ainvoke(
input, config
)
tool_message.content = cast(
Union[str, list], msg_content_output(tool_message.content)
)
return tool_message
response = await self.tools_by_name[call["name"]].ainvoke(input)

# GraphInterrupt is a special exception that will always be raised.
# It can be triggered in the following scenarios:
# (1) a NodeInterrupt is raised inside a tool
Expand All @@ -327,9 +378,24 @@ async def _arun_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage
else:
content = _handle_tool_error(e, flag=self.handle_tool_errors)

return ToolMessage(
content=content, name=call["name"], tool_call_id=call["id"], status="error"
)
return ToolMessage(
content=content,
name=call["name"],
tool_call_id=call["id"],
status="error",
)

if isinstance(response, Command):
return self._validate_tool_command(response, call, input_type)
elif isinstance(response, ToolMessage):
response.content = cast(
Union[str, list], msg_content_output(response.content)
)
return response
else:
raise TypeError(
f"Tool {call['name']} returned unexpected type: {type(response)}"
)

def _parse_input(
self,
Expand All @@ -341,14 +407,14 @@ def _parse_input(
store: BaseStore,
) -> Tuple[list[ToolCall], Literal["list", "dict"]]:
if isinstance(input, list):
output_type = "list"
input_type = "list"
message: AnyMessage = input[-1]
elif isinstance(input, dict) and (messages := input.get(self.messages_key, [])):
output_type = "dict"
input_type = "dict"
message = messages[-1]
elif messages := getattr(input, self.messages_key, None):
# Assume dataclass-like state that can coerce from dict
output_type = "dict"
input_type = "dict"
message = messages[-1]
else:
raise ValueError("No message found in input")
Expand All @@ -359,7 +425,7 @@ def _parse_input(
tool_calls = [
self._inject_tool_args(call, input, store) for call in message.tool_calls
]
return tool_calls, output_type
return tool_calls, input_type

def _validate_tool_call(self, call: ToolCall) -> Optional[ToolMessage]:
if (requested_tool := call["name"]) not in self.tools_by_name:
Expand Down Expand Up @@ -453,6 +519,67 @@ def _inject_tool_args(
tool_call_with_store = self._inject_store(tool_call_with_state, store)
return tool_call_with_store

def _validate_tool_command(
    self, command: Command, call: ToolCall, input_type: Literal["list", "dict"]
) -> Command:
    """Validate (and normalize) a ``Command`` returned by a tool.

    Checks that the shape of ``command.update`` matches how the ToolNode was
    invoked (list-of-messages input vs. dict state input), that the update
    contains at most one ``ToolMessage``, that its ``tool_call_id`` matches
    the originating tool call, and — when the command targets the current
    graph — that a ``ToolMessage`` is present at all.

    Args:
        command: The ``Command`` object the tool returned.
        call: The ``ToolCall`` that produced this command; supplies the
            expected tool name and tool_call_id.
        input_type: ``"list"`` if the node was invoked with a bare list of
            messages, ``"dict"`` if with a state dict / dataclass-like state.

    Returns:
        A deep copy of ``command`` with the embedded ``ToolMessage.name``
        filled in, or the original ``command`` unchanged when its ``update``
        is neither a dict nor a list.

    Raises:
        ValueError: If the update shape does not match ``input_type``, if
            multiple ToolMessages are present, if the tool_call_id does not
            match, or if no ToolMessage exists for a current-graph command.
    """
    if isinstance(command.update, dict):
        # input type is dict when ToolNode is invoked with a dict input (e.g. {"messages": [AIMessage(..., tool_calls=[...])]})
        if input_type != "dict":
            raise ValueError(
                f"Tools can provide a dict in Command.update only when using dict with '{self.messages_key}' key as ToolNode input, "
                f"got: {command.update} for tool '{call['name']}'"
            )

        # Deep-copy so the mutations below never alter the tool's own object.
        updated_command = deepcopy(command)
        state_update = cast(dict[str, Any], updated_command.update) or {}
        messages_update = state_update.get(self.messages_key, [])
    elif isinstance(command.update, list):
        # input type is list when ToolNode is invoked with a list input (e.g. [AIMessage(..., tool_calls=[...])])
        if input_type != "list":
            raise ValueError(
                f"Tools can provide a list of messages in Command.update only when using list of messages as ToolNode input, "
                f"got: {command.update} for tool '{call['name']}'"
            )

        updated_command = deepcopy(command)
        messages_update = updated_command.update
    else:
        # Non-message update (e.g. None or a custom payload): nothing to validate.
        return command

    # convert to message objects if updates are in a dict format
    # NOTE(review): convert_to_messages may build NEW message objects from dict
    # entries; the in-place mutations below would then land on those converted
    # copies rather than on updated_command.update — confirm dict-form message
    # updates round-trip correctly.
    messages_update = convert_to_messages(messages_update)
    have_seen_tool_messages = False
    for message in messages_update:
        if not isinstance(message, ToolMessage):
            continue

        if have_seen_tool_messages:
            raise ValueError(
                f"Expected at most one ToolMessage in Command.update for tool '{call['name']}', got multiple: {messages_update}."
            )

        if message.tool_call_id != call["id"]:
            raise ValueError(
                f"ToolMessage.tool_call_id must match the tool call id. Expected: {call['id']}, got: {message.tool_call_id} for tool '{call['name']}'."
            )

        # Stamp the tool name so downstream consumers can attribute the message.
        message.name = call["name"]
        have_seen_tool_messages = True

    # validate that we always have exactly one ToolMessage in Command.update if command is sent to the CURRENT graph
    if updated_command.graph is None and not have_seen_tool_messages:
        example_update = (
            '`Command(update={"messages": [ToolMessage("Success", tool_call_id=tool_call_id), ...]}, ...)`'
            if input_type == "dict"
            else '`Command(update=[ToolMessage("Success", tool_call_id=tool_call_id), ...], ...)`'
        )
        raise ValueError(
            f"Expected exactly one message (ToolMessage) in Command.update for tool '{call['name']}', got: {messages_update}. "
            "Every tool call (LLM requesting to call a tool) in the message history MUST have a corresponding ToolMessage. "
            f"You can fix it by modifying the tool to return {example_update}."
        )
    return updated_command


def tools_condition(
state: Union[list[AnyMessage], dict[str, Any], BaseModel],
Expand Down
10 changes: 9 additions & 1 deletion libs/langgraph/langgraph/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@
from langgraph.store.base import BaseStore


try:
    # Command subclasses this mixin below — presumably so langchain-core can
    # recognize a Command as a valid direct tool output; confirm against
    # langchain-core's tool invocation path.
    from langchain_core.messages.tool import ToolOutputMixin
except ImportError:
    # Older langchain-core versions don't ship ToolOutputMixin; provide a
    # no-op placeholder so the Command class definition still works.

    class ToolOutputMixin:  # type: ignore[no-redef]
        pass

All = Literal["*"]
"""Special value to indicate that graph should interrupt on all nodes."""

Expand Down Expand Up @@ -244,7 +252,7 @@ def __eq__(self, value: object) -> bool:


@dataclasses.dataclass(**_DC_KWARGS)
class Command(Generic[N]):
class Command(Generic[N], ToolOutputMixin):
"""One or more commands to update the graph's state and send messages to nodes.
Args:
Expand Down
12 changes: 6 additions & 6 deletions libs/langgraph/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion libs/langgraph/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ repository = "https://www.github.com/langchain-ai/langgraph"

[tool.poetry.dependencies]
python = ">=3.9.0,<4.0"
langchain-core = ">=0.2.43,<0.4.0,!=0.3.0,!=0.3.1,!=0.3.2,!=0.3.3,!=0.3.4,!=0.3.5,!=0.3.6,!=0.3.7,!=0.3.8,!=0.3.9,!=0.3.10,!=0.3.11,!=0.3.12,!=0.3.13,!=0.3.14"
langchain-core = ">=0.2.43,<0.4.0,!=0.3.0,!=0.3.1,!=0.3.2,!=0.3.3,!=0.3.4,!=0.3.5,!=0.3.6,!=0.3.7,!=0.3.8,!=0.3.9,!=0.3.10,!=0.3.11,!=0.3.12,!=0.3.13,!=0.3.14,!=0.3.15,!=0.3.16,!=0.3.17,!=0.3.18,!=0.3.19,!=0.3.20,!=0.3.21,!=0.3.22"
langgraph-checkpoint = "^2.0.4"
langgraph-sdk = "^0.1.42"

Expand Down
Loading

0 comments on commit 59bfa5d

Please sign in to comment.