diff --git a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py index 488d39e99a..9c246ef741 100644 --- a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py +++ b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py @@ -298,18 +298,16 @@ async def aput_writes( if all(w[0] in WRITES_IDX_MAP for w in writes) else self.INSERT_CHECKPOINT_WRITES_SQL ) + params = await asyncio.to_thread( + self._dump_writes, + config["configurable"]["thread_id"], + config["configurable"]["checkpoint_ns"], + config["configurable"]["checkpoint_id"], + task_id, + writes, + ) async with self._cursor(pipeline=True) as cur: - await cur.executemany( - query, - await asyncio.to_thread( - self._dump_writes, - config["configurable"]["thread_id"], - config["configurable"]["checkpoint_ns"], - config["configurable"]["checkpoint_id"], - task_id, - writes, - ), - ) + await cur.executemany(query, params) @asynccontextmanager async def _cursor(self, *, pipeline: bool = False) -> AsyncIterator[AsyncCursor]: diff --git a/libs/langgraph/langgraph/constants.py b/libs/langgraph/langgraph/constants.py index 7cea9de855..15dc47fd9a 100644 --- a/libs/langgraph/langgraph/constants.py +++ b/libs/langgraph/langgraph/constants.py @@ -10,10 +10,12 @@ CONFIG_KEY_RESUMING = "__pregel_resuming" CONFIG_KEY_TASK_ID = "__pregel_task_id" CONFIG_KEY_DEDUPE_TASKS = "__pregel_dedupe_tasks" +CONFIG_KEY_ENSURE_LATEST = "__pregel_ensure_latest" # this one part of public API so more readable CONFIG_KEY_CHECKPOINT_MAP = "checkpoint_map" INTERRUPT = "__interrupt__" ERROR = "__error__" +NO_WRITES = "__no_writes__" SCHEDULED = "__scheduled__" TASKS = "__pregel_tasks" # for backwards compat, this is the original name of PUSH PUSH = "__pregel_push" @@ -23,6 +25,7 @@ SCHEDULED, INTERRUPT, ERROR, + NO_WRITES, TASKS, PUSH, PULL, @@ -34,6 +37,7 @@ CONFIG_KEY_RESUMING, CONFIG_KEY_TASK_ID, CONFIG_KEY_DEDUPE_TASKS, + CONFIG_KEY_ENSURE_LATEST, INPUT, RUNTIME_PLACEHOLDER, } diff --git a/libs/langgraph/langgraph/errors.py b/libs/langgraph/langgraph/errors.py index aa5b57857b..2fcd644734 100644 --- a/libs/langgraph/langgraph/errors.py +++ b/libs/langgraph/langgraph/errors.py @@ -55,6 +55,12 @@ class TaskNotFound(Exception): pass +class CheckpointNotLatest(Exception): + """Raised when the checkpoint is not the latest version.""" + + pass + + __all__ = [ "GraphRecursionError", "InvalidUpdateError", diff --git a/libs/langgraph/langgraph/pregel/__init__.py b/libs/langgraph/langgraph/pregel/__init__.py index d1229d6e92..c2324f9208 100644 --- a/libs/langgraph/langgraph/pregel/__init__.py +++ b/libs/langgraph/langgraph/pregel/__init__.py @@ -21,7 +21,6 @@ from uuid import UUID, uuid5 from langchain_core.globals import get_debug -from langchain_core.load.dump import dumpd from langchain_core.runnables import ( Runnable, RunnableLambda, @@ -1160,7 +1159,7 @@ def output() -> Iterator: config = ensure_config(merge_configs(self.config, config)) callback_manager = get_callback_manager_for_config(config) run_manager = callback_manager.on_chain_start( - dumpd(self), + None, input, name=config.get("run_name", self.get_name()), run_id=config.get("run_id"), @@ -1341,7 +1340,7 @@ def output() -> Iterator: config = ensure_config(merge_configs(self.config, config)) callback_manager = get_async_callback_manager_for_config(config) run_manager = await callback_manager.on_chain_start( - dumpd(self), + None, input, name=config.get("run_name", self.get_name()), run_id=config.get("run_id"), diff --git a/libs/langgraph/langgraph/pregel/algo.py b/libs/langgraph/langgraph/pregel/algo.py index d0f8e712e2..9f98379808 100644 --- a/libs/langgraph/langgraph/pregel/algo.py +++ b/libs/langgraph/langgraph/pregel/algo.py @@ -28,6 +28,7 @@ CONFIG_KEY_SEND, CONFIG_KEY_TASK_ID, INTERRUPT, + NO_WRITES, NS_SEP, PULL, PUSH, @@ -196,7 +197,9 @@ def apply_writes( pending_writes_by_managed: dict[str, list[Any]] = defaultdict(list) for task in tasks: for chan, val in task.writes: - if chan == TASKS: + if chan == NO_WRITES: + pass + elif chan == TASKS: checkpoint["pending_sends"].append(val) elif chan in channels: pending_writes_by_channel[chan].append(val) @@ -331,6 +334,8 @@ def prepare_single_task( if task_path[0] == PUSH: idx = int(task_path[1]) + if idx >= len(checkpoint["pending_sends"]): + return packet = checkpoint["pending_sends"][idx] if not isinstance(packet, Send): logger.warning( @@ -425,6 +430,8 @@ def prepare_single_task( return PregelTask(task_id, packet.node, task_path) elif task_path[0] == PULL: name = str(task_path[1]) + if name not in processes: + return proc = processes[name] version_type = type(next(iter(checkpoint["channel_versions"].values()), None)) null_version = version_type() diff --git a/libs/langgraph/langgraph/pregel/loop.py b/libs/langgraph/langgraph/pregel/loop.py index 7772827c46..c0551c14b0 100644 --- a/libs/langgraph/langgraph/pregel/loop.py +++ b/libs/langgraph/langgraph/pregel/loop.py @@ -40,6 +40,7 @@ from langgraph.constants import ( CONFIG_KEY_CHECKPOINT_MAP, CONFIG_KEY_DEDUPE_TASKS, + CONFIG_KEY_ENSURE_LATEST, CONFIG_KEY_RESUMING, CONFIG_KEY_STREAM, CONFIG_KEY_TASK_ID, @@ -49,7 +50,7 @@ SCHEDULED, TAG_HIDDEN, ) -from langgraph.errors import EmptyInputError, GraphInterrupt +from langgraph.errors import CheckpointNotLatest, EmptyInputError, GraphInterrupt from langgraph.managed.base import ( ManagedValueMapping, ManagedValueSpec, @@ -599,11 +600,26 @@ def _update_mv(self, key: str, values: Sequence[Any]) -> None: # context manager def __enter__(self) -> Self: - saved = ( - self.checkpointer.get_tuple(self.checkpoint_config) - if self.checkpointer - else None - ) or CheckpointTuple(self.config, empty_checkpoint(), {"step": -2}, None, []) + if self.config.get("configurable", {}).get( + CONFIG_KEY_ENSURE_LATEST + ) and self.checkpoint_config["configurable"].get("checkpoint_id"): + saved = self.checkpointer.get_tuple( + patch_configurable(self.checkpoint_config, {"checkpoint_id": None}) + ) + if ( + saved is None + or saved.checkpoint["id"] + != self.checkpoint_config["configurable"]["checkpoint_id"] + ): + raise CheckpointNotLatest + elif self.checkpointer: + saved = self.checkpointer.get_tuple(self.checkpoint_config) + else: + saved = None + if saved is None: + saved = CheckpointTuple( + self.config, empty_checkpoint(), {"step": -2}, None, [] + ) self.checkpoint_config = { **self.config, **saved.config, @@ -702,11 +718,26 @@ def _update_mv(self, key: str, values: Sequence[Any]) -> None: # context manager async def __aenter__(self) -> Self: - saved = ( - await self.checkpointer.aget_tuple(self.checkpoint_config) - if self.checkpointer - else None - ) or CheckpointTuple(self.config, empty_checkpoint(), {"step": -2}, None, []) + if self.config.get("configurable", {}).get( + CONFIG_KEY_ENSURE_LATEST + ) and self.checkpoint_config["configurable"].get("checkpoint_id"): + saved = await self.checkpointer.aget_tuple( + patch_configurable(self.checkpoint_config, {"checkpoint_id": None}) + ) + if ( + saved is None + or saved.checkpoint["id"] + != self.checkpoint_config["configurable"]["checkpoint_id"] + ): + raise CheckpointNotLatest + elif self.checkpointer: + saved = await self.checkpointer.aget_tuple(self.checkpoint_config) + else: + saved = None + if saved is None: + saved = CheckpointTuple( + self.config, empty_checkpoint(), {"step": -2}, None, [] + ) self.checkpoint_config = { **self.config, **saved.config, diff --git a/libs/langgraph/langgraph/pregel/runner.py b/libs/langgraph/langgraph/pregel/runner.py index 3f42fbd5f5..086afa9daa 100644 --- a/libs/langgraph/langgraph/pregel/runner.py +++ b/libs/langgraph/langgraph/pregel/runner.py @@ -12,7 +12,7 @@ Union, ) -from langgraph.constants import ERROR, INTERRUPT +from langgraph.constants import ERROR, INTERRUPT, NO_WRITES from langgraph.errors import GraphInterrupt from langgraph.pregel.executor import Submit from langgraph.pregel.retry import arun_with_retry, run_with_retry @@ -69,12 +69,15 @@ def tick( if exc := _exception(fut): if isinstance(exc, GraphInterrupt): # save interrupt to checkpointer - self.put_writes(task.id, [(INTERRUPT, i) for i in exc.args[0]]) + if interrupts := [(INTERRUPT, i) for i in exc.args[0]]: + self.put_writes(task.id, interrupts) else: # save error to checkpointer self.put_writes(task.id, [(ERROR, exc)]) - else: + if not task.writes: + # add no writes marker + task.writes.append((NO_WRITES, None)) # save task writes to checkpointer self.put_writes(task.id, task.writes) else: @@ -130,11 +133,15 @@ async def atick( if exc := _exception(fut): if isinstance(exc, GraphInterrupt): # save interrupt to checkpointer - self.put_writes(task.id, [(INTERRUPT, i) for i in exc.args[0]]) + if interrupts := [(INTERRUPT, i) for i in exc.args[0]]: + self.put_writes(task.id, interrupts) else: # save error to checkpointer self.put_writes(task.id, [(ERROR, exc)]) else: + if not task.writes: + # add no writes marker + task.writes.append((NO_WRITES, None)) # save task writes to checkpointer self.put_writes(task.id, task.writes) else: diff --git a/libs/scheduler-kafka/langgraph/scheduler/kafka/executor.py b/libs/scheduler-kafka/langgraph/scheduler/kafka/executor.py index 91b53f31d7..714fbb647f 100644 --- a/libs/scheduler-kafka/langgraph/scheduler/kafka/executor.py +++ b/libs/scheduler-kafka/langgraph/scheduler/kafka/executor.py @@ -8,7 +8,7 @@ import langgraph.scheduler.kafka.serde as serde from langgraph.constants import ERROR -from langgraph.errors import TaskNotFound +from langgraph.errors import CheckpointNotLatest, TaskNotFound from langgraph.pregel import Pregel from langgraph.pregel.algo import prepare_single_task from langgraph.pregel.executor import AsyncBackgroundExecutor, Submit @@ -22,6 +22,7 @@ MessageToOrchestrator, Topics, ) +from langgraph.utils.config import patch_configurable class KafkaExecutor(AbstractAsyncContextManager): @@ -91,6 +92,8 @@ async def __anext__(self) -> Sequence[MessageToExecutor]: async def each(self, msg: MessageToExecutor) -> None: try: await aretry(self.retry_policy, self.attempt, msg) + except CheckpointNotLatest: + pass except Exception as exc: await self.producer.send_and_wait( self.topics.error, @@ -103,9 +106,13 @@ async def each(self, msg: MessageToExecutor) -> None: async def attempt(self, msg: MessageToExecutor) -> None: # process message - saved = await self.graph.checkpointer.aget_tuple(msg["config"]) + saved = await self.graph.checkpointer.aget_tuple( + patch_configurable(msg["config"], {"checkpoint_id": None}) + ) if saved is None: raise RuntimeError("Checkpoint not found") + if saved.checkpoint["id"] != msg["config"]["configurable"]["checkpoint_id"]: + raise CheckpointNotLatest() async with AsyncChannelsManager( self.graph.channels, saved.checkpoint, msg["config"], self.graph.store ) as (channels, managed), AsyncBackgroundExecutor() as submit: @@ -138,6 +145,8 @@ async def attempt(self, msg: MessageToExecutor) -> None: await self.producer.send_and_wait( self.topics.orchestrator, value=MessageToOrchestrator(input=None, config=msg["config"]), + # use thread_id as partition key + key=msg["config"]["configurable"]["thread_id"].encode(), ) def _put_writes( diff --git a/libs/scheduler-kafka/langgraph/scheduler/kafka/orchestrator.py b/libs/scheduler-kafka/langgraph/scheduler/kafka/orchestrator.py index efa18d9e95..9abadf743f 100644 --- a/libs/scheduler-kafka/langgraph/scheduler/kafka/orchestrator.py +++ b/libs/scheduler-kafka/langgraph/scheduler/kafka/orchestrator.py @@ -6,7 +6,13 @@ from langchain_core.runnables import ensure_config import langgraph.scheduler.kafka.serde as serde -from langgraph.constants import CONFIG_KEY_DEDUPE_TASKS, INTERRUPT, SCHEDULED +from langgraph.constants import ( + CONFIG_KEY_DEDUPE_TASKS, + CONFIG_KEY_ENSURE_LATEST, + INTERRUPT, + SCHEDULED, +) +from langgraph.errors import CheckpointNotLatest from langgraph.pregel import Pregel from langgraph.pregel.loop import AsyncPregelLoop from langgraph.pregel.types import RetryPolicy @@ -83,6 +89,8 @@ async def __anext__(self) -> list[MessageToOrchestrator]: async def each(self, msg: MessageToOrchestrator) -> None: try: await aretry(self.retry_policy, self.attempt, msg) + except CheckpointNotLatest: + pass except Exception as exc: await self.producer.send_and_wait( self.topics.error, @@ -127,6 +135,7 @@ async def attempt(self, msg: MessageToOrchestrator) -> None: { **loop.checkpoint_config["configurable"], CONFIG_KEY_DEDUPE_TASKS: True, + CONFIG_KEY_ENSURE_LATEST: True, }, ), task=ExecutorTask(id=task.id, path=task.path), diff --git a/libs/scheduler-kafka/tests/test_fanout.py b/libs/scheduler-kafka/tests/test_fanout.py index e00787aee7..4e5e9586a6 100644 --- a/libs/scheduler-kafka/tests/test_fanout.py +++ b/libs/scheduler-kafka/tests/test_fanout.py @@ -126,6 +126,7 @@ async def test_fanout_graph(topics: Topics, checkpointer: BaseCheckpointSaver) - "config": { "callbacks": None, "configurable": { + "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": False, "checkpoint_id": c.config["configurable"]["checkpoint_id"], @@ -146,6 +147,7 @@ async def test_fanout_graph(topics: Topics, checkpointer: BaseCheckpointSaver) - "config": { "callbacks": None, "configurable": { + "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": False, "checkpoint_id": c.config["configurable"]["checkpoint_id"], @@ -206,6 +208,7 @@ async def test_fanout_graph_w_interrupt( "config": { "callbacks": None, "configurable": { + "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": False, "checkpoint_id": c.config["configurable"]["checkpoint_id"], @@ -226,6 +229,7 @@ async def test_fanout_graph_w_interrupt( "config": { "callbacks": None, "configurable": { + "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": False, "checkpoint_id": c.config["configurable"]["checkpoint_id"],