diff --git a/libs/langgraph/langgraph/pregel/algo.py b/libs/langgraph/langgraph/pregel/algo.py index 3c3008ce4..5d104a85f 100644 --- a/libs/langgraph/langgraph/pregel/algo.py +++ b/libs/langgraph/langgraph/pregel/algo.py @@ -595,7 +595,7 @@ def prepare_single_task( for tid, c, v in pending_writes if tid in (NULL_TASK_ID, task_id) and c == RESUME ), - MISSING, + configurable.get(CONFIG_KEY_RESUME_VALUE, MISSING), ), }, ), @@ -720,7 +720,7 @@ def prepare_single_task( if tid in (NULL_TASK_ID, task_id) and c == RESUME ), - MISSING, + configurable.get(CONFIG_KEY_RESUME_VALUE, MISSING), ), }, ), diff --git a/libs/langgraph/langgraph/pregel/remote.py b/libs/langgraph/langgraph/pregel/remote.py index abe27eb28..d45cdb310 100644 --- a/libs/langgraph/langgraph/pregel/remote.py +++ b/libs/langgraph/langgraph/pregel/remote.py @@ -1,3 +1,4 @@ +from dataclasses import asdict from typing import ( Any, AsyncIterator, @@ -27,6 +28,7 @@ get_sync_client, ) from langgraph_sdk.schema import Checkpoint, ThreadState +from langgraph_sdk.schema import Command as CommandSDK from langgraph_sdk.schema import StreamMode as StreamModeSDK from typing_extensions import Self @@ -41,7 +43,7 @@ from langgraph.errors import GraphInterrupt from langgraph.pregel.protocol import PregelProtocol from langgraph.pregel.types import All, PregelTask, StateSnapshot, StreamMode -from langgraph.types import Interrupt, StreamProtocol +from langgraph.types import Command, Interrupt, StreamProtocol from langgraph.utils.config import merge_configs @@ -597,11 +599,17 @@ def stream( stream_modes, requested, req_single, stream = self._get_stream_modes( stream_mode, config ) + if isinstance(input, Command): + command: Optional[CommandSDK] = cast(CommandSDK, asdict(input)) + input = None + else: + command = None for chunk in sync_client.runs.stream( thread_id=sanitized_config["configurable"].get("thread_id"), assistant_id=self.name, input=input, + command=command, config=sanitized_config, stream_mode=stream_modes, interrupt_before=interrupt_before, @@ -680,11 +688,17 @@ async def astream( stream_modes, requested, req_single, stream = self._get_stream_modes( stream_mode, config ) + if isinstance(input, Command): + command: Optional[CommandSDK] = cast(CommandSDK, asdict(input)) + input = None + else: + command = None async for chunk in client.runs.stream( thread_id=sanitized_config["configurable"].get("thread_id"), assistant_id=self.name, input=input, + command=command, config=sanitized_config, stream_mode=stream_modes, interrupt_before=interrupt_before, diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index c2ed63d28..6ff70276f 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -8728,6 +8728,176 @@ def start(state: State) -> list[Union[Send, str]]: ) +@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) +def test_dynamic_interrupt_subgraph( + request: pytest.FixtureRequest, checkpointer_name: str +) -> None: + checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") + + class SubgraphState(TypedDict): + my_key: str + market: str + + tool_two_node_count = 0 + + def tool_two_node(s: SubgraphState) -> SubgraphState: + nonlocal tool_two_node_count + tool_two_node_count += 1 + if s["market"] == "DE": + answer = interrupt("Just because...") + else: + answer = " all good" + return {"my_key": answer} + + subgraph = StateGraph(SubgraphState) + subgraph.add_node("do", tool_two_node, retry=RetryPolicy()) + subgraph.add_edge(START, "do") + + class State(TypedDict): + my_key: Annotated[str, operator.add] + market: str + + tool_two_graph = StateGraph(State) + tool_two_graph.add_node("tool_two", subgraph.compile()) + tool_two_graph.add_edge(START, "tool_two") + tool_two = tool_two_graph.compile() + + tracer = FakeTracer() + assert tool_two.invoke( + {"my_key": "value", "market": "DE"}, {"callbacks": [tracer]} + ) == { + "my_key": "value", + "market": "DE", + } + assert tool_two_node_count == 1, "interrupts aren't retried" + assert len(tracer.runs) == 1 + run = tracer.runs[0] + assert run.end_time is not None + assert run.error is None + assert run.outputs == {"market": "DE", "my_key": "value"} + + assert tool_two.invoke({"my_key": "value", "market": "US"}) == { + "my_key": "value all good", + "market": "US", + } + + tool_two = tool_two_graph.compile(checkpointer=checkpointer) + + # missing thread_id + with pytest.raises(ValueError, match="thread_id"): + tool_two.invoke({"my_key": "value", "market": "DE"}) + + # flow: interrupt -> resume with answer + thread2 = {"configurable": {"thread_id": "2"}} + # stop when about to enter node + assert [ + c for c in tool_two.stream({"my_key": "value ⛰️", "market": "DE"}, thread2) + ] == [ + { + "__interrupt__": ( + Interrupt( + value="Just because...", + resumable=True, + ns=[AnyStr("tool_two:"), AnyStr("do:")], + ), + ) + }, + ] + # resume with answer + assert [c for c in tool_two.stream(Command(resume=" my answer"), thread2)] == [ + {"tool_two": {"my_key": " my answer", "market": "DE"}}, + ] + + # flow: interrupt -> clear tasks + thread1 = {"configurable": {"thread_id": "1"}} + # stop when about to enter node + assert tool_two.invoke({"my_key": "value ⛰️", "market": "DE"}, thread1) == { + "my_key": "value ⛰️", + "market": "DE", + } + assert [ + c.metadata + for c in tool_two.checkpointer.list( + {"configurable": {"thread_id": "1", "checkpoint_ns": ""}} + ) + ] == [ + { + "parents": {}, + "source": "loop", + "step": 0, + "writes": None, + "thread_id": "1", + }, + { + "parents": {}, + "source": "input", + "step": -1, + "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, + "thread_id": "1", + }, + ] + assert tool_two.get_state(thread1) == StateSnapshot( + values={"my_key": "value ⛰️", "market": "DE"}, + next=("tool_two",), + tasks=( + PregelTask( + AnyStr(), + "tool_two", + (PULL, "tool_two"), + interrupts=( + Interrupt( + value="Just because...", + resumable=True, + ns=[AnyStr("tool_two:"), AnyStr("do:")], + ), + ), + state={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr("tool_two:"), + } + }, + ), + ), + config=tool_two.checkpointer.get_tuple(thread1).config, + created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], + metadata={ + "parents": {}, + "source": "loop", + "step": 0, + "writes": None, + "thread_id": "1", + }, + parent_config=[ + *tool_two.checkpointer.list( + {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}, limit=2 + ) + ][-1].config, + ) + # clear the interrupt and next tasks + tool_two.update_state(thread1, None, as_node=END) + # interrupt and next tasks are cleared + assert tool_two.get_state(thread1) == StateSnapshot( + values={"my_key": "value ⛰️", "market": "DE"}, + next=(), + tasks=(), + config=tool_two.checkpointer.get_tuple(thread1).config, + created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], + metadata={ + "parents": {}, + "source": "update", + "step": 1, + "writes": {}, + "thread_id": "1", + }, + parent_config=[ + *tool_two.checkpointer.list( + {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}, limit=2 + ) + ][-1].config, + ) + + @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_start_branch_then( snapshot: SnapshotAssertion, request: pytest.FixtureRequest, checkpointer_name: str @@ -14471,3 +14641,35 @@ class CustomParentState(TypedDict): }, tasks=(), ) + + +@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) +def test_interrupt_subgraph(request: pytest.FixtureRequest, checkpointer_name: str): + checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") + + class State(TypedDict): + baz: str + + def foo(state): + return {"baz": "foo"} + + def bar(state): + value = interrupt("Please provide baz value:") + return {"baz": value} + + child_builder = StateGraph(State) + child_builder.add_node(bar) + child_builder.add_edge(START, "bar") + + builder = StateGraph(State) + builder.add_node(foo) + builder.add_node("bar", child_builder.compile()) + builder.add_edge(START, "foo") + builder.add_edge("foo", "bar") + graph = builder.compile(checkpointer=checkpointer) + + thread1 = {"configurable": {"thread_id": "1"}} + # First run, interrupted at bar + assert graph.invoke({"baz": ""}, thread1) + # Resume with answer + assert graph.invoke(Command(resume="bar"), thread1) diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index 813fa8240..378ca1ac1 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -429,6 +429,189 @@ async def tool_two_node(s: State) -> State: ) +@pytest.mark.skipif( + sys.version_info < (3, 11), + reason="Python 3.11+ is required for async contextvars support", +) +@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) +async def test_dynamic_interrupt_subgraph(checkpointer_name: str) -> None: + class SubgraphState(TypedDict): + my_key: str + market: str + + tool_two_node_count = 0 + + def tool_two_node(s: SubgraphState) -> SubgraphState: + nonlocal tool_two_node_count + tool_two_node_count += 1 + if s["market"] == "DE": + answer = interrupt("Just because...") + else: + answer = " all good" + return {"my_key": answer} + + subgraph = StateGraph(SubgraphState) + subgraph.add_node("do", tool_two_node, retry=RetryPolicy()) + subgraph.add_edge(START, "do") + + class State(TypedDict): + my_key: Annotated[str, operator.add] + market: str + + tool_two_graph = StateGraph(State) + tool_two_graph.add_node("tool_two", subgraph.compile()) + tool_two_graph.add_edge(START, "tool_two") + tool_two = tool_two_graph.compile() + + tracer = FakeTracer() + assert await tool_two.ainvoke( + {"my_key": "value", "market": "DE"}, {"callbacks": [tracer]} + ) == { + "my_key": "value", + "market": "DE", + } + assert tool_two_node_count == 1, "interrupts aren't retried" + assert len(tracer.runs) == 1 + run = tracer.runs[0] + assert run.end_time is not None + assert run.error is None + assert run.outputs == {"market": "DE", "my_key": "value"} + + assert await tool_two.ainvoke({"my_key": "value", "market": "US"}) == { + "my_key": "value all good", + "market": "US", + } + + async with awith_checkpointer(checkpointer_name) as checkpointer: + tool_two = tool_two_graph.compile(checkpointer=checkpointer) + + # missing thread_id + with pytest.raises(ValueError, match="thread_id"): + await tool_two.ainvoke({"my_key": "value", "market": "DE"}) + + # flow: interrupt -> resume with answer + thread2 = {"configurable": {"thread_id": "2"}} + # stop when about to enter node + assert [ + c + async for c in tool_two.astream( + {"my_key": "value ⛰️", "market": "DE"}, thread2 + ) + ] == [ + { + "__interrupt__": ( + Interrupt( + value="Just because...", + resumable=True, + ns=[AnyStr("tool_two:"), AnyStr("do:")], + ), + ) + }, + ] + # resume with answer + assert [ + c async for c in tool_two.astream(Command(resume=" my answer"), thread2) + ] == [ + {"tool_two": {"my_key": " my answer", "market": "DE"}}, + ] + + # flow: interrupt -> clear + thread1 = {"configurable": {"thread_id": "1"}} + thread1root = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}} + # stop when about to enter node + assert [ + c + async for c in tool_two.astream( + {"my_key": "value ⛰️", "market": "DE"}, thread1 + ) + ] == [ + { + "__interrupt__": ( + Interrupt( + value="Just because...", + resumable=True, + ns=[AnyStr("tool_two:"), AnyStr("do:")], + ), + ) + }, + ] + assert [c.metadata async for c in tool_two.checkpointer.alist(thread1root)] == [ + { + "parents": {}, + "source": "loop", + "step": 0, + "writes": None, + "thread_id": "1", + }, + { + "parents": {}, + "source": "input", + "step": -1, + "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, + "thread_id": "1", + }, + ] + tup = await tool_two.checkpointer.aget_tuple(thread1) + assert await tool_two.aget_state(thread1) == StateSnapshot( + values={"my_key": "value ⛰️", "market": "DE"}, + next=("tool_two",), + tasks=( + PregelTask( + AnyStr(), + "tool_two", + (PULL, "tool_two"), + interrupts=( + Interrupt( + value="Just because...", + resumable=True, + ns=[AnyStr("tool_two:"), AnyStr("do:")], + ), + ), + state={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr("tool_two:"), + } + }, + ), + ), + config=tup.config, + created_at=tup.checkpoint["ts"], + metadata={ + "parents": {}, + "source": "loop", + "step": 0, + "writes": None, + "thread_id": "1", + }, + parent_config=[ + c async for c in tool_two.checkpointer.alist(thread1root, limit=2) + ][-1].config, + ) + + # clear the interrupt and next tasks + await tool_two.aupdate_state(thread1, None, as_node=END) + # interrupt is cleared, as well as the next tasks + tup = await tool_two.checkpointer.aget_tuple(thread1) + assert await tool_two.aget_state(thread1) == StateSnapshot( + values={"my_key": "value ⛰️", "market": "DE"}, + next=(), + tasks=(), + config=tup.config, + created_at=tup.checkpoint["ts"], + metadata={ + "parents": {}, + "source": "update", + "step": 1, + "writes": {}, + "thread_id": "1", + }, + parent_config=[ + c async for c in tool_two.checkpointer.alist(thread1root, limit=2) + ][-1].config, + ) + + @pytest.mark.skipif(not FF_SEND_V2, reason="send v2 is not enabled") @pytest.mark.skipif( sys.version_info < (3, 11), @@ -12677,3 +12860,39 @@ class CustomParentState(TypedDict): }, tasks=(), ) + + +@pytest.mark.skipif( + sys.version_info < (3, 11), + reason="Python 3.11+ is required for async contextvars support", +) +@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) +async def test_interrupt_subgraph(checkpointer_name: str): + class State(TypedDict): + baz: str + + def foo(state): + return {"baz": "foo"} + + def bar(state): + value = interrupt("Please provide baz value:") + return {"baz": value} + + child_builder = StateGraph(State) + child_builder.add_node(bar) + child_builder.add_edge(START, "bar") + + builder = StateGraph(State) + builder.add_node(foo) + builder.add_node("bar", child_builder.compile()) + builder.add_edge(START, "foo") + builder.add_edge("foo", "bar") + + async with awith_checkpointer(checkpointer_name) as checkpointer: + graph = builder.compile(checkpointer=checkpointer) + + thread1 = {"configurable": {"thread_id": "1"}} + # First run, interrupted at bar + assert await graph.ainvoke({"baz": ""}, thread1) + # Resume with answer + assert await graph.ainvoke(Command(resume="bar"), thread1) diff --git a/libs/sdk-py/langgraph_sdk/client.py b/libs/sdk-py/langgraph_sdk/client.py index 7a959a477..63eb17be1 100644 --- a/libs/sdk-py/langgraph_sdk/client.py +++ b/libs/sdk-py/langgraph_sdk/client.py @@ -3321,6 +3321,7 @@ def stream( assistant_id: str, *, input: Optional[dict] = None, + command: Optional[Command] = None, stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values", stream_subgraphs: bool = False, metadata: Optional[dict] = None, @@ -3344,6 +3345,7 @@ def stream( assistant_id: str, *, input: Optional[dict] = None, + command: Optional[Command] = None, stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values", stream_subgraphs: bool = False, metadata: Optional[dict] = None, @@ -3364,6 +3366,7 @@ def stream( assistant_id: str, *, input: Optional[dict] = None, + command: Optional[Command] = None, stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values", stream_subgraphs: bool = False, metadata: Optional[dict] = None, @@ -3388,6 +3391,7 @@ def stream( assistant_id: The assistant ID or graph name to stream from. If using graph name, will default to first assistant created from that graph. input: The input to the graph. + command: The command to execute. stream_mode: The stream mode(s) to use. stream_subgraphs: Whether to stream output from subgraphs. metadata: Metadata to assign to the run. @@ -3438,6 +3442,7 @@ def stream( """ # noqa: E501 payload = { "input": input, + "command": command, "config": config, "metadata": metadata, "stream_mode": stream_mode, @@ -3471,6 +3476,7 @@ def create( assistant_id: str, *, input: Optional[dict] = None, + command: Optional[Command] = None, stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values", stream_subgraphs: bool = False, metadata: Optional[dict] = None, @@ -3490,6 +3496,7 @@ def create( assistant_id: str, *, input: Optional[dict] = None, + command: Optional[Command] = None, stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values", stream_subgraphs: bool = False, metadata: Optional[dict] = None, @@ -3510,6 +3517,7 @@ def create( assistant_id: str, *, input: Optional[dict] = None, + command: Optional[Command] = None, stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values", stream_subgraphs: bool = False, metadata: Optional[dict] = None, @@ -3532,6 +3540,7 @@ def create( assistant_id: The assistant ID or graph name to stream from. If using graph name, will default to first assistant created from that graph. input: The input to the graph. + command: The command to execute. stream_mode: The stream mode(s) to use. stream_subgraphs: Whether to stream output from subgraphs. metadata: Metadata to assign to the run. @@ -3618,6 +3627,7 @@ def create( """ # noqa: E501 payload = { "input": input, + "command": command, "stream_mode": stream_mode, "stream_subgraphs": stream_subgraphs, "config": config, @@ -3655,6 +3665,7 @@ def wait( assistant_id: str, *, input: Optional[dict] = None, + command: Optional[Command] = None, metadata: Optional[dict] = None, config: Optional[Config] = None, checkpoint: Optional[Checkpoint] = None, @@ -3675,6 +3686,7 @@ def wait( assistant_id: str, *, input: Optional[dict] = None, + command: Optional[Command] = None, metadata: Optional[dict] = None, config: Optional[Config] = None, interrupt_before: Optional[Union[All, Sequence[str]]] = None, @@ -3692,6 +3704,7 @@ def wait( assistant_id: str, *, input: Optional[dict] = None, + command: Optional[Command] = None, metadata: Optional[dict] = None, config: Optional[Config] = None, checkpoint: Optional[Checkpoint] = None, @@ -3713,6 +3726,7 @@ def wait( assistant_id: The assistant ID or graph name to run. If using graph name, will default to first assistant created from that graph. input: The input to the graph. + command: The command to execute. metadata: Metadata to assign to the run. config: The configuration for the assistant. checkpoint: The checkpoint to resume from. @@ -3779,6 +3793,7 @@ def wait( """ # noqa: E501 payload = { "input": input, + "command": command, "config": config, "metadata": metadata, "assistant_id": assistant_id,