Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lib: Merge GraphCommand and Command #2638

Merged
merged 2 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions libs/langgraph/langgraph/graph/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from langgraph.graph.graph import END, START, Graph

Check notice on line 1 in libs/langgraph/langgraph/graph/__init__.py

View workflow job for this annotation

GitHub Actions / benchmark

Benchmark results

......................................... fanout_to_subgraph_10x: Mean +- std dev: 62.0 ms +- 1.7 ms ......................................... fanout_to_subgraph_10x_sync: Mean +- std dev: 52.0 ms +- 1.4 ms ......................................... fanout_to_subgraph_10x_checkpoint: Mean +- std dev: 93.0 ms +- 8.3 ms ......................................... fanout_to_subgraph_10x_checkpoint_sync: Mean +- std dev: 94.8 ms +- 1.7 ms ......................................... fanout_to_subgraph_100x: Mean +- std dev: 611 ms +- 25 ms ......................................... fanout_to_subgraph_100x_sync: Mean +- std dev: 505 ms +- 5 ms ......................................... fanout_to_subgraph_100x_checkpoint: Mean +- std dev: 958 ms +- 35 ms ......................................... fanout_to_subgraph_100x_checkpoint_sync: Mean +- std dev: 931 ms +- 17 ms ......................................... react_agent_10x: Mean +- std dev: 30.9 ms +- 0.7 ms ......................................... react_agent_10x_sync: Mean +- std dev: 22.4 ms +- 0.2 ms ......................................... react_agent_10x_checkpoint: Mean +- std dev: 47.1 ms +- 0.9 ms ......................................... react_agent_10x_checkpoint_sync: Mean +- std dev: 36.9 ms +- 0.6 ms ......................................... react_agent_100x: Mean +- std dev: 343 ms +- 7 ms ......................................... react_agent_100x_sync: Mean +- std dev: 271 ms +- 3 ms ......................................... react_agent_100x_checkpoint: Mean +- std dev: 938 ms +- 9 ms ......................................... react_agent_100x_checkpoint_sync: Mean +- std dev: 838 ms +- 8 ms ......................................... wide_state_25x300: Mean +- std dev: 24.3 ms +- 0.4 ms ......................................... wide_state_25x300_sync: Mean +- std dev: 15.5 ms +- 0.1 ms ......................................... wide_state_25x300_checkpoint: Mean +- std dev: 278 ms +- 5 ms ......................................... wide_state_25x300_checkpoint_sync: Mean +- std dev: 265 ms +- 5 ms ......................................... wide_state_15x600: Mean +- std dev: 28.4 ms +- 0.5 ms ......................................... wide_state_15x600_sync: Mean +- std dev: 17.9 ms +- 0.2 ms ......................................... wide_state_15x600_checkpoint: Mean +- std dev: 478 ms +- 5 ms ......................................... wide_state_15x600_checkpoint_sync: Mean +- std dev: 463 ms +- 6 ms ......................................... wide_state_9x1200: Mean +- std dev: 28.4 ms +- 0.5 ms ......................................... wide_state_9x1200_sync: Mean +- std dev: 17.9 ms +- 0.2 ms ......................................... wide_state_9x1200_checkpoint: Mean +- std dev: 312 ms +- 4 ms ......................................... wide_state_9x1200_checkpoint_sync: Mean +- std dev: 297 ms +- 4 ms

Check notice on line 1 in libs/langgraph/langgraph/graph/__init__.py

View workflow job for this annotation

GitHub Actions / benchmark

Comparison against main

