Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

langgraph: add support for BaseModel updates to Command #2747

Merged
merged 10 commits into from
Jan 23, 2025
3 changes: 3 additions & 0 deletions libs/langgraph/langgraph/types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import dataclasses

Check notice on line 1 in libs/langgraph/langgraph/types.py

View workflow job for this annotation

GitHub Actions / benchmark

Benchmark results

......................................... fanout_to_subgraph_10x: Mean +- std dev: 63.2 ms +- 1.2 ms ......................................... fanout_to_subgraph_10x_sync: Mean +- std dev: 55.2 ms +- 0.7 ms ......................................... fanout_to_subgraph_10x_checkpoint: Mean +- std dev: 79.8 ms +- 1.1 ms ......................................... fanout_to_subgraph_10x_checkpoint_sync: Mean +- std dev: 97.7 ms +- 1.0 ms ......................................... fanout_to_subgraph_100x: Mean +- std dev: 629 ms +- 10 ms ......................................... fanout_to_subgraph_100x_sync: Mean +- std dev: 546 ms +- 15 ms ......................................... fanout_to_subgraph_100x_checkpoint: Mean +- std dev: 808 ms +- 23 ms ......................................... fanout_to_subgraph_100x_checkpoint_sync: Mean +- std dev: 988 ms +- 15 ms ......................................... react_agent_10x: Mean +- std dev: 30.7 ms +- 0.7 ms ......................................... react_agent_10x_sync: Mean +- std dev: 23.1 ms +- 0.2 ms ......................................... react_agent_10x_checkpoint: Mean +- std dev: 38.8 ms +- 0.9 ms ......................................... react_agent_10x_checkpoint_sync: Mean +- std dev: 36.0 ms +- 0.5 ms ......................................... react_agent_100x: Mean +- std dev: 339 ms +- 7 ms ......................................... react_agent_100x_sync: Mean +- std dev: 271 ms +- 2 ms ......................................... react_agent_100x_checkpoint: Mean +- std dev: 640 ms +- 7 ms ......................................... react_agent_100x_checkpoint_sync: Mean +- std dev: 612 ms +- 6 ms ......................................... wide_state_25x300: Mean +- std dev: 23.6 ms +- 0.5 ms ......................................... wide_state_25x300_sync: Mean +- std dev: 15.9 ms +- 0.2 ms ......................................... wide_state_25x300_checkpoint: Mean +- std dev: 252 ms +- 15 ms ......................................... wide_state_25x300_checkpoint_sync: Mean +- std dev: 249 ms +- 17 ms ......................................... wide_state_15x600: Mean +- std dev: 27.7 ms +- 0.6 ms ......................................... wide_state_15x600_sync: Mean +- std dev: 18.4 ms +- 0.4 ms ......................................... wide_state_15x600_checkpoint: Mean +- std dev: 430 ms +- 13 ms ......................................... wide_state_15x600_checkpoint_sync: Mean +- std dev: 424 ms +- 13 ms ......................................... wide_state_9x1200: Mean +- std dev: 27.7 ms +- 0.5 ms ......................................... wide_state_9x1200_sync: Mean +- std dev: 18.3 ms +- 0.1 ms ......................................... wide_state_9x1200_checkpoint: Mean +- std dev: 281 ms +- 13 ms ......................................... wide_state_9x1200_checkpoint_sync: Mean +- std dev: 275 ms +- 13 ms

Check notice on line 1 in libs/langgraph/langgraph/types.py

View workflow job for this annotation

GitHub Actions / benchmark

Comparison against main

