Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lib: Add support for multiple interrupts per node #2636

Merged
merged 3 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions libs/langgraph/langgraph/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import sys

Check notice on line 1 in libs/langgraph/langgraph/constants.py

View workflow job for this annotation

GitHub Actions / benchmark

Benchmark results

......................................... fanout_to_subgraph_10x: Mean +- std dev: 61.3 ms +- 1.7 ms ......................................... fanout_to_subgraph_10x_sync: Mean +- std dev: 51.3 ms +- 0.7 ms ......................................... fanout_to_subgraph_10x_checkpoint: Mean +- std dev: 91.6 ms +- 7.8 ms ......................................... fanout_to_subgraph_10x_checkpoint_sync: Mean +- std dev: 93.5 ms +- 2.0 ms ......................................... fanout_to_subgraph_100x: Mean +- std dev: 589 ms +- 22 ms ......................................... fanout_to_subgraph_100x_sync: Mean +- std dev: 500 ms +- 5 ms ......................................... fanout_to_subgraph_100x_checkpoint: Mean +- std dev: 934 ms +- 40 ms ......................................... fanout_to_subgraph_100x_checkpoint_sync: Mean +- std dev: 919 ms +- 17 ms ......................................... react_agent_10x: Mean +- std dev: 30.9 ms +- 0.7 ms ......................................... react_agent_10x_sync: Mean +- std dev: 22.2 ms +- 0.2 ms ......................................... react_agent_10x_checkpoint: Mean +- std dev: 46.5 ms +- 0.8 ms ......................................... react_agent_10x_checkpoint_sync: Mean +- std dev: 36.9 ms +- 0.5 ms ......................................... react_agent_100x: Mean +- std dev: 343 ms +- 7 ms ......................................... react_agent_100x_sync: Mean +- std dev: 270 ms +- 4 ms ......................................... react_agent_100x_checkpoint: Mean +- std dev: 939 ms +- 10 ms ......................................... react_agent_100x_checkpoint_sync: Mean +- std dev: 842 ms +- 11 ms ......................................... wide_state_25x300: Mean +- std dev: 24.1 ms +- 0.4 ms ......................................... wide_state_25x300_sync: Mean +- std dev: 15.4 ms +- 0.1 ms ......................................... 
wide_state_25x300_checkpoint: Mean +- std dev: 280 ms +- 3 ms ......................................... wide_state_25x300_checkpoint_sync: Mean +- std dev: 267 ms +- 4 ms ......................................... wide_state_15x600: Mean +- std dev: 28.1 ms +- 0.5 ms ......................................... wide_state_15x600_sync: Mean +- std dev: 17.8 ms +- 0.2 ms ......................................... wide_state_15x600_checkpoint: Mean +- std dev: 480 ms +- 5 ms ......................................... wide_state_15x600_checkpoint_sync: Mean +- std dev: 466 ms +- 6 ms ......................................... wide_state_9x1200: Mean +- std dev: 28.0 ms +- 0.5 ms ......................................... wide_state_9x1200_sync: Mean +- std dev: 17.7 ms +- 0.1 ms ......................................... wide_state_9x1200_checkpoint: Mean +- std dev: 313 ms +- 3 ms ......................................... wide_state_9x1200_checkpoint_sync: Mean +- std dev: 298 ms +- 4 ms

Check notice on line 1 in libs/langgraph/langgraph/constants.py

View workflow job for this annotation

GitHub Actions / benchmark

Comparison against main