+------------------------------------+---------+-----------------------+ | Benchmark | main | changes | +====================================+=========+=======================+ | wide_state_9x1200_checkpoint_sync | 298 ms | 297 ms: 1.00x faster | +------------------------------------+---------+-----------------------+ | react_agent_100x_sync | 272 ms | 271 ms: 1.00x faster | +------------------------------------+---------+-----------------------+ | wide_state_9x1200 | 28.2 ms | 28.4 ms: 1.01x slower | +------------------------------------+---------+-----------------------+ | wide_state_15x600_sync | 17.8 ms | 17.9 ms: 1.01x slower | +------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_sync | 501 ms | 505 ms: 1.01x slower | +------------------------------------+---------+-----------------------+ | wide_state_9x1200_sync | 17.8 ms | 17.9 ms: 1.01x slower | +------------------------------------+---------+-----------------------+ | wide_state_25x300_sync | 15.3 ms | 15.5 ms: 1.01x slower | +------------------------------------+---------+-----------------------+ | wide_state_15x600 | 28.1 ms | 28.4 ms: 1.01x slower | +------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_sync | 51.4 ms | 52.0 ms: 1.01x slower | +------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_checkpoint | 945 ms | 958 ms: 1.01x slower | +------------------------------------+---------+-----------------------+ | Geometric mean | (ref) | 1.00x slower | +------------------------------------+---------+-----------------------+ Benchmark hidden because not significant (18): wide_state_15x600_checkpoint_sync, react_agent_10x, wide_state_9x1200_checkpoint, wide_state_25x300_checkpoint_sync, react_agent_10x_checkpoint, wide_state_15x600_checkpoint, react_agent_100x_checkpoint, react_agent_100x_checkpoint_sync, wide_state_25x300_checkpoint, react_agent_10x_sync, fanout_to_subgraph_10x_checkpoint_sync, react_agent_100x, react_agent_10x_checkpoint_sync, fanout_to_subgraph_100x_checkpoint_sync, wide_state_25x300, fanout_to_subgraph_10x, fanout_to_subgraph_10x_checkpoint, fanout_to_subgraph_100x
from langgraph.graph.message import MessageGraph, MessagesState, add_messages
from langgraph.graph.state import GraphCommand, StateGraph
from langgraph.graph.state import StateGraph

