Skip to content

Commit

Permalink
Merge pull request #2430 from langchain-ai/nc/15nov/command-subgraph
Browse files Browse the repository at this point in the history
Handle interrupt/resume for subgraphs
  • Loading branch information
nfcampos authored Dec 3, 2024
2 parents 15f0765 + 4e26a5c commit d70b659
Show file tree
Hide file tree
Showing 5 changed files with 453 additions and 3 deletions.
4 changes: 2 additions & 2 deletions libs/langgraph/langgraph/pregel/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,7 @@ def prepare_single_task(
for tid, c, v in pending_writes
if tid in (NULL_TASK_ID, task_id) and c == RESUME
),
MISSING,
configurable.get(CONFIG_KEY_RESUME_VALUE, MISSING),
),
},
),
Expand Down Expand Up @@ -720,7 +720,7 @@ def prepare_single_task(
if tid in (NULL_TASK_ID, task_id)
and c == RESUME
),
MISSING,
configurable.get(CONFIG_KEY_RESUME_VALUE, MISSING),
),
},
),
Expand Down
16 changes: 15 additions & 1 deletion libs/langgraph/langgraph/pregel/remote.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dataclasses import asdict
from typing import (
Any,
AsyncIterator,
Expand Down Expand Up @@ -27,6 +28,7 @@
get_sync_client,
)
from langgraph_sdk.schema import Checkpoint, ThreadState
from langgraph_sdk.schema import Command as CommandSDK
from langgraph_sdk.schema import StreamMode as StreamModeSDK
from typing_extensions import Self

Expand All @@ -41,7 +43,7 @@
from langgraph.errors import GraphInterrupt
from langgraph.pregel.protocol import PregelProtocol
from langgraph.pregel.types import All, PregelTask, StateSnapshot, StreamMode
from langgraph.types import Interrupt, StreamProtocol
from langgraph.types import Command, Interrupt, StreamProtocol
from langgraph.utils.config import merge_configs