+-----------------------------------------+----------+-----------------------+ | Benchmark | main | changes | +=========================================+==========+=======================+ | fanout_to_subgraph_100x_checkpoint | 920 ms | 808 ms: 1.14x faster | +-----------------------------------------+----------+-----------------------+ | react_agent_100x_checkpoint_sync | 675 ms | 612 ms: 1.10x faster | +-----------------------------------------+----------+-----------------------+ | fanout_to_subgraph_100x | 691 ms | 629 ms: 1.10x faster | +-----------------------------------------+----------+-----------------------+ | react_agent_100x_checkpoint | 670 ms | 640 ms: 1.05x faster | +-----------------------------------------+----------+-----------------------+ | fanout_to_subgraph_100x_checkpoint_sync | 1.03 sec | 988 ms: 1.05x faster | +-----------------------------------------+----------+-----------------------+ | wide_state_25x300_checkpoint | 263 ms | 252 ms: 1.04x faster | +-----------------------------------------+----------+-----------------------+ | fanout_to_subgraph_10x_checkpoint_sync | 101 ms | 97.7 ms: 1.03x faster | +-----------------------------------------+----------+-----------------------+ | wide_state_25x300_sync | 16.4 ms | 15.9 ms: 1.03x faster | +-----------------------------------------+----------+-----------------------+ | fanout_to_subgraph_100x_sync | 562 ms | 546 ms: 1.03x faster | +-----------------------------------------+----------+-----------------------+ | react_agent_10x_checkpoint_sync | 36.9 ms | 36.0 ms: 1.02x faster | +-----------------------------------------+----------+-----------------------+ | fanout_to_subgraph_10x_checkpoint | 81.3 ms | 79.8 ms: 1.02x faster | +-----------------------------------------+----------+-----------------------+ | wide_state_15x600_checkpoint_sync | 432 ms | 424 ms: 1.02x faster | +-----------------------------------------+----------+-----------------------+ | react_agent_100x_sync | 276 ms | 271 ms: 1.02x faster | +-----------------------------------------+----------+-----------------------+ | wide_state_25x300 | 24.0 ms | 23.6 ms: 1.02x faster | +-----------------------------------------+----------+-----------------------+ | fanout_to_subgraph_10x | 64.1 ms | 63.2 ms: 1.02x faster | +-----------------------------------------+----------+-----------------------+ | wide_state_9x1200_checkpoint_sync | 279 ms | 275 ms: 1.01x faster | +-----------------------------------------+----------+-----------------------+ | react_agent_100x | 344 ms | 339 ms: 1.01x faster | +-----------------------------------------+----------+-----------------------+ | react_agent_10x_sync | 23.4 ms | 23.1 ms: 1.01x faster | +-----------------------------------------+----------+-----------------------+ | fanout_to_subgraph_10x_sync | 55.8 ms | 55.2 ms: 1.01x faster | +-----------------------------------------+----------+-----------------------+ | react_agent_10x | 31.1 ms | 30.7 ms: 1.01x faster | +-----------------------------------------+----------+-----------------------+ | wide_state_15x600_checkpoint | 435 ms | 430 ms: 1.01x faster | +-----------------------------------------+----------+-----------------------+ | wide_state_15x600_sync | 18.5 ms | 18.4 ms: 1.01x faster | +-----------------------------------------+----------+-----------------------+ | wide_state_15x600 | 28.0 ms | 27.7 ms: 1.01x faster | +-----------------------------------------+----------+-----------------------+ | react_agent_10x_checkpoint | 39.2 ms | 38.8 ms: 1.01x faster | +-----------------------------------------+----------+-----------------------+ | wide_state_9x1200_sync | 18.4 ms | 18.3 ms: 1.0
import sys
from collections import deque
from typing import (
Expand All @@ -16,6 +16,7 @@
TypeVar,
Union,
cast,
get_type_hints,
)

from langchain_core.runnables import Runnable, RunnableConfig
Expand Down Expand Up @@ -289,6 +290,8 @@
for t in self.update
):
return self.update
elif hints := get_type_hints(type(self.update)):
return [(k, getattr(self.update, k)) for k in hints]
elif self.update is not None:
return [("__root__", self.update)]
else:
Expand Down
30 changes: 30 additions & 0 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from collections import Counter, deque
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from dataclasses import dataclass
from random import randrange
from typing import (
Annotated,
Expand Down Expand Up @@ -5134,6 +5135,35 @@ def my_node(state: State):
assert graph.invoke({"foo": ""}) == {"foo": "ab"}


def test_command_pydantic_dataclass() -> None:
from pydantic import BaseModel

class PydanticState(BaseModel):
foo: str

@dataclass
class DataclassState:
foo: str

for State in (PydanticState, DataclassState):

def node_a(state) -> Command[Literal["node_b"]]:
return Command(
update=State(foo="foo"),
goto="node_b",
)

def node_b(state):
return {"foo": state.foo + "bar"}

builder = StateGraph(State)
builder.add_edge(START, "node_a")
builder.add_node(node_a)
builder.add_node(node_b)
graph = builder.compile()
assert graph.invoke(State(foo="")) == {"foo": "foobar"}


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_command_with_static_breakpoints(
request: pytest.FixtureRequest, checkpointer_name: str
Expand Down
Loading