From 2b65308508fa829198cef3ca547b61b46ca8142e Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 15 Nov 2024 14:01:16 -0800 Subject: [PATCH 1/6] WIP: Handle commands for subgraphs --- libs/langgraph/tests/test_pregel.py | 32 +++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index c2ed63d28..d45c3ae22 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -14471,3 +14471,35 @@ class CustomParentState(TypedDict): }, tasks=(), ) + + +@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) +def test_interrupt_subgraph(request: pytest.FixtureRequest, checkpointer_name: str): + checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") + + class State(TypedDict): + baz: str + + def foo(state): + return {"baz": "foo"} + + def bar(state): + value = interrupt("Please provide baz value:") + return {"baz": value} + + child_builder = StateGraph(State) + child_builder.add_node(bar) + child_builder.add_edge(START, "bar") + + builder = StateGraph(State) + builder.add_node(foo) + builder.add_node("bar", child_builder.compile()) + builder.add_edge(START, "foo") + builder.add_edge("foo", "bar") + graph = builder.compile(checkpointer=checkpointer) + + thread1 = {"configurable": {"thread_id": "1"}} + # First run, interrupted at bar + assert graph.invoke({"baz": ""}, thread1) + # Resume with answer + assert graph.invoke(Command(resume="bar"), thread1) From 6fc1c602ab151525904fcec033f327f75b0db34e Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 2 Dec 2024 13:53:11 -0800 Subject: [PATCH 2/6] Add one more --- libs/langgraph/tests/test_pregel.py | 153 ++++++++++++++++++++++++++++ 1 file changed, 153 insertions(+) diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index d45c3ae22..113a6c3c8 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -8728,6 +8728,159 @@ def start(state: State) -> list[Union[Send, str]]: ) +@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) +def test_dynamic_interrupt_subgraph( + request: pytest.FixtureRequest, checkpointer_name: str +) -> None: + checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") + + class SubgraphState(TypedDict): + my_key: str + market: str + + tool_two_node_count = 0 + + def tool_two_node(s: SubgraphState) -> SubgraphState: + print("tool_two_node", s) + nonlocal tool_two_node_count + tool_two_node_count += 1 + if s["market"] == "DE": + answer = interrupt("Just because...") + else: + answer = " all good" + print("after interrupt") + return {"my_key": answer} + + subgraph = StateGraph(SubgraphState) + subgraph.add_node("do", tool_two_node, retry=RetryPolicy()) + subgraph.add_edge(START, "do") + + class State(TypedDict): + my_key: Annotated[str, operator.add] + market: str + + tool_two_graph = StateGraph(State) + tool_two_graph.add_node("tool_two", subgraph.compile()) + tool_two_graph.add_edge(START, "tool_two") + tool_two = tool_two_graph.compile() + + tracer = FakeTracer() + assert tool_two.invoke( + {"my_key": "value", "market": "DE"}, {"callbacks": [tracer]} + ) == { + "my_key": "value", + "market": "DE", + } + assert tool_two_node_count == 1, "interrupts aren't retried" + assert len(tracer.runs) == 1 + run = tracer.runs[0] + assert run.end_time is not None + assert run.error is None + assert run.outputs == {"market": "DE", "my_key": "value"} + + assert tool_two.invoke({"my_key": "value", "market": "US"}) == { + "my_key": "value all good", + "market": "US", + } + + tool_two = tool_two_graph.compile(checkpointer=checkpointer) + + # missing thread_id + with pytest.raises(ValueError, match="thread_id"): + tool_two.invoke({"my_key": "value", "market": "DE"}) + + # flow: interrupt -> resume with answer + thread2 = {"configurable": {"thread_id": "2"}} + # stop when about to enter node + assert [ + c for c in tool_two.stream({"my_key": "value ⛰️", "market": "DE"}, thread2) + ] == [ + { + "__interrupt__": ( + Interrupt( + value="Just because...", + resumable=True, + ns=[AnyStr("tool_two:"), AnyStr("do:")], + ), + ) + }, + ] + # resume with answer + assert [c for c in tool_two.stream(Command(resume=" my answer"), thread2)] == [ + {"tool_two": {"my_key": " my answer"}}, + ] + + # flow: interrupt -> clear tasks + thread1 = {"configurable": {"thread_id": "1"}} + # stop when about to enter node + assert tool_two.invoke({"my_key": "value ⛰️", "market": "DE"}, thread1) == { + "my_key": "value ⛰️", + "market": "DE", + } + assert [c.metadata for c in tool_two.checkpointer.list(thread1)] == [ + { + "parents": {}, + "source": "loop", + "step": 0, + "writes": None, + "thread_id": "1", + }, + { + "parents": {}, + "source": "input", + "step": -1, + "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, + "thread_id": "1", + }, + ] + assert tool_two.get_state(thread1) == StateSnapshot( + values={"my_key": "value ⛰️", "market": "DE"}, + next=("tool_two",), + tasks=( + PregelTask( + AnyStr(), + "tool_two", + (PULL, "tool_two"), + interrupts=( + Interrupt( + value="Just because...", + resumable=True, + ns=[AnyStr("tool_two:")], + ), + ), + ), + ), + config=tool_two.checkpointer.get_tuple(thread1).config, + created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], + metadata={ + "parents": {}, + "source": "loop", + "step": 0, + "writes": None, + "thread_id": "1", + }, + parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, + ) + # clear the interrupt and next tasks + tool_two.update_state(thread1, None, as_node=END) + # interrupt and next tasks are cleared + assert tool_two.get_state(thread1) == StateSnapshot( + values={"my_key": "value ⛰️", "market": "DE"}, + next=(), + tasks=(), + config=tool_two.checkpointer.get_tuple(thread1).config, + created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], + metadata={ + "parents": {}, + "source": "update", + "step": 1, + "writes": {}, + "thread_id": "1", + }, + parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, + ) + + @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_start_branch_then( snapshot: SnapshotAssertion, request: pytest.FixtureRequest, checkpointer_name: str From efbd02a27da9abdb15d2c31aeb1c76ec2a430ccd Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 2 Dec 2024 16:43:35 -0800 Subject: [PATCH 3/6] Implement support for interrupt/resume in subgraphs --- libs/langgraph/langgraph/pregel/algo.py | 4 +- libs/langgraph/tests/test_pregel.py | 31 +++- libs/langgraph/tests/test_pregel_async.py | 215 ++++++++++++++++++++++ 3 files changed, 241 insertions(+), 9 deletions(-) diff --git a/libs/langgraph/langgraph/pregel/algo.py b/libs/langgraph/langgraph/pregel/algo.py index 3c3008ce4..5d104a85f 100644 --- a/libs/langgraph/langgraph/pregel/algo.py +++ b/libs/langgraph/langgraph/pregel/algo.py @@ -595,7 +595,7 @@ def prepare_single_task( for tid, c, v in pending_writes if tid in (NULL_TASK_ID, task_id) and c == RESUME ), - MISSING, + configurable.get(CONFIG_KEY_RESUME_VALUE, MISSING), ), }, ), @@ -720,7 +720,7 @@ def prepare_single_task( if tid in (NULL_TASK_ID, task_id) and c == RESUME ), - MISSING, + configurable.get(CONFIG_KEY_RESUME_VALUE, MISSING), ), }, ), diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index 113a6c3c8..6ff70276f 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -8741,14 +8741,12 @@ class SubgraphState(TypedDict): tool_two_node_count = 0 def tool_two_node(s: SubgraphState) -> SubgraphState: - print("tool_two_node", s) nonlocal tool_two_node_count tool_two_node_count += 1 if s["market"] == "DE": answer = interrupt("Just because...") else: answer = " all good" - print("after interrupt") return {"my_key": answer} subgraph = StateGraph(SubgraphState) @@ -8807,7 +8805,7 @@ class State(TypedDict): ] # resume with answer assert [c for c in tool_two.stream(Command(resume=" my answer"), thread2)] == [ - {"tool_two": {"my_key": " my answer"}}, + {"tool_two": {"my_key": " my answer", "market": "DE"}}, ] # flow: interrupt -> clear tasks @@ -8817,7 +8815,12 @@ class State(TypedDict): "my_key": "value ⛰️", "market": "DE", } - assert [c.metadata for c in tool_two.checkpointer.list(thread1)] == [ + assert [ + c.metadata + for c in tool_two.checkpointer.list( + {"configurable": {"thread_id": "1", "checkpoint_ns": ""}} + ) + ] == [ { "parents": {}, "source": "loop", @@ -8845,9 +8848,15 @@ class State(TypedDict): Interrupt( value="Just because...", resumable=True, - ns=[AnyStr("tool_two:")], + ns=[AnyStr("tool_two:"), AnyStr("do:")], ), ), + state={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr("tool_two:"), + } + }, ), ), config=tool_two.checkpointer.get_tuple(thread1).config, @@ -8859,7 +8868,11 @@ class State(TypedDict): "writes": None, "thread_id": "1", }, - parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, + parent_config=[ + *tool_two.checkpointer.list( + {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}, limit=2 + ) + ][-1].config, ) # clear the interrupt and next tasks tool_two.update_state(thread1, None, as_node=END) @@ -8877,7 +8890,11 @@ class State(TypedDict): "writes": {}, "thread_id": "1", }, - parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, + parent_config=[ + *tool_two.checkpointer.list( + {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}, limit=2 + ) + ][-1].config, ) diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index 813fa8240..6c85e57ed 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -429,6 +429,189 @@ async def tool_two_node(s: State) -> State: ) +@pytest.mark.skipif( + sys.version_info < (3, 11), + reason="Python 3.11+ is required for async contextvars support", +) +@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) +async def test_dynamic_interrupt_subgraph(checkpointer_name: str) -> None: + class SubgraphState(TypedDict): + my_key: str + market: str + + tool_two_node_count = 0 + + def tool_two_node(s: SubgraphState) -> SubgraphState: + nonlocal tool_two_node_count + tool_two_node_count += 1 + if s["market"] == "DE": + answer = interrupt("Just because...") + else: + answer = " all good" + return {"my_key": answer} + + subgraph = StateGraph(SubgraphState) + subgraph.add_node("do", tool_two_node, retry=RetryPolicy()) + subgraph.add_edge(START, "do") + + class State(TypedDict): + my_key: Annotated[str, operator.add] + market: str + + tool_two_graph = StateGraph(State) + tool_two_graph.add_node("tool_two", subgraph.compile()) + tool_two_graph.add_edge(START, "tool_two") + tool_two = tool_two_graph.compile() + + tracer = FakeTracer() + assert await tool_two.ainvoke( + {"my_key": "value", "market": "DE"}, {"callbacks": [tracer]} + ) == { + "my_key": "value", + "market": "DE", + } + assert tool_two_node_count == 1, "interrupts aren't retried" + assert len(tracer.runs) == 1 + run = tracer.runs[0] + assert run.end_time is not None + assert run.error is None + assert run.outputs == {"market": "DE", "my_key": "value"} + + assert await tool_two.ainvoke({"my_key": "value", "market": "US"}) == { + "my_key": "value all good", + "market": "US", + } + + async with awith_checkpointer(checkpointer_name) as checkpointer: + tool_two = tool_two_graph.compile(checkpointer=checkpointer) + + # missing thread_id + with pytest.raises(ValueError, match="thread_id"): + await tool_two.ainvoke({"my_key": "value", "market": "DE"}) + + # flow: interrupt -> resume with answer + thread2 = {"configurable": {"thread_id": "2"}} + # stop when about to enter node + assert [ + c + async for c in tool_two.astream( + {"my_key": "value ⛰️", "market": "DE"}, thread2 + ) + ] == [ + { + "__interrupt__": ( + Interrupt( + value="Just because...", + resumable=True, + ns=[AnyStr("tool_two:"), AnyStr("do:")], + ), + ) + }, + ] + # resume with answer + assert [ + c async for c in tool_two.astream(Command(resume=" my answer"), thread2) + ] == [ + {"tool_two": {"my_key": " my answer", "market": "DE"}}, + ] + + # flow: interrupt -> clear + thread1 = {"configurable": {"thread_id": "1"}} + thread1root = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}} + # stop when about to enter node + assert [ + c + async for c in tool_two.astream( + {"my_key": "value ⛰️", "market": "DE"}, thread1 + ) + ] == [ + { + "__interrupt__": ( + Interrupt( + value="Just because...", + resumable=True, + ns=[AnyStr("tool_two:"), AnyStr("do:")], + ), + ) + }, + ] + assert [c.metadata async for c in tool_two.checkpointer.alist(thread1root)] == [ + { + "parents": {}, + "source": "loop", + "step": 0, + "writes": None, + "thread_id": "1", + }, + { + "parents": {}, + "source": "input", + "step": -1, + "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, + "thread_id": "1", + }, + ] + tup = await tool_two.checkpointer.aget_tuple(thread1) + assert await tool_two.aget_state(thread1) == StateSnapshot( + values={"my_key": "value ⛰️", "market": "DE"}, + next=("tool_two",), + tasks=( + PregelTask( + AnyStr(), + "tool_two", + (PULL, "tool_two"), + interrupts=( + Interrupt( + value="Just because...", + resumable=True, + ns=[AnyStr("tool_two:"), AnyStr("do:")], + ), + ), + state={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr("tool_two:"), + } + }, + ), + ), + config=tup.config, + created_at=tup.checkpoint["ts"], + metadata={ + "parents": {}, + "source": "loop", + "step": 0, + "writes": None, + "thread_id": "1", + }, + parent_config=[ + c async for c in tool_two.checkpointer.alist(thread1root, limit=2) + ][-1].config, + ) + + # clear the interrupt and next tasks + await tool_two.aupdate_state(thread1, None, as_node=END) + # interrupt is cleared, as well as the next tasks + tup = await tool_two.checkpointer.aget_tuple(thread1) + assert await tool_two.aget_state(thread1) == StateSnapshot( + values={"my_key": "value ⛰️", "market": "DE"}, + next=(), + tasks=(), + config=tup.config, + created_at=tup.checkpoint["ts"], + metadata={ + "parents": {}, + "source": "update", + "step": 1, + "writes": {}, + "thread_id": "1", + }, + parent_config=[ + c async for c in tool_two.checkpointer.alist(thread1root, limit=2) + ][-1].config, + ) + + @pytest.mark.skipif(not FF_SEND_V2, reason="send v2 is not enabled") @pytest.mark.skipif( sys.version_info < (3, 11), @@ -12677,3 +12860,35 @@ class CustomParentState(TypedDict): }, tasks=(), ) + + +@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) +async def test_interrupt_subgraph(checkpointer_name: str): + class State(TypedDict): + baz: str + + def foo(state): + return {"baz": "foo"} + + def bar(state): + value = interrupt("Please provide baz value:") + return {"baz": value} + + child_builder = StateGraph(State) + child_builder.add_node(bar) + child_builder.add_edge(START, "bar") + + builder = StateGraph(State) + builder.add_node(foo) + builder.add_node("bar", child_builder.compile()) + builder.add_edge(START, "foo") + builder.add_edge("foo", "bar") + + async with awith_checkpointer(checkpointer_name) as checkpointer: + graph = builder.compile(checkpointer=checkpointer) + + thread1 = {"configurable": {"thread_id": "1"}} + # First run, interrupted at bar + assert await graph.ainvoke({"baz": ""}, thread1) + # Resume with answer + assert await graph.ainvoke(Command(resume="bar"), thread1) From a3feaef2eb0a25ccdc0abbeedd8107c2b1b8e3b8 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 2 Dec 2024 16:45:07 -0800 Subject: [PATCH 4/6] lib: Handle Command in RemoteGraph --- libs/langgraph/langgraph/pregel/remote.py | 15 ++++++++++++++- libs/sdk-py/langgraph_sdk/client.py | 15 +++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/libs/langgraph/langgraph/pregel/remote.py b/libs/langgraph/langgraph/pregel/remote.py index abe27eb28..8ee7ec884 100644 --- a/libs/langgraph/langgraph/pregel/remote.py +++ b/libs/langgraph/langgraph/pregel/remote.py @@ -1,3 +1,4 @@ +from dataclasses import asdict from typing import ( Any, AsyncIterator, @@ -41,7 +42,7 @@ from langgraph.errors import GraphInterrupt from langgraph.pregel.protocol import PregelProtocol from langgraph.pregel.types import All, PregelTask, StateSnapshot, StreamMode -from langgraph.types import Interrupt, StreamProtocol +from langgraph.types import Command, Interrupt, StreamProtocol from langgraph.utils.config import merge_configs @@ -597,11 +598,17 @@ def stream( stream_modes, requested, req_single, stream = self._get_stream_modes( stream_mode, config ) + if isinstance(input, Command): + command: dict[str, Any] = asdict(input) + input = None + else: + command = None for chunk in sync_client.runs.stream( thread_id=sanitized_config["configurable"].get("thread_id"), assistant_id=self.name, input=input, + command=command, config=sanitized_config, stream_mode=stream_modes, interrupt_before=interrupt_before, @@ -680,11 +687,17 @@ async def astream( stream_modes, requested, req_single, stream = self._get_stream_modes( stream_mode, config ) + if isinstance(input, Command): + command: dict[str, Any] = asdict(input) + input = None + else: + command = None async for chunk in client.runs.stream( thread_id=sanitized_config["configurable"].get("thread_id"), assistant_id=self.name, input=input, + command=command, config=sanitized_config, stream_mode=stream_modes, interrupt_before=interrupt_before, diff --git a/libs/sdk-py/langgraph_sdk/client.py b/libs/sdk-py/langgraph_sdk/client.py index 81e8a8506..be909d336 100644 --- a/libs/sdk-py/langgraph_sdk/client.py +++ b/libs/sdk-py/langgraph_sdk/client.py @@ -3303,6 +3303,7 @@ def stream( assistant_id: str, *, input: Optional[dict] = None, + command: Optional[Command] = None, stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values", stream_subgraphs: bool = False, metadata: Optional[dict] = None, @@ -3326,6 +3327,7 @@ def stream( assistant_id: str, *, input: Optional[dict] = None, + command: Optional[Command] = None, stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values", stream_subgraphs: bool = False, metadata: Optional[dict] = None, @@ -3346,6 +3348,7 @@ def stream( assistant_id: str, *, input: Optional[dict] = None, + command: Optional[Command] = None, stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values", stream_subgraphs: bool = False, metadata: Optional[dict] = None, @@ -3370,6 +3373,7 @@ def stream( assistant_id: The assistant ID or graph name to stream from. If using graph name, will default to first assistant created from that graph. input: The input to the graph. + command: The command to execute. stream_mode: The stream mode(s) to use. stream_subgraphs: Whether to stream output from subgraphs. metadata: Metadata to assign to the run. @@ -3420,6 +3424,7 @@ def stream( """ # noqa: E501 payload = { "input": input, + "command": command, "config": config, "metadata": metadata, "stream_mode": stream_mode, @@ -3453,6 +3458,7 @@ def create( assistant_id: str, *, input: Optional[dict] = None, + command: Optional[Command] = None, stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values", stream_subgraphs: bool = False, metadata: Optional[dict] = None, @@ -3472,6 +3478,7 @@ def create( assistant_id: str, *, input: Optional[dict] = None, + command: Optional[Command] = None, stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values", stream_subgraphs: bool = False, metadata: Optional[dict] = None, @@ -3492,6 +3499,7 @@ def create( assistant_id: str, *, input: Optional[dict] = None, + command: Optional[Command] = None, stream_mode: Union[StreamMode, Sequence[StreamMode]] = "values", stream_subgraphs: bool = False, metadata: Optional[dict] = None, @@ -3514,6 +3522,7 @@ def create( assistant_id: The assistant ID or graph name to stream from. If using graph name, will default to first assistant created from that graph. input: The input to the graph. + command: The command to execute. stream_mode: The stream mode(s) to use. stream_subgraphs: Whether to stream output from subgraphs. metadata: Metadata to assign to the run. @@ -3600,6 +3609,7 @@ def create( """ # noqa: E501 payload = { "input": input, + "command": command, "stream_mode": stream_mode, "stream_subgraphs": stream_subgraphs, "config": config, @@ -3637,6 +3647,7 @@ def wait( assistant_id: str, *, input: Optional[dict] = None, + command: Optional[Command] = None, metadata: Optional[dict] = None, config: Optional[Config] = None, checkpoint: Optional[Checkpoint] = None, @@ -3657,6 +3668,7 @@ def wait( assistant_id: str, *, input: Optional[dict] = None, + command: Optional[Command] = None, metadata: Optional[dict] = None, config: Optional[Config] = None, interrupt_before: Optional[Union[All, Sequence[str]]] = None, @@ -3674,6 +3686,7 @@ def wait( assistant_id: str, *, input: Optional[dict] = None, + command: Optional[Command] = None, metadata: Optional[dict] = None, config: Optional[Config] = None, checkpoint: Optional[Checkpoint] = None, @@ -3695,6 +3708,7 @@ def wait( assistant_id: The assistant ID or graph name to run. If using graph name, will default to first assistant created from that graph. input: The input to the graph. + command: The command to execute. metadata: Metadata to assign to the run. config: The configuration for the assistant. checkpoint: The checkpoint to resume from. @@ -3761,6 +3775,7 @@ def wait( """ # noqa: E501 payload = { "input": input, + "command": command, "config": config, "metadata": metadata, "assistant_id": assistant_id, From d36e6ceaaf3c9c9d57ae751bf1ab9e965fec370e Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 2 Dec 2024 16:59:38 -0800 Subject: [PATCH 5/6] Fix --- libs/langgraph/tests/test_pregel_async.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index 6c85e57ed..378ca1ac1 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -12862,6 +12862,10 @@ class CustomParentState(TypedDict): ) +@pytest.mark.skipif( + sys.version_info < (3, 11), + reason="Python 3.11+ is required for async contextvars support", +) @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_interrupt_subgraph(checkpointer_name: str): class State(TypedDict): From 0071bd1e1cf91f84a9c364ee3eae8eacd9ad1001 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 2 Dec 2024 17:01:12 -0800 Subject: [PATCH 6/6] Lint --- libs/langgraph/langgraph/pregel/remote.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/libs/langgraph/langgraph/pregel/remote.py b/libs/langgraph/langgraph/pregel/remote.py index 8ee7ec884..d45cdb310 100644 --- a/libs/langgraph/langgraph/pregel/remote.py +++ b/libs/langgraph/langgraph/pregel/remote.py @@ -28,6 +28,7 @@ get_sync_client, ) from langgraph_sdk.schema import Checkpoint, ThreadState +from langgraph_sdk.schema import Command as CommandSDK from langgraph_sdk.schema import StreamMode as StreamModeSDK from typing_extensions import Self @@ -599,7 +600,7 @@ def stream( stream_mode, config ) if isinstance(input, Command): - command: dict[str, Any] = asdict(input) + command: Optional[CommandSDK] = cast(CommandSDK, asdict(input)) input = None else: command = None @@ -688,7 +689,7 @@ async def astream( stream_mode, config ) if isinstance(input, Command): - command: dict[str, Any] = asdict(input) + command: Optional[CommandSDK] = cast(CommandSDK, asdict(input)) input = None else: command = None