Skip to content

Commit

Permalink
Use thread_id for partition key, ignore tasks for stale checkpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Sep 10, 2024
1 parent df70717 commit b395824
Show file tree
Hide file tree
Showing 10 changed files with 107 additions and 33 deletions.
20 changes: 9 additions & 11 deletions libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,18 +298,16 @@ async def aput_writes(
if all(w[0] in WRITES_IDX_MAP for w in writes)
else self.INSERT_CHECKPOINT_WRITES_SQL
)
params = await asyncio.to_thread(
self._dump_writes,
config["configurable"]["thread_id"],
config["configurable"]["checkpoint_ns"],
config["configurable"]["checkpoint_id"],
task_id,
writes,
)
async with self._cursor(pipeline=True) as cur:
await cur.executemany(
query,
await asyncio.to_thread(
self._dump_writes,
config["configurable"]["thread_id"],
config["configurable"]["checkpoint_ns"],
config["configurable"]["checkpoint_id"],
task_id,
writes,
),
)
await cur.executemany(query, params)

@asynccontextmanager
async def _cursor(self, *, pipeline: bool = False) -> AsyncIterator[AsyncCursor]:
Expand Down
4 changes: 4 additions & 0 deletions libs/langgraph/langgraph/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
CONFIG_KEY_RESUMING = "__pregel_resuming"
CONFIG_KEY_TASK_ID = "__pregel_task_id"
CONFIG_KEY_DEDUPE_TASKS = "__pregel_dedupe_tasks"
CONFIG_KEY_ENSURE_LATEST = "__pregel_ensure_latest"
# this one part of public API so more readable
CONFIG_KEY_CHECKPOINT_MAP = "checkpoint_map"
INTERRUPT = "__interrupt__"
ERROR = "__error__"
NO_WRITES = "__no_writes__"
SCHEDULED = "__scheduled__"
TASKS = "__pregel_tasks" # for backwards compat, this is the original name of PUSH
PUSH = "__pregel_push"
Expand All @@ -23,6 +25,7 @@
SCHEDULED,
INTERRUPT,
ERROR,
NO_WRITES,
TASKS,
PUSH,
PULL,
Expand All @@ -34,6 +37,7 @@
CONFIG_KEY_RESUMING,
CONFIG_KEY_TASK_ID,
CONFIG_KEY_DEDUPE_TASKS,
CONFIG_KEY_ENSURE_LATEST,
INPUT,
RUNTIME_PLACEHOLDER,
}
Expand Down
6 changes: 6 additions & 0 deletions libs/langgraph/langgraph/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ class TaskNotFound(Exception):
pass


class CheckpointNotLatest(Exception):
"""Raised when the checkpoint is not the latest version."""

pass


