From 3629be27c9196f3a9c5173ef9fdcc874a2daae1e Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 10 Dec 2024 11:00:09 -0800 Subject: [PATCH] Fix --- libs/langgraph/langgraph/pregel/loop.py | 24 ++++++++-------- libs/langgraph/tests/test_pregel.py | 16 +++++------ libs/langgraph/tests/test_pregel_async.py | 35 +++++++++++++++++++++++ 3 files changed, 54 insertions(+), 21 deletions(-) diff --git a/libs/langgraph/langgraph/pregel/loop.py b/libs/langgraph/langgraph/pregel/loop.py index e962762599..a8e945edd8 100644 --- a/libs/langgraph/langgraph/pregel/loop.py +++ b/libs/langgraph/langgraph/pregel/loop.py @@ -422,18 +422,6 @@ def tick( self.status = "out_of_steps" return False - # apply NULL writes - if null_writes := [ - w[1:] for w in self.checkpoint_pending_writes if w[0] == NULL_TASK_ID - ]: - mv_writes = apply_writes( - self.checkpoint, - self.channels, - [PregelTaskWrites((), INPUT, null_writes, [])], - self.checkpointer_get_next_version, - ) - for key, values in mv_writes.items(): - self._update_mv(key, values) # prepare next tasks self.tasks = prepare_next_tasks( self.checkpoint, @@ -552,6 +540,18 @@ def _first(self, *, input_keys: Union[str, Sequence[str]]) -> None: # save writes for tid, ws in writes.items(): self.put_writes(tid, ws) + # apply NULL writes + if null_writes := [ + w[1:] for w in self.checkpoint_pending_writes if w[0] == NULL_TASK_ID + ]: + mv_writes = apply_writes( + self.checkpoint, + self.channels, + [PregelTaskWrites((), INPUT, null_writes, [])], + self.checkpointer_get_next_version, + ) + for key, values in mv_writes.items(): + self._update_mv(key, values) # proceed past previous checkpoint if is_resuming: self.checkpoint["versions_seen"].setdefault(INTERRUPT, {}) diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index 6aa17f6b2f..00cd0a604f 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -14906,9 +14906,14 @@ def my_node(state: State): assert graph.invoke({"foo": ""}) == {"foo": "ab"} -def test_command_with_static_breakpoints() -> None: +@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) +def test_command_with_static_breakpoints( + request: pytest.FixtureRequest, checkpointer_name: str +) -> None: """Test that we can use Command to resume and update with static breakpoints.""" + checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") + class State(TypedDict): """The graph state.""" @@ -14930,15 +14935,8 @@ def node2(state: State): builder.add_edge(START, "node1") builder.add_edge("node1", "node2") - # A checkpointer must be enabled for interrupts to work! - checkpointer = MemorySaver() graph = builder.compile(checkpointer=checkpointer, interrupt_before=["node1"]) - - config = { - "configurable": { - "thread_id": uuid.uuid4(), - } - } + config = {"configurable": {"thread_id": str(uuid.uuid4())}} # Start the graph and interrupt at the first node graph.invoke({"foo": "abc"}, config) diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index b3317e5aec..19fe6e65ad 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -13199,3 +13199,38 @@ async def ask_age(s: State): ] == [ {"node": {"age": 19}}, ] + + +@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) +async def test_command_with_static_breakpoints(checkpointer_name: str) -> None: + """Test that we can use Command to resume and update with static breakpoints.""" + + class State(TypedDict): + """The graph state.""" + + foo: str + + def node1(state: State): + return { + "foo": state["foo"] + "|node-1", + } + + def node2(state: State): + return { + "foo": state["foo"] + "|node-2", + } + + builder = StateGraph(State) + builder.add_node("node1", node1) + builder.add_node("node2", node2) + builder.add_edge(START, "node1") + builder.add_edge("node1", "node2") + + async with awith_checkpointer(checkpointer_name) as checkpointer: + graph = builder.compile(checkpointer=checkpointer, interrupt_before=["node1"]) + config = {"configurable": {"thread_id": str(uuid.uuid4())}} + + # Start the graph and interrupt at the first node + await graph.ainvoke({"foo": "abc"}, config) + result = await graph.ainvoke(Command(update={"foo": "def"}), config) + assert result == {"foo": "def|node-1|node-2"}