Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lib: Merge GraphCommand and Command #2638

Merged
merged 2 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
lib: Merge GraphCommand and Command
- Now we have only Command
- Command(goto=) combines the previous functionality of Command(send=) and Command(goto=)
  • Loading branch information
nfcampos committed Dec 4, 2024
commit b4b3ac6f57adad2b0458026622c1e0ceb07c6c9f
3 changes: 1 addition & 2 deletions libs/langgraph/langgraph/graph/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from langgraph.graph.graph import END, START, Graph

Check notice on line 1 in libs/langgraph/langgraph/graph/__init__.py

View workflow job for this annotation

GitHub Actions / benchmark

Benchmark results

......................................... fanout_to_subgraph_10x: Mean +- std dev: 61.7 ms +- 1.6 ms ......................................... fanout_to_subgraph_10x_sync: Mean +- std dev: 52.1 ms +- 1.2 ms ......................................... fanout_to_subgraph_10x_checkpoint: Mean +- std dev: 93.1 ms +- 8.9 ms ......................................... fanout_to_subgraph_10x_checkpoint_sync: Mean +- std dev: 94.9 ms +- 1.5 ms ......................................... fanout_to_subgraph_100x: Mean +- std dev: 616 ms +- 29 ms ......................................... fanout_to_subgraph_100x_sync: Mean +- std dev: 505 ms +- 5 ms ......................................... fanout_to_subgraph_100x_checkpoint: Mean +- std dev: 958 ms +- 34 ms ......................................... fanout_to_subgraph_100x_checkpoint_sync: Mean +- std dev: 935 ms +- 19 ms ......................................... react_agent_10x: Mean +- std dev: 30.9 ms +- 0.6 ms ......................................... react_agent_10x_sync: Mean +- std dev: 22.5 ms +- 0.2 ms ......................................... react_agent_10x_checkpoint: Mean +- std dev: 47.2 ms +- 0.8 ms ......................................... react_agent_10x_checkpoint_sync: Mean +- std dev: 36.7 ms +- 0.4 ms ......................................... react_agent_100x: Mean +- std dev: 342 ms +- 7 ms ......................................... react_agent_100x_sync: Mean +- std dev: 271 ms +- 3 ms ......................................... react_agent_100x_checkpoint: Mean +- std dev: 944 ms +- 10 ms ......................................... react_agent_100x_checkpoint_sync: Mean +- std dev: 845 ms +- 9 ms ......................................... wide_state_25x300: Mean +- std dev: 24.3 ms +- 0.4 ms ......................................... wide_state_25x300_sync: Mean +- std dev: 15.5 ms +- 0.2 ms ......................................... wide_state_25x300_checkpoint: Mean +- std dev: 279 ms +- 5 ms ......................................... wide_state_25x300_checkpoint_sync: Mean +- std dev: 267 ms +- 4 ms ......................................... wide_state_15x600: Mean +- std dev: 28.3 ms +- 0.5 ms ......................................... wide_state_15x600_sync: Mean +- std dev: 17.9 ms +- 0.2 ms ......................................... wide_state_15x600_checkpoint: Mean +- std dev: 481 ms +- 5 ms ......................................... wide_state_15x600_checkpoint_sync: Mean +- std dev: 466 ms +- 5 ms ......................................... wide_state_9x1200: Mean +- std dev: 28.2 ms +- 0.5 ms ......................................... wide_state_9x1200_sync: Mean +- std dev: 17.9 ms +- 0.1 ms ......................................... wide_state_9x1200_checkpoint: Mean +- std dev: 314 ms +- 4 ms ......................................... wide_state_9x1200_checkpoint_sync: Mean +- std dev: 299 ms +- 4 ms

Check notice on line 1 in libs/langgraph/langgraph/graph/__init__.py

View workflow job for this annotation

GitHub Actions / benchmark

Comparison against main