+-----------------------------------------+---------+-----------------------+ | Benchmark | main | changes | +=========================================+=========+=======================+ | fanout_to_subgraph_100x | 606 ms | 589 ms: 1.03x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_10x_checkpoint | 47.1 ms | 46.5 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_checkpoint_sync | 94.6 ms | 93.5 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_checkpoint | 945 ms | 934 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_checkpoint_sync | 929 ms | 919 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_100x_sync | 272 ms | 270 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_10x_sync | 22.3 ms | 22.2 ms: 1.00x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_15x600_checkpoint | 478 ms | 480 ms: 1.00x slower | +-----------------------------------------+---------+-----------------------+ | react_agent_100x_checkpoint_sync | 838 ms | 842 ms: 1.01x slower | +-----------------------------------------+---------+-----------------------+ | wide_state_25x300_checkpoint | 278 ms | 280 ms: 1.01x slower | +-----------------------------------------+---------+-----------------------+ | Geometric mean | (ref) | 1.00x faster | +-----------------------------------------+---------+-----------------------+ Benchmark hidden because not significant (18): fanout_to_subgraph_10x_checkpoint, fanout_to_subgraph_10x, wide_state_25x300, react_agent_10x, wide_state_9x1200, wide_state_9x1200_sync, wide_state_15x600, fanout_to_subgraph_100x_sync, 
fanout_to_subgraph_10x_sync, wide_state_9x1200_checkpoint_sync, wide_state_15x600_sync, wide_state_9x1200_checkpoint, react_agent_100x_checkpoint, react_agent_10x_checkpoint_sync, wide_state_25x300_sync, react_agent_100x, wide_state_15x600_checkpoint_sync, wide_state_25x300_checkpoint_sync
from os import getenv
from types import MappingProxyType
from typing import Any, Literal, Mapping, cast
Expand Down Expand Up @@ -75,6 +75,10 @@
# callback to be called when a node is finished
CONFIG_KEY_RESUME_VALUE = sys.intern("__pregel_resume_value")
# holds the value that "answers" an interrupt() call
CONFIG_KEY_WRITES = sys.intern("__pregel_writes")
# read-only list of existing task writes
CONFIG_KEY_SCRATCHPAD = sys.intern("__pregel_scratchpad")
# holds a mutable dict for temporary storage scoped to the current task

# --- Other constants ---
PUSH = sys.intern("__pregel_push")
Expand Down
35 changes: 16 additions & 19 deletions libs/langgraph/langgraph/pregel/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@
CONFIG_KEY_CHECKPOINT_NS,
CONFIG_KEY_CHECKPOINTER,
CONFIG_KEY_READ,
CONFIG_KEY_RESUME_VALUE,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is CONFIG_KEY_RESUME_VALUE used anymore?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, removed now

CONFIG_KEY_SCRATCHPAD,
CONFIG_KEY_SEND,
CONFIG_KEY_STORE,
CONFIG_KEY_TASK_ID,
CONFIG_KEY_WRITES,
EMPTY_SEQ,
INTERRUPT,
MISSING,
NO_WRITES,
NS_END,
NS_SEP,
Expand Down Expand Up @@ -589,14 +589,13 @@ def prepare_single_task(
},
CONFIG_KEY_CHECKPOINT_ID: None,
CONFIG_KEY_CHECKPOINT_NS: task_checkpoint_ns,
CONFIG_KEY_RESUME_VALUE: next(
(
v
for tid, c, v in pending_writes
if tid in (NULL_TASK_ID, task_id) and c == RESUME
),
configurable.get(CONFIG_KEY_RESUME_VALUE, MISSING),
),
CONFIG_KEY_WRITES: [
w
for w in pending_writes
+ configurable.get(CONFIG_KEY_WRITES, [])
if w[0] in (NULL_TASK_ID, task_id)
],
CONFIG_KEY_SCRATCHPAD: {},
},
),
triggers,
Expand Down Expand Up @@ -713,15 +712,13 @@ def prepare_single_task(
},
CONFIG_KEY_CHECKPOINT_ID: None,
CONFIG_KEY_CHECKPOINT_NS: task_checkpoint_ns,
CONFIG_KEY_RESUME_VALUE: next(
(
v
for tid, c, v in pending_writes
if tid in (NULL_TASK_ID, task_id)
and c == RESUME
),
configurable.get(CONFIG_KEY_RESUME_VALUE, MISSING),
),
CONFIG_KEY_WRITES: [
w
for w in pending_writes
+ configurable.get(CONFIG_KEY_WRITES, [])
if w[0] in (NULL_TASK_ID, task_id)
],
CONFIG_KEY_SCRATCHPAD: {},
},
),
triggers,
Expand Down
9 changes: 7 additions & 2 deletions libs/langgraph/langgraph/pregel/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from langchain_core.runnables.utils import AddableDict

