Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lib: Add support for multiple interrupts per node #2636

Merged
merged 3 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions libs/langgraph/langgraph/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import sys

Check notice on line 1 in libs/langgraph/langgraph/constants.py

View workflow job for this annotation

GitHub Actions / benchmark

Benchmark results

......................................... fanout_to_subgraph_10x: Mean +- std dev: 61.7 ms +- 1.3 ms ......................................... fanout_to_subgraph_10x_sync: Mean +- std dev: 52.1 ms +- 1.3 ms ......................................... fanout_to_subgraph_10x_checkpoint: Mean +- std dev: 93.4 ms +- 8.5 ms ......................................... fanout_to_subgraph_10x_checkpoint_sync: Mean +- std dev: 95.5 ms +- 1.7 ms ......................................... fanout_to_subgraph_100x: Mean +- std dev: 641 ms +- 30 ms ......................................... fanout_to_subgraph_100x_sync: Mean +- std dev: 507 ms +- 6 ms ......................................... fanout_to_subgraph_100x_checkpoint: Mean +- std dev: 924 ms +- 43 ms ......................................... fanout_to_subgraph_100x_checkpoint_sync: Mean +- std dev: 927 ms +- 16 ms ......................................... react_agent_10x: Mean +- std dev: 30.8 ms +- 0.8 ms ......................................... react_agent_10x_sync: Mean +- std dev: 22.4 ms +- 0.2 ms ......................................... react_agent_10x_checkpoint: Mean +- std dev: 46.9 ms +- 0.8 ms ......................................... react_agent_10x_checkpoint_sync: Mean +- std dev: 36.7 ms +- 0.8 ms ......................................... react_agent_100x: Mean +- std dev: 344 ms +- 7 ms ......................................... react_agent_100x_sync: Mean +- std dev: 272 ms +- 3 ms ......................................... react_agent_100x_checkpoint: Mean +- std dev: 935 ms +- 9 ms ......................................... react_agent_100x_checkpoint_sync: Mean +- std dev: 834 ms +- 11 ms ......................................... wide_state_25x300: Mean +- std dev: 24.2 ms +- 0.5 ms ......................................... wide_state_25x300_sync: Mean +- std dev: 15.5 ms +- 0.1 ms ......................................... 
wide_state_25x300_checkpoint: Mean +- std dev: 279 ms +- 5 ms ......................................... wide_state_25x300_checkpoint_sync: Mean +- std dev: 266 ms +- 5 ms ......................................... wide_state_15x600: Mean +- std dev: 28.2 ms +- 0.4 ms ......................................... wide_state_15x600_sync: Mean +- std dev: 17.8 ms +- 0.2 ms ......................................... wide_state_15x600_checkpoint: Mean +- std dev: 478 ms +- 5 ms ......................................... wide_state_15x600_checkpoint_sync: Mean +- std dev: 465 ms +- 9 ms ......................................... wide_state_9x1200: Mean +- std dev: 28.3 ms +- 0.5 ms ......................................... wide_state_9x1200_sync: Mean +- std dev: 17.9 ms +- 0.3 ms ......................................... wide_state_9x1200_checkpoint: Mean +- std dev: 312 ms +- 5 ms ......................................... wide_state_9x1200_checkpoint_sync: Mean +- std dev: 298 ms +- 4 ms

Check notice on line 1 in libs/langgraph/langgraph/constants.py

View workflow job for this annotation

GitHub Actions / benchmark

Comparison against main