+-----------------------------------------+---------+-----------------------+ | Benchmark | main | changes | +=========================================+=========+=======================+ | react_agent_10x_checkpoint_sync | 36.8 ms | 36.7 ms: 1.00x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_100x_sync | 272 ms | 271 ms: 1.00x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_9x1200_checkpoint | 313 ms | 314 ms: 1.00x slower | +-----------------------------------------+---------+-----------------------+ | react_agent_10x_sync | 22.3 ms | 22.5 ms: 1.01x slower | +-----------------------------------------+---------+-----------------------+ | wide_state_15x600_checkpoint | 478 ms | 481 ms: 1.01x slower | +-----------------------------------------+---------+-----------------------+ | wide_state_15x600 | 28.1 ms | 28.3 ms: 1.01x slower | +-----------------------------------------+---------+-----------------------+ | react_agent_100x_checkpoint | 939 ms | 944 ms: 1.01x slower | +-----------------------------------------+---------+-----------------------+ | wide_state_9x1200_sync | 17.8 ms | 17.9 ms: 1.01x slower | +-----------------------------------------+---------+-----------------------+ | wide_state_15x600_sync | 17.8 ms | 17.9 ms: 1.01x slower | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_checkpoint_sync | 929 ms | 935 ms: 1.01x slower | +-----------------------------------------+---------+-----------------------+ | react_agent_100x_checkpoint_sync | 838 ms | 845 ms: 1.01x slower | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_sync | 501 ms | 505 ms: 1.01x slower | +-----------------------------------------+---------+-----------------------+ | wide_state_25x300_sync | 15.3 ms | 15.5 ms: 1.01x slower | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_sync | 51.4 ms | 52.1 ms: 1.01x slower | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_checkpoint | 945 ms | 958 ms: 1.01x slower | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x | 606 ms | 616 ms: 1.02x slower | +-----------------------------------------+---------+-----------------------+ | Geometric mean | (ref) | 1.00x slower | +-----------------------------------------+---------+-----------------------+ Benchmark hidden because not significant (12): react_agent_10x, react_agent_100x, fanout_to_subgraph_10x, react_agent_10x_checkpoint, wide_state_9x1200_checkpoint_sync, wide_state_15x600_checkpoint_sync, wide_state_9x1200, wide_state_25x300, wide_state_25x300_checkpoint, fanout_to_subgraph_10x_checkpoint_sync, wide_state_25x300_checkpoint_sync, fanout_to_subgraph_10x_checkpoint
from langgraph.graph.message import MessageGraph, MessagesState, add_messages
from langgraph.graph.state import GraphCommand, StateGraph
from langgraph.graph.state import StateGraph

