From ef6c5b471126c17083d7c885ab5a4175aab14c09 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 10 Dec 2024 10:40:34 -0800 Subject: [PATCH] lib: Add unit test for multistep planner graph --- libs/langgraph/tests/test_pregel.py | 61 +++++++++++++++++++++++ libs/langgraph/tests/test_pregel_async.py | 61 +++++++++++++++++++++++ 2 files changed, 122 insertions(+) diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index 91d7907d0..0103ce6b2 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -14944,3 +14944,64 @@ def node2(state: State): graph.invoke({"foo": "abc"}, config) result = graph.invoke(Command(resume="node1"), config) assert result == {"foo": "abc|node-1|node-2"} + + +@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) +def test_multistep_plan(request: pytest.FixtureRequest, checkpointer_name: str): + from langchain_core.messages import AnyMessage + + checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") + + class State(TypedDict, total=False): + plan: list[Union[str, list[str]]] + messages: Annotated[list[AnyMessage], add_messages] + + def planner(state: State): + if state.get("plan") is None: + # create plan somehow + plan = ["step1", ["step2", "step3"], "step4"] + # pick the first step to execute next + first_step, *plan = plan + # put the rest of plan in state + return Command(goto=first_step, update={"plan": plan}) + elif state["plan"]: + # go to the next step of the plan + next_step, *next_plan = state["plan"] + return Command(goto=next_step, update={"plan": next_plan}) + else: + # the end of the plan + pass + + def step1(state: State): + return Command(goto="planner", update={"messages": [("human", "step1")]}) + + def step2(state: State): + return Command(goto="planner", update={"messages": [("human", "step2")]}) + + def step3(state: State): + return Command(goto="planner", update={"messages": [("human", "step3")]}) + + def step4(state: State): + return Command(goto="planner", update={"messages": [("human", "step4")]}) + + builder = StateGraph(State) + builder.add_node(planner) + builder.add_node(step1) + builder.add_node(step2) + builder.add_node(step3) + builder.add_node(step4) + builder.add_edge(START, "planner") + graph = builder.compile(checkpointer=checkpointer) + + config = {"configurable": {"thread_id": "1"}} + + assert graph.invoke({"messages": [("human", "start")]}, config) == { + "messages": [ + _AnyIdHumanMessage(content="start"), + _AnyIdHumanMessage(content="step1"), + _AnyIdHumanMessage(content="step2"), + _AnyIdHumanMessage(content="step3"), + _AnyIdHumanMessage(content="step4"), + ], + "plan": [], + } diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index b3317e5ae..0064117bf 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -13199,3 +13199,64 @@ async def ask_age(s: State): ] == [ {"node": {"age": 19}}, ] + + +@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) +async def test_multistep_plan(checkpointer_name: str): + from langchain_core.messages import AnyMessage + + class State(TypedDict, total=False): + plan: list[Union[str, list[str]]] + messages: Annotated[list[AnyMessage], add_messages] + + def planner(state: State): + if state.get("plan") is None: + # create plan somehow + plan = ["step1", ["step2", "step3"], "step4"] + # pick the first step to execute next + first_step, *plan = plan + # put the rest of plan in state + return Command(goto=first_step, update={"plan": plan}) + elif state["plan"]: + # go to the next step of the plan + next_step, *next_plan = state["plan"] + return Command(goto=next_step, update={"plan": next_plan}) + else: + # the end of the plan + pass + + def step1(state: State): + return Command(goto="planner", update={"messages": [("human", "step1")]}) + + def step2(state: State): + return Command(goto="planner", update={"messages": [("human", "step2")]}) + + def step3(state: State): + return Command(goto="planner", update={"messages": [("human", "step3")]}) + + def step4(state: State): + return Command(goto="planner", update={"messages": [("human", "step4")]}) + + builder = StateGraph(State) + builder.add_node(planner) + builder.add_node(step1) + builder.add_node(step2) + builder.add_node(step3) + builder.add_node(step4) + builder.add_edge(START, "planner") + + async with awith_checkpointer(checkpointer_name) as checkpointer: + graph = builder.compile(checkpointer=checkpointer) + + config = {"configurable": {"thread_id": "1"}} + + assert await graph.ainvoke({"messages": [("human", "start")]}, config) == { + "messages": [ + _AnyIdHumanMessage(content="start"), + _AnyIdHumanMessage(content="step1"), + _AnyIdHumanMessage(content="step2"), + _AnyIdHumanMessage(content="step3"), + _AnyIdHumanMessage(content="step4"), + ], + "plan": [], + }