from langgraph.channels.base import BaseChannel, EmptyChannelError
from langgraph.checkpoint.base import PendingWrite
from langgraph.constants import (
EMPTY_SEQ,
ERROR,
Expand Down Expand Up @@ -66,7 +67,7 @@ def read_channels(


def map_command(
cmd: Command,
cmd: Command, pending_writes: list[PendingWrite]
) -> Iterator[tuple[str, str, Any]]:
"""Map input chunk to a sequence of pending writes in the form (channel, value)."""
if cmd.graph == Command.PARENT:
Expand All @@ -85,7 +86,11 @@ def map_command(
if cmd.resume:
if isinstance(cmd.resume, dict) and all(is_task_id(k) for k in cmd.resume):
for tid, resume in cmd.resume.items():
yield (tid, RESUME, resume)
existing = next(
(w for w in pending_writes if w[0] == tid and w[1] == RESUME), []
)
existing.append(resume)
yield (tid, RESUME, existing)
else:
yield (NULL_TASK_ID, RESUME, cmd.resume)
if cmd.update:
Expand Down
25 changes: 23 additions & 2 deletions libs/langgraph/langgraph/pregel/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from langgraph.channels.base import BaseChannel
from langgraph.checkpoint.base import (
WRITES_IDX_MAP,
BaseCheckpointSaver,
ChannelVersions,
Checkpoint,
Expand Down Expand Up @@ -263,8 +264,28 @@ def put_writes(self, task_id: str, writes: Sequence[tuple[str, Any]]) -> None:
"""Put writes for a task, to be read by the next tick."""
if not writes:
return
# deduplicate writes to special channels, last write wins
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ooc: why do we know we don't have to deduplicate if one of the writes is to a regular channel?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Regular channels can't be written to more than once (writes after the first are ignored)

if all(w[0] in WRITES_IDX_MAP for w in writes):
writes = list({w[0]: w for w in writes}.values())
# save writes
self.checkpoint_pending_writes.extend((task_id, k, v) for k, v in writes)
for c, v in writes:
if (
c in WRITES_IDX_MAP
and (
idx := next(
(
i
for i, w in enumerate(self.checkpoint_pending_writes)
if w[0] == task_id and w[1] == c
),
None,
)
)
is not None
):
self.checkpoint_pending_writes[idx] = (task_id, c, v)
else:
self.checkpoint_pending_writes.append((task_id, c, v))
if self.checkpointer_put_writes is not None:
self.submit(
self.checkpointer_put_writes,
Expand Down Expand Up @@ -536,7 +557,7 @@ def _first(self, *, input_keys: Union[str, Sequence[str]]) -> None:
elif isinstance(self.input, Command):
writes: defaultdict[str, list[tuple[str, Any]]] = defaultdict(list)
# group writes by task ID
for tid, c, v in map_command(self.input):
for tid, c, v in map_command(self.input, self.checkpoint_pending_writes):
writes[tid].append((c, v))
if not writes:
raise EmptyInputError("Received empty Command input")
Expand Down
3 changes: 3 additions & 0 deletions libs/langgraph/langgraph/pregel/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
INTERRUPT,
NO_WRITES,
PUSH,
RESUME,
TAG_HIDDEN,
)
from langgraph.errors import GraphBubbleUp, GraphInterrupt
Expand Down Expand Up @@ -297,6 +298,8 @@ def commit(
if isinstance(exception, GraphInterrupt):
# save interrupt to checkpointer
if interrupts := [(INTERRUPT, i) for i in exception.args[0]]:
if resumes := [w for w in task.writes if w[0] == RESUME]:
interrupts.extend(resumes)
self.put_writes(task.id, interrupts)
elif isinstance(exception, GraphBubbleUp):
raise exception
Expand Down
66 changes: 53 additions & 13 deletions libs/langgraph/langgraph/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Optional,
Sequence,
Type,
TypedDict,
TypeVar,
Union,
cast,
Expand All @@ -21,11 +22,16 @@
from langchain_core.runnables import Runnable, RunnableConfig
from typing_extensions import Self

from langgraph.checkpoint.base import BaseCheckpointSaver, CheckpointMetadata
from langgraph.checkpoint.base import (
BaseCheckpointSaver,
CheckpointMetadata,
PendingWrite,
)

if TYPE_CHECKING:
from langgraph.store.base import BaseStore


All = Literal["*"]
"""Special value to indicate that graph should interrupt on all nodes."""

Expand Down Expand Up @@ -300,26 +306,60 @@ def __init__(
self.stop = stop


class PregelScratchpad(TypedDict, total=False):
    """Task-scoped mutable temporary storage (all keys optional via total=False).

    Passed to a task under CONFIG_KEY_SCRATCHPAD and mutated by interrupt()
    to track its position among multiple interrupts within one task.
    """

    # number of interrupt() calls made so far in the current task; starts at 0
    # on the first call and is incremented on each subsequent call
    interrupt_counter: int
    # set to True once the shared (NULL_TASK_ID) RESUME write has been consumed
    # by this task, so it is not reused for a later interrupt
    used_null_resume: bool
    # resume values collected so far, indexed by interrupt_counter
    resume: list[Any]


def interrupt(value: Any) -> Any:
from langgraph.constants import (
CONFIG_KEY_CHECKPOINT_NS,
CONFIG_KEY_RESUME_VALUE,
MISSING,
CONFIG_KEY_SCRATCHPAD,
CONFIG_KEY_SEND,
CONFIG_KEY_TASK_ID,
CONFIG_KEY_WRITES,
NS_SEP,
NULL_TASK_ID,
RESUME,
)
from langgraph.errors import GraphInterrupt
from langgraph.utils.config import get_configurable

conf = get_configurable()
if (resume := conf.get(CONFIG_KEY_RESUME_VALUE, MISSING)) and resume is not MISSING:
return resume
# track interrupt index
scratchpad: PregelScratchpad = conf[CONFIG_KEY_SCRATCHPAD]
if "interrupt_counter" not in scratchpad:
scratchpad["interrupt_counter"] = 0
else:
raise GraphInterrupt(
(
Interrupt(
value=value,
resumable=True,
ns=cast(str, conf[CONFIG_KEY_CHECKPOINT_NS]).split(NS_SEP),
),
)
scratchpad["interrupt_counter"] += 1
idx = scratchpad["interrupt_counter"]
# find previous resume values
task_id = conf[CONFIG_KEY_TASK_ID]
writes: list[PendingWrite] = conf[CONFIG_KEY_WRITES]
scratchpad.setdefault(
"resume", next((w[2] for w in writes if w[0] == task_id and w[1] == RESUME), [])
)
if scratchpad["resume"]:
if idx < len(scratchpad["resume"]):
return scratchpad["resume"][idx]
# find current resume value
if not scratchpad.get("used_null_resume"):
scratchpad["used_null_resume"] = True
for tid, c, v in sorted(writes, key=lambda x: x[0], reverse=True):
if tid == NULL_TASK_ID and c == RESUME:
assert len(scratchpad["resume"]) == idx, (scratchpad["resume"], idx)
scratchpad["resume"].append(v)
print("saving:", scratchpad["resume"])
conf[CONFIG_KEY_SEND]([(RESUME, scratchpad["resume"])])
return v
# no resume value found
raise GraphInterrupt(
(
Interrupt(
value=value,
resumable=True,
ns=cast(str, conf[CONFIG_KEY_CHECKPOINT_NS]).split(NS_SEP),
),
)
)
Loading
Loading