__all__ = [
"END",
"START",
"Graph",
"StateGraph",
"GraphCommand",
"MessageGraph",
"add_messages",
"MessagesState",
Expand Down
48 changes: 12 additions & 36 deletions libs/langgraph/langgraph/graph/state.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import dataclasses
import inspect
import logging
import typing
Expand All @@ -9,7 +8,6 @@
from typing import (
Any,
Callable,
Generic,
Literal,
NamedTuple,
Optional,
Expand Down Expand Up @@ -55,7 +53,7 @@
from langgraph.pregel.read import ChannelRead, PregelNode
from langgraph.pregel.write import SKIP_WRITE, ChannelWrite, ChannelWriteEntry
from langgraph.store.base import BaseStore
from langgraph.types import _DC_KWARGS, All, Checkpointer, Command, N, RetryPolicy
from langgraph.types import All, Checkpointer, Command, RetryPolicy
from langgraph.utils.fields import get_field_default
from langgraph.utils.pydantic import create_model
from langgraph.utils.runnable import RunnableCallable, coerce_to_runnable
Expand Down Expand Up @@ -84,22 +82,6 @@ def _get_node_name(node: RunnableLike) -> str:
raise TypeError(f"Unsupported node type: {type(node)}")


@dataclasses.dataclass(**_DC_KWARGS)
class GraphCommand(Generic[N], Command[N]):
"""One or more commands to update a StateGraph's state and go to, or send messages to nodes."""

goto: Union[str, Sequence[str]] = ()

def __repr__(self) -> str:
# get all non-None values
contents = ", ".join(
f"{key}={value!r}"
for key, value in dataclasses.asdict(self).items()
if value
)
return f"Command({contents})"


class StateNodeSpec(NamedTuple):
runnable: Runnable
metadata: Optional[dict[str, Any]]
Expand Down Expand Up @@ -392,7 +374,7 @@ def add_node(
input = input_hint
if (
(rtn := hints.get("return"))
and get_origin(rtn) in (Command, GraphCommand)
and get_origin(rtn) is Command
and (rargs := get_args(rtn))
and get_origin(rargs[0]) is Literal
and (vals := get_args(rargs[0]))
Expand Down Expand Up @@ -834,15 +816,12 @@ def _control_branch(value: Any) -> Sequence[Union[str, Send]]:
if value.graph == Command.PARENT:
raise ParentCommand(value)
rtn: list[Union[str, Send]] = []
if isinstance(value, GraphCommand):
if isinstance(value.goto, str):
rtn.append(value.goto)
else:
rtn.extend(value.goto)
if isinstance(value.send, Send):
rtn.append(value.send)
if isinstance(value.goto, Send):
rtn.append(value.goto)
elif isinstance(value.goto, str):
rtn.append(value.goto)
else:
rtn.extend(value.send)
rtn.extend(value.goto)
return rtn


Expand All @@ -854,15 +833,12 @@ async def _acontrol_branch(value: Any) -> Sequence[Union[str, Send]]:
if value.graph == Command.PARENT:
raise ParentCommand(value)
rtn: list[Union[str, Send]] = []
if isinstance(value, GraphCommand):
if isinstance(value.goto, str):
rtn.append(value.goto)
else:
rtn.extend(value.goto)
if isinstance(value.send, Send):
rtn.append(value.send)
if isinstance(value.goto, Send):
rtn.append(value.goto)
elif isinstance(value.goto, str):
rtn.append(value.goto)
else:
rtn.extend(value.send)
rtn.extend(value.goto)
return rtn


Expand Down
9 changes: 5 additions & 4 deletions libs/langgraph/langgraph/pregel/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,17 +72,18 @@ def map_command(
"""Map input chunk to a sequence of pending writes in the form (channel, value)."""
if cmd.graph == Command.PARENT:
raise InvalidUpdateError("There is not parent graph")
if cmd.send:
if cmd.goto:
if isinstance(cmd.send, (tuple, list)):
sends = cmd.send
sends = cmd.goto
else:
sends = [cmd.send]
sends = [cmd.goto]
for send in sends:
if not isinstance(send, Send):
raise TypeError(
f"In Command.send, expected Send, got {type(send).__name__}"
f"In Command.goto, expected Send, got {type(send).__name__}"
)
yield (NULL_TASK_ID, PUSH if FF_SEND_V2 else TASKS, send)
# TODO handle goto str for state graph
if cmd.resume:
if isinstance(cmd.resume, dict) and all(is_task_id(k) for k in cmd.resume):
for tid, resume in cmd.resume.items():
Expand Down
2 changes: 1 addition & 1 deletion libs/langgraph/langgraph/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,8 @@ class Command(Generic[N]):

graph: Optional[str] = None
update: Optional[dict[str, Any]] = None
send: Union[Send, Sequence[Send]] = ()
resume: Optional[Union[Any, dict[str, Any]]] = None
goto: Union[Send, Sequence[Union[Send, str]], str] = ()

def __repr__(self) -> str:
# get all non-None values
Expand Down
58 changes: 29 additions & 29 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
START,
)
from langgraph.errors import InvalidUpdateError, MultipleSubgraphsError, NodeInterrupt
from langgraph.graph import END, Graph, GraphCommand, StateGraph
from langgraph.graph import END, Graph, StateGraph
from langgraph.graph.message import MessageGraph, MessagesState, add_messages
from langgraph.managed.shared_value import SharedValue
from langgraph.prebuilt.chat_agent_executor import create_tool_calling_executor
Expand Down Expand Up @@ -270,10 +270,10 @@ class State(TypedDict):
bar: str

def node_a(state: State):
return GraphCommand(goto="b", update={"foo": "bar"})
return Command(goto="b", update={"foo": "bar"})

def node_b(state: State):
return GraphCommand(goto=END, update={"bar": "baz"})
return Command(goto=END, update={"bar": "baz"})

builder = StateGraph(State)
builder.add_node("a", node_a)
Expand Down Expand Up @@ -1925,8 +1925,8 @@ def __call__(self, state):

def send_for_fun(state):
return [
Send("2", Command(send=Send("2", 3))),
Send("2", GraphCommand(send=Send("2", 4))),
Send("2", Command(goto=Send("2", 3))),
Send("2", Command(goto=Send("2", 4))),
"3.1",
]

Expand All @@ -1947,8 +1947,8 @@ def route_to_three(state) -> Literal["3"]:
== [
"0",
"1",
"2|Command(send=Send(node='2', arg=3))",
"2|Command(send=Send(node='2', arg=4))",
"2|Command(goto=Send(node='2', arg=3))",
"2|Command(goto=Send(node='2', arg=4))",
"2|3",
"2|4",
"3",
Expand All @@ -1959,8 +1959,8 @@ def route_to_three(state) -> Literal["3"]:
"0",
"1",
"3.1",
"2|Command(send=Send(node='2', arg=3))",
"2|Command(send=Send(node='2', arg=4))",
"2|Command(goto=Send(node='2', arg=3))",
"2|Command(goto=Send(node='2', arg=4))",
"3",
"2|3",
"2|4",
Expand Down Expand Up @@ -2000,15 +2000,15 @@ def __call__(self, state):
if isinstance(state, list)
else ["|".join((self.name, str(state)))]
)
if isinstance(state, GraphCommand):
if isinstance(state, Command):
return replace(state, update=update)
else:
return update

def send_for_fun(state):
return [
Send("2", GraphCommand(send=Send("2", 3))),
Send("2", GraphCommand(send=Send("flaky", 4))),
Send("2", Command(goto=Send("2", 3))),
Send("2", Command(goto=Send("flaky", 4))),
"3.1",
]

Expand All @@ -2030,8 +2030,8 @@ def route_to_three(state) -> Literal["3"]:
assert graph.invoke(["0"], thread1, debug=1) == [
"0",
"1",
"2|Command(send=Send(node='2', arg=3))",
"2|Command(send=Send(node='flaky', arg=4))",
"2|Command(goto=Send(node='2', arg=3))",
"2|Command(goto=Send(node='flaky', arg=4))",
"2|3",
]
assert builder.nodes["2"].runnable.func.ticks == 3
Expand All @@ -2046,8 +2046,8 @@ def route_to_three(state) -> Literal["3"]:
assert graph.invoke(None, thread1, debug=1) == [
"0",
"1",
"2|Command(send=Send(node='2', arg=3))",
"2|Command(send=Send(node='flaky', arg=4))",
"2|Command(goto=Send(node='2', arg=3))",
"2|Command(goto=Send(node='flaky', arg=4))",
"2|3",
"flaky|4",
"3",
Expand All @@ -2069,8 +2069,8 @@ def route_to_three(state) -> Literal["3"]:
values=[
"0",
"1",
"2|Command(send=Send(node='2', arg=3))",
"2|Command(send=Send(node='flaky', arg=4))",
"2|Command(goto=Send(node='2', arg=3))",
"2|Command(goto=Send(node='flaky', arg=4))",
"2|3",
"flaky|4",
"3",
Expand Down Expand Up @@ -2105,8 +2105,8 @@ def route_to_three(state) -> Literal["3"]:
values=[
"0",
"1",
"2|Command(send=Send(node='2', arg=3))",
"2|Command(send=Send(node='flaky', arg=4))",
"2|Command(goto=Send(node='2', arg=3))",
"2|Command(goto=Send(node='flaky', arg=4))",
"2|3",
"flaky|4",
],
Expand All @@ -2123,8 +2123,8 @@ def route_to_three(state) -> Literal["3"]:
"writes": {
"1": ["1"],
"2": [
["2|Command(send=Send(node='2', arg=3))"],
["2|Command(send=Send(node='flaky', arg=4))"],
["2|Command(goto=Send(node='2', arg=3))"],
["2|Command(goto=Send(node='flaky', arg=4))"],
["2|3"],
],
"flaky": ["flaky|4"],
Expand Down Expand Up @@ -2209,7 +2209,7 @@ def route_to_three(state) -> Literal["3"]:
error=None,
interrupts=(),
state=None,
result=["2|Command(send=Send(node='2', arg=3))"],
result=["2|Command(goto=Send(node='2', arg=3))"],
),
PregelTask(
id=AnyStr(),
Expand All @@ -2223,7 +2223,7 @@ def route_to_three(state) -> Literal["3"]:
error=None,
interrupts=(),
state=None,
result=["2|Command(send=Send(node='flaky', arg=4))"],
result=["2|Command(goto=Send(node='flaky', arg=4))"],
),
PregelTask(
id=AnyStr(),
Expand Down Expand Up @@ -2786,10 +2786,10 @@ def test_send_react_interrupt_control(
tool_calls=[ToolCall(name="foo", args={"hi": [1, 2, 3]}, id=AnyStr())],
)

def agent(state) -> GraphCommand[Literal["foo"]]:
return GraphCommand(
def agent(state) -> Command[Literal["foo"]]:
return Command(
update={"messages": ai_message},
send=[Send(call["name"], call) for call in ai_message.tool_calls],
goto=[Send(call["name"], call) for call in ai_message.tool_calls],
)

foo_called = 0
Expand Down Expand Up @@ -14580,9 +14580,9 @@ def test_parent_command(request: pytest.FixtureRequest, checkpointer_name: str)
from langchain_core.tools import tool

@tool(return_direct=True)
def get_user_name() -> GraphCommand:
def get_user_name() -> Command:
"""Retrieve user name"""
return GraphCommand(update={"user_name": "Meow"}, graph=GraphCommand.PARENT)
return Command(update={"user_name": "Meow"}, graph=Command.PARENT)

subgraph_builder = StateGraph(MessagesState)
subgraph_builder.add_node("tool", get_user_name)
Expand Down
Loading
Loading