Skip to content

Commit

Permalink
Fix Send order after interrupt/resume
Browse files Browse the repository at this point in the history
- order was incorrectly based on task id, instead of the correct task path
- this requires storing task paths on checkpointers
- addition of task_path to put_writes is made backwards compatible by checking signature on call, and treating it as an optinal arg
  • Loading branch information
nfcampos committed Jan 15, 2025
1 parent 50cb387 commit db2e9aa
Show file tree
Hide file tree
Showing 11 changed files with 264 additions and 153 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,7 @@ def put_writes(
config: RunnableConfig,
writes: Sequence[tuple[str, Any]],
task_id: str,
task_path: str = "",
) -> None:
"""Store intermediate writes linked to a checkpoint.
Expand All @@ -350,6 +351,7 @@ def put_writes(
config["configurable"]["checkpoint_ns"],
config["configurable"]["checkpoint_id"],
task_id,
task_path,
writes,
),
)
Expand Down
2 changes: 2 additions & 0 deletions libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ async def aput_writes(
config: RunnableConfig,
writes: Sequence[tuple[str, Any]],
task_id: str,
task_path: str = "",
) -> None:
"""Store intermediate writes linked to a checkpoint asynchronously.
Expand All @@ -306,6 +307,7 @@ async def aput_writes(
config["configurable"]["checkpoint_ns"],
config["configurable"]["checkpoint_id"],
task_id,
task_path,
writes,
)
async with self._cursor(pipeline=True) as cur:
Expand Down
13 changes: 8 additions & 5 deletions libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
"""
CREATE INDEX CONCURRENTLY IF NOT EXISTS checkpoint_writes_thread_id_idx ON checkpoint_writes(thread_id);
""",
"""ALTER TABLE checkpoint_writes ADD COLUMN task_path TEXT NOT NULL DEFAULT '';""",
]

SELECT_SQL = f"""
Expand Down Expand Up @@ -94,7 +95,7 @@
and cw.checkpoint_id = checkpoints.checkpoint_id
) as pending_writes,
(
select array_agg(array[cw.type::bytea, cw.blob] order by cw.task_id, cw.idx)
select array_agg(array[cw.type::bytea, cw.blob] order by cw.task_path, cw.task_id, cw.idx)
from checkpoint_writes cw
where cw.thread_id = checkpoints.thread_id
and cw.checkpoint_ns = checkpoints.checkpoint_ns
Expand All @@ -119,17 +120,17 @@
"""

UPSERT_CHECKPOINT_WRITES_SQL = """
INSERT INTO checkpoint_writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, blob)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
INSERT INTO checkpoint_writes (thread_id, checkpoint_ns, checkpoint_id, task_id, task_path, idx, channel, type, blob)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) DO UPDATE SET
channel = EXCLUDED.channel,
type = EXCLUDED.type,
blob = EXCLUDED.blob;
"""

INSERT_CHECKPOINT_WRITES_SQL = """
INSERT INTO checkpoint_writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, blob)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
INSERT INTO checkpoint_writes (thread_id, checkpoint_ns, checkpoint_id, task_id, task_path, idx, channel, type, blob)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) DO NOTHING
"""