__all__ = [
"GraphRecursionError",
"InvalidUpdateError",
Expand Down
5 changes: 2 additions & 3 deletions libs/langgraph/langgraph/pregel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from uuid import UUID, uuid5

from langchain_core.globals import get_debug
from langchain_core.load.dump import dumpd
from langchain_core.runnables import (
Runnable,
RunnableLambda,
Expand Down Expand Up @@ -1160,7 +1159,7 @@ def output() -> Iterator:
config = ensure_config(merge_configs(self.config, config))
callback_manager = get_callback_manager_for_config(config)
run_manager = callback_manager.on_chain_start(
dumpd(self),
None,
input,
name=config.get("run_name", self.get_name()),
run_id=config.get("run_id"),
Expand Down Expand Up @@ -1341,7 +1340,7 @@ def output() -> Iterator:
config = ensure_config(merge_configs(self.config, config))
callback_manager = get_async_callback_manager_for_config(config)
run_manager = await callback_manager.on_chain_start(
dumpd(self),
None,
input,
name=config.get("run_name", self.get_name()),
run_id=config.get("run_id"),
Expand Down
9 changes: 8 additions & 1 deletion libs/langgraph/langgraph/pregel/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
CONFIG_KEY_SEND,
CONFIG_KEY_TASK_ID,
INTERRUPT,
NO_WRITES,
NS_SEP,
PULL,
PUSH,
Expand Down Expand Up @@ -196,7 +197,9 @@ def apply_writes(
pending_writes_by_managed: dict[str, list[Any]] = defaultdict(list)
for task in tasks:
for chan, val in task.writes:
if chan == TASKS:
if chan == NO_WRITES:
pass
elif chan == TASKS:
checkpoint["pending_sends"].append(val)
elif chan in channels:
pending_writes_by_channel[chan].append(val)
Expand Down Expand Up @@ -331,6 +334,8 @@ def prepare_single_task(

if task_path[0] == PUSH:
idx = int(task_path[1])
if idx >= len(checkpoint["pending_sends"]):
return
packet = checkpoint["pending_sends"][idx]
if not isinstance(packet, Send):
logger.warning(
Expand Down Expand Up @@ -425,6 +430,8 @@ def prepare_single_task(
return PregelTask(task_id, packet.node, task_path)
elif task_path[0] == PULL:
name = str(task_path[1])
if name not in processes:
return
proc = processes[name]
version_type = type(next(iter(checkpoint["channel_versions"].values()), None))
null_version = version_type()
Expand Down
53 changes: 42 additions & 11 deletions libs/langgraph/langgraph/pregel/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from langgraph.constants import (
CONFIG_KEY_CHECKPOINT_MAP,
CONFIG_KEY_DEDUPE_TASKS,
CONFIG_KEY_ENSURE_LATEST,
CONFIG_KEY_RESUMING,
CONFIG_KEY_STREAM,
CONFIG_KEY_TASK_ID,
Expand All @@ -49,7 +50,7 @@
SCHEDULED,
TAG_HIDDEN,
)
from langgraph.errors import EmptyInputError, GraphInterrupt
from langgraph.errors import CheckpointNotLatest, EmptyInputError, GraphInterrupt
from langgraph.managed.base import (
ManagedValueMapping,
ManagedValueSpec,
Expand Down Expand Up @@ -599,11 +600,26 @@ def _update_mv(self, key: str, values: Sequence[Any]) -> None:
# context manager

def __enter__(self) -> Self:
saved = (
self.checkpointer.get_tuple(self.checkpoint_config)
if self.checkpointer
else None
) or CheckpointTuple(self.config, empty_checkpoint(), {"step": -2}, None, [])
if self.config.get("configurable", {}).get(
CONFIG_KEY_ENSURE_LATEST
) and self.checkpoint_config["configurable"].get("checkpoint_id"):
saved = self.checkpointer.get_tuple(
patch_configurable(self.checkpoint_config, {"checkpoint_id": None})
)
if (
saved is None
or saved.checkpoint["id"]
!= self.checkpoint_config["configurable"]["checkpoint_id"]
):
raise CheckpointNotLatest
elif self.checkpointer:
saved = self.checkpointer.get_tuple(self.checkpoint_config)
else:
saved = None
if saved is None:
saved = CheckpointTuple(
self.config, empty_checkpoint(), {"step": -2}, None, []
)
self.checkpoint_config = {
**self.config,
**saved.config,
Expand Down Expand Up @@ -702,11 +718,26 @@ def _update_mv(self, key: str, values: Sequence[Any]) -> None:
# context manager

async def __aenter__(self) -> Self:
saved = (
await self.checkpointer.aget_tuple(self.checkpoint_config)
if self.checkpointer
else None
) or CheckpointTuple(self.config, empty_checkpoint(), {"step": -2}, None, [])
if self.config.get("configurable", {}).get(
CONFIG_KEY_ENSURE_LATEST
) and self.checkpoint_config["configurable"].get("checkpoint_id"):
saved = await self.checkpointer.aget_tuple(
patch_configurable(self.checkpoint_config, {"checkpoint_id": None})
)
if (
saved is None
or saved.checkpoint["id"]
!= self.checkpoint_config["configurable"]["checkpoint_id"]
):
raise CheckpointNotLatest
elif self.checkpointer:
saved = await self.checkpointer.aget_tuple(self.checkpoint_config)
else:
saved = None
if saved is None:
saved = CheckpointTuple(
self.config, empty_checkpoint(), {"step": -2}, None, []
)
self.checkpoint_config = {
**self.config,
**saved.config,
Expand Down
15 changes: 11 additions & 4 deletions libs/langgraph/langgraph/pregel/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
Union,
)

from langgraph.constants import ERROR, INTERRUPT
from langgraph.constants import ERROR, INTERRUPT, NO_WRITES
from langgraph.errors import GraphInterrupt
from langgraph.pregel.executor import Submit
from langgraph.pregel.retry import arun_with_retry, run_with_retry
Expand Down Expand Up @@ -69,12 +69,15 @@ def tick(
if exc := _exception(fut):
if isinstance(exc, GraphInterrupt):
# save interrupt to checkpointer
self.put_writes(task.id, [(INTERRUPT, i) for i in exc.args[0]])
if interrupts := [(INTERRUPT, i) for i in exc.args[0]]:
self.put_writes(task.id, interrupts)
else:
# save error to checkpointer
self.put_writes(task.id, [(ERROR, exc)])

else:
if not task.writes:
# add no writes marker
task.writes.append((NO_WRITES, None))
# save task writes to checkpointer
self.put_writes(task.id, task.writes)
else:
Expand Down Expand Up @@ -130,11 +133,15 @@ async def atick(
if exc := _exception(fut):
if isinstance(exc, GraphInterrupt):
# save interrupt to checkpointer
self.put_writes(task.id, [(INTERRUPT, i) for i in exc.args[0]])
if interrupts := [(INTERRUPT, i) for i in exc.args[0]]:
self.put_writes(task.id, interrupts)
else:
# save error to checkpointer
self.put_writes(task.id, [(ERROR, exc)])
else:
if not task.writes:
# add no writes marker
task.writes.append((NO_WRITES, None))
# save task writes to checkpointer
self.put_writes(task.id, task.writes)
else:
Expand Down
13 changes: 11 additions & 2 deletions libs/scheduler-kafka/langgraph/scheduler/kafka/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import langgraph.scheduler.kafka.serde as serde
from langgraph.constants import ERROR
from langgraph.errors import TaskNotFound
from langgraph.errors import CheckpointNotLatest, TaskNotFound
from langgraph.pregel import Pregel
from langgraph.pregel.algo import prepare_single_task
from langgraph.pregel.executor import AsyncBackgroundExecutor, Submit
Expand All @@ -22,6 +22,7 @@
MessageToOrchestrator,
Topics,
)
from langgraph.utils.config import patch_configurable


class KafkaExecutor(AbstractAsyncContextManager):
Expand Down Expand Up @@ -91,6 +92,8 @@ async def __anext__(self) -> Sequence[MessageToExecutor]:
async def each(self, msg: MessageToExecutor) -> None:
try:
await aretry(self.retry_policy, self.attempt, msg)
except CheckpointNotLatest:
pass
except Exception as exc:
await self.producer.send_and_wait(
self.topics.error,
Expand All @@ -103,9 +106,13 @@ async def each(self, msg: MessageToExecutor) -> None:

async def attempt(self, msg: MessageToExecutor) -> None:
# process message
saved = await self.graph.checkpointer.aget_tuple(msg["config"])
saved = await self.graph.checkpointer.aget_tuple(
patch_configurable(msg["config"], {"checkpoint_id": None})
)
if saved is None:
raise RuntimeError("Checkpoint not found")
if saved.checkpoint["id"] != msg["config"]["configurable"]["checkpoint_id"]:
raise CheckpointNotLatest()
async with AsyncChannelsManager(
self.graph.channels, saved.checkpoint, msg["config"], self.graph.store
) as (channels, managed), AsyncBackgroundExecutor() as submit:
Expand Down Expand Up @@ -138,6 +145,8 @@ async def attempt(self, msg: MessageToExecutor) -> None:
await self.producer.send_and_wait(
self.topics.orchestrator,
value=MessageToOrchestrator(input=None, config=msg["config"]),
# use thread_id as partition key
key=msg["config"]["configurable"]["thread_id"].encode(),
)

def _put_writes(
Expand Down
11 changes: 10 additions & 1 deletion libs/scheduler-kafka/langgraph/scheduler/kafka/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@
from langchain_core.runnables import ensure_config

import langgraph.scheduler.kafka.serde as serde
from langgraph.constants import CONFIG_KEY_DEDUPE_TASKS, INTERRUPT, SCHEDULED
from langgraph.constants import (
CONFIG_KEY_DEDUPE_TASKS,
CONFIG_KEY_ENSURE_LATEST,
INTERRUPT,
SCHEDULED,
)
from langgraph.errors import CheckpointNotLatest
from langgraph.pregel import Pregel
from langgraph.pregel.loop import AsyncPregelLoop
from langgraph.pregel.types import RetryPolicy
Expand Down Expand Up @@ -83,6 +89,8 @@ async def __anext__(self) -> list[MessageToOrchestrator]:
async def each(self, msg: MessageToOrchestrator) -> None:
try:
await aretry(self.retry_policy, self.attempt, msg)
except CheckpointNotLatest:
pass
except Exception as exc:
await self.producer.send_and_wait(
self.topics.error,
Expand Down Expand Up @@ -127,6 +135,7 @@ async def attempt(self, msg: MessageToOrchestrator) -> None:
{
**loop.checkpoint_config["configurable"],
CONFIG_KEY_DEDUPE_TASKS: True,
CONFIG_KEY_ENSURE_LATEST: True,
},
),
task=ExecutorTask(id=task.id, path=task.path),
Expand Down
4 changes: 4 additions & 0 deletions libs/scheduler-kafka/tests/test_fanout.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ async def test_fanout_graph(topics: Topics, checkpointer: BaseCheckpointSaver) -
"config": {
"callbacks": None,
"configurable": {
"__pregel_ensure_latest": True,
"__pregel_dedupe_tasks": True,
"__pregel_resuming": False,
"checkpoint_id": c.config["configurable"]["checkpoint_id"],
Expand All @@ -146,6 +147,7 @@ async def test_fanout_graph(topics: Topics, checkpointer: BaseCheckpointSaver) -
"config": {
"callbacks": None,
"configurable": {
"__pregel_ensure_latest": True,
"__pregel_dedupe_tasks": True,
"__pregel_resuming": False,
"checkpoint_id": c.config["configurable"]["checkpoint_id"],
Expand Down Expand Up @@ -206,6 +208,7 @@ async def test_fanout_graph_w_interrupt(
"config": {
"callbacks": None,
"configurable": {
"__pregel_ensure_latest": True,
"__pregel_dedupe_tasks": True,
"__pregel_resuming": False,
"checkpoint_id": c.config["configurable"]["checkpoint_id"],
Expand All @@ -226,6 +229,7 @@ async def test_fanout_graph_w_interrupt(
"config": {
"callbacks": None,
"configurable": {
"__pregel_ensure_latest": True,
"__pregel_dedupe_tasks": True,
"__pregel_resuming": False,
"checkpoint_id": c.config["configurable"]["checkpoint_id"],
Expand Down

0 comments on commit b395824

Please sign in to comment.