+----------------------------------------+---------+-----------------------+ | Benchmark | main | changes | +========================================+=========+=======================+ | fanout_to_subgraph_100x_checkpoint | 945 ms | 924 ms: 1.02x faster | +----------------------------------------+---------+-----------------------+ | react_agent_10x_checkpoint | 47.1 ms | 46.9 ms: 1.01x faster | +----------------------------------------+---------+-----------------------+ | react_agent_100x_checkpoint_sync | 838 ms | 834 ms: 1.00x faster | +----------------------------------------+---------+-----------------------+ | react_agent_100x_checkpoint | 939 ms | 935 ms: 1.00x faster | +----------------------------------------+---------+-----------------------+ | react_agent_10x_sync | 22.3 ms | 22.4 ms: 1.00x slower | +----------------------------------------+---------+-----------------------+ | wide_state_15x600_sync | 17.8 ms | 17.8 ms: 1.00x slower | +----------------------------------------+---------+-----------------------+ | wide_state_9x1200_sync | 17.8 ms | 17.9 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | wide_state_9x1200 | 28.2 ms | 28.3 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | wide_state_25x300_sync | 15.3 ms | 15.5 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_checkpoint_sync | 94.6 ms | 95.5 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_sync | 501 ms | 507 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_sync | 51.4 ms | 52.1 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x | 606 ms | 641 ms: 1.06x slower | 
+----------------------------------------+---------+-----------------------+ | Geometric mean | (ref) | 1.00x slower | +----------------------------------------+---------+-----------------------+ Benchmark hidden because not significant (15): react_agent_10x, react_agent_10x_checkpoint_sync, wide_state_9x1200_checkpoint_sync, fanout_to_subgraph_100x_checkpoint_sync, wide_state_9x1200_checkpoint, react_agent_100x_sync, wide_state_25x300, wide_state_15x600_checkpoint, wide_state_25x300_checkpoint_sync, fanout_to_subgraph_10x, wide_state_25x300_checkpoint, wide_state_15x600_checkpoint_sync, react_agent_100x, wide_state_15x600, fanout_to_subgraph_10x_checkpoint
from os import getenv
from types import MappingProxyType
from typing import Any, Literal, Mapping, cast
Expand Down Expand Up @@ -72,9 +72,11 @@
CONFIG_KEY_CHECKPOINT_NS = sys.intern("checkpoint_ns")
# holds the current checkpoint_ns, "" for root graph
CONFIG_KEY_NODE_FINISHED = sys.intern("__pregel_node_finished")
# callback to be called when a node is finished
CONFIG_KEY_RESUME_VALUE = sys.intern("__pregel_resume_value")
# holds the value that "answers" an interrupt() call
CONFIG_KEY_WRITES = sys.intern("__pregel_writes")
# read-only list of existing task writes
CONFIG_KEY_SCRATCHPAD = sys.intern("__pregel_scratchpad")
# holds a mutable dict for temporary storage scoped to the current task

# --- Other constants ---
PUSH = sys.intern("__pregel_push")
Expand Down
35 changes: 16 additions & 19 deletions libs/langgraph/langgraph/pregel/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@
CONFIG_KEY_CHECKPOINT_NS,
CONFIG_KEY_CHECKPOINTER,
CONFIG_KEY_READ,
CONFIG_KEY_RESUME_VALUE,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is CONFIG_KEY_RESUME_VALUE used anymore?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, removed now

CONFIG_KEY_SCRATCHPAD,
CONFIG_KEY_SEND,
CONFIG_KEY_STORE,
CONFIG_KEY_TASK_ID,
CONFIG_KEY_WRITES,
EMPTY_SEQ,
INTERRUPT,
MISSING,
NO_WRITES,
NS_END,
NS_SEP,
Expand Down Expand Up @@ -589,14 +589,13 @@ def prepare_single_task(
},
CONFIG_KEY_CHECKPOINT_ID: None,
CONFIG_KEY_CHECKPOINT_NS: task_checkpoint_ns,
CONFIG_KEY_RESUME_VALUE: next(
(
v
for tid, c, v in pending_writes
if tid in (NULL_TASK_ID, task_id) and c == RESUME
),
configurable.get(CONFIG_KEY_RESUME_VALUE, MISSING),
),
CONFIG_KEY_WRITES: [
w
for w in pending_writes
+ configurable.get(CONFIG_KEY_WRITES, [])
if w[0] in (NULL_TASK_ID, task_id)
],
CONFIG_KEY_SCRATCHPAD: {},
},
),
triggers,
Expand Down Expand Up @@ -713,15 +712,13 @@ def prepare_single_task(
},
CONFIG_KEY_CHECKPOINT_ID: None,
CONFIG_KEY_CHECKPOINT_NS: task_checkpoint_ns,
CONFIG_KEY_RESUME_VALUE: next(
(
v
for tid, c, v in pending_writes
if tid in (NULL_TASK_ID, task_id)
and c == RESUME
),
configurable.get(CONFIG_KEY_RESUME_VALUE, MISSING),
),
CONFIG_KEY_WRITES: [
w
for w in pending_writes
+ configurable.get(CONFIG_KEY_WRITES, [])
if w[0] in (NULL_TASK_ID, task_id)
],
CONFIG_KEY_SCRATCHPAD: {},
},
),
triggers,
Expand Down
9 changes: 7 additions & 2 deletions libs/langgraph/langgraph/pregel/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from langchain_core.runnables.utils import AddableDict

