Skip to content

Commit

Permalink
Fix Send order after interrupt/resume (#3037)
Browse files Browse the repository at this point in the history
- order was incorrectly based on task id, instead of the correct task
path
- this requires storing task paths on checkpointers
- addition of task_path to put_writes is made backwards compatible by
checking signature on call, and treating it as an optional arg
  • Loading branch information
nfcampos authored Jan 15, 2025
2 parents be1d035 + 0adbd89 commit aab6fdf
Show file tree
Hide file tree
Showing 7 changed files with 186 additions and 110 deletions.
10 changes: 5 additions & 5 deletions libs/langgraph/langgraph/pregel/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def apply_writes(
# sort tasks on path, to ensure deterministic order for update application
# any path parts after the 3rd are ignored for sorting
# (we use them for eg. task ids which aren't good for sorting)
tasks = sorted(tasks, key=lambda t: _tuple_str(t.path[:3]))
tasks = sorted(tasks, key=lambda t: task_path_str(t.path[:3]))
# if no task has triggers this is applying writes from the null task only
# so we don't do anything other than update the channels written to
bump_step = any(t.triggers for t in tasks)
Expand Down Expand Up @@ -450,7 +450,7 @@ def prepare_single_task(
str(step),
name,
PUSH,
_tuple_str(task_path[1]),
task_path_str(task_path[1]),
str(task_path[2]),
)
task_checkpoint_ns = f"{checkpoint_ns}:{task_id}"
Expand Down Expand Up @@ -813,10 +813,10 @@ def _uuid5_str(namespace: bytes, *parts: str) -> str:
return f"{hex[:8]}-{hex[8:12]}-{hex[12:16]}-{hex[16:20]}-{hex[20:32]}"


def _tuple_str(tup: Union[str, int, tuple]) -> str:
"""Generate a string representation of a tuple."""
def task_path_str(tup: Union[str, int, tuple]) -> str:
"""Generate a string representation of the task path."""
return (
f"~{', '.join(_tuple_str(x) for x in tup)}"
f"~{', '.join(task_path_str(x) for x in tup)}"
if isinstance(tup, (tuple, list))
else f"{tup:010d}"
if isinstance(tup, int)
Expand Down
53 changes: 40 additions & 13 deletions libs/langgraph/langgraph/pregel/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import concurrent.futures
from collections import defaultdict, deque
from contextlib import AsyncExitStack, ExitStack
from inspect import signature
from types import TracebackType
from typing import (
Any,
Expand Down Expand Up @@ -81,6 +82,7 @@
prepare_next_tasks,
prepare_single_task,
should_interrupt,
task_path_str,
)
from langgraph.pregel.debug import (
map_debug_checkpoint,
Expand Down Expand Up @@ -151,6 +153,7 @@ class PregelLoop(LoopProtocol):
checkpointer_put_writes: Optional[
Callable[[RunnableConfig, Sequence[tuple[str, Any]], str], Any]
]
checkpointer_put_writes_accepts_task_path: bool
_checkpointer_put_after_previous: Optional[
Callable[
[
Expand Down Expand Up @@ -288,20 +291,34 @@ def put_writes(self, task_id: str, writes: Sequence[tuple[str, Any]]) -> None:
else:
self.checkpoint_pending_writes.append((task_id, c, v))
if self.checkpointer_put_writes is not None:
self.submit(
self.checkpointer_put_writes,
patch_configurable(
self.checkpoint_config,
{
CONFIG_KEY_CHECKPOINT_NS: self.config[CONF].get(
CONFIG_KEY_CHECKPOINT_NS, ""
),
CONFIG_KEY_CHECKPOINT_ID: self.checkpoint["id"],
},
),
writes,
task_id,
config = patch_configurable(
self.checkpoint_config,
{
CONFIG_KEY_CHECKPOINT_NS: self.config[CONF].get(
CONFIG_KEY_CHECKPOINT_NS, ""
),
CONFIG_KEY_CHECKPOINT_ID: self.checkpoint["id"],
},
)
if self.checkpointer_put_writes_accepts_task_path:
if hasattr(self, "tasks"):
task = self.tasks.get(task_id)
else:
task = None
self.submit(
self.checkpointer_put_writes,
config,
writes,
task_id,
task_path_str(task.path) if task else "",
)
else:
self.submit(
self.checkpointer_put_writes,
config,
writes,
task_id,
)
# output writes
if hasattr(self, "tasks"):
self._output_writes(task_id, writes)
Expand Down Expand Up @@ -813,10 +830,15 @@ def __init__(
if checkpointer:
self.checkpointer_get_next_version = checkpointer.get_next_version
self.checkpointer_put_writes = checkpointer.put_writes
self.checkpointer_put_writes_accepts_task_path = (
signature(checkpointer.put_writes).parameters.get("task_path")
is not None
)
else:
self.checkpointer_get_next_version = increment
self._checkpointer_put_after_previous = None # type: ignore[assignment]
self.checkpointer_put_writes = None
self.checkpointer_put_writes_accepts_task_path = False

def _checkpointer_put_after_previous(
self,
Expand Down Expand Up @@ -945,10 +967,15 @@ def __init__(
if checkpointer:
self.checkpointer_get_next_version = checkpointer.get_next_version
self.checkpointer_put_writes = checkpointer.aput_writes
self.checkpointer_put_writes_accepts_task_path = (
signature(checkpointer.aput_writes).parameters.get("task_path")
is not None
)
else:
self.checkpointer_get_next_version = increment
self._checkpointer_put_after_previous = None # type: ignore[assignment]
self.checkpointer_put_writes = None
self.checkpointer_put_writes_accepts_task_path = False

async def _checkpointer_put_after_previous(
self,
Expand Down
12 changes: 6 additions & 6 deletions libs/langgraph/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion libs/langgraph/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ repository = "https://www.github.com/langchain-ai/langgraph"
[tool.poetry.dependencies]
python = ">=3.9.0,<4.0"
langchain-core = ">=0.2.43,<0.4.0,!=0.3.0,!=0.3.1,!=0.3.2,!=0.3.3,!=0.3.4,!=0.3.5,!=0.3.6,!=0.3.7,!=0.3.8,!=0.3.9,!=0.3.10,!=0.3.11,!=0.3.12,!=0.3.13,!=0.3.14,!=0.3.15,!=0.3.16,!=0.3.17,!=0.3.18,!=0.3.19,!=0.3.20,!=0.3.21,!=0.3.22"
langgraph-checkpoint = "^2.0.4"
langgraph-checkpoint = "^2.0.10"
langgraph-sdk = "^0.1.42"

[tool.poetry.group.dev.dependencies]
Expand Down
12 changes: 6 additions & 6 deletions libs/langgraph/tests/test_algo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from langgraph.checkpoint.base import empty_checkpoint
from langgraph.constants import PULL, PUSH
from langgraph.pregel.algo import _tuple_str, prepare_next_tasks
from langgraph.pregel.algo import prepare_next_tasks, task_path_str
from langgraph.pregel.manager import ChannelsManager


Expand Down Expand Up @@ -49,16 +49,16 @@ def test_tuple_str() -> None:
push_path_b = (PUSH, push_path_a, 1)
push_path_c = (PUSH, push_path_b, 3)

assert _tuple_str(push_path_a) == f"~{PUSH}, 0000000002"
assert _tuple_str(push_path_b) == f"~{PUSH}, ~{PUSH}, 0000000002, 0000000001"
assert task_path_str(push_path_a) == f"~{PUSH}, 0000000002"
assert task_path_str(push_path_b) == f"~{PUSH}, ~{PUSH}, 0000000002, 0000000001"
assert (
_tuple_str(push_path_c)
task_path_str(push_path_c)
== f"~{PUSH}, ~{PUSH}, ~{PUSH}, 0000000002, 0000000001, 0000000003"
)
assert _tuple_str(pull_path_a) == f"~{PULL}, abc"
assert task_path_str(pull_path_a) == f"~{PULL}, abc"

path_list = [push_path_b, push_path_a, pull_path_a, push_path_c]
assert sorted(map(_tuple_str, path_list)) == [
assert sorted(map(task_path_str, path_list)) == [
f"~{PULL}, abc",
f"~{PUSH}, 0000000002",
f"~{PUSH}, ~{PUSH}, 0000000002, 0000000001",
Expand Down
Loading

0 comments on commit aab6fdf

Please sign in to comment.