__all__ = [
"END",
"START",
"Graph",
"StateGraph",
"GraphCommand",
"MessageGraph",
"add_messages",
"MessagesState",
Expand Down
48 changes: 12 additions & 36 deletions libs/langgraph/langgraph/graph/state.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import dataclasses
import inspect
import logging
import typing
Expand All @@ -9,7 +8,6 @@
from typing import (
Any,
Callable,
Generic,
Literal,
NamedTuple,
Optional,
Expand Down Expand Up @@ -55,7 +53,7 @@
from langgraph.pregel.read import ChannelRead, PregelNode
from langgraph.pregel.write import SKIP_WRITE, ChannelWrite, ChannelWriteEntry
from langgraph.store.base import BaseStore
from langgraph.types import _DC_KWARGS, All, Checkpointer, Command, N, RetryPolicy
from langgraph.types import All, Checkpointer, Command, RetryPolicy
from langgraph.utils.fields import get_field_default
from langgraph.utils.pydantic import create_model
from langgraph.utils.runnable import RunnableCallable, coerce_to_runnable
Expand Down Expand Up @@ -84,22 +82,6 @@ def _get_node_name(node: RunnableLike) -> str:
raise TypeError(f"Unsupported node type: {type(node)}")


@dataclasses.dataclass(**_DC_KWARGS)
class GraphCommand(Generic[N], Command[N]):
"""One or more commands to update a StateGraph's state and go to, or send messages to nodes."""

goto: Union[str, Sequence[str]] = ()

def __repr__(self) -> str:
# get all non-None values
contents = ", ".join(
f"{key}={value!r}"
for key, value in dataclasses.asdict(self).items()
if value
)
return f"Command({contents})"


class StateNodeSpec(NamedTuple):
runnable: Runnable
metadata: Optional[dict[str, Any]]
Expand Down Expand Up @@ -392,7 +374,7 @@ def add_node(
input = input_hint
if (
(rtn := hints.get("return"))
and get_origin(rtn) in (Command, GraphCommand)
and get_origin(rtn) is Command
and (rargs := get_args(rtn))
and get_origin(rargs[0]) is Literal
and (vals := get_args(rargs[0]))
Expand Down Expand Up @@ -834,15 +816,12 @@ def _control_branch(value: Any) -> Sequence[Union[str, Send]]:
if value.graph == Command.PARENT:
raise ParentCommand(value)
rtn: list[Union[str, Send]] = []
if isinstance(value, GraphCommand):
if isinstance(value.goto, str):
rtn.append(value.goto)
else:
rtn.extend(value.goto)
if isinstance(value.send, Send):
rtn.append(value.send)
if isinstance(value.goto, Send):
rtn.append(value.goto)
elif isinstance(value.goto, str):
rtn.append(value.goto)
else:
rtn.extend(value.send)
rtn.extend(value.goto)
return rtn


Expand All @@ -854,15 +833,12 @@ async def _acontrol_branch(value: Any) -> Sequence[Union[str, Send]]:
if value.graph == Command.PARENT:
raise ParentCommand(value)
rtn: list[Union[str, Send]] = []
if isinstance(value, GraphCommand):
if isinstance(value.goto, str):
rtn.append(value.goto)
else:
rtn.extend(value.goto)
if isinstance(value.send, Send):
rtn.append(value.send)
if isinstance(value.goto, Send):
rtn.append(value.goto)
elif isinstance(value.goto, str):
rtn.append(value.goto)
else:
rtn.extend(value.send)
rtn.extend(value.goto)
return rtn


Expand Down
11 changes: 6 additions & 5 deletions libs/langgraph/langgraph/pregel/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,17 +72,18 @@ def map_command(
"""Map input chunk to a sequence of pending writes in the form (channel, value)."""
if cmd.graph == Command.PARENT:
raise InvalidUpdateError("There is not parent graph")
if cmd.send:
if isinstance(cmd.send, (tuple, list)):
sends = cmd.send
if cmd.goto:
if isinstance(cmd.goto, (tuple, list)):
sends = cmd.goto
else:
sends = [cmd.send]
sends = [cmd.goto]
for send in sends:
if not isinstance(send, Send):
raise TypeError(
f"In Command.send, expected Send, got {type(send).__name__}"
f"In Command.goto, expected Send, got {type(send).__name__}"
)
yield (NULL_TASK_ID, PUSH if FF_SEND_V2 else TASKS, send)
# TODO handle goto str for state graph
if cmd.resume:
if isinstance(cmd.resume, dict) and all(is_task_id(k) for k in cmd.resume):
for tid, resume in cmd.resume.items():
Expand Down
2 changes: 1 addition & 1 deletion libs/langgraph/langgraph/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,8 @@ class Command(Generic[N]):

graph: Optional[str] = None
update: Optional[dict[str, Any]] = None
send: Union[Send, Sequence[Send]] = ()
resume: Optional[Union[Any, dict[str, Any]]] = None
goto: Union[Send, Sequence[Union[Send, str]], str] = ()

def __repr__(self) -> str:
# get all non-None values
Expand Down
58 changes: 29 additions & 29 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
START,
)
from langgraph.errors import InvalidUpdateError, MultipleSubgraphsError, NodeInterrupt
from langgraph.graph import END, Graph, GraphCommand, StateGraph
from langgraph.graph import END, Graph, StateGraph
from langgraph.graph.message import MessageGraph, MessagesState, add_messages
from langgraph.managed.shared_value import SharedValue
from langgraph.prebuilt.chat_agent_executor import create_tool_calling_executor
Expand Down Expand Up @@ -270,10 +270,10 @@ class State(TypedDict):
bar: str

def node_a(state: State):
return GraphCommand(goto="b", update={"foo": "bar"})
return Command(goto="b", update={"foo": "bar"})

def node_b(state: State):
return GraphCommand(goto=END, update={"bar": "baz"})
return Command(goto=END, update={"bar": "baz"})

builder = StateGraph(State)
builder.add_node("a", node_a)
Expand Down Expand Up @@ -1925,8 +1925,8 @@ def __call__(self, state):

def send_for_fun(state):
return [
Send("2", Command(send=Send("2", 3))),
Send("2", GraphCommand(send=Send("2", 4))),
Send("2", Command(goto=Send("2", 3))),
Send("2", Command(goto=Send("2", 4))),
"3.1",
]

Expand All @@ -1947,8 +1947,8 @@ def route_to_three(state) -> Literal["3"]:
== [
"0",
"1",
"2|Command(send=Send(node='2', arg=3))",
"2|Command(send=Send(node='2', arg=4))",
"2|Command(goto=Send(node='2', arg=3))",
"2|Command(goto=Send(node='2', arg=4))",
"2|3",
"2|4",
"3",
Expand All @@ -1959,8 +1959,8 @@ def route_to_three(state) -> Literal["3"]:
"0",
"1",
"3.1",
"2|Command(send=Send(node='2', arg=3))",
"2|Command(send=Send(node='2', arg=4))",
"2|Command(goto=Send(node='2', arg=3))",
"2|Command(goto=Send(node='2', arg=4))",
"3",
"2|3",
"2|4",
Expand Down Expand Up @@ -2000,15 +2000,15 @@ def __call__(self, state):
if isinstance(state, list)
else ["|".join((self.name, str(state)))]
)
if isinstance(state, GraphCommand):
if isinstance(state, Command):
return replace(state, update=update)
else:
return update

def send_for_fun(state):
return [
Send("2", GraphCommand(send=Send("2", 3))),
Send("2", GraphCommand(send=Send("flaky", 4))),
Send("2", Command(goto=Send("2", 3))),
Send("2", Command(goto=Send("flaky", 4))),
"3.1",
]

Expand All @@ -2030,8 +2030,8 @@ def route_to_three(state) -> Literal["3"]:
assert graph.invoke(["0"], thread1, debug=1) == [
"0",
"1",
"2|Command(send=Send(node='2', arg=3))",
"2|Command(send=Send(node='flaky', arg=4))",
"2|Command(goto=Send(node='2', arg=3))",
"2|Command(goto=Send(node='flaky', arg=4))",
"2|3",
]
assert builder.nodes["2"].runnable.func.ticks == 3
Expand All @@ -2046,8 +2046,8 @@ def route_to_three(state) -> Literal["3"]:
assert graph.invoke(None, thread1, debug=1) == [
"0",
"1",
"2|Command(send=Send(node='2', arg=3))",
"2|Command(send=Send(node='flaky', arg=4))",
"2|Command(goto=Send(node='2', arg=3))",
"2|Command(goto=Send(node='flaky', arg=4))",
"2|3",
"flaky|4",
"3",
Expand All @@ -2069,8 +2069,8 @@ def route_to_three(state) -> Literal["3"]:
values=[
"0",
"1",
"2|Command(send=Send(node='2', arg=3))",
"2|Command(send=Send(node='flaky', arg=4))",
"2|Command(goto=Send(node='2', arg=3))",
"2|Command(goto=Send(node='flaky', arg=4))",
"2|3",
"flaky|4",
"3",
Expand Down Expand Up @@ -2105,8 +2105,8 @@ def route_to_three(state) -> Literal["3"]:
values=[
"0",
"1",
"2|Command(send=Send(node='2', arg=3))",
"2|Command(send=Send(node='flaky', arg=4))",
"2|Command(goto=Send(node='2', arg=3))",
"2|Command(goto=Send(node='flaky', arg=4))",
"2|3",
"flaky|4",
],
Expand All @@ -2123,8 +2123,8 @@ def route_to_three(state) -> Literal["3"]:
"writes": {
"1": ["1"],
"2": [
["2|Command(send=Send(node='2', arg=3))"],
["2|Command(send=Send(node='flaky', arg=4))"],
["2|Command(goto=Send(node='2', arg=3))"],
["2|Command(goto=Send(node='flaky', arg=4))"],
["2|3"],
],
"flaky": ["flaky|4"],
Expand Down Expand Up @@ -2209,7 +2209,7 @@ def route_to_three(state) -> Literal["3"]:
error=None,
interrupts=(),
state=None,
result=["2|Command(send=Send(node='2', arg=3))"],
result=["2|Command(goto=Send(node='2', arg=3))"],
),
PregelTask(
id=AnyStr(),
Expand All @@ -2223,7 +2223,7 @@ def route_to_three(state) -> Literal["3"]:
error=None,
interrupts=(),
state=None,
result=["2|Command(send=Send(node='flaky', arg=4))"],
result=["2|Command(goto=Send(node='flaky', arg=4))"],
),
PregelTask(
id=AnyStr(),
Expand Down Expand Up @@ -2786,10 +2786,10 @@ def test_send_react_interrupt_control(
tool_calls=[ToolCall(name="foo", args={"hi": [1, 2, 3]}, id=AnyStr())],
)

def agent(state) -> GraphCommand[Literal["foo"]]:
return GraphCommand(
def agent(state) -> Command[Literal["foo"]]:
return Command(
update={"messages": ai_message},
send=[Send(call["name"], call) for call in ai_message.tool_calls],
goto=[Send(call["name"], call) for call in ai_message.tool_calls],
)

foo_called = 0
Expand Down Expand Up @@ -14580,9 +14580,9 @@ def test_parent_command(request: pytest.FixtureRequest, checkpointer_name: str)
from langchain_core.tools import tool

@tool(return_direct=True)
def get_user_name() -> GraphCommand:
def get_user_name() -> Command:
"""Retrieve user name"""
return GraphCommand(update={"user_name": "Meow"}, graph=GraphCommand.PARENT)
return Command(update={"user_name": "Meow"}, graph=Command.PARENT)

subgraph_builder = StateGraph(MessagesState)
subgraph_builder.add_node("tool", get_user_name)
Expand Down
Loading
Loading