Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lib: Add unit test for multistep planner graph #2695

Merged
merged 1 commit into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import enum

Check notice on line 1 in libs/langgraph/tests/test_pregel.py

View workflow job for this annotation

GitHub Actions / benchmark

Benchmark results

......................................... fanout_to_subgraph_10x: Mean +- std dev: 62.1 ms +- 1.4 ms ......................................... fanout_to_subgraph_10x_sync: Mean +- std dev: 52.4 ms +- 1.1 ms ......................................... fanout_to_subgraph_10x_checkpoint: Mean +- std dev: 93.5 ms +- 7.3 ms ......................................... fanout_to_subgraph_10x_checkpoint_sync: Mean +- std dev: 95.0 ms +- 1.9 ms ......................................... fanout_to_subgraph_100x: Mean +- std dev: 616 ms +- 10 ms ......................................... fanout_to_subgraph_100x_sync: Mean +- std dev: 512 ms +- 6 ms ......................................... fanout_to_subgraph_100x_checkpoint: Mean +- std dev: 936 ms +- 39 ms ......................................... fanout_to_subgraph_100x_checkpoint_sync: Mean +- std dev: 947 ms +- 16 ms ......................................... react_agent_10x: Mean +- std dev: 31.1 ms +- 0.6 ms ......................................... react_agent_10x_sync: Mean +- std dev: 22.8 ms +- 0.2 ms ......................................... react_agent_10x_checkpoint: Mean +- std dev: 47.0 ms +- 0.8 ms ......................................... react_agent_10x_checkpoint_sync: Mean +- std dev: 36.9 ms +- 0.4 ms ......................................... react_agent_100x: Mean +- std dev: 347 ms +- 6 ms ......................................... react_agent_100x_sync: Mean +- std dev: 274 ms +- 4 ms ......................................... react_agent_100x_checkpoint: Mean +- std dev: 938 ms +- 8 ms ......................................... react_agent_100x_checkpoint_sync: Mean +- std dev: 837 ms +- 7 ms ......................................... wide_state_25x300: Mean +- std dev: 23.8 ms +- 0.5 ms ......................................... wide_state_25x300_sync: Mean +- std dev: 14.9 ms +- 0.3 ms ......................................... wide_state_25x300_checkpoint: Mean +- std dev: 287 ms +- 13 ms ......................................... wide_state_25x300_checkpoint_sync: Mean +- std dev: 273 ms +- 12 ms ......................................... wide_state_15x600: Mean +- std dev: 27.6 ms +- 0.4 ms ......................................... wide_state_15x600_sync: Mean +- std dev: 17.3 ms +- 0.1 ms ......................................... wide_state_15x600_checkpoint: Mean +- std dev: 486 ms +- 13 ms ......................................... wide_state_15x600_checkpoint_sync: Mean +- std dev: 474 ms +- 14 ms ......................................... wide_state_9x1200: Mean +- std dev: 27.6 ms +- 0.6 ms ......................................... wide_state_9x1200_sync: Mean +- std dev: 17.3 ms +- 0.1 ms ......................................... wide_state_9x1200_checkpoint: Mean +- std dev: 321 ms +- 13 ms ......................................... wide_state_9x1200_checkpoint_sync: Mean +- std dev: 306 ms +- 13 ms

Check notice on line 1 in libs/langgraph/tests/test_pregel.py

View workflow job for this annotation

GitHub Actions / benchmark

Comparison against main

+----------------------------------------+---------+-----------------------+ | Benchmark | main | changes | +========================================+=========+=======================+ | fanout_to_subgraph_100x | 628 ms | 616 ms: 1.02x faster | +----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_checkpoint | 952 ms | 936 ms: 1.02x faster | +----------------------------------------+---------+-----------------------+ | wide_state_15x600_sync | 17.3 ms | 17.3 ms: 1.00x slower | +----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_sync | 510 ms | 512 ms: 1.00x slower | +----------------------------------------+---------+-----------------------+ | wide_state_15x600 | 27.5 ms | 27.6 ms: 1.00x slower | +----------------------------------------+---------+-----------------------+ | react_agent_100x_checkpoint_sync | 833 ms | 837 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | react_agent_10x_checkpoint_sync | 36.6 ms | 36.9 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_checkpoint_sync | 94.4 ms | 95.0 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | react_agent_10x_checkpoint | 46.6 ms | 47.0 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | react_agent_100x_checkpoint | 930 ms | 938 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | Geometric mean | (ref) | 1.00x slower | +----------------------------------------+---------+-----------------------+ Benchmark hidden because not significant (18): fanout_to_subgraph_10x_checkpoint, react_agent_100x_sync, fanout_to_subgraph_10x, wide_state_9x1200_checkpoint_sync, wide_state_9x1200_sync, react_agent_10x_sync, react_agent_100x, wide_state_25x300, wide_state_25x300_sync, wide_state_25x300_checkpoint_sync, wide_state_15x600_checkpoint, wide_state_25x300_checkpoint, wide_state_9x1200_checkpoint, fanout_to_subgraph_100x_checkpoint_sync, react_agent_10x, fanout_to_subgraph_10x_sync, wide_state_9x1200, wide_state_15x600_checkpoint_sync
import json
import logging
import operator
Expand Down Expand Up @@ -14944,3 +14944,64 @@
graph.invoke({"foo": "abc"}, config)
result = graph.invoke(Command(resume="node1"), config)
assert result == {"foo": "abc|node-1|node-2"}


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_multistep_plan(request: pytest.FixtureRequest, checkpointer_name: str):
from langchain_core.messages import AnyMessage

checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")

class State(TypedDict, total=False):
plan: list[Union[str, list[str]]]
messages: Annotated[list[AnyMessage], add_messages]

def planner(state: State):
if state.get("plan") is None:
# create plan somehow
plan = ["step1", ["step2", "step3"], "step4"]
# pick the first step to execute next
first_step, *plan = plan
# put the rest of plan in state
return Command(goto=first_step, update={"plan": plan})
elif state["plan"]:
# go to the next step of the plan
next_step, *next_plan = state["plan"]
return Command(goto=next_step, update={"plan": next_plan})
else:
# the end of the plan
pass

def step1(state: State):
return Command(goto="planner", update={"messages": [("human", "step1")]})

def step2(state: State):
return Command(goto="planner", update={"messages": [("human", "step2")]})

def step3(state: State):
return Command(goto="planner", update={"messages": [("human", "step3")]})

def step4(state: State):
return Command(goto="planner", update={"messages": [("human", "step4")]})

builder = StateGraph(State)
builder.add_node(planner)
builder.add_node(step1)
builder.add_node(step2)
builder.add_node(step3)
builder.add_node(step4)
builder.add_edge(START, "planner")
graph = builder.compile(checkpointer=checkpointer)

config = {"configurable": {"thread_id": "1"}}

assert graph.invoke({"messages": [("human", "start")]}, config) == {
"messages": [
_AnyIdHumanMessage(content="start"),
_AnyIdHumanMessage(content="step1"),
_AnyIdHumanMessage(content="step2"),
_AnyIdHumanMessage(content="step3"),
_AnyIdHumanMessage(content="step4"),
],
"plan": [],
}
61 changes: 61 additions & 0 deletions libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -13199,3 +13199,64 @@ async def ask_age(s: State):
] == [
{"node": {"age": 19}},
]


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_multistep_plan(checkpointer_name: str):
from langchain_core.messages import AnyMessage

class State(TypedDict, total=False):
plan: list[Union[str, list[str]]]
messages: Annotated[list[AnyMessage], add_messages]

def planner(state: State):
if state.get("plan") is None:
# create plan somehow
plan = ["step1", ["step2", "step3"], "step4"]
# pick the first step to execute next
first_step, *plan = plan
# put the rest of plan in state
return Command(goto=first_step, update={"plan": plan})
elif state["plan"]:
# go to the next step of the plan
next_step, *next_plan = state["plan"]
return Command(goto=next_step, update={"plan": next_plan})
else:
# the end of the plan
pass

def step1(state: State):
return Command(goto="planner", update={"messages": [("human", "step1")]})

def step2(state: State):
return Command(goto="planner", update={"messages": [("human", "step2")]})

def step3(state: State):
return Command(goto="planner", update={"messages": [("human", "step3")]})

def step4(state: State):
return Command(goto="planner", update={"messages": [("human", "step4")]})

builder = StateGraph(State)
builder.add_node(planner)
builder.add_node(step1)
builder.add_node(step2)
builder.add_node(step3)
builder.add_node(step4)
builder.add_edge(START, "planner")

async with awith_checkpointer(checkpointer_name) as checkpointer:
graph = builder.compile(checkpointer=checkpointer)

config = {"configurable": {"thread_id": "1"}}

assert await graph.ainvoke({"messages": [("human", "start")]}, config) == {
"messages": [
_AnyIdHumanMessage(content="start"),
_AnyIdHumanMessage(content="step1"),
_AnyIdHumanMessage(content="step2"),
_AnyIdHumanMessage(content="step3"),
_AnyIdHumanMessage(content="step4"),
],
"plan": [],
}
Loading