Expand Down Expand Up @@ -597,11 +599,17 @@ def stream(
stream_modes, requested, req_single, stream = self._get_stream_modes(
stream_mode, config
)
if isinstance(input, Command):
command: Optional[CommandSDK] = cast(CommandSDK, asdict(input))
input = None
else:
command = None

for chunk in sync_client.runs.stream(
thread_id=sanitized_config["configurable"].get("thread_id"),
assistant_id=self.name,
input=input,
command=command,
config=sanitized_config,
stream_mode=stream_modes,
interrupt_before=interrupt_before,
Expand Down Expand Up @@ -680,11 +688,17 @@ async def astream(
stream_modes, requested, req_single, stream = self._get_stream_modes(
stream_mode, config
)
if isinstance(input, Command):
command: Optional[CommandSDK] = cast(CommandSDK, asdict(input))
input = None
else:
command = None

async for chunk in client.runs.stream(
thread_id=sanitized_config["configurable"].get("thread_id"),
assistant_id=self.name,
input=input,
command=command,
config=sanitized_config,
stream_mode=stream_modes,
interrupt_before=interrupt_before,
Expand Down
202 changes: 202 additions & 0 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8728,6 +8728,176 @@ def start(state: State) -> list[Union[Send, str]]:
)


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_dynamic_interrupt_subgraph(
request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")

class SubgraphState(TypedDict):
my_key: str
market: str

tool_two_node_count = 0

def tool_two_node(s: SubgraphState) -> SubgraphState:
nonlocal tool_two_node_count
tool_two_node_count += 1
if s["market"] == "DE":
answer = interrupt("Just because...")
else:
answer = " all good"
return {"my_key": answer}

subgraph = StateGraph(SubgraphState)
subgraph.add_node("do", tool_two_node, retry=RetryPolicy())
subgraph.add_edge(START, "do")

class State(TypedDict):
my_key: Annotated[str, operator.add]
market: str

tool_two_graph = StateGraph(State)
tool_two_graph.add_node("tool_two", subgraph.compile())
tool_two_graph.add_edge(START, "tool_two")
tool_two = tool_two_graph.compile()

tracer = FakeTracer()
assert tool_two.invoke(
{"my_key": "value", "market": "DE"}, {"callbacks": [tracer]}
) == {
"my_key": "value",
"market": "DE",
}
assert tool_two_node_count == 1, "interrupts aren't retried"
assert len(tracer.runs) == 1
run = tracer.runs[0]
assert run.end_time is not None
assert run.error is None
assert run.outputs == {"market": "DE", "my_key": "value"}

assert tool_two.invoke({"my_key": "value", "market": "US"}) == {
"my_key": "value all good",
"market": "US",
}

tool_two = tool_two_graph.compile(checkpointer=checkpointer)

# missing thread_id
with pytest.raises(ValueError, match="thread_id"):
tool_two.invoke({"my_key": "value", "market": "DE"})

# flow: interrupt -> resume with answer
thread2 = {"configurable": {"thread_id": "2"}}
# stop when about to enter node
assert [
c for c in tool_two.stream({"my_key": "value ⛰️", "market": "DE"}, thread2)
] == [
{
"__interrupt__": (
Interrupt(
value="Just because...",
resumable=True,
ns=[AnyStr("tool_two:"), AnyStr("do:")],
),
)
},
]
# resume with answer
assert [c for c in tool_two.stream(Command(resume=" my answer"), thread2)] == [
{"tool_two": {"my_key": " my answer", "market": "DE"}},
]

# flow: interrupt -> clear tasks
thread1 = {"configurable": {"thread_id": "1"}}
# stop when about to enter node
assert tool_two.invoke({"my_key": "value ⛰️", "market": "DE"}, thread1) == {
"my_key": "value ⛰️",
"market": "DE",
}
assert [
c.metadata
for c in tool_two.checkpointer.list(
{"configurable": {"thread_id": "1", "checkpoint_ns": ""}}
)
] == [
{
"parents": {},
"source": "loop",
"step": 0,
"writes": None,
"thread_id": "1",
},
{
"parents": {},
"source": "input",
"step": -1,
"writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}},
"thread_id": "1",
},
]
assert tool_two.get_state(thread1) == StateSnapshot(
values={"my_key": "value ⛰️", "market": "DE"},
next=("tool_two",),
tasks=(
PregelTask(
AnyStr(),
"tool_two",
(PULL, "tool_two"),
interrupts=(
Interrupt(
value="Just because...",
resumable=True,
ns=[AnyStr("tool_two:"), AnyStr("do:")],
),
),
state={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("tool_two:"),
}
},
),
),
config=tool_two.checkpointer.get_tuple(thread1).config,
created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 0,
"writes": None,
"thread_id": "1",
},
parent_config=[
*tool_two.checkpointer.list(
{"configurable": {"thread_id": "1", "checkpoint_ns": ""}}, limit=2
)
][-1].config,
)
# clear the interrupt and next tasks
tool_two.update_state(thread1, None, as_node=END)
# interrupt and next tasks are cleared
assert tool_two.get_state(thread1) == StateSnapshot(
values={"my_key": "value ⛰️", "market": "DE"},
next=(),
tasks=(),
config=tool_two.checkpointer.get_tuple(thread1).config,
created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 1,
"writes": {},
"thread_id": "1",
},
parent_config=[
*tool_two.checkpointer.list(
{"configurable": {"thread_id": "1", "checkpoint_ns": ""}}, limit=2
)
][-1].config,
)


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_start_branch_then(
snapshot: SnapshotAssertion, request: pytest.FixtureRequest, checkpointer_name: str
Expand Down Expand Up @@ -14471,3 +14641,35 @@ class CustomParentState(TypedDict):
},
tasks=(),
)


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_interrupt_subgraph(request: pytest.FixtureRequest, checkpointer_name: str):
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")

class State(TypedDict):
baz: str

def foo(state):
return {"baz": "foo"}

def bar(state):
value = interrupt("Please provide baz value:")
return {"baz": value}

child_builder = StateGraph(State)
child_builder.add_node(bar)
child_builder.add_edge(START, "bar")

builder = StateGraph(State)
builder.add_node(foo)
builder.add_node("bar", child_builder.compile())
builder.add_edge(START, "foo")
builder.add_edge("foo", "bar")
graph = builder.compile(checkpointer=checkpointer)

thread1 = {"configurable": {"thread_id": "1"}}
# First run, interrupted at bar
assert graph.invoke({"baz": ""}, thread1)
# Resume with answer
assert graph.invoke(Command(resume="bar"), thread1)
Loading

0 comments on commit d70b659

Please sign in to comment.