Expand Down Expand Up @@ -220,6 +221,7 @@ def _dump_writes(
checkpoint_ns: str,
checkpoint_id: str,
task_id: str,
task_path: str,
writes: Sequence[tuple[str, Any]],
) -> list[tuple[str, str, str, str, int, str, str, bytes]]:
return [
Expand All @@ -228,6 +230,7 @@ def _dump_writes(
checkpoint_ns,
checkpoint_id,
task_id,
task_path,
WRITES_IDX_MAP.get(channel, idx),
channel,
*self.serde.dumps_typed(value),
Expand Down
17 changes: 12 additions & 5 deletions libs/checkpoint-postgres/langgraph/checkpoint/postgres/shallow.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@
"""
CREATE INDEX CONCURRENTLY IF NOT EXISTS checkpoint_writes_thread_id_idx ON checkpoint_writes(thread_id);
""",
"""
ALTER TABLE checkpoint_writes ADD COLUMN task_path TEXT NOT NULL DEFAULT '';
""",
]

SELECT_SQL = f"""
Expand All @@ -99,7 +102,7 @@
and cw.checkpoint_id = (checkpoint->>'id')
) as pending_writes,
(
select array_agg(array[cw.type::bytea, cw.blob] order by cw.task_id, cw.idx)
select array_agg(array[cw.type::bytea, cw.blob] order by cw.task_path, cw.task_id, cw.idx)
from checkpoint_writes cw
where cw.thread_id = checkpoints.thread_id
and cw.checkpoint_ns = checkpoints.checkpoint_ns
Expand All @@ -125,17 +128,17 @@
"""

UPSERT_CHECKPOINT_WRITES_SQL = """
INSERT INTO checkpoint_writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, blob)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
INSERT INTO checkpoint_writes (thread_id, checkpoint_ns, checkpoint_id, task_id, task_path, idx, channel, type, blob)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) DO UPDATE SET
channel = EXCLUDED.channel,
type = EXCLUDED.type,
blob = EXCLUDED.blob;
"""

INSERT_CHECKPOINT_WRITES_SQL = """
INSERT INTO checkpoint_writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, blob)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
INSERT INTO checkpoint_writes (thread_id, checkpoint_ns, checkpoint_id, task_id, task_path, idx, channel, type, blob)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) DO NOTHING
"""

Expand Down Expand Up @@ -430,6 +433,7 @@ def put_writes(
config: RunnableConfig,
writes: Sequence[tuple[str, Any]],
task_id: str,
task_path: str = "",
) -> None:
"""Store intermediate writes linked to a checkpoint.
Expand All @@ -453,6 +457,7 @@ def put_writes(
config["configurable"]["checkpoint_ns"],
config["configurable"]["checkpoint_id"],
task_id,
task_path,
writes,
),
)
Expand Down Expand Up @@ -747,6 +752,7 @@ async def aput_writes(
config: RunnableConfig,
writes: Sequence[tuple[str, Any]],
task_id: str,
task_path: str = "",
) -> None:
"""Store intermediate writes linked to a checkpoint asynchronously.
Expand All @@ -768,6 +774,7 @@ async def aput_writes(
config["configurable"]["checkpoint_ns"],
config["configurable"]["checkpoint_id"],
task_id,
task_path,
writes,
)
async with self._cursor(pipeline=True) as cur:
Expand Down
11 changes: 6 additions & 5 deletions libs/checkpoint/langgraph/checkpoint/base/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
from datetime import datetime, timezone
from typing import (
from typing import ( # noqa: UP035
Any,
AsyncIterator,
Dict,
Generic,
Iterator,
List,
Literal,
Mapping,
NamedTuple,
Optional,
Sequence,
Tuple,
TypedDict,
TypeVar,
Expand Down Expand Up @@ -305,13 +302,15 @@ def put_writes(
config: RunnableConfig,
writes: Sequence[Tuple[str, Any]],
task_id: str,
task_path: str = "",
) -> None:
"""Store intermediate writes linked to a checkpoint.
Args:
config (RunnableConfig): Configuration of the related checkpoint.
writes (List[Tuple[str, Any]]): List of writes to store.
task_id (str): Identifier for the task creating the writes.
task_path (str): Path of the task creating the writes.
Raises:
NotImplementedError: Implement this method in your custom checkpoint saver.
Expand Down Expand Up @@ -397,13 +396,15 @@ async def aput_writes(
config: RunnableConfig,
writes: Sequence[Tuple[str, Any]],
task_id: str,
task_path: str = "",
) -> None:
"""Asynchronously store intermediate writes linked to a checkpoint.
Args:
config (RunnableConfig): Configuration of the related checkpoint.
writes (List[Tuple[str, Any]]): List of writes to store.
task_id (str): Identifier for the task creating the writes.
task_path (str): Path of the task creating the writes.
Raises:
NotImplementedError: Implement this method in your custom checkpoint saver.
Expand Down
Loading

0 comments on commit db2e9aa

Please sign in to comment.