from langgraph.channels.base import BaseChannel, EmptyChannelError
from langgraph.checkpoint.base import PendingWrite
from langgraph.constants import (
EMPTY_SEQ,
ERROR,
Expand Down Expand Up @@ -66,7 +67,7 @@ def read_channels(


def map_command(
cmd: Command,
cmd: Command, pending_writes: list[PendingWrite]
) -> Iterator[tuple[str, str, Any]]:
"""Map input chunk to a sequence of pending writes in the form (channel, value)."""
if cmd.graph == Command.PARENT:
Expand All @@ -85,7 +86,11 @@ def map_command(
if cmd.resume:
if isinstance(cmd.resume, dict) and all(is_task_id(k) for k in cmd.resume):
for tid, resume in cmd.resume.items():
yield (tid, RESUME, resume)
existing: list[Any] = next(
(w[2] for w in pending_writes if w[0] == tid and w[1] == RESUME), []
)
existing.append(resume)
yield (tid, RESUME, existing)
else:
yield (NULL_TASK_ID, RESUME, cmd.resume)
if cmd.update:
Expand Down
25 changes: 23 additions & 2 deletions libs/langgraph/langgraph/pregel/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from langgraph.channels.base import BaseChannel
from langgraph.checkpoint.base import (
WRITES_IDX_MAP,
BaseCheckpointSaver,
ChannelVersions,
Checkpoint,
Expand Down Expand Up @@ -263,8 +264,28 @@ def put_writes(self, task_id: str, writes: Sequence[tuple[str, Any]]) -> None:
"""Put writes for a task, to be read by the next tick."""
if not writes:
return
# deduplicate writes to special channels, last write wins
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ooc: why do we know we don't have to deduplicate if one of the writes is to a regular channel?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Regular channels can't be written to more than once (writes after the first are ignored)

if all(w[0] in WRITES_IDX_MAP for w in writes):
writes = list({w[0]: w for w in writes}.values())
# save writes
self.checkpoint_pending_writes.extend((task_id, k, v) for k, v in writes)
for c, v in writes:
if (
c in WRITES_IDX_MAP
and (
idx := next(
(
i
for i, w in enumerate(self.checkpoint_pending_writes)
if w[0] == task_id and w[1] == c
),
None,
)
)
is not None
):
self.checkpoint_pending_writes[idx] = (task_id, c, v)
else:
self.checkpoint_pending_writes.append((task_id, c, v))
if self.checkpointer_put_writes is not None:
self.submit(
self.checkpointer_put_writes,
Expand Down Expand Up @@ -536,7 +557,7 @@ def _first(self, *, input_keys: Union[str, Sequence[str]]) -> None:
elif isinstance(self.input, Command):
writes: defaultdict[str, list[tuple[str, Any]]] = defaultdict(list)
# group writes by task ID
for tid, c, v in map_command(self.input):
for tid, c, v in map_command(self.input, self.checkpoint_pending_writes):
writes[tid].append((c, v))
if not writes:
raise EmptyInputError("Received empty Command input")
Expand Down
3 changes: 3 additions & 0 deletions libs/langgraph/langgraph/pregel/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
INTERRUPT,
NO_WRITES,
PUSH,
RESUME,
TAG_HIDDEN,
)
from langgraph.errors import GraphBubbleUp, GraphInterrupt
Expand Down Expand Up @@ -297,6 +298,8 @@ def commit(
if isinstance(exception, GraphInterrupt):
# save interrupt to checkpointer
if interrupts := [(INTERRUPT, i) for i in exception.args[0]]:
if resumes := [w for w in task.writes if w[0] == RESUME]:
interrupts.extend(resumes)
self.put_writes(task.id, interrupts)
elif isinstance(exception, GraphBubbleUp):
raise exception
Expand Down
66 changes: 53 additions & 13 deletions libs/langgraph/langgraph/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Optional,
Sequence,
Type,
TypedDict,
TypeVar,
Union,
cast,
Expand All @@ -21,11 +22,16 @@
from langchain_core.runnables import Runnable, RunnableConfig
from typing_extensions import Self

from langgraph.checkpoint.base import BaseCheckpointSaver, CheckpointMetadata
from langgraph.checkpoint.base import (
BaseCheckpointSaver,
CheckpointMetadata,
PendingWrite,
)

if TYPE_CHECKING:
from langgraph.store.base import BaseStore


All = Literal["*"]
"""Special value to indicate that graph should interrupt on all nodes."""

Expand Down Expand Up @@ -300,26 +306,60 @@ def __init__(
self.stop = stop


# Mutable per-task scratch storage for the Pregel loop; every key is optional.
# - interrupt_counter: 0-based index of the current interrupt() call in the task
# - used_null_resume: whether the shared (NULL-task) resume value was consumed
# - resume: resume values accumulated for this task, in interrupt order
PregelScratchpad = TypedDict(
    "PregelScratchpad",
    {
        "interrupt_counter": int,
        "used_null_resume": bool,
        "resume": list[Any],
    },
    total=False,
)


def interrupt(value: Any) -> Any:
    """Interrupt the current task, surfacing `value` to the caller, and resume
    with the value supplied on the next invocation.

    On the first call (per interrupt index within the task) this raises
    `GraphInterrupt` carrying `value`; once a matching resume value exists in
    the task's pending writes it is returned instead of raising.

    Args:
        value: Payload to surface to the caller alongside the interrupt.

    Returns:
        The resume value that "answers" this interrupt, once provided.

    Raises:
        GraphInterrupt: When no resume value is available yet.
    """
    # Imports are deferred to avoid a circular import at module load time.
    from langgraph.constants import (
        CONFIG_KEY_CHECKPOINT_NS,
        CONFIG_KEY_SCRATCHPAD,
        CONFIG_KEY_SEND,
        CONFIG_KEY_TASK_ID,
        CONFIG_KEY_WRITES,
        NS_SEP,
        NULL_TASK_ID,
        RESUME,
    )
    from langgraph.errors import GraphInterrupt
    from langgraph.utils.config import get_configurable

    conf = get_configurable()
    # Track which interrupt (0-based) within this task we are at, so that a
    # task re-executed after a resume can skip past already-answered interrupts.
    scratchpad: PregelScratchpad = conf[CONFIG_KEY_SCRATCHPAD]
    if "interrupt_counter" not in scratchpad:
        scratchpad["interrupt_counter"] = 0
    else:
        scratchpad["interrupt_counter"] += 1
    idx = scratchpad["interrupt_counter"]
    # Find resume values recorded for this task in previous runs.
    task_id = conf[CONFIG_KEY_TASK_ID]
    writes: list[PendingWrite] = conf[CONFIG_KEY_WRITES]
    scratchpad.setdefault(
        "resume", next((w[2] for w in writes if w[0] == task_id and w[1] == RESUME), [])
    )
    if scratchpad["resume"]:
        if idx < len(scratchpad["resume"]):
            # This interrupt was already answered in a previous run.
            return scratchpad["resume"][idx]
    # Look for a resume value addressed to the whole run (NULL task id); it may
    # be consumed at most once per task, hence the used_null_resume flag.
    if not scratchpad.get("used_null_resume"):
        scratchpad["used_null_resume"] = True
        for tid, c, v in sorted(writes, key=lambda x: x[0], reverse=True):
            if tid == NULL_TASK_ID and c == RESUME:
                # Internal invariant: the null resume answers exactly the
                # current interrupt index.
                assert len(scratchpad["resume"]) == idx, (scratchpad["resume"], idx)
                scratchpad["resume"].append(v)
                # Persist the accumulated resume list so later runs see it.
                conf[CONFIG_KEY_SEND]([(RESUME, scratchpad["resume"])])
                return v
    # No resume value found: interrupt the task.
    raise GraphInterrupt(
        (
            Interrupt(
                value=value,
                resumable=True,
                ns=cast(str, conf[CONFIG_KEY_CHECKPOINT_NS]).split(NS_SEP),
            ),
        )
    )
Loading
Loading