diff --git a/libs/langgraph/langgraph/pregel/__init__.py b/libs/langgraph/langgraph/pregel/__init__.py index c8812f1c8..81257c883 100644 --- a/libs/langgraph/langgraph/pregel/__init__.py +++ b/libs/langgraph/langgraph/pregel/__init__.py @@ -65,6 +65,7 @@ CONFIG_KEY_STREAM, CONFIG_KEY_STREAM_WRITER, CONFIG_KEY_TASK_ID, + END, ERROR, INPUT, INTERRUPT, @@ -901,8 +902,8 @@ def update_state( checkpoint, LoopProtocol(config=config, step=step + 1, stop=step + 2), ) as (channels, managed): - # no values, just clear all tasks - if values is None and as_node is None: + # no values as END, just clear all tasks + if values is None and as_node == END: if saved is not None: # tasks for this checkpoint next_tasks = prepare_next_tasks( @@ -955,6 +956,25 @@ def update_state( return patch_checkpoint_map( next_config, saved.metadata if saved else None ) + # no values, copy checkpoint + if values is None and as_node is None: + next_checkpoint = create_checkpoint(checkpoint, None, step) + # copy checkpoint + next_config = checkpointer.put( + checkpoint_config, + next_checkpoint, + { + **checkpoint_metadata, + "source": "update", + "step": step + 1, + "writes": {}, + "parents": saved.metadata.get("parents", {}) if saved else {}, + }, + {}, + ) + return patch_checkpoint_map( + next_config, saved.metadata if saved else None + ) # apply pending writes, if not on specific checkpoint if ( CONFIG_KEY_CHECKPOINT_ID not in config[CONF] diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index f31509e86..fda881b15 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -8556,8 +8556,8 @@ def tool_two_node(s: State) -> State: parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, ) # clear the interrupt and next tasks - tool_two.update_state(thread1, None) - # interrupt is cleared, task will still run next + tool_two.update_state(thread1, None, as_node=END) + # interrupt and next tasks are cleared assert tool_two.get_state(thread1) == StateSnapshot( values={"my_key": "value ⛰️", "market": "DE"}, next=(), @@ -8575,6 +8575,167 @@ def tool_two_node(s: State) -> State: ) +@pytest.mark.skipif(not FF_SEND_V2, reason="send v2 is not enabled") +@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) +def test_copy_checkpoint( + request: pytest.FixtureRequest, checkpointer_name: str +) -> None: + checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") + + class State(TypedDict): + my_key: Annotated[str, operator.add] + market: str + + def tool_one(s: State) -> State: + return {"my_key": " one"} + + tool_two_node_count = 0 + + def tool_two_node(s: State) -> State: + nonlocal tool_two_node_count + tool_two_node_count += 1 + if s["market"] == "DE": + answer = interrupt("Just because...") + else: + answer = " all good" + return {"my_key": answer} + + def start(state: State) -> list[Union[Send, str]]: + return ["tool_two", Send("tool_one", state)] + + tool_two_graph = StateGraph(State) + tool_two_graph.add_node("tool_two", tool_two_node, retry=RetryPolicy()) + tool_two_graph.add_node("tool_one", tool_one) + tool_two_graph.set_conditional_entry_point(start) + tool_two = tool_two_graph.compile() + + tracer = FakeTracer() + assert tool_two.invoke( + {"my_key": "value", "market": "DE"}, {"callbacks": [tracer]} + ) == { + "my_key": "value one", + "market": "DE", + } + assert tool_two_node_count == 1, "interrupts aren't retried" + assert len(tracer.runs) == 1 + run = tracer.runs[0] + assert run.end_time is not None + assert run.error is None + assert run.outputs == {"market": "DE", "my_key": "value one"} + + assert tool_two.invoke({"my_key": "value", "market": "US"}) == { + "my_key": "value one all good", + "market": "US", + } + + tool_two = tool_two_graph.compile(checkpointer=checkpointer) + + # missing thread_id + with pytest.raises(ValueError, match="thread_id"): + tool_two.invoke({"my_key": "value", "market": "DE"}) + + # flow: interrupt -> resume with answer + thread2 = {"configurable": {"thread_id": "2"}} + # stop when about to enter node + assert [ + c for c in tool_two.stream({"my_key": "value ⛰️", "market": "DE"}, thread2) + ] == [ + { + "tool_one": {"my_key": " one"}, + }, + { + "__interrupt__": ( + Interrupt( + value="Just because...", + resumable=True, + ns=[AnyStr("tool_two:")], + ), + ) + }, + ] + # resume with answer + assert [c for c in tool_two.stream(Command(resume=" my answer"), thread2)] == [ + {"tool_two": {"my_key": " my answer"}}, + ] + + # flow: interrupt -> clear tasks + thread1 = {"configurable": {"thread_id": "1"}} + # stop when about to enter node + assert tool_two.invoke({"my_key": "value ⛰️", "market": "DE"}, thread1) == { + "my_key": "value ⛰️ one", + "market": "DE", + } + assert [c.metadata for c in tool_two.checkpointer.list(thread1)] == [ + { + "parents": {}, + "source": "loop", + "step": 0, + "writes": {"tool_one": {"my_key": " one"}}, + "thread_id": "1", + }, + { + "parents": {}, + "source": "input", + "step": -1, + "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, + "thread_id": "1", + }, + ] + assert tool_two.get_state(thread1) == StateSnapshot( + values={"my_key": "value ⛰️ one", "market": "DE"}, + next=("tool_two",), + tasks=( + PregelTask( + AnyStr(), + "tool_two", + (PULL, "tool_two"), + interrupts=( + Interrupt( + value="Just because...", + resumable=True, + ns=[AnyStr("tool_two:")], + ), + ), + ), + ), + config=tool_two.checkpointer.get_tuple(thread1).config, + created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], + metadata={ + "parents": {}, + "source": "loop", + "step": 0, + "writes": {"tool_one": {"my_key": " one"}}, + "thread_id": "1", + }, + parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, + ) + # clear the interrupt and next tasks + tool_two.update_state(thread1, None) + # interrupt is cleared, next task is kept + assert tool_two.get_state(thread1) == StateSnapshot( + values={"my_key": "value ⛰️ one", "market": "DE"}, + next=("tool_two",), + tasks=( + PregelTask( + AnyStr(), + "tool_two", + (PULL, "tool_two"), + interrupts=(), + ), + ), + config=tool_two.checkpointer.get_tuple(thread1).config, + created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], + metadata={ + "parents": {}, + "source": "update", + "step": 1, + "writes": {}, + "thread_id": "1", + }, + parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, + ) + + @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_start_branch_then( snapshot: SnapshotAssertion, request: pytest.FixtureRequest, checkpointer_name: str