From 44ee0199fd39c3124577954c7367325760d89433 Mon Sep 17 00:00:00 2001 From: Vadym Barda Date: Fri, 20 Dec 2024 12:51:20 -0500 Subject: [PATCH] checkpoint postgres: add a shallow checkpointer (#2826) This PR adds a "shallow" version of `PostgresSaver` checkpointer that ONLY stores the most recent checkpoint and does NOT retain any history. It is meant to be a light-weight drop-in replacement for the PostgresSaver that supports most of the LangGraph persistence functionality with the exception of time travel. --- .../langgraph/checkpoint/postgres/__init__.py | 3 +- .../langgraph/checkpoint/postgres/aio.py | 3 +- .../langgraph/checkpoint/postgres/shallow.py | 918 ++++++++ libs/checkpoint-postgres/tests/test_async.py | 39 +- libs/checkpoint-postgres/tests/test_sync.py | 32 +- .../tests/__snapshots__/test_large_cases.ambr | 530 +++++ .../__snapshots__/test_large_cases_async.ambr | 25 + .../tests/__snapshots__/test_pregel.ambr | 166 ++ .../__snapshots__/test_pregel_async.ambr | 136 ++ libs/langgraph/tests/conftest.py | 68 +- libs/langgraph/tests/test_large_cases.py | 1899 ++++++++++++----- .../langgraph/tests/test_large_cases_async.py | 1250 +++++++---- libs/langgraph/tests/test_pregel.py | 47 +- libs/langgraph/tests/test_pregel_async.py | 306 ++- 14 files changed, 4331 insertions(+), 1091 deletions(-) create mode 100644 libs/checkpoint-postgres/langgraph/checkpoint/postgres/shallow.py diff --git a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py index e5a3cce55..475761808 100644 --- a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py +++ b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py @@ -19,6 +19,7 @@ ) from langgraph.checkpoint.postgres import _internal from langgraph.checkpoint.postgres.base import BasePostgresSaver +from langgraph.checkpoint.postgres.shallow import ShallowPostgresSaver from langgraph.checkpoint.serde.base import SerializerProtocol Conn = _internal.Conn # For backward compatibility @@ -396,4 +397,4 @@ def _cursor(self, *, pipeline: bool = False) -> Iterator[Cursor[DictRow]]: yield cur -__all__ = ["PostgresSaver", "BasePostgresSaver", "Conn"] +__all__ = ["PostgresSaver", "BasePostgresSaver", "ShallowPostgresSaver", "Conn"] diff --git a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py index 3a1e13db1..a5a82b185 100644 --- a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py +++ b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py @@ -19,6 +19,7 @@ ) from langgraph.checkpoint.postgres import _ainternal from langgraph.checkpoint.postgres.base import BasePostgresSaver +from langgraph.checkpoint.postgres.shallow import AsyncShallowPostgresSaver from langgraph.checkpoint.serde.base import SerializerProtocol Conn = _ainternal.Conn # For backward compatibility @@ -464,4 +465,4 @@ def put_writes( ).result() -__all__ = ["AsyncPostgresSaver", "Conn"] +__all__ = ["AsyncPostgresSaver", "AsyncShallowPostgresSaver", "Conn"] diff --git a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/shallow.py b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/shallow.py new file mode 100644 index 000000000..163607cf9 --- /dev/null +++ b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/shallow.py @@ -0,0 +1,918 @@ +import asyncio +import threading +from collections.abc import AsyncIterator, Iterator, Sequence +from contextlib import asynccontextmanager, 
contextmanager +from typing import Any, Optional + +from langchain_core.runnables import RunnableConfig +from psycopg import ( + AsyncConnection, + AsyncCursor, + AsyncPipeline, + Capabilities, + Connection, + Cursor, + Pipeline, +) +from psycopg.rows import DictRow, dict_row +from psycopg.types.json import Jsonb +from psycopg_pool import AsyncConnectionPool, ConnectionPool + +from langgraph.checkpoint.base import ( + WRITES_IDX_MAP, + ChannelVersions, + Checkpoint, + CheckpointMetadata, + CheckpointTuple, +) +from langgraph.checkpoint.postgres import _ainternal, _internal +from langgraph.checkpoint.postgres.base import BasePostgresSaver +from langgraph.checkpoint.serde.base import SerializerProtocol +from langgraph.checkpoint.serde.types import TASKS + +""" +To add a new migration, add a new string to the MIGRATIONS list. +The position of the migration in the list is the version number. +""" +MIGRATIONS = [ + """CREATE TABLE IF NOT EXISTS checkpoint_migrations ( + v INTEGER PRIMARY KEY +);""", + """CREATE TABLE IF NOT EXISTS checkpoints ( + thread_id TEXT NOT NULL, + checkpoint_ns TEXT NOT NULL DEFAULT '', + type TEXT, + checkpoint JSONB NOT NULL, + metadata JSONB NOT NULL DEFAULT '{}', + PRIMARY KEY (thread_id, checkpoint_ns) +);""", + """CREATE TABLE IF NOT EXISTS checkpoint_blobs ( + thread_id TEXT NOT NULL, + checkpoint_ns TEXT NOT NULL DEFAULT '', + channel TEXT NOT NULL, + type TEXT NOT NULL, + blob BYTEA, + PRIMARY KEY (thread_id, checkpoint_ns, channel) +);""", + """CREATE TABLE IF NOT EXISTS checkpoint_writes ( + thread_id TEXT NOT NULL, + checkpoint_ns TEXT NOT NULL DEFAULT '', + checkpoint_id TEXT NOT NULL, + task_id TEXT NOT NULL, + idx INTEGER NOT NULL, + channel TEXT NOT NULL, + type TEXT, + blob BYTEA NOT NULL, + PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) +);""", + """ + CREATE INDEX CONCURRENTLY IF NOT EXISTS checkpoints_thread_id_idx ON checkpoints(thread_id); + """, + """ + CREATE INDEX CONCURRENTLY IF NOT EXISTS checkpoint_blobs_thread_id_idx ON checkpoint_blobs(thread_id); + """, + """ + CREATE INDEX CONCURRENTLY IF NOT EXISTS checkpoint_writes_thread_id_idx ON checkpoint_writes(thread_id); + """, +] + +SELECT_SQL = f""" +select + thread_id, + checkpoint, + checkpoint_ns, + metadata, + ( + select array_agg(array[bl.channel::bytea, bl.type::bytea, bl.blob]) + from jsonb_each_text(checkpoint -> 'channel_versions') + inner join checkpoint_blobs bl + on bl.thread_id = checkpoints.thread_id + and bl.checkpoint_ns = checkpoints.checkpoint_ns + and bl.channel = jsonb_each_text.key + ) as channel_values, + ( + select + array_agg(array[cw.task_id::text::bytea, cw.channel::bytea, cw.type::bytea, cw.blob] order by cw.task_id, cw.idx) + from checkpoint_writes cw + where cw.thread_id = checkpoints.thread_id + and cw.checkpoint_ns = checkpoints.checkpoint_ns + and cw.checkpoint_id = (checkpoint->>'id') + ) as pending_writes, + ( + select array_agg(array[cw.type::bytea, cw.blob] order by cw.task_id, cw.idx) + from checkpoint_writes cw + where cw.thread_id = checkpoints.thread_id + and cw.checkpoint_ns = checkpoints.checkpoint_ns + and cw.channel = '{TASKS}' + ) as pending_sends +from checkpoints """ + +UPSERT_CHECKPOINT_BLOBS_SQL = """ + INSERT INTO checkpoint_blobs (thread_id, checkpoint_ns, channel, type, blob) + VALUES (%s, %s, %s, %s, %s) + ON CONFLICT (thread_id, checkpoint_ns, channel) DO UPDATE SET + type = EXCLUDED.type, + blob = EXCLUDED.blob; +""" + +UPSERT_CHECKPOINTS_SQL = """ + INSERT INTO checkpoints (thread_id, checkpoint_ns, checkpoint, 
metadata) + VALUES (%s, %s, %s, %s) + ON CONFLICT (thread_id, checkpoint_ns) + DO UPDATE SET + checkpoint = EXCLUDED.checkpoint, + metadata = EXCLUDED.metadata; +""" + +UPSERT_CHECKPOINT_WRITES_SQL = """ + INSERT INTO checkpoint_writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, blob) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s) + ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) DO UPDATE SET + channel = EXCLUDED.channel, + type = EXCLUDED.type, + blob = EXCLUDED.blob; +""" + +INSERT_CHECKPOINT_WRITES_SQL = """ + INSERT INTO checkpoint_writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, blob) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s) + ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) DO NOTHING +""" + + +def _dump_blobs( + serde: SerializerProtocol, + thread_id: str, + checkpoint_ns: str, + values: dict[str, Any], + versions: ChannelVersions, +) -> list[tuple[str, str, str, str, str, Optional[bytes]]]: + if not versions: + return [] + + return [ + ( + thread_id, + checkpoint_ns, + k, + *(serde.dumps_typed(values[k]) if k in values else ("empty", None)), + ) + for k in versions + ] + + +class ShallowPostgresSaver(BasePostgresSaver): + """A checkpoint saver that uses Postgres to store checkpoints. + + This checkpointer ONLY stores the most recent checkpoint and does NOT retain any history. + It is meant to be a light-weight drop-in replacement for the PostgresSaver that + supports most of the LangGraph persistence functionality with the exception of time travel. + """ + + SELECT_SQL = SELECT_SQL + MIGRATIONS = MIGRATIONS + UPSERT_CHECKPOINT_BLOBS_SQL = UPSERT_CHECKPOINT_BLOBS_SQL + UPSERT_CHECKPOINTS_SQL = UPSERT_CHECKPOINTS_SQL + UPSERT_CHECKPOINT_WRITES_SQL = UPSERT_CHECKPOINT_WRITES_SQL + INSERT_CHECKPOINT_WRITES_SQL = INSERT_CHECKPOINT_WRITES_SQL + + lock: threading.Lock + + def __init__( + self, + conn: _internal.Conn, + pipe: Optional[Pipeline] = None, + serde: Optional[SerializerProtocol] = None, + ) -> None: + super().__init__(serde=serde) + if isinstance(conn, ConnectionPool) and pipe is not None: + raise ValueError( + "Pipeline should be used only with a single Connection, not ConnectionPool." + ) + + self.conn = conn + self.pipe = pipe + self.lock = threading.Lock() + self.supports_pipeline = Capabilities().has_pipeline() + + @classmethod + @contextmanager + def from_conn_string( + cls, conn_string: str, *, pipeline: bool = False + ) -> Iterator["ShallowPostgresSaver"]: + """Create a new ShallowPostgresSaver instance from a connection string. + + Args: + conn_string (str): The Postgres connection info string. + pipeline (bool): whether to use Pipeline + + Returns: + ShallowPostgresSaver: A new ShallowPostgresSaver instance. + """ + with Connection.connect( + conn_string, autocommit=True, prepare_threshold=0, row_factory=dict_row + ) as conn: + if pipeline: + with conn.pipeline() as pipe: + yield cls(conn, pipe) + else: + yield cls(conn) + + def setup(self) -> None: + """Set up the checkpoint database asynchronously. + + This method creates the necessary tables in the Postgres database if they don't + already exist and runs database migrations. It MUST be called directly by the user + the first time checkpointer is used. 
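        Examples:

            A minimal first-time setup sketch (an illustration, assuming a Postgres instance reachable at DB_URI):

            >>> from langgraph.checkpoint.postgres import ShallowPostgresSaver
            >>> DB_URI = "postgres://postgres:postgres@localhost:5432/postgres?sslmode=disable"
            >>> with ShallowPostgresSaver.from_conn_string(DB_URI) as checkpointer:
            >>>     checkpointer.setup()  # creates tables and runs migrations if needed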
+ """ + with self._cursor() as cur: + cur.execute(self.MIGRATIONS[0]) + results = cur.execute( + "SELECT v FROM checkpoint_migrations ORDER BY v DESC LIMIT 1" + ) + row = results.fetchone() + if row is None: + version = -1 + else: + version = row["v"] + for v, migration in zip( + range(version + 1, len(self.MIGRATIONS)), + self.MIGRATIONS[version + 1 :], + ): + cur.execute(migration) + cur.execute(f"INSERT INTO checkpoint_migrations (v) VALUES ({v})") + if self.pipe: + self.pipe.sync() + + def list( + self, + config: Optional[RunnableConfig], + *, + filter: Optional[dict[str, Any]] = None, + before: Optional[RunnableConfig] = None, + limit: Optional[int] = None, + ) -> Iterator[CheckpointTuple]: + """List checkpoints from the database. + + This method retrieves a list of checkpoint tuples from the Postgres database based + on the provided config. For ShallowPostgresSaver, this method returns a list with + ONLY the most recent checkpoint. + """ + where, args = self._search_where(config, filter, before) + query = self.SELECT_SQL + where + if limit: + query += f" LIMIT {limit}" + with self._cursor() as cur: + cur.execute(self.SELECT_SQL + where, args, binary=True) + for value in cur: + checkpoint = self._load_checkpoint( + value["checkpoint"], + value["channel_values"], + value["pending_sends"], + ) + yield CheckpointTuple( + config={ + "configurable": { + "thread_id": value["thread_id"], + "checkpoint_ns": value["checkpoint_ns"], + "checkpoint_id": checkpoint["id"], + } + }, + checkpoint=checkpoint, + metadata=self._load_metadata(value["metadata"]), + pending_writes=self._load_writes(value["pending_writes"]), + ) + + def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: + """Get a checkpoint tuple from the database. + + This method retrieves a checkpoint tuple from the Postgres database based on the + provided config (matching the thread ID in the config). + + Args: + config (RunnableConfig): The config to use for retrieving the checkpoint. + + Returns: + Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found. + + Examples: + + Basic: + >>> config = {"configurable": {"thread_id": "1"}} + >>> checkpoint_tuple = memory.get_tuple(config) + >>> print(checkpoint_tuple) + CheckpointTuple(...) + + With timestamp: + + >>> config = { + ... "configurable": { + ... "thread_id": "1", + ... "checkpoint_ns": "", + ... "checkpoint_id": "1ef4f797-8335-6428-8001-8a1503f9b875", + ... } + ... } + >>> checkpoint_tuple = memory.get_tuple(config) + >>> print(checkpoint_tuple) + CheckpointTuple(...) + """ # noqa + thread_id = config["configurable"]["thread_id"] + checkpoint_ns = config["configurable"].get("checkpoint_ns", "") + args = (thread_id, checkpoint_ns) + where = "WHERE thread_id = %s AND checkpoint_ns = %s" + + with self._cursor() as cur: + cur.execute( + self.SELECT_SQL + where, + args, + binary=True, + ) + + for value in cur: + checkpoint = self._load_checkpoint( + value["checkpoint"], + value["channel_values"], + value["pending_sends"], + ) + return CheckpointTuple( + config={ + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": checkpoint["id"], + } + }, + checkpoint=checkpoint, + metadata=self._load_metadata(value["metadata"]), + pending_writes=self._load_writes(value["pending_writes"]), + ) + + def put( + self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: ChannelVersions, + ) -> RunnableConfig: + """Save a checkpoint to the database. 
+ + This method saves a checkpoint to the Postgres database. The checkpoint is associated + with the provided config. For ShallowPostgresSaver, this method saves ONLY the most recent + checkpoint and overwrites a previous checkpoint, if it exists. + + Args: + config (RunnableConfig): The config to associate with the checkpoint. + checkpoint (Checkpoint): The checkpoint to save. + metadata (CheckpointMetadata): Additional metadata to save with the checkpoint. + new_versions (ChannelVersions): New channel versions as of this write. + + Returns: + RunnableConfig: Updated configuration after storing the checkpoint. + + Examples: + + >>> from langgraph.checkpoint.postgres import ShallowPostgresSaver + >>> DB_URI = "postgres://postgres:postgres@localhost:5432/postgres?sslmode=disable" + >>> with ShallowPostgresSaver.from_conn_string(DB_URI) as memory: + >>> config = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}} + >>> checkpoint = {"ts": "2024-05-04T06:32:42.235444+00:00", "id": "1ef4f797-8335-6428-8001-8a1503f9b875", "channel_values": {"key": "value"}} + >>> saved_config = memory.put(config, checkpoint, {"source": "input", "step": 1, "writes": {"key": "value"}}, {}) + >>> print(saved_config) + {'configurable': {'thread_id': '1', 'checkpoint_ns': '', 'checkpoint_id': '1ef4f797-8335-6428-8001-8a1503f9b875'}} + """ + configurable = config["configurable"].copy() + thread_id = configurable.pop("thread_id") + checkpoint_ns = configurable.pop("checkpoint_ns") + + copy = checkpoint.copy() + next_config = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": checkpoint["id"], + } + } + + with self._cursor(pipeline=True) as cur: + cur.execute( + """DELETE FROM checkpoint_writes + WHERE thread_id = %s AND checkpoint_ns = %s AND checkpoint_id NOT IN (%s, %s)""", + ( + thread_id, + checkpoint_ns, + checkpoint["id"], + configurable.get("checkpoint_id", ""), + ), + ) + cur.executemany( + self.UPSERT_CHECKPOINT_BLOBS_SQL, + _dump_blobs( + self.serde, + thread_id, + checkpoint_ns, + copy.pop("channel_values"), # type: ignore[misc] + new_versions, + ), + ) + cur.execute( + self.UPSERT_CHECKPOINTS_SQL, + ( + thread_id, + checkpoint_ns, + Jsonb(self._dump_checkpoint(copy)), + self._dump_metadata(metadata), + ), + ) + return next_config + + def put_writes( + self, + config: RunnableConfig, + writes: Sequence[tuple[str, Any]], + task_id: str, + ) -> None: + """Store intermediate writes linked to a checkpoint. + + This method saves intermediate writes associated with a checkpoint to the Postgres database. + + Args: + config (RunnableConfig): Configuration of the related checkpoint. + writes (List[Tuple[str, Any]]): List of writes to store. + task_id (str): Identifier for the task creating the writes. + """ + query = ( + self.UPSERT_CHECKPOINT_WRITES_SQL + if all(w[0] in WRITES_IDX_MAP for w in writes) + else self.INSERT_CHECKPOINT_WRITES_SQL + ) + with self._cursor(pipeline=True) as cur: + cur.executemany( + query, + self._dump_writes( + config["configurable"]["thread_id"], + config["configurable"]["checkpoint_ns"], + config["configurable"]["checkpoint_id"], + task_id, + writes, + ), + ) + + @contextmanager + def _cursor(self, *, pipeline: bool = False) -> Iterator[Cursor[DictRow]]: + """Create a database cursor as a context manager. + + Args: + pipeline (bool): whether to use pipeline for the DB operations inside the context manager. + Will be applied regardless of whether the ShallowPostgresSaver instance was initialized with a pipeline. 
+ If pipeline mode is not supported, will fall back to using transaction context manager. + """ + with _internal.get_connection(self.conn) as conn: + if self.pipe: + # a connection in pipeline mode can be used concurrently + # in multiple threads/coroutines, but only one cursor can be + # used at a time + try: + with conn.cursor(binary=True, row_factory=dict_row) as cur: + yield cur + finally: + if pipeline: + self.pipe.sync() + elif pipeline: + # a connection not in pipeline mode can only be used by one + # thread/coroutine at a time, so we acquire a lock + if self.supports_pipeline: + with ( + self.lock, + conn.pipeline(), + conn.cursor(binary=True, row_factory=dict_row) as cur, + ): + yield cur + else: + # Use connection's transaction context manager when pipeline mode not supported + with ( + self.lock, + conn.transaction(), + conn.cursor(binary=True, row_factory=dict_row) as cur, + ): + yield cur + else: + with self.lock, conn.cursor(binary=True, row_factory=dict_row) as cur: + yield cur + + +class AsyncShallowPostgresSaver(BasePostgresSaver): + """A checkpoint saver that uses Postgres to store checkpoints asynchronously. + + This checkpointer ONLY stores the most recent checkpoint and does NOT retain any history. + It is meant to be a light-weight drop-in replacement for the AsyncPostgresSaver that + supports most of the LangGraph persistence functionality with the exception of time travel. + """ + + SELECT_SQL = SELECT_SQL + MIGRATIONS = MIGRATIONS + UPSERT_CHECKPOINT_BLOBS_SQL = UPSERT_CHECKPOINT_BLOBS_SQL + UPSERT_CHECKPOINTS_SQL = UPSERT_CHECKPOINTS_SQL + UPSERT_CHECKPOINT_WRITES_SQL = UPSERT_CHECKPOINT_WRITES_SQL + INSERT_CHECKPOINT_WRITES_SQL = INSERT_CHECKPOINT_WRITES_SQL + lock: asyncio.Lock + + def __init__( + self, + conn: _ainternal.Conn, + pipe: Optional[AsyncPipeline] = None, + serde: Optional[SerializerProtocol] = None, + ) -> None: + super().__init__(serde=serde) + if isinstance(conn, AsyncConnectionPool) and pipe is not None: + raise ValueError( + "Pipeline should be used only with a single AsyncConnection, not AsyncConnectionPool." + ) + + self.conn = conn + self.pipe = pipe + self.lock = asyncio.Lock() + self.loop = asyncio.get_running_loop() + self.supports_pipeline = Capabilities().has_pipeline() + + @classmethod + @asynccontextmanager + async def from_conn_string( + cls, + conn_string: str, + *, + pipeline: bool = False, + serde: Optional[SerializerProtocol] = None, + ) -> AsyncIterator["AsyncShallowPostgresSaver"]: + """Create a new AsyncShallowPostgresSaver instance from a connection string. + + Args: + conn_string (str): The Postgres connection info string. + pipeline (bool): whether to use AsyncPipeline + + Returns: + AsyncShallowPostgresSaver: A new AsyncShallowPostgresSaver instance. + """ + async with await AsyncConnection.connect( + conn_string, autocommit=True, prepare_threshold=0, row_factory=dict_row + ) as conn: + if pipeline: + async with conn.pipeline() as pipe: + yield cls(conn=conn, pipe=pipe, serde=serde) + else: + yield cls(conn=conn, serde=serde) + + async def setup(self) -> None: + """Set up the checkpoint database asynchronously. + + This method creates the necessary tables in the Postgres database if they don't + already exist and runs database migrations. It MUST be called directly by the user + the first time checkpointer is used. 
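        Examples:

            A minimal async setup sketch (an illustration, assuming a Postgres instance reachable at DB_URI):

            >>> from langgraph.checkpoint.postgres.aio import AsyncShallowPostgresSaver
            >>> DB_URI = "postgres://postgres:postgres@localhost:5432/postgres?sslmode=disable"
            >>> async with AsyncShallowPostgresSaver.from_conn_string(DB_URI) as checkpointer:
            >>>     await checkpointer.setup()  # creates tables and runs migrations if needed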
+ """ + async with self._cursor() as cur: + await cur.execute(self.MIGRATIONS[0]) + results = await cur.execute( + "SELECT v FROM checkpoint_migrations ORDER BY v DESC LIMIT 1" + ) + row = await results.fetchone() + if row is None: + version = -1 + else: + version = row["v"] + for v, migration in zip( + range(version + 1, len(self.MIGRATIONS)), + self.MIGRATIONS[version + 1 :], + ): + await cur.execute(migration) + await cur.execute(f"INSERT INTO checkpoint_migrations (v) VALUES ({v})") + if self.pipe: + await self.pipe.sync() + + async def alist( + self, + config: Optional[RunnableConfig], + *, + filter: Optional[dict[str, Any]] = None, + before: Optional[RunnableConfig] = None, + limit: Optional[int] = None, + ) -> AsyncIterator[CheckpointTuple]: + """List checkpoints from the database asynchronously. + + This method retrieves a list of checkpoint tuples from the Postgres database based + on the provided config. For ShallowPostgresSaver, this method returns a list with + ONLY the most recent checkpoint. + """ + where, args = self._search_where(config, filter, before) + query = self.SELECT_SQL + where + if limit: + query += f" LIMIT {limit}" + async with self._cursor() as cur: + await cur.execute(self.SELECT_SQL + where, args, binary=True) + async for value in cur: + checkpoint = await asyncio.to_thread( + self._load_checkpoint, + value["checkpoint"], + value["channel_values"], + value["pending_sends"], + ) + yield CheckpointTuple( + config={ + "configurable": { + "thread_id": value["thread_id"], + "checkpoint_ns": value["checkpoint_ns"], + "checkpoint_id": checkpoint["id"], + } + }, + checkpoint=checkpoint, + metadata=self._load_metadata(value["metadata"]), + pending_writes=await asyncio.to_thread( + self._load_writes, value["pending_writes"] + ), + ) + + async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: + """Get a checkpoint tuple from the database asynchronously. + + This method retrieves a checkpoint tuple from the Postgres database based on the + provided config (matching the thread ID in the config). + + Args: + config (RunnableConfig): The config to use for retrieving the checkpoint. + + Returns: + Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found. + """ + thread_id = config["configurable"]["thread_id"] + checkpoint_ns = config["configurable"].get("checkpoint_ns", "") + args = (thread_id, checkpoint_ns) + where = "WHERE thread_id = %s AND checkpoint_ns = %s" + + async with self._cursor() as cur: + await cur.execute( + self.SELECT_SQL + where, + args, + binary=True, + ) + + async for value in cur: + checkpoint = await asyncio.to_thread( + self._load_checkpoint, + value["checkpoint"], + value["channel_values"], + value["pending_sends"], + ) + return CheckpointTuple( + config={ + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": checkpoint["id"], + } + }, + checkpoint=checkpoint, + metadata=self._load_metadata(value["metadata"]), + pending_writes=await asyncio.to_thread( + self._load_writes, value["pending_writes"] + ), + ) + + async def aput( + self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: ChannelVersions, + ) -> RunnableConfig: + """Save a checkpoint to the database asynchronously. + + This method saves a checkpoint to the Postgres database. The checkpoint is associated + with the provided config. 
For AsyncShallowPostgresSaver, this method saves ONLY the most recent + checkpoint and overwrites a previous checkpoint, if it exists. + + Args: + config (RunnableConfig): The config to associate with the checkpoint. + checkpoint (Checkpoint): The checkpoint to save. + metadata (CheckpointMetadata): Additional metadata to save with the checkpoint. + new_versions (ChannelVersions): New channel versions as of this write. + + Returns: + RunnableConfig: Updated configuration after storing the checkpoint. + """ + configurable = config["configurable"].copy() + thread_id = configurable.pop("thread_id") + checkpoint_ns = configurable.pop("checkpoint_ns") + + copy = checkpoint.copy() + next_config = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": checkpoint["id"], + } + } + + async with self._cursor(pipeline=True) as cur: + await cur.execute( + """DELETE FROM checkpoint_writes + WHERE thread_id = %s AND checkpoint_ns = %s AND checkpoint_id NOT IN (%s, %s)""", + ( + thread_id, + checkpoint_ns, + checkpoint["id"], + configurable.get("checkpoint_id", ""), + ), + ) + await cur.executemany( + self.UPSERT_CHECKPOINT_BLOBS_SQL, + _dump_blobs( + self.serde, + thread_id, + checkpoint_ns, + copy.pop("channel_values"), # type: ignore[misc] + new_versions, + ), + ) + await cur.execute( + self.UPSERT_CHECKPOINTS_SQL, + ( + thread_id, + checkpoint_ns, + Jsonb(self._dump_checkpoint(copy)), + self._dump_metadata(metadata), + ), + ) + return next_config + + async def aput_writes( + self, + config: RunnableConfig, + writes: Sequence[tuple[str, Any]], + task_id: str, + ) -> None: + """Store intermediate writes linked to a checkpoint asynchronously. + + This method saves intermediate writes associated with a checkpoint to the database. + + Args: + config (RunnableConfig): Configuration of the related checkpoint. + writes (Sequence[Tuple[str, Any]]): List of writes to store, each as (channel, value) pair. + task_id (str): Identifier for the task creating the writes. + """ + query = ( + self.UPSERT_CHECKPOINT_WRITES_SQL + if all(w[0] in WRITES_IDX_MAP for w in writes) + else self.INSERT_CHECKPOINT_WRITES_SQL + ) + params = await asyncio.to_thread( + self._dump_writes, + config["configurable"]["thread_id"], + config["configurable"]["checkpoint_ns"], + config["configurable"]["checkpoint_id"], + task_id, + writes, + ) + async with self._cursor(pipeline=True) as cur: + await cur.executemany(query, params) + + @asynccontextmanager + async def _cursor( + self, *, pipeline: bool = False + ) -> AsyncIterator[AsyncCursor[DictRow]]: + """Create a database cursor as a context manager. + + Args: + pipeline (bool): whether to use pipeline for the DB operations inside the context manager. + Will be applied regardless of whether the AsyncShallowPostgresSaver instance was initialized with a pipeline. + If pipeline mode is not supported, will fall back to using transaction context manager. 
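        Example:

            An internal-usage sketch mirroring how aput drives this helper (params is illustrative):

            >>> async with self._cursor(pipeline=True) as cur:
            >>>     await cur.execute(self.UPSERT_CHECKPOINTS_SQL, params)  # synced/committed on exit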
+ """ + async with _ainternal.get_connection(self.conn) as conn: + if self.pipe: + # a connection in pipeline mode can be used concurrently + # in multiple threads/coroutines, but only one cursor can be + # used at a time + try: + async with conn.cursor(binary=True, row_factory=dict_row) as cur: + yield cur + finally: + if pipeline: + await self.pipe.sync() + elif pipeline: + # a connection not in pipeline mode can only be used by one + # thread/coroutine at a time, so we acquire a lock + if self.supports_pipeline: + async with ( + self.lock, + conn.pipeline(), + conn.cursor(binary=True, row_factory=dict_row) as cur, + ): + yield cur + else: + # Use connection's transaction context manager when pipeline mode not supported + async with ( + self.lock, + conn.transaction(), + conn.cursor(binary=True, row_factory=dict_row) as cur, + ): + yield cur + else: + async with ( + self.lock, + conn.cursor(binary=True, row_factory=dict_row) as cur, + ): + yield cur + + def list( + self, + config: Optional[RunnableConfig], + *, + filter: Optional[dict[str, Any]] = None, + before: Optional[RunnableConfig] = None, + limit: Optional[int] = None, + ) -> Iterator[CheckpointTuple]: + """List checkpoints from the database. + + This method retrieves a list of checkpoint tuples from the Postgres database based + on the provided config. For ShallowPostgresSaver, this method returns a list with + ONLY the most recent checkpoint. + """ + aiter_ = self.alist(config, filter=filter, before=before, limit=limit) + while True: + try: + yield asyncio.run_coroutine_threadsafe( + anext(aiter_), # noqa: F821 + self.loop, + ).result() + except StopAsyncIteration: + break + + def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: + """Get a checkpoint tuple from the database. + + This method retrieves a checkpoint tuple from the Postgres database based on the + provided config (matching the thread ID in the config). + + Args: + config (RunnableConfig): The config to use for retrieving the checkpoint. + + Returns: + Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found. + """ + try: + # check if we are in the main thread, only bg threads can block + # we don't check in other methods to avoid the overhead + if asyncio.get_running_loop() is self.loop: + raise asyncio.InvalidStateError( + "Synchronous calls to AsyncShallowPostgresSaver are only allowed from a " + "different thread. From the main thread, use the async interface." + "For example, use `await checkpointer.aget_tuple(...)` or `await " + "graph.ainvoke(...)`." + ) + except RuntimeError: + pass + return asyncio.run_coroutine_threadsafe( + self.aget_tuple(config), self.loop + ).result() + + def put( + self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: ChannelVersions, + ) -> RunnableConfig: + """Save a checkpoint to the database. + + This method saves a checkpoint to the Postgres database. The checkpoint is associated + with the provided config. For AsyncShallowPostgresSaver, this method saves ONLY the most recent + checkpoint and overwrites a previous checkpoint, if it exists. + + Args: + config (RunnableConfig): The config to associate with the checkpoint. + checkpoint (Checkpoint): The checkpoint to save. + metadata (CheckpointMetadata): Additional metadata to save with the checkpoint. + new_versions (ChannelVersions): New channel versions as of this write. + + Returns: + RunnableConfig: Updated configuration after storing the checkpoint. 
+ """ + return asyncio.run_coroutine_threadsafe( + self.aput(config, checkpoint, metadata, new_versions), self.loop + ).result() + + def put_writes( + self, + config: RunnableConfig, + writes: Sequence[tuple[str, Any]], + task_id: str, + ) -> None: + """Store intermediate writes linked to a checkpoint. + + This method saves intermediate writes associated with a checkpoint to the database. + + Args: + config (RunnableConfig): Configuration of the related checkpoint. + writes (Sequence[Tuple[str, Any]]): List of writes to store, each as (channel, value) pair. + task_id (str): Identifier for the task creating the writes. + """ + return asyncio.run_coroutine_threadsafe( + self.aput_writes(config, writes, task_id), self.loop + ).result() diff --git a/libs/checkpoint-postgres/tests/test_async.py b/libs/checkpoint-postgres/tests/test_async.py index 67848707c..7590012dd 100644 --- a/libs/checkpoint-postgres/tests/test_async.py +++ b/libs/checkpoint-postgres/tests/test_async.py @@ -16,7 +16,10 @@ create_checkpoint, empty_checkpoint, ) -from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver +from langgraph.checkpoint.postgres.aio import ( + AsyncPostgresSaver, + AsyncShallowPostgresSaver, +) from tests.conftest import DEFAULT_POSTGRES_URI @@ -103,11 +106,41 @@ async def _base_saver(): await conn.execute(f"DROP DATABASE {database}") +@asynccontextmanager +async def _shallow_saver(): + """Fixture for shallow connection mode testing.""" + database = f"test_{uuid4().hex[:16]}" + # create unique db + async with await AsyncConnection.connect( + DEFAULT_POSTGRES_URI, autocommit=True + ) as conn: + await conn.execute(f"CREATE DATABASE {database}") + try: + async with await AsyncConnection.connect( + DEFAULT_POSTGRES_URI + database, + autocommit=True, + prepare_threshold=0, + row_factory=dict_row, + ) as conn: + checkpointer = AsyncShallowPostgresSaver(conn) + await checkpointer.setup() + yield checkpointer + finally: + # drop unique db + async with await AsyncConnection.connect( + DEFAULT_POSTGRES_URI, autocommit=True + ) as conn: + await conn.execute(f"DROP DATABASE {database}") + + @asynccontextmanager async def _saver(name: str): if name == "base": async with _base_saver() as saver: yield saver + elif name == "shallow": + async with _shallow_saver() as saver: + yield saver elif name == "pool": async with _pool_saver() as saver: yield saver @@ -167,7 +200,7 @@ def test_data(): } -@pytest.mark.parametrize("saver_name", ["base", "pool", "pipe"]) +@pytest.mark.parametrize("saver_name", ["base", "pool", "pipe", "shallow"]) async def test_asearch(request, saver_name: str, test_data) -> None: async with _saver(saver_name) as saver: configs = test_data["configs"] @@ -212,7 +245,7 @@ async def test_asearch(request, saver_name: str, test_data) -> None: } == {"", "inner"} -@pytest.mark.parametrize("saver_name", ["base", "pool", "pipe"]) +@pytest.mark.parametrize("saver_name", ["base", "pool", "pipe", "shallow"]) async def test_null_chars(request, saver_name: str, test_data) -> None: async with _saver(saver_name) as saver: config = await saver.aput( diff --git a/libs/checkpoint-postgres/tests/test_sync.py b/libs/checkpoint-postgres/tests/test_sync.py index 52e186a55..c5cdbd431 100644 --- a/libs/checkpoint-postgres/tests/test_sync.py +++ b/libs/checkpoint-postgres/tests/test_sync.py @@ -16,7 +16,7 @@ create_checkpoint, empty_checkpoint, ) -from langgraph.checkpoint.postgres import PostgresSaver +from langgraph.checkpoint.postgres import PostgresSaver, ShallowPostgresSaver from tests.conftest import 
DEFAULT_POSTGRES_URI @@ -91,11 +91,37 @@ def _base_saver(): conn.execute(f"DROP DATABASE {database}") +@contextmanager +def _shallow_saver(): + """Fixture for regular connection mode testing with a shallow checkpointer.""" + database = f"test_{uuid4().hex[:16]}" + # create unique db + with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn: + conn.execute(f"CREATE DATABASE {database}") + try: + with Connection.connect( + DEFAULT_POSTGRES_URI + database, + autocommit=True, + prepare_threshold=0, + row_factory=dict_row, + ) as conn: + checkpointer = ShallowPostgresSaver(conn) + checkpointer.setup() + yield checkpointer + finally: + # drop unique db + with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn: + conn.execute(f"DROP DATABASE {database}") + + @contextmanager def _saver(name: str): if name == "base": with _base_saver() as saver: yield saver + elif name == "shallow": + with _shallow_saver() as saver: + yield saver elif name == "pool": with _pool_saver() as saver: yield saver @@ -155,7 +181,7 @@ def test_data(): } -@pytest.mark.parametrize("saver_name", ["base", "pool", "pipe"]) +@pytest.mark.parametrize("saver_name", ["base", "pool", "pipe", "shallow"]) def test_search(saver_name: str, test_data) -> None: with _saver(saver_name) as saver: configs = test_data["configs"] @@ -198,7 +224,7 @@ def test_search(saver_name: str, test_data) -> None: } == {"", "inner"} -@pytest.mark.parametrize("saver_name", ["base", "pool", "pipe"]) +@pytest.mark.parametrize("saver_name", ["base", "pool", "pipe", "shallow"]) def test_null_chars(saver_name: str, test_data) -> None: with _saver(saver_name) as saver: config = saver.put( diff --git a/libs/langgraph/tests/__snapshots__/test_large_cases.ambr b/libs/langgraph/tests/__snapshots__/test_large_cases.ambr index 6ba27e7ed..017b1aa9b 100644 --- a/libs/langgraph/tests/__snapshots__/test_large_cases.ambr +++ b/libs/langgraph/tests/__snapshots__/test_large_cases.ambr @@ -135,6 +135,40 @@ ''' # --- +# name: test_branch_then[postgres_shallow] + ''' + graph TD; + __start__ --> prepare; + finish --> __end__; + prepare -.-> tool_two_slow; + tool_two_slow --> finish; + prepare -.-> tool_two_fast; + tool_two_fast --> finish; + + ''' +# --- +# name: test_branch_then[postgres_shallow].1 + ''' + %%{init: {'flowchart': {'curve': 'linear'}}}%% + graph TD; + __start__([

<p>__start__</p>
]):::first + prepare(prepare) + tool_two_slow(tool_two_slow) + tool_two_fast(tool_two_fast) + finish(finish) + __end__([

<p>__end__</p>
]):::last + __start__ --> prepare; + finish --> __end__; + prepare -.-> tool_two_slow; + tool_two_slow --> finish; + prepare -.-> tool_two_fast; + tool_two_fast --> finish; + classDef default fill:#f2f0ff,line-height:1.2 + classDef first fill-opacity:0 + classDef last fill:#bfb6fc + + ''' +# --- # name: test_branch_then[sqlite] ''' graph TD; @@ -1269,6 +1303,281 @@ ''' # --- +# name: test_conditional_graph[postgres_shallow] + ''' + { + "nodes": [ + { + "id": "__start__", + "type": "schema", + "data": "__start__" + }, + { + "id": "agent", + "type": "runnable", + "data": { + "id": [ + "langchain", + "schema", + "runnable", + "RunnableAssign" + ], + "name": "agent" + } + }, + { + "id": "tools", + "type": "runnable", + "data": { + "id": [ + "langgraph", + "utils", + "runnable", + "RunnableCallable" + ], + "name": "tools" + }, + "metadata": { + "parents": {}, + "version": 2, + "variant": "b" + } + }, + { + "id": "__end__", + "type": "schema", + "data": "__end__" + } + ], + "edges": [ + { + "source": "__start__", + "target": "agent" + }, + { + "source": "tools", + "target": "agent" + }, + { + "source": "agent", + "target": "tools", + "data": "continue", + "conditional": true + }, + { + "source": "agent", + "target": "__end__", + "data": "exit", + "conditional": true + } + ] + } + ''' +# --- +# name: test_conditional_graph[postgres_shallow].1 + ''' + graph TD; + __start__ --> agent; + tools --> agent; + agent -.  continue  .-> tools; + agent -.  exit  .-> __end__; + + ''' +# --- +# name: test_conditional_graph[postgres_shallow].2 + ''' + %%{init: {'flowchart': {'curve': 'linear'}}}%% + graph TD; + __start__([

<p>__start__</p>
]):::first + agent(agent) + tools(tools
parents = {} + version = 2 + variant = b) + __end__([

<p>__end__</p>
]):::last + __start__ --> agent; + tools --> agent; + agent -.  continue  .-> tools; + agent -.  exit  .-> __end__; + classDef default fill:#f2f0ff,line-height:1.2 + classDef first fill-opacity:0 + classDef last fill:#bfb6fc + + ''' +# --- +# name: test_conditional_graph[postgres_shallow].3 + ''' + { + "nodes": [ + { + "id": "__start__", + "type": "schema", + "data": "__start__" + }, + { + "id": "agent", + "type": "runnable", + "data": { + "id": [ + "langchain", + "schema", + "runnable", + "RunnableAssign" + ], + "name": "agent" + } + }, + { + "id": "tools", + "type": "runnable", + "data": { + "id": [ + "langgraph", + "utils", + "runnable", + "RunnableCallable" + ], + "name": "tools" + }, + "metadata": { + "parents": {}, + "version": 2, + "variant": "b" + } + }, + { + "id": "__end__", + "type": "schema", + "data": "__end__" + } + ], + "edges": [ + { + "source": "__start__", + "target": "agent" + }, + { + "source": "tools", + "target": "agent" + }, + { + "source": "agent", + "target": "tools", + "data": "continue", + "conditional": true + }, + { + "source": "agent", + "target": "__end__", + "data": "exit", + "conditional": true + } + ] + } + ''' +# --- +# name: test_conditional_graph[postgres_shallow].4 + ''' + graph TD; + __start__ --> agent; + tools --> agent; + agent -.  continue  .-> tools; + agent -.  exit  .-> __end__; + + ''' +# --- +# name: test_conditional_graph[postgres_shallow].5 + dict({ + 'edges': list([ + dict({ + 'source': '__start__', + 'target': 'agent', + }), + dict({ + 'source': 'tools', + 'target': 'agent', + }), + dict({ + 'conditional': True, + 'data': 'continue', + 'source': 'agent', + 'target': 'tools', + }), + dict({ + 'conditional': True, + 'data': 'exit', + 'source': 'agent', + 'target': '__end__', + }), + ]), + 'nodes': list([ + dict({ + 'data': '__start__', + 'id': '__start__', + 'type': 'schema', + }), + dict({ + 'data': dict({ + 'id': list([ + 'langchain', + 'schema', + 'runnable', + 'RunnableAssign', + ]), + 'name': 'agent', + }), + 'id': 'agent', + 'metadata': dict({ + '__interrupt': 'after', + }), + 'type': 'runnable', + }), + dict({ + 'data': dict({ + 'id': list([ + 'langgraph', + 'utils', + 'runnable', + 'RunnableCallable', + ]), + 'name': 'tools', + }), + 'id': 'tools', + 'metadata': dict({ + 'parents': dict({ + }), + 'variant': 'b', + 'version': 2, + }), + 'type': 'runnable', + }), + dict({ + 'data': '__end__', + 'id': '__end__', + 'type': 'schema', + }), + ]), + }) +# --- +# name: test_conditional_graph[postgres_shallow].6 + ''' + %%{init: {'flowchart': {'curve': 'linear'}}}%% + graph TD; + __start__([

<p>__start__</p>
]):::first + agent(agent
__interrupt = after) + tools(tools
parents = {} + version = 2 + variant = b) + __end__([

<p>__end__</p>
]):::last + __start__ --> agent; + tools --> agent; + agent -.  continue  .-> tools; + agent -.  exit  .-> __end__; + classDef default fill:#f2f0ff,line-height:1.2 + classDef first fill-opacity:0 + classDef last fill:#bfb6fc + + ''' +# --- # name: test_conditional_graph[sqlite] ''' { @@ -1872,6 +2181,88 @@ ''' # --- +# name: test_conditional_state_graph[postgres_shallow] + '{"$defs": {"AgentAction": {"description": "Represents a request to execute an action by an agent.\\n\\nThe action consists of the name of the tool to execute and the input to pass\\nto the tool. The log is used to pass along extra information about the action.", "properties": {"tool": {"title": "Tool", "type": "string"}, "tool_input": {"anyOf": [{"type": "string"}, {"type": "object"}], "title": "Tool Input"}, "log": {"title": "Log", "type": "string"}, "type": {"const": "AgentAction", "default": "AgentAction", "enum": ["AgentAction"], "title": "Type", "type": "string"}}, "required": ["tool", "tool_input", "log"], "title": "AgentAction", "type": "object"}, "AgentFinish": {"description": "Final return value of an ActionAgent.\\n\\nAgents return an AgentFinish when they have reached a stopping condition.", "properties": {"return_values": {"title": "Return Values", "type": "object"}, "log": {"title": "Log", "type": "string"}, "type": {"const": "AgentFinish", "default": "AgentFinish", "enum": ["AgentFinish"], "title": "Type", "type": "string"}}, "required": ["return_values", "log"], "title": "AgentFinish", "type": "object"}}, "properties": {"input": {"default": null, "title": "Input", "type": "string"}, "agent_outcome": {"anyOf": [{"$ref": "#/$defs/AgentAction"}, {"$ref": "#/$defs/AgentFinish"}, {"type": "null"}], "default": null, "title": "Agent Outcome"}, "intermediate_steps": {"default": null, "items": {"maxItems": 2, "minItems": 2, "prefixItems": [{"$ref": "#/$defs/AgentAction"}, {"type": "string"}], "type": "array"}, "title": "Intermediate Steps", "type": "array"}}, "title": "LangGraphInput", "type": "object"}' +# --- +# name: test_conditional_state_graph[postgres_shallow].1 + '{"$defs": {"AgentAction": {"description": "Represents a request to execute an action by an agent.\\n\\nThe action consists of the name of the tool to execute and the input to pass\\nto the tool. 
The log is used to pass along extra information about the action.", "properties": {"tool": {"title": "Tool", "type": "string"}, "tool_input": {"anyOf": [{"type": "string"}, {"type": "object"}], "title": "Tool Input"}, "log": {"title": "Log", "type": "string"}, "type": {"const": "AgentAction", "default": "AgentAction", "enum": ["AgentAction"], "title": "Type", "type": "string"}}, "required": ["tool", "tool_input", "log"], "title": "AgentAction", "type": "object"}, "AgentFinish": {"description": "Final return value of an ActionAgent.\\n\\nAgents return an AgentFinish when they have reached a stopping condition.", "properties": {"return_values": {"title": "Return Values", "type": "object"}, "log": {"title": "Log", "type": "string"}, "type": {"const": "AgentFinish", "default": "AgentFinish", "enum": ["AgentFinish"], "title": "Type", "type": "string"}}, "required": ["return_values", "log"], "title": "AgentFinish", "type": "object"}}, "properties": {"input": {"default": null, "title": "Input", "type": "string"}, "agent_outcome": {"anyOf": [{"$ref": "#/$defs/AgentAction"}, {"$ref": "#/$defs/AgentFinish"}, {"type": "null"}], "default": null, "title": "Agent Outcome"}, "intermediate_steps": {"default": null, "items": {"maxItems": 2, "minItems": 2, "prefixItems": [{"$ref": "#/$defs/AgentAction"}, {"type": "string"}], "type": "array"}, "title": "Intermediate Steps", "type": "array"}}, "title": "LangGraphOutput", "type": "object"}' +# --- +# name: test_conditional_state_graph[postgres_shallow].2 + ''' + { + "nodes": [ + { + "id": "__start__", + "type": "schema", + "data": "__start__" + }, + { + "id": "agent", + "type": "runnable", + "data": { + "id": [ + "langchain", + "schema", + "runnable", + "RunnableSequence" + ], + "name": "agent" + } + }, + { + "id": "tools", + "type": "runnable", + "data": { + "id": [ + "langgraph", + "utils", + "runnable", + "RunnableCallable" + ], + "name": "tools" + } + }, + { + "id": "__end__", + "type": "schema", + "data": "__end__" + } + ], + "edges": [ + { + "source": "__start__", + "target": "agent" + }, + { + "source": "tools", + "target": "agent" + }, + { + "source": "agent", + "target": "tools", + "data": "continue", + "conditional": true + }, + { + "source": "agent", + "target": "__end__", + "data": "exit", + "conditional": true + } + ] + } + ''' +# --- +# name: test_conditional_state_graph[postgres_shallow].3 + ''' + graph TD; + __start__ --> agent; + tools --> agent; + agent -.  continue  .-> tools; + agent -.  exit  .-> __end__; + + ''' +# --- # name: test_conditional_state_graph[sqlite] '{"$defs": {"AgentAction": {"description": "Represents a request to execute an action by an agent.\\n\\nThe action consists of the name of the tool to execute and the input to pass\\nto the tool. 
The log is used to pass along extra information about the action.", "properties": {"tool": {"title": "Tool", "type": "string"}, "tool_input": {"anyOf": [{"type": "string"}, {"type": "object"}], "title": "Tool Input"}, "log": {"title": "Log", "type": "string"}, "type": {"const": "AgentAction", "default": "AgentAction", "enum": ["AgentAction"], "title": "Type", "type": "string"}}, "required": ["tool", "tool_input", "log"], "title": "AgentAction", "type": "object"}, "AgentFinish": {"description": "Final return value of an ActionAgent.\\n\\nAgents return an AgentFinish when they have reached a stopping condition.", "properties": {"return_values": {"title": "Return Values", "type": "object"}, "log": {"title": "Log", "type": "string"}, "type": {"const": "AgentFinish", "default": "AgentFinish", "enum": ["AgentFinish"], "title": "Type", "type": "string"}}, "required": ["return_values", "log"], "title": "AgentFinish", "type": "object"}}, "properties": {"input": {"default": null, "title": "Input", "type": "string"}, "agent_outcome": {"anyOf": [{"$ref": "#/$defs/AgentAction"}, {"$ref": "#/$defs/AgentFinish"}, {"type": "null"}], "default": null, "title": "Agent Outcome"}, "intermediate_steps": {"default": null, "items": {"maxItems": 2, "minItems": 2, "prefixItems": [{"$ref": "#/$defs/AgentAction"}, {"type": "string"}], "type": "array"}, "title": "Intermediate Steps", "type": "array"}}, "title": "LangGraphInput", "type": "object"}' # --- @@ -2278,6 +2669,87 @@ ''' # --- +# name: test_message_graph[postgres_shallow] + '{"$defs": {"AIMessage": {"additionalProperties": true, "description": "Message from an AI.\\n\\nAIMessage is returned from a chat model as a response to a prompt.\\n\\nThis message represents the output of the model and consists of both\\nthe raw output as returned by the model together standardized fields\\n(e.g., tool calls, usage metadata) added by the LangChain framework.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "ai", "default": "ai", "enum": ["ai"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "example": {"default": false, "title": "Example", "type": "boolean"}, "tool_calls": {"default": [], "items": {"$ref": "#/$defs/ToolCall"}, "title": "Tool Calls", "type": "array"}, "invalid_tool_calls": {"default": [], "items": {"$ref": "#/$defs/InvalidToolCall"}, "title": "Invalid Tool Calls", "type": "array"}, "usage_metadata": {"anyOf": [{"$ref": "#/$defs/UsageMetadata"}, {"type": "null"}], "default": null}}, "required": ["content"], "title": "AIMessage", "type": "object"}, "AIMessageChunk": {"additionalProperties": true, "description": "Message chunk from an AI.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "AIMessageChunk", "default": "AIMessageChunk", "enum": ["AIMessageChunk"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], 
"default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "example": {"default": false, "title": "Example", "type": "boolean"}, "tool_calls": {"default": [], "items": {"$ref": "#/$defs/ToolCall"}, "title": "Tool Calls", "type": "array"}, "invalid_tool_calls": {"default": [], "items": {"$ref": "#/$defs/InvalidToolCall"}, "title": "Invalid Tool Calls", "type": "array"}, "usage_metadata": {"anyOf": [{"$ref": "#/$defs/UsageMetadata"}, {"type": "null"}], "default": null}, "tool_call_chunks": {"default": [], "items": {"$ref": "#/$defs/ToolCallChunk"}, "title": "Tool Call Chunks", "type": "array"}}, "required": ["content"], "title": "AIMessageChunk", "type": "object"}, "ChatMessage": {"additionalProperties": true, "description": "Message that can be assigned an arbitrary speaker (i.e. role).", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "chat", "default": "chat", "enum": ["chat"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "role": {"title": "Role", "type": "string"}}, "required": ["content", "role"], "title": "ChatMessage", "type": "object"}, "ChatMessageChunk": {"additionalProperties": true, "description": "Chat Message chunk.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "ChatMessageChunk", "default": "ChatMessageChunk", "enum": ["ChatMessageChunk"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "role": {"title": "Role", "type": "string"}}, "required": ["content", "role"], "title": "ChatMessageChunk", "type": "object"}, "FunctionMessage": {"additionalProperties": true, "description": "Message for passing the result of executing a tool back to a model.\\n\\nFunctionMessage are an older version of the ToolMessage schema, and\\ndo not contain the tool_call_id field.\\n\\nThe tool_call_id field is used to associate the tool call request with the\\ntool call response. 
This is useful in situations where a chat model is able\\nto request multiple tool calls in parallel.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "function", "default": "function", "enum": ["function"], "title": "Type", "type": "string"}, "name": {"title": "Name", "type": "string"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}}, "required": ["content", "name"], "title": "FunctionMessage", "type": "object"}, "FunctionMessageChunk": {"additionalProperties": true, "description": "Function Message chunk.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "FunctionMessageChunk", "default": "FunctionMessageChunk", "enum": ["FunctionMessageChunk"], "title": "Type", "type": "string"}, "name": {"title": "Name", "type": "string"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}}, "required": ["content", "name"], "title": "FunctionMessageChunk", "type": "object"}, "HumanMessage": {"additionalProperties": true, "description": "Message from a human.\\n\\nHumanMessages are messages that are passed in from a human to the model.\\n\\nExample:\\n\\n .. code-block:: python\\n\\n from langchain_core.messages import HumanMessage, SystemMessage\\n\\n messages = [\\n SystemMessage(\\n content=\\"You are a helpful assistant! 
Your name is Bob.\\"\\n ),\\n HumanMessage(\\n content=\\"What is your name?\\"\\n )\\n ]\\n\\n # Instantiate a chat model and invoke it with the messages\\n model = ...\\n print(model.invoke(messages))", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "human", "default": "human", "enum": ["human"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "example": {"default": false, "title": "Example", "type": "boolean"}}, "required": ["content"], "title": "HumanMessage", "type": "object"}, "HumanMessageChunk": {"additionalProperties": true, "description": "Human Message chunk.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "HumanMessageChunk", "default": "HumanMessageChunk", "enum": ["HumanMessageChunk"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "example": {"default": false, "title": "Example", "type": "boolean"}}, "required": ["content"], "title": "HumanMessageChunk", "type": "object"}, "InputTokenDetails": {"description": "Breakdown of input token counts.\\n\\nDoes *not* need to sum to full input token count. Does *not* need to have all keys.\\n\\nExample:\\n\\n .. code-block:: python\\n\\n {\\n \\"audio\\": 10,\\n \\"cache_creation\\": 200,\\n \\"cache_read\\": 100,\\n }\\n\\n.. versionadded:: 0.3.9", "properties": {"audio": {"title": "Audio", "type": "integer"}, "cache_creation": {"title": "Cache Creation", "type": "integer"}, "cache_read": {"title": "Cache Read", "type": "integer"}}, "title": "InputTokenDetails", "type": "object"}, "InvalidToolCall": {"description": "Allowance for errors made by LLM.\\n\\nHere we add an `error` key to surface errors made during generation\\n(e.g., invalid JSON arguments.)", "properties": {"name": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Name"}, "args": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Args"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Id"}, "error": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Error"}, "type": {"const": "invalid_tool_call", "enum": ["invalid_tool_call"], "title": "Type", "type": "string"}}, "required": ["name", "args", "id", "error"], "title": "InvalidToolCall", "type": "object"}, "OutputTokenDetails": {"description": "Breakdown of output token counts.\\n\\nDoes *not* need to sum to full output token count. Does *not* need to have all keys.\\n\\nExample:\\n\\n .. code-block:: python\\n\\n {\\n \\"audio\\": 10,\\n \\"reasoning\\": 200,\\n }\\n\\n.. 
versionadded:: 0.3.9", "properties": {"audio": {"title": "Audio", "type": "integer"}, "reasoning": {"title": "Reasoning", "type": "integer"}}, "title": "OutputTokenDetails", "type": "object"}, "SystemMessage": {"additionalProperties": true, "description": "Message for priming AI behavior.\\n\\nThe system message is usually passed in as the first of a sequence\\nof input messages.\\n\\nExample:\\n\\n .. code-block:: python\\n\\n from langchain_core.messages import HumanMessage, SystemMessage\\n\\n messages = [\\n SystemMessage(\\n content=\\"You are a helpful assistant! Your name is Bob.\\"\\n ),\\n HumanMessage(\\n content=\\"What is your name?\\"\\n )\\n ]\\n\\n # Define a chat model and invoke it with the messages\\n print(model.invoke(messages))", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "system", "default": "system", "enum": ["system"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}}, "required": ["content"], "title": "SystemMessage", "type": "object"}, "SystemMessageChunk": {"additionalProperties": true, "description": "System Message chunk.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "SystemMessageChunk", "default": "SystemMessageChunk", "enum": ["SystemMessageChunk"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}}, "required": ["content"], "title": "SystemMessageChunk", "type": "object"}, "ToolCall": {"description": "Represents a request to call a tool.\\n\\nExample:\\n\\n .. code-block:: python\\n\\n {\\n \\"name\\": \\"foo\\",\\n \\"args\\": {\\"a\\": 1},\\n \\"id\\": \\"123\\"\\n }\\n\\n This represents a request to call the tool named \\"foo\\" with arguments {\\"a\\": 1}\\n and an identifier of \\"123\\".", "properties": {"name": {"title": "Name", "type": "string"}, "args": {"title": "Args", "type": "object"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Id"}, "type": {"const": "tool_call", "enum": ["tool_call"], "title": "Type", "type": "string"}}, "required": ["name", "args", "id"], "title": "ToolCall", "type": "object"}, "ToolCallChunk": {"description": "A chunk of a tool call (e.g., as part of a stream).\\n\\nWhen merging ToolCallChunks (e.g., via AIMessageChunk.__add__),\\nall string attributes are concatenated. Chunks are only merged if their\\nvalues of `index` are equal and not None.\\n\\nExample:\\n\\n.. 
code-block:: python\\n\\n left_chunks = [ToolCallChunk(name=\\"foo\\", args=\'{\\"a\\":\', index=0)]\\n right_chunks = [ToolCallChunk(name=None, args=\'1}\', index=0)]\\n\\n (\\n AIMessageChunk(content=\\"\\", tool_call_chunks=left_chunks)\\n + AIMessageChunk(content=\\"\\", tool_call_chunks=right_chunks)\\n ).tool_call_chunks == [ToolCallChunk(name=\'foo\', args=\'{\\"a\\":1}\', index=0)]", "properties": {"name": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Name"}, "args": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Args"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Id"}, "index": {"anyOf": [{"type": "integer"}, {"type": "null"}], "title": "Index"}, "type": {"const": "tool_call_chunk", "enum": ["tool_call_chunk"], "title": "Type", "type": "string"}}, "required": ["name", "args", "id", "index"], "title": "ToolCallChunk", "type": "object"}, "ToolMessage": {"additionalProperties": true, "description": "Message for passing the result of executing a tool back to a model.\\n\\nToolMessages contain the result of a tool invocation. Typically, the result\\nis encoded inside the `content` field.\\n\\nExample: A ToolMessage representing a result of 42 from a tool call with id\\n\\n .. code-block:: python\\n\\n from langchain_core.messages import ToolMessage\\n\\n ToolMessage(content=\'42\', tool_call_id=\'call_Jja7J89XsjrOLA5r!MEOW!SL\')\\n\\n\\nExample: A ToolMessage where only part of the tool output is sent to the model\\n and the full output is passed in to artifact.\\n\\n .. versionadded:: 0.2.17\\n\\n .. code-block:: python\\n\\n from langchain_core.messages import ToolMessage\\n\\n tool_output = {\\n \\"stdout\\": \\"From the graph we can see that the correlation between x and y is ...\\",\\n \\"stderr\\": None,\\n \\"artifacts\\": {\\"type\\": \\"image\\", \\"base64_data\\": \\"/9j/4gIcSU...\\"},\\n }\\n\\n ToolMessage(\\n content=tool_output[\\"stdout\\"],\\n artifact=tool_output,\\n tool_call_id=\'call_Jja7J89XsjrOLA5r!MEOW!SL\',\\n )\\n\\nThe tool_call_id field is used to associate the tool call request with the\\ntool call response. 
This is useful in situations where a chat model is able\\nto request multiple tool calls in parallel.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "tool", "default": "tool", "enum": ["tool"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "tool_call_id": {"title": "Tool Call Id", "type": "string"}, "artifact": {"default": null, "title": "Artifact"}, "status": {"default": "success", "enum": ["success", "error"], "title": "Status", "type": "string"}}, "required": ["content", "tool_call_id"], "title": "ToolMessage", "type": "object"}, "ToolMessageChunk": {"additionalProperties": true, "description": "Tool Message chunk.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "ToolMessageChunk", "default": "ToolMessageChunk", "enum": ["ToolMessageChunk"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "tool_call_id": {"title": "Tool Call Id", "type": "string"}, "artifact": {"default": null, "title": "Artifact"}, "status": {"default": "success", "enum": ["success", "error"], "title": "Status", "type": "string"}}, "required": ["content", "tool_call_id"], "title": "ToolMessageChunk", "type": "object"}, "UsageMetadata": {"description": "Usage metadata for a message, such as token counts.\\n\\nThis is a standard representation of token usage that is consistent across models.\\n\\nExample:\\n\\n .. code-block:: python\\n\\n {\\n \\"input_tokens\\": 350,\\n \\"output_tokens\\": 240,\\n \\"total_tokens\\": 590,\\n \\"input_token_details\\": {\\n \\"audio\\": 10,\\n \\"cache_creation\\": 200,\\n \\"cache_read\\": 100,\\n },\\n \\"output_token_details\\": {\\n \\"audio\\": 10,\\n \\"reasoning\\": 200,\\n }\\n }\\n\\n.. 
versionchanged:: 0.3.9\\n\\n Added ``input_token_details`` and ``output_token_details``.", "properties": {"input_tokens": {"title": "Input Tokens", "type": "integer"}, "output_tokens": {"title": "Output Tokens", "type": "integer"}, "total_tokens": {"title": "Total Tokens", "type": "integer"}, "input_token_details": {"$ref": "#/$defs/InputTokenDetails"}, "output_token_details": {"$ref": "#/$defs/OutputTokenDetails"}}, "required": ["input_tokens", "output_tokens", "total_tokens"], "title": "UsageMetadata", "type": "object"}}, "default": null, "items": {"oneOf": [{"$ref": "#/$defs/AIMessage"}, {"$ref": "#/$defs/HumanMessage"}, {"$ref": "#/$defs/ChatMessage"}, {"$ref": "#/$defs/SystemMessage"}, {"$ref": "#/$defs/FunctionMessage"}, {"$ref": "#/$defs/ToolMessage"}, {"$ref": "#/$defs/AIMessageChunk"}, {"$ref": "#/$defs/HumanMessageChunk"}, {"$ref": "#/$defs/ChatMessageChunk"}, {"$ref": "#/$defs/SystemMessageChunk"}, {"$ref": "#/$defs/FunctionMessageChunk"}, {"$ref": "#/$defs/ToolMessageChunk"}]}, "title": "LangGraphInput", "type": "array"}' +# --- +# name: test_message_graph[postgres_shallow].1 + '{"$defs": {"AIMessage": {"additionalProperties": true, "description": "Message from an AI.\\n\\nAIMessage is returned from a chat model as a response to a prompt.\\n\\nThis message represents the output of the model and consists of both\\nthe raw output as returned by the model together standardized fields\\n(e.g., tool calls, usage metadata) added by the LangChain framework.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "ai", "default": "ai", "enum": ["ai"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "example": {"default": false, "title": "Example", "type": "boolean"}, "tool_calls": {"default": [], "items": {"$ref": "#/$defs/ToolCall"}, "title": "Tool Calls", "type": "array"}, "invalid_tool_calls": {"default": [], "items": {"$ref": "#/$defs/InvalidToolCall"}, "title": "Invalid Tool Calls", "type": "array"}, "usage_metadata": {"anyOf": [{"$ref": "#/$defs/UsageMetadata"}, {"type": "null"}], "default": null}}, "required": ["content"], "title": "AIMessage", "type": "object"}, "AIMessageChunk": {"additionalProperties": true, "description": "Message chunk from an AI.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "AIMessageChunk", "default": "AIMessageChunk", "enum": ["AIMessageChunk"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "example": {"default": false, "title": "Example", "type": "boolean"}, "tool_calls": {"default": [], "items": {"$ref": "#/$defs/ToolCall"}, "title": "Tool Calls", "type": "array"}, "invalid_tool_calls": {"default": [], "items": {"$ref": "#/$defs/InvalidToolCall"}, "title": "Invalid Tool Calls", 
"type": "array"}, "usage_metadata": {"anyOf": [{"$ref": "#/$defs/UsageMetadata"}, {"type": "null"}], "default": null}, "tool_call_chunks": {"default": [], "items": {"$ref": "#/$defs/ToolCallChunk"}, "title": "Tool Call Chunks", "type": "array"}}, "required": ["content"], "title": "AIMessageChunk", "type": "object"}, "ChatMessage": {"additionalProperties": true, "description": "Message that can be assigned an arbitrary speaker (i.e. role).", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "chat", "default": "chat", "enum": ["chat"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "role": {"title": "Role", "type": "string"}}, "required": ["content", "role"], "title": "ChatMessage", "type": "object"}, "ChatMessageChunk": {"additionalProperties": true, "description": "Chat Message chunk.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "ChatMessageChunk", "default": "ChatMessageChunk", "enum": ["ChatMessageChunk"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "role": {"title": "Role", "type": "string"}}, "required": ["content", "role"], "title": "ChatMessageChunk", "type": "object"}, "FunctionMessage": {"additionalProperties": true, "description": "Message for passing the result of executing a tool back to a model.\\n\\nFunctionMessage are an older version of the ToolMessage schema, and\\ndo not contain the tool_call_id field.\\n\\nThe tool_call_id field is used to associate the tool call request with the\\ntool call response. 
This is useful in situations where a chat model is able\\nto request multiple tool calls in parallel.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "function", "default": "function", "enum": ["function"], "title": "Type", "type": "string"}, "name": {"title": "Name", "type": "string"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}}, "required": ["content", "name"], "title": "FunctionMessage", "type": "object"}, "FunctionMessageChunk": {"additionalProperties": true, "description": "Function Message chunk.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "FunctionMessageChunk", "default": "FunctionMessageChunk", "enum": ["FunctionMessageChunk"], "title": "Type", "type": "string"}, "name": {"title": "Name", "type": "string"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}}, "required": ["content", "name"], "title": "FunctionMessageChunk", "type": "object"}, "HumanMessage": {"additionalProperties": true, "description": "Message from a human.\\n\\nHumanMessages are messages that are passed in from a human to the model.\\n\\nExample:\\n\\n .. code-block:: python\\n\\n from langchain_core.messages import HumanMessage, SystemMessage\\n\\n messages = [\\n SystemMessage(\\n content=\\"You are a helpful assistant! 
Your name is Bob.\\"\\n ),\\n HumanMessage(\\n content=\\"What is your name?\\"\\n )\\n ]\\n\\n # Instantiate a chat model and invoke it with the messages\\n model = ...\\n print(model.invoke(messages))", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "human", "default": "human", "enum": ["human"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "example": {"default": false, "title": "Example", "type": "boolean"}}, "required": ["content"], "title": "HumanMessage", "type": "object"}, "HumanMessageChunk": {"additionalProperties": true, "description": "Human Message chunk.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "HumanMessageChunk", "default": "HumanMessageChunk", "enum": ["HumanMessageChunk"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "example": {"default": false, "title": "Example", "type": "boolean"}}, "required": ["content"], "title": "HumanMessageChunk", "type": "object"}, "InputTokenDetails": {"description": "Breakdown of input token counts.\\n\\nDoes *not* need to sum to full input token count. Does *not* need to have all keys.\\n\\nExample:\\n\\n .. code-block:: python\\n\\n {\\n \\"audio\\": 10,\\n \\"cache_creation\\": 200,\\n \\"cache_read\\": 100,\\n }\\n\\n.. versionadded:: 0.3.9", "properties": {"audio": {"title": "Audio", "type": "integer"}, "cache_creation": {"title": "Cache Creation", "type": "integer"}, "cache_read": {"title": "Cache Read", "type": "integer"}}, "title": "InputTokenDetails", "type": "object"}, "InvalidToolCall": {"description": "Allowance for errors made by LLM.\\n\\nHere we add an `error` key to surface errors made during generation\\n(e.g., invalid JSON arguments.)", "properties": {"name": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Name"}, "args": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Args"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Id"}, "error": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Error"}, "type": {"const": "invalid_tool_call", "enum": ["invalid_tool_call"], "title": "Type", "type": "string"}}, "required": ["name", "args", "id", "error"], "title": "InvalidToolCall", "type": "object"}, "OutputTokenDetails": {"description": "Breakdown of output token counts.\\n\\nDoes *not* need to sum to full output token count. Does *not* need to have all keys.\\n\\nExample:\\n\\n .. code-block:: python\\n\\n {\\n \\"audio\\": 10,\\n \\"reasoning\\": 200,\\n }\\n\\n.. 
versionadded:: 0.3.9", "properties": {"audio": {"title": "Audio", "type": "integer"}, "reasoning": {"title": "Reasoning", "type": "integer"}}, "title": "OutputTokenDetails", "type": "object"}, "SystemMessage": {"additionalProperties": true, "description": "Message for priming AI behavior.\\n\\nThe system message is usually passed in as the first of a sequence\\nof input messages.\\n\\nExample:\\n\\n .. code-block:: python\\n\\n from langchain_core.messages import HumanMessage, SystemMessage\\n\\n messages = [\\n SystemMessage(\\n content=\\"You are a helpful assistant! Your name is Bob.\\"\\n ),\\n HumanMessage(\\n content=\\"What is your name?\\"\\n )\\n ]\\n\\n # Define a chat model and invoke it with the messages\\n print(model.invoke(messages))", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "system", "default": "system", "enum": ["system"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}}, "required": ["content"], "title": "SystemMessage", "type": "object"}, "SystemMessageChunk": {"additionalProperties": true, "description": "System Message chunk.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "SystemMessageChunk", "default": "SystemMessageChunk", "enum": ["SystemMessageChunk"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}}, "required": ["content"], "title": "SystemMessageChunk", "type": "object"}, "ToolCall": {"description": "Represents a request to call a tool.\\n\\nExample:\\n\\n .. code-block:: python\\n\\n {\\n \\"name\\": \\"foo\\",\\n \\"args\\": {\\"a\\": 1},\\n \\"id\\": \\"123\\"\\n }\\n\\n This represents a request to call the tool named \\"foo\\" with arguments {\\"a\\": 1}\\n and an identifier of \\"123\\".", "properties": {"name": {"title": "Name", "type": "string"}, "args": {"title": "Args", "type": "object"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Id"}, "type": {"const": "tool_call", "enum": ["tool_call"], "title": "Type", "type": "string"}}, "required": ["name", "args", "id"], "title": "ToolCall", "type": "object"}, "ToolCallChunk": {"description": "A chunk of a tool call (e.g., as part of a stream).\\n\\nWhen merging ToolCallChunks (e.g., via AIMessageChunk.__add__),\\nall string attributes are concatenated. Chunks are only merged if their\\nvalues of `index` are equal and not None.\\n\\nExample:\\n\\n.. 
code-block:: python\\n\\n left_chunks = [ToolCallChunk(name=\\"foo\\", args=\'{\\"a\\":\', index=0)]\\n right_chunks = [ToolCallChunk(name=None, args=\'1}\', index=0)]\\n\\n (\\n AIMessageChunk(content=\\"\\", tool_call_chunks=left_chunks)\\n + AIMessageChunk(content=\\"\\", tool_call_chunks=right_chunks)\\n ).tool_call_chunks == [ToolCallChunk(name=\'foo\', args=\'{\\"a\\":1}\', index=0)]", "properties": {"name": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Name"}, "args": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Args"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Id"}, "index": {"anyOf": [{"type": "integer"}, {"type": "null"}], "title": "Index"}, "type": {"const": "tool_call_chunk", "enum": ["tool_call_chunk"], "title": "Type", "type": "string"}}, "required": ["name", "args", "id", "index"], "title": "ToolCallChunk", "type": "object"}, "ToolMessage": {"additionalProperties": true, "description": "Message for passing the result of executing a tool back to a model.\\n\\nToolMessages contain the result of a tool invocation. Typically, the result\\nis encoded inside the `content` field.\\n\\nExample: A ToolMessage representing a result of 42 from a tool call with id\\n\\n .. code-block:: python\\n\\n from langchain_core.messages import ToolMessage\\n\\n ToolMessage(content=\'42\', tool_call_id=\'call_Jja7J89XsjrOLA5r!MEOW!SL\')\\n\\n\\nExample: A ToolMessage where only part of the tool output is sent to the model\\n and the full output is passed in to artifact.\\n\\n .. versionadded:: 0.2.17\\n\\n .. code-block:: python\\n\\n from langchain_core.messages import ToolMessage\\n\\n tool_output = {\\n \\"stdout\\": \\"From the graph we can see that the correlation between x and y is ...\\",\\n \\"stderr\\": None,\\n \\"artifacts\\": {\\"type\\": \\"image\\", \\"base64_data\\": \\"/9j/4gIcSU...\\"},\\n }\\n\\n ToolMessage(\\n content=tool_output[\\"stdout\\"],\\n artifact=tool_output,\\n tool_call_id=\'call_Jja7J89XsjrOLA5r!MEOW!SL\',\\n )\\n\\nThe tool_call_id field is used to associate the tool call request with the\\ntool call response. 
This is useful in situations where a chat model is able\\nto request multiple tool calls in parallel.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "tool", "default": "tool", "enum": ["tool"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "tool_call_id": {"title": "Tool Call Id", "type": "string"}, "artifact": {"default": null, "title": "Artifact"}, "status": {"default": "success", "enum": ["success", "error"], "title": "Status", "type": "string"}}, "required": ["content", "tool_call_id"], "title": "ToolMessage", "type": "object"}, "ToolMessageChunk": {"additionalProperties": true, "description": "Tool Message chunk.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "ToolMessageChunk", "default": "ToolMessageChunk", "enum": ["ToolMessageChunk"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "tool_call_id": {"title": "Tool Call Id", "type": "string"}, "artifact": {"default": null, "title": "Artifact"}, "status": {"default": "success", "enum": ["success", "error"], "title": "Status", "type": "string"}}, "required": ["content", "tool_call_id"], "title": "ToolMessageChunk", "type": "object"}, "UsageMetadata": {"description": "Usage metadata for a message, such as token counts.\\n\\nThis is a standard representation of token usage that is consistent across models.\\n\\nExample:\\n\\n .. code-block:: python\\n\\n {\\n \\"input_tokens\\": 350,\\n \\"output_tokens\\": 240,\\n \\"total_tokens\\": 590,\\n \\"input_token_details\\": {\\n \\"audio\\": 10,\\n \\"cache_creation\\": 200,\\n \\"cache_read\\": 100,\\n },\\n \\"output_token_details\\": {\\n \\"audio\\": 10,\\n \\"reasoning\\": 200,\\n }\\n }\\n\\n.. 
versionchanged:: 0.3.9\\n\\n Added ``input_token_details`` and ``output_token_details``.", "properties": {"input_tokens": {"title": "Input Tokens", "type": "integer"}, "output_tokens": {"title": "Output Tokens", "type": "integer"}, "total_tokens": {"title": "Total Tokens", "type": "integer"}, "input_token_details": {"$ref": "#/$defs/InputTokenDetails"}, "output_token_details": {"$ref": "#/$defs/OutputTokenDetails"}}, "required": ["input_tokens", "output_tokens", "total_tokens"], "title": "UsageMetadata", "type": "object"}}, "default": null, "items": {"oneOf": [{"$ref": "#/$defs/AIMessage"}, {"$ref": "#/$defs/HumanMessage"}, {"$ref": "#/$defs/ChatMessage"}, {"$ref": "#/$defs/SystemMessage"}, {"$ref": "#/$defs/FunctionMessage"}, {"$ref": "#/$defs/ToolMessage"}, {"$ref": "#/$defs/AIMessageChunk"}, {"$ref": "#/$defs/HumanMessageChunk"}, {"$ref": "#/$defs/ChatMessageChunk"}, {"$ref": "#/$defs/SystemMessageChunk"}, {"$ref": "#/$defs/FunctionMessageChunk"}, {"$ref": "#/$defs/ToolMessageChunk"}]}, "title": "LangGraphOutput", "type": "array"}' +# --- +# name: test_message_graph[postgres_shallow].2 + ''' + { + "nodes": [ + { + "id": "__start__", + "type": "schema", + "data": "__start__" + }, + { + "id": "agent", + "type": "runnable", + "data": { + "id": [ + "tests", + "test_large_cases", + "FakeFuntionChatModel" + ], + "name": "agent" + } + }, + { + "id": "tools", + "type": "runnable", + "data": { + "id": [ + "langgraph", + "prebuilt", + "tool_node", + "ToolNode" + ], + "name": "tools" + } + }, + { + "id": "__end__", + "type": "schema", + "data": "__end__" + } + ], + "edges": [ + { + "source": "__start__", + "target": "agent" + }, + { + "source": "tools", + "target": "agent" + }, + { + "source": "agent", + "target": "tools", + "data": "continue", + "conditional": true + }, + { + "source": "agent", + "target": "__end__", + "data": "end", + "conditional": true + } + ] + } + ''' +# --- +# name: test_message_graph[postgres_shallow].3 + ''' + graph TD; + __start__ --> agent; + tools --> agent; + agent -.  continue  .-> tools; + agent -.  
end  .-> __end__; + + ''' +# --- # name: test_message_graph[sqlite] '{"$defs": {"AIMessage": {"additionalProperties": true, "description": "Message from an AI.\\n\\nAIMessage is returned from a chat model as a response to a prompt.\\n\\nThis message represents the output of the model and consists of both\\nthe raw output as returned by the model together standardized fields\\n(e.g., tool calls, usage metadata) added by the LangChain framework.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "ai", "default": "ai", "enum": ["ai"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "example": {"default": false, "title": "Example", "type": "boolean"}, "tool_calls": {"default": [], "items": {"$ref": "#/$defs/ToolCall"}, "title": "Tool Calls", "type": "array"}, "invalid_tool_calls": {"default": [], "items": {"$ref": "#/$defs/InvalidToolCall"}, "title": "Invalid Tool Calls", "type": "array"}, "usage_metadata": {"anyOf": [{"$ref": "#/$defs/UsageMetadata"}, {"type": "null"}], "default": null}}, "required": ["content"], "title": "AIMessage", "type": "object"}, "AIMessageChunk": {"additionalProperties": true, "description": "Message chunk from an AI.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "AIMessageChunk", "default": "AIMessageChunk", "enum": ["AIMessageChunk"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "example": {"default": false, "title": "Example", "type": "boolean"}, "tool_calls": {"default": [], "items": {"$ref": "#/$defs/ToolCall"}, "title": "Tool Calls", "type": "array"}, "invalid_tool_calls": {"default": [], "items": {"$ref": "#/$defs/InvalidToolCall"}, "title": "Invalid Tool Calls", "type": "array"}, "usage_metadata": {"anyOf": [{"$ref": "#/$defs/UsageMetadata"}, {"type": "null"}], "default": null}, "tool_call_chunks": {"default": [], "items": {"$ref": "#/$defs/ToolCallChunk"}, "title": "Tool Call Chunks", "type": "array"}}, "required": ["content"], "title": "AIMessageChunk", "type": "object"}, "ChatMessage": {"additionalProperties": true, "description": "Message that can be assigned an arbitrary speaker (i.e. 
role).", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "chat", "default": "chat", "enum": ["chat"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "role": {"title": "Role", "type": "string"}}, "required": ["content", "role"], "title": "ChatMessage", "type": "object"}, "ChatMessageChunk": {"additionalProperties": true, "description": "Chat Message chunk.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "ChatMessageChunk", "default": "ChatMessageChunk", "enum": ["ChatMessageChunk"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "role": {"title": "Role", "type": "string"}}, "required": ["content", "role"], "title": "ChatMessageChunk", "type": "object"}, "FunctionMessage": {"additionalProperties": true, "description": "Message for passing the result of executing a tool back to a model.\\n\\nFunctionMessage are an older version of the ToolMessage schema, and\\ndo not contain the tool_call_id field.\\n\\nThe tool_call_id field is used to associate the tool call request with the\\ntool call response. 
This is useful in situations where a chat model is able\\nto request multiple tool calls in parallel.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "function", "default": "function", "enum": ["function"], "title": "Type", "type": "string"}, "name": {"title": "Name", "type": "string"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}}, "required": ["content", "name"], "title": "FunctionMessage", "type": "object"}, "FunctionMessageChunk": {"additionalProperties": true, "description": "Function Message chunk.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "FunctionMessageChunk", "default": "FunctionMessageChunk", "enum": ["FunctionMessageChunk"], "title": "Type", "type": "string"}, "name": {"title": "Name", "type": "string"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}}, "required": ["content", "name"], "title": "FunctionMessageChunk", "type": "object"}, "HumanMessage": {"additionalProperties": true, "description": "Message from a human.\\n\\nHumanMessages are messages that are passed in from a human to the model.\\n\\nExample:\\n\\n .. code-block:: python\\n\\n from langchain_core.messages import HumanMessage, SystemMessage\\n\\n messages = [\\n SystemMessage(\\n content=\\"You are a helpful assistant! 
Your name is Bob.\\"\\n ),\\n HumanMessage(\\n content=\\"What is your name?\\"\\n )\\n ]\\n\\n # Instantiate a chat model and invoke it with the messages\\n model = ...\\n print(model.invoke(messages))", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "human", "default": "human", "enum": ["human"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "example": {"default": false, "title": "Example", "type": "boolean"}}, "required": ["content"], "title": "HumanMessage", "type": "object"}, "HumanMessageChunk": {"additionalProperties": true, "description": "Human Message chunk.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "HumanMessageChunk", "default": "HumanMessageChunk", "enum": ["HumanMessageChunk"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "example": {"default": false, "title": "Example", "type": "boolean"}}, "required": ["content"], "title": "HumanMessageChunk", "type": "object"}, "InputTokenDetails": {"description": "Breakdown of input token counts.\\n\\nDoes *not* need to sum to full input token count. Does *not* need to have all keys.\\n\\nExample:\\n\\n .. code-block:: python\\n\\n {\\n \\"audio\\": 10,\\n \\"cache_creation\\": 200,\\n \\"cache_read\\": 100,\\n }\\n\\n.. versionadded:: 0.3.9", "properties": {"audio": {"title": "Audio", "type": "integer"}, "cache_creation": {"title": "Cache Creation", "type": "integer"}, "cache_read": {"title": "Cache Read", "type": "integer"}}, "title": "InputTokenDetails", "type": "object"}, "InvalidToolCall": {"description": "Allowance for errors made by LLM.\\n\\nHere we add an `error` key to surface errors made during generation\\n(e.g., invalid JSON arguments.)", "properties": {"name": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Name"}, "args": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Args"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Id"}, "error": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Error"}, "type": {"const": "invalid_tool_call", "enum": ["invalid_tool_call"], "title": "Type", "type": "string"}}, "required": ["name", "args", "id", "error"], "title": "InvalidToolCall", "type": "object"}, "OutputTokenDetails": {"description": "Breakdown of output token counts.\\n\\nDoes *not* need to sum to full output token count. Does *not* need to have all keys.\\n\\nExample:\\n\\n .. code-block:: python\\n\\n {\\n \\"audio\\": 10,\\n \\"reasoning\\": 200,\\n }\\n\\n.. 
versionadded:: 0.3.9", "properties": {"audio": {"title": "Audio", "type": "integer"}, "reasoning": {"title": "Reasoning", "type": "integer"}}, "title": "OutputTokenDetails", "type": "object"}, "SystemMessage": {"additionalProperties": true, "description": "Message for priming AI behavior.\\n\\nThe system message is usually passed in as the first of a sequence\\nof input messages.\\n\\nExample:\\n\\n .. code-block:: python\\n\\n from langchain_core.messages import HumanMessage, SystemMessage\\n\\n messages = [\\n SystemMessage(\\n content=\\"You are a helpful assistant! Your name is Bob.\\"\\n ),\\n HumanMessage(\\n content=\\"What is your name?\\"\\n )\\n ]\\n\\n # Define a chat model and invoke it with the messages\\n print(model.invoke(messages))", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "system", "default": "system", "enum": ["system"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}}, "required": ["content"], "title": "SystemMessage", "type": "object"}, "SystemMessageChunk": {"additionalProperties": true, "description": "System Message chunk.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "SystemMessageChunk", "default": "SystemMessageChunk", "enum": ["SystemMessageChunk"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}}, "required": ["content"], "title": "SystemMessageChunk", "type": "object"}, "ToolCall": {"description": "Represents a request to call a tool.\\n\\nExample:\\n\\n .. code-block:: python\\n\\n {\\n \\"name\\": \\"foo\\",\\n \\"args\\": {\\"a\\": 1},\\n \\"id\\": \\"123\\"\\n }\\n\\n This represents a request to call the tool named \\"foo\\" with arguments {\\"a\\": 1}\\n and an identifier of \\"123\\".", "properties": {"name": {"title": "Name", "type": "string"}, "args": {"title": "Args", "type": "object"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Id"}, "type": {"const": "tool_call", "enum": ["tool_call"], "title": "Type", "type": "string"}}, "required": ["name", "args", "id"], "title": "ToolCall", "type": "object"}, "ToolCallChunk": {"description": "A chunk of a tool call (e.g., as part of a stream).\\n\\nWhen merging ToolCallChunks (e.g., via AIMessageChunk.__add__),\\nall string attributes are concatenated. Chunks are only merged if their\\nvalues of `index` are equal and not None.\\n\\nExample:\\n\\n.. 
code-block:: python\\n\\n left_chunks = [ToolCallChunk(name=\\"foo\\", args=\'{\\"a\\":\', index=0)]\\n right_chunks = [ToolCallChunk(name=None, args=\'1}\', index=0)]\\n\\n (\\n AIMessageChunk(content=\\"\\", tool_call_chunks=left_chunks)\\n + AIMessageChunk(content=\\"\\", tool_call_chunks=right_chunks)\\n ).tool_call_chunks == [ToolCallChunk(name=\'foo\', args=\'{\\"a\\":1}\', index=0)]", "properties": {"name": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Name"}, "args": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Args"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Id"}, "index": {"anyOf": [{"type": "integer"}, {"type": "null"}], "title": "Index"}, "type": {"const": "tool_call_chunk", "enum": ["tool_call_chunk"], "title": "Type", "type": "string"}}, "required": ["name", "args", "id", "index"], "title": "ToolCallChunk", "type": "object"}, "ToolMessage": {"additionalProperties": true, "description": "Message for passing the result of executing a tool back to a model.\\n\\nToolMessages contain the result of a tool invocation. Typically, the result\\nis encoded inside the `content` field.\\n\\nExample: A ToolMessage representing a result of 42 from a tool call with id\\n\\n .. code-block:: python\\n\\n from langchain_core.messages import ToolMessage\\n\\n ToolMessage(content=\'42\', tool_call_id=\'call_Jja7J89XsjrOLA5r!MEOW!SL\')\\n\\n\\nExample: A ToolMessage where only part of the tool output is sent to the model\\n and the full output is passed in to artifact.\\n\\n .. versionadded:: 0.2.17\\n\\n .. code-block:: python\\n\\n from langchain_core.messages import ToolMessage\\n\\n tool_output = {\\n \\"stdout\\": \\"From the graph we can see that the correlation between x and y is ...\\",\\n \\"stderr\\": None,\\n \\"artifacts\\": {\\"type\\": \\"image\\", \\"base64_data\\": \\"/9j/4gIcSU...\\"},\\n }\\n\\n ToolMessage(\\n content=tool_output[\\"stdout\\"],\\n artifact=tool_output,\\n tool_call_id=\'call_Jja7J89XsjrOLA5r!MEOW!SL\',\\n )\\n\\nThe tool_call_id field is used to associate the tool call request with the\\ntool call response. 
This is useful in situations where a chat model is able\\nto request multiple tool calls in parallel.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "tool", "default": "tool", "enum": ["tool"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "tool_call_id": {"title": "Tool Call Id", "type": "string"}, "artifact": {"default": null, "title": "Artifact"}, "status": {"default": "success", "enum": ["success", "error"], "title": "Status", "type": "string"}}, "required": ["content", "tool_call_id"], "title": "ToolMessage", "type": "object"}, "ToolMessageChunk": {"additionalProperties": true, "description": "Tool Message chunk.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "ToolMessageChunk", "default": "ToolMessageChunk", "enum": ["ToolMessageChunk"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "tool_call_id": {"title": "Tool Call Id", "type": "string"}, "artifact": {"default": null, "title": "Artifact"}, "status": {"default": "success", "enum": ["success", "error"], "title": "Status", "type": "string"}}, "required": ["content", "tool_call_id"], "title": "ToolMessageChunk", "type": "object"}, "UsageMetadata": {"description": "Usage metadata for a message, such as token counts.\\n\\nThis is a standard representation of token usage that is consistent across models.\\n\\nExample:\\n\\n .. code-block:: python\\n\\n {\\n \\"input_tokens\\": 350,\\n \\"output_tokens\\": 240,\\n \\"total_tokens\\": 590,\\n \\"input_token_details\\": {\\n \\"audio\\": 10,\\n \\"cache_creation\\": 200,\\n \\"cache_read\\": 100,\\n },\\n \\"output_token_details\\": {\\n \\"audio\\": 10,\\n \\"reasoning\\": 200,\\n }\\n }\\n\\n.. 
versionchanged:: 0.3.9\\n\\n Added ``input_token_details`` and ``output_token_details``.", "properties": {"input_tokens": {"title": "Input Tokens", "type": "integer"}, "output_tokens": {"title": "Output Tokens", "type": "integer"}, "total_tokens": {"title": "Total Tokens", "type": "integer"}, "input_token_details": {"$ref": "#/$defs/InputTokenDetails"}, "output_token_details": {"$ref": "#/$defs/OutputTokenDetails"}}, "required": ["input_tokens", "output_tokens", "total_tokens"], "title": "UsageMetadata", "type": "object"}}, "default": null, "items": {"oneOf": [{"$ref": "#/$defs/AIMessage"}, {"$ref": "#/$defs/HumanMessage"}, {"$ref": "#/$defs/ChatMessage"}, {"$ref": "#/$defs/SystemMessage"}, {"$ref": "#/$defs/FunctionMessage"}, {"$ref": "#/$defs/ToolMessage"}, {"$ref": "#/$defs/AIMessageChunk"}, {"$ref": "#/$defs/HumanMessageChunk"}, {"$ref": "#/$defs/ChatMessageChunk"}, {"$ref": "#/$defs/SystemMessageChunk"}, {"$ref": "#/$defs/FunctionMessageChunk"}, {"$ref": "#/$defs/ToolMessageChunk"}]}, "title": "LangGraphInput", "type": "array"}' # --- @@ -2499,6 +2971,21 @@ ''' # --- +# name: test_send_react_interrupt_control[postgres_shallow] + ''' + %%{init: {'flowchart': {'curve': 'linear'}}}%% + graph TD; + __start__([
__start__
]):::first + agent(agent) + foo([foo]):::last + __start__ --> agent; + agent -.-> foo; + classDef default fill:#f2f0ff,line-height:1.2 + classDef first fill-opacity:0 + classDef last fill:#bfb6fc + + ''' +# --- # name: test_send_react_interrupt_control[sqlite] ''' %%{init: {'flowchart': {'curve': 'linear'}}}%% @@ -2586,6 +3073,24 @@ ''' # --- +# name: test_start_branch_then[postgres_shallow] + ''' + %%{init: {'flowchart': {'curve': 'linear'}}}%% + graph TD; + __start__([
__start__
]):::first + tool_two_slow(tool_two_slow) + tool_two_fast(tool_two_fast) + __end__([
__end__
]):::last + __start__ -.-> tool_two_slow; + tool_two_slow --> __end__; + __start__ -.-> tool_two_fast; + tool_two_fast --> __end__; + classDef default fill:#f2f0ff,line-height:1.2 + classDef first fill-opacity:0 + classDef last fill:#bfb6fc + + ''' +# --- # name: test_start_branch_then[sqlite] ''' %%{init: {'flowchart': {'curve': 'linear'}}}%% @@ -2704,6 +3209,31 @@ ''' # --- +# name: test_weather_subgraph[postgres_shallow] + ''' + %%{init: {'flowchart': {'curve': 'linear'}}}%% + graph TD; + __start__([
__start__
]):::first + router_node(router_node) + normal_llm_node(normal_llm_node) + weather_graph_model_node(model_node) + weather_graph_weather_node(weather_node __interrupt = before) + __end__([
__end__
]):::last + __start__ --> router_node; + normal_llm_node --> __end__; + weather_graph_weather_node --> __end__; + router_node -.-> normal_llm_node; + router_node -.-> weather_graph_model_node; + router_node -.-> __end__; + subgraph weather_graph + weather_graph_model_node --> weather_graph_weather_node; + end + classDef default fill:#f2f0ff,line-height:1.2 + classDef first fill-opacity:0 + classDef last fill:#bfb6fc + + ''' +# --- # name: test_weather_subgraph[sqlite] ''' %%{init: {'flowchart': {'curve': 'linear'}}}%% diff --git a/libs/langgraph/tests/__snapshots__/test_large_cases_async.ambr b/libs/langgraph/tests/__snapshots__/test_large_cases_async.ambr index a374d02d5..898dd4f89 100644 --- a/libs/langgraph/tests/__snapshots__/test_large_cases_async.ambr +++ b/libs/langgraph/tests/__snapshots__/test_large_cases_async.ambr @@ -99,6 +99,31 @@ ''' # --- +# name: test_weather_subgraph[postgres_aio_shallow] + ''' + %%{init: {'flowchart': {'curve': 'linear'}}}%% + graph TD; + __start__([
__start__
]):::first + router_node(router_node) + normal_llm_node(normal_llm_node) + weather_graph_model_node(model_node) + weather_graph_weather_node(weather_node __interrupt = before) + __end__([
__end__
]):::last + __start__ --> router_node; + normal_llm_node --> __end__; + weather_graph_weather_node --> __end__; + router_node -.-> normal_llm_node; + router_node -.-> weather_graph_model_node; + router_node -.-> __end__; + subgraph weather_graph + weather_graph_model_node --> weather_graph_weather_node; + end + classDef default fill:#f2f0ff,line-height:1.2 + classDef first fill-opacity:0 + classDef last fill:#bfb6fc + + ''' +# --- # name: test_weather_subgraph[sqlite_aio] ''' %%{init: {'flowchart': {'curve': 'linear'}}}%% diff --git a/libs/langgraph/tests/__snapshots__/test_pregel.ambr b/libs/langgraph/tests/__snapshots__/test_pregel.ambr index a44711ab3..5036f370e 100644 --- a/libs/langgraph/tests/__snapshots__/test_pregel.ambr +++ b/libs/langgraph/tests/__snapshots__/test_pregel.ambr @@ -2878,6 +2878,19 @@ ''' # --- +# name: test_in_one_fan_out_state_graph_waiting_edge[postgres_shallow] + ''' + graph TD; + __start__ --> rewrite_query; + analyzer_one --> retriever_one; + qa --> __end__; + retriever_one --> qa; + retriever_two --> qa; + rewrite_query --> analyzer_one; + rewrite_query --> retriever_two; + + ''' +# --- # name: test_in_one_fan_out_state_graph_waiting_edge[sqlite] ''' graph TD; @@ -3311,6 +3324,76 @@ 'type': 'object', }) # --- +# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic1[postgres_shallow] + ''' + graph TD; + __start__ --> rewrite_query; + analyzer_one --> retriever_one; + qa --> __end__; + retriever_one --> qa; + retriever_two --> qa; + rewrite_query --> analyzer_one; + rewrite_query -.-> retriever_two; + + ''' +# --- +# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic1[postgres_shallow].1 + dict({ + 'definitions': dict({ + 'InnerObject': dict({ + 'properties': dict({ + 'yo': dict({ + 'title': 'Yo', + 'type': 'integer', + }), + }), + 'required': list([ + 'yo', + ]), + 'title': 'InnerObject', + 'type': 'object', + }), + }), + 'properties': dict({ + 'inner': dict({ + '$ref': '#/definitions/InnerObject', + }), + 'query': dict({ + 'title': 'Query', + 'type': 'string', + }), + }), + 'required': list([ + 'query', + 'inner', + ]), + 'title': 'Input', + 'type': 'object', + }) +# --- +# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic1[postgres_shallow].2 + dict({ + 'properties': dict({ + 'answer': dict({ + 'title': 'Answer', + 'type': 'string', + }), + 'docs': dict({ + 'items': dict({ + 'type': 'string', + }), + 'title': 'Docs', + 'type': 'array', + }), + }), + 'required': list([ + 'answer', + 'docs', + ]), + 'title': 'Output', + 'type': 'object', + }) +# --- # name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic1[sqlite] ''' graph TD; @@ -3788,6 +3871,76 @@ 'type': 'object', }) # --- +# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic2[postgres_shallow] + ''' + graph TD; + __start__ --> rewrite_query; + analyzer_one --> retriever_one; + qa --> __end__; + retriever_one --> qa; + retriever_two --> qa; + rewrite_query --> analyzer_one; + rewrite_query -.-> retriever_two; + + ''' +# --- +# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic2[postgres_shallow].1 + dict({ + '$defs': dict({ + 'InnerObject': dict({ + 'properties': dict({ + 'yo': dict({ + 'title': 'Yo', + 'type': 'integer', + }), + }), + 'required': list([ + 'yo', + ]), + 'title': 'InnerObject', + 'type': 'object', + }), + }), + 'properties': dict({ + 'inner': dict({ + '$ref': '#/$defs/InnerObject', + }), + 'query': dict({ + 'title': 'Query', + 
'type': 'string', + }), + }), + 'required': list([ + 'query', + 'inner', + ]), + 'title': 'Input', + 'type': 'object', + }) +# --- +# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic2[postgres_shallow].2 + dict({ + 'properties': dict({ + 'answer': dict({ + 'title': 'Answer', + 'type': 'string', + }), + 'docs': dict({ + 'items': dict({ + 'type': 'string', + }), + 'title': 'Docs', + 'type': 'array', + }), + }), + 'required': list([ + 'answer', + 'docs', + ]), + 'title': 'Output', + 'type': 'object', + }) +# --- # name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic2[sqlite] ''' graph TD; @@ -3923,6 +4076,19 @@ ''' # --- +# name: test_in_one_fan_out_state_graph_waiting_edge_via_branch[postgres_shallow] + ''' + graph TD; + __start__ --> rewrite_query; + analyzer_one --> retriever_one; + qa --> __end__; + retriever_one --> qa; + retriever_two --> qa; + rewrite_query --> analyzer_one; + rewrite_query -.-> retriever_two; + + ''' +# --- # name: test_in_one_fan_out_state_graph_waiting_edge_via_branch[sqlite] ''' graph TD; diff --git a/libs/langgraph/tests/__snapshots__/test_pregel_async.ambr b/libs/langgraph/tests/__snapshots__/test_pregel_async.ambr index 46916c7a4..22c2562ce 100644 --- a/libs/langgraph/tests/__snapshots__/test_pregel_async.ambr +++ b/libs/langgraph/tests/__snapshots__/test_pregel_async.ambr @@ -934,6 +934,127 @@ 'type': 'object', }) # --- +# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic2[postgres_aio_shallow] + ''' + graph TD; + __start__ --> rewrite_query; + analyzer_one --> retriever_one; + qa --> __end__; + retriever_one --> qa; + retriever_two --> qa; + rewrite_query --> analyzer_one; + rewrite_query -.-> retriever_two; + + ''' +# --- +# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic2[postgres_aio_shallow].1 + dict({ + '$defs': dict({ + 'InnerObject': dict({ + 'properties': dict({ + 'yo': dict({ + 'title': 'Yo', + 'type': 'integer', + }), + }), + 'required': list([ + 'yo', + ]), + 'title': 'InnerObject', + 'type': 'object', + }), + }), + 'properties': dict({ + 'answer': dict({ + 'anyOf': list([ + dict({ + 'type': 'string', + }), + dict({ + 'type': 'null', + }), + ]), + 'default': None, + 'title': 'Answer', + }), + 'docs': dict({ + 'items': dict({ + 'type': 'string', + }), + 'title': 'Docs', + 'type': 'array', + }), + 'inner': dict({ + '$ref': '#/$defs/InnerObject', + }), + 'query': dict({ + 'title': 'Query', + 'type': 'string', + }), + }), + 'required': list([ + 'query', + 'inner', + 'docs', + ]), + 'title': 'State', + 'type': 'object', + }) +# --- +# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic2[postgres_aio_shallow].2 + dict({ + '$defs': dict({ + 'InnerObject': dict({ + 'properties': dict({ + 'yo': dict({ + 'title': 'Yo', + 'type': 'integer', + }), + }), + 'required': list([ + 'yo', + ]), + 'title': 'InnerObject', + 'type': 'object', + }), + }), + 'properties': dict({ + 'answer': dict({ + 'anyOf': list([ + dict({ + 'type': 'string', + }), + dict({ + 'type': 'null', + }), + ]), + 'default': None, + 'title': 'Answer', + }), + 'docs': dict({ + 'items': dict({ + 'type': 'string', + }), + 'title': 'Docs', + 'type': 'array', + }), + 'inner': dict({ + '$ref': '#/$defs/InnerObject', + }), + 'query': dict({ + 'title': 'Query', + 'type': 'string', + }), + }), + 'required': list([ + 'query', + 'inner', + 'docs', + ]), + 'title': 'State', + 'type': 'object', + }) +# --- # name: 
test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic2[sqlite_aio] ''' graph TD; @@ -1362,6 +1483,21 @@ ''' # --- +# name: test_send_react_interrupt_control[postgres_aio_shallow] + ''' + %%{init: {'flowchart': {'curve': 'linear'}}}%% + graph TD; + __start__([
<p>__start__</p>
]):::first + agent(agent) + foo([foo]):::last + __start__ --> agent; + agent -.-> foo; + classDef default fill:#f2f0ff,line-height:1.2 + classDef first fill-opacity:0 + classDef last fill:#bfb6fc + + ''' +# --- # name: test_send_react_interrupt_control[sqlite_aio] ''' %%{init: {'flowchart': {'curve': 'linear'}}}%% diff --git a/libs/langgraph/tests/conftest.py b/libs/langgraph/tests/conftest.py index 0381206e3..a7909eb15 100644 --- a/libs/langgraph/tests/conftest.py +++ b/libs/langgraph/tests/conftest.py @@ -13,8 +13,11 @@ from langgraph.checkpoint.base import BaseCheckpointSaver from langgraph.checkpoint.duckdb import DuckDBSaver from langgraph.checkpoint.duckdb.aio import AsyncDuckDBSaver -from langgraph.checkpoint.postgres import PostgresSaver -from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver +from langgraph.checkpoint.postgres import PostgresSaver, ShallowPostgresSaver +from langgraph.checkpoint.postgres.aio import ( + AsyncPostgresSaver, + AsyncShallowPostgresSaver, +) from langgraph.checkpoint.sqlite import SqliteSaver from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver from langgraph.store.base import BaseStore @@ -100,6 +103,25 @@ def checkpointer_postgres(): conn.execute(f"DROP DATABASE {database}") +@pytest.fixture(scope="function") +def checkpointer_postgres_shallow(): + database = f"test_{uuid4().hex[:16]}" + # create unique db + with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn: + conn.execute(f"CREATE DATABASE {database}") + try: + # yield checkpointer + with ShallowPostgresSaver.from_conn_string( + DEFAULT_POSTGRES_URI + database + ) as checkpointer: + checkpointer.setup() + yield checkpointer + finally: + # drop unique db + with Connection.connect(DEFAULT_POSTGRES_URI, autocommit=True) as conn: + conn.execute(f"DROP DATABASE {database}") + + @pytest.fixture(scope="function") def checkpointer_postgres_pipe(): database = f"test_{uuid4().hex[:16]}" @@ -167,6 +189,31 @@ async def _checkpointer_postgres_aio(): await conn.execute(f"DROP DATABASE {database}") +@asynccontextmanager +async def _checkpointer_postgres_aio_shallow(): + if sys.version_info < (3, 10): + pytest.skip("Async Postgres tests require Python 3.10+") + database = f"test_{uuid4().hex[:16]}" + # create unique db + async with await AsyncConnection.connect( + DEFAULT_POSTGRES_URI, autocommit=True + ) as conn: + await conn.execute(f"CREATE DATABASE {database}") + try: + # yield checkpointer + async with AsyncShallowPostgresSaver.from_conn_string( + DEFAULT_POSTGRES_URI + database + ) as checkpointer: + await checkpointer.setup() + yield checkpointer + finally: + # drop unique db + async with await AsyncConnection.connect( + DEFAULT_POSTGRES_URI, autocommit=True + ) as conn: + await conn.execute(f"DROP DATABASE {database}") + + @asynccontextmanager async def _checkpointer_postgres_aio_pipe(): if sys.version_info < (3, 10): @@ -240,6 +287,9 @@ async def awith_checkpointer( elif checkpointer_name == "postgres_aio": async with _checkpointer_postgres_aio() as checkpointer: yield checkpointer + elif checkpointer_name == "postgres_aio_shallow": + async with _checkpointer_postgres_aio_shallow() as checkpointer: + yield checkpointer elif checkpointer_name == "postgres_aio_pipe": async with _checkpointer_postgres_aio_pipe() as checkpointer: yield checkpointer @@ -417,20 +467,30 @@ async def awith_store(store_name: Optional[str]) -> AsyncIterator[BaseStore]: raise NotImplementedError(f"Unknown store {store_name}") -ALL_CHECKPOINTERS_SYNC = [ +SHALLOW_CHECKPOINTERS_SYNC = 
["postgres_shallow"] +REGULAR_CHECKPOINTERS_SYNC = [ "memory", "sqlite", "postgres", "postgres_pipe", "postgres_pool", ] -ALL_CHECKPOINTERS_ASYNC = [ +ALL_CHECKPOINTERS_SYNC = [ + *REGULAR_CHECKPOINTERS_SYNC, + *SHALLOW_CHECKPOINTERS_SYNC, +] +SHALLOW_CHECKPOINTERS_ASYNC = ["postgres_aio_shallow"] +REGULAR_CHECKPOINTERS_ASYNC = [ "memory", "sqlite_aio", "postgres_aio", "postgres_aio_pipe", "postgres_aio_pool", ] +ALL_CHECKPOINTERS_ASYNC = [ + *REGULAR_CHECKPOINTERS_ASYNC, + *SHALLOW_CHECKPOINTERS_ASYNC, +] ALL_CHECKPOINTERS_ASYNC_PLUS_NONE = [ *ALL_CHECKPOINTERS_ASYNC, None, diff --git a/libs/langgraph/tests/test_large_cases.py b/libs/langgraph/tests/test_large_cases.py index a89d68184..f3f399f38 100644 --- a/libs/langgraph/tests/test_large_cases.py +++ b/libs/langgraph/tests/test_large_cases.py @@ -38,7 +38,11 @@ ) from tests.agents import AgentAction, AgentFinish from tests.any_str import AnyDict, AnyStr, UnsortedSequence -from tests.conftest import ALL_CHECKPOINTERS_SYNC, SHOULD_CHECK_SNAPSHOTS +from tests.conftest import ( + ALL_CHECKPOINTERS_SYNC, + REGULAR_CHECKPOINTERS_SYNC, + SHOULD_CHECK_SNAPSHOTS, +) from tests.fake_chat import FakeChatModel from tests.fake_tracer import FakeTracer from tests.messages import ( @@ -112,6 +116,9 @@ def test_invoke_two_processes_in_out_interrupt( snapshot = app.get_state(thread2) assert snapshot.next == () + if "shallow" in checkpointer_name: + return + # list history history = [c for c in app.get_state_history(thread1)] assert history == [ @@ -296,7 +303,7 @@ def test_invoke_two_processes_in_out_interrupt( ] -@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) +@pytest.mark.parametrize("checkpointer_name", REGULAR_CHECKPOINTERS_SYNC) def test_fork_always_re_runs_nodes( request: pytest.FixtureRequest, checkpointer_name: str, mocker: MockerFixture ) -> None: @@ -742,8 +749,14 @@ def should_continue(data: dict) -> str: }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], - config=app_w_interrupt.checkpointer.get_tuple(config).config, + created_at=AnyStr(), + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, metadata={ "parents": {}, "source": "loop", @@ -762,7 +775,11 @@ def should_continue(data: dict) -> str: }, "thread_id": "1", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert ( app_w_interrupt.checkpointer.get_tuple(config).config["configurable"][ @@ -796,8 +813,14 @@ def should_continue(data: dict) -> str: }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -814,7 +837,11 @@ def should_continue(data: dict) -> str: }, "thread_id": "1", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -908,8 +935,14 @@ def 
should_continue(data: dict) -> str: }, tasks=(), next=(), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -935,7 +968,11 @@ def should_continue(data: dict) -> str: }, "thread_id": "1", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) # test state get/update methods with interrupt_before @@ -971,8 +1008,14 @@ def should_continue(data: dict) -> str: }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -991,7 +1034,11 @@ def should_continue(data: dict) -> str: }, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) app_w_interrupt.update_state( @@ -1019,8 +1066,14 @@ def should_continue(data: dict) -> str: }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -1037,7 +1090,11 @@ def should_continue(data: dict) -> str: }, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -1131,8 +1188,14 @@ def should_continue(data: dict) -> str: }, tasks=(), next=(), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -1158,7 +1221,11 @@ def should_continue(data: dict) -> str: }, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) # test re-invoke to continue with interrupt_before @@ -1194,8 +1261,14 @@ def should_continue(data: dict) -> str: }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "3", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + 
created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -1214,7 +1287,11 @@ def should_continue(data: dict) -> str: }, "thread_id": "3", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -1583,8 +1660,14 @@ def should_continue(data: AgentState) -> str: }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -1600,7 +1683,11 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "1", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) with assert_ctx_once(): @@ -1626,8 +1713,14 @@ def should_continue(data: AgentState) -> str: }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -1643,7 +1736,11 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "1", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) with assert_ctx_once(): @@ -1704,8 +1801,14 @@ def should_continue(data: AgentState) -> str: }, tasks=(), next=(), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -1720,7 +1823,11 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "1", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) # test state get/update methods with interrupt_before @@ -1755,8 +1862,14 @@ def should_continue(data: AgentState) -> str: }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -1772,7 +1885,11 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else 
list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) app_w_interrupt.update_state( @@ -1797,8 +1914,14 @@ def should_continue(data: AgentState) -> str: }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -1814,7 +1937,11 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -1873,8 +2000,14 @@ def should_continue(data: AgentState) -> str: }, tasks=(), next=(), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -1889,7 +2022,11 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) # test w interrupt before all @@ -1913,8 +2050,14 @@ def should_continue(data: AgentState) -> str: }, tasks=(PregelTask(AnyStr(), "agent", (PULL, "agent")),), next=("agent",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "3", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -1922,7 +2065,11 @@ def should_continue(data: AgentState) -> str: "writes": None, "thread_id": "3", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -1945,8 +2092,14 @@ def should_continue(data: AgentState) -> str: }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "3", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -1962,7 +2115,11 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "3", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -2001,8 +2158,14 @@ def should_continue(data: AgentState) -> str: }, tasks=(PregelTask(AnyStr(), "agent", (PULL, "agent")),), 
next=("agent",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "3", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -2023,7 +2186,11 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "3", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -2069,8 +2236,14 @@ def should_continue(data: AgentState) -> str: }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "4", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -2086,7 +2259,11 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "4", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -2125,8 +2302,14 @@ def should_continue(data: AgentState) -> str: }, tasks=(PregelTask(AnyStr(), "agent", (PULL, "agent")),), next=("agent",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "4", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -2147,7 +2330,11 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "4", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -2882,8 +3069,14 @@ def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: PregelTask(AnyStr(), "tools", (PUSH, ("__pregel_pull", "agent"), 2)), ), next=("tools",), - config=(app_w_interrupt.checkpointer.get_tuple(config)).config, - created_at=(app_w_interrupt.checkpointer.get_tuple(config)).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -2891,7 +3084,11 @@ def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: "writes": None, "thread_id": "1", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config + ), ) # modify ai message @@ -2921,8 +3118,14 @@ def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: }, tasks=(PregelTask(AnyStr(), "tools", (PUSH, (), 0)),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - 
created_at=(app_w_interrupt.checkpointer.get_tuple(config)).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -2945,7 +3148,11 @@ def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: }, "thread_id": "1", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -3052,8 +3259,14 @@ def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: PregelTask(AnyStr(), "tools", (PUSH, ("__pregel_pull", "agent"), 3)), ), next=("tools", "tools"), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=(app_w_interrupt.checkpointer.get_tuple(config)).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -3069,7 +3282,11 @@ def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: }, "thread_id": "1", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config + ), ) app_w_interrupt.update_state( @@ -3106,8 +3323,14 @@ def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: }, tasks=(), next=(), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=(app_w_interrupt.checkpointer.get_tuple(config)).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -3120,7 +3343,11 @@ def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: }, "thread_id": "1", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config + ), ) # interrupt before tools @@ -3199,8 +3426,14 @@ def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: PregelTask(AnyStr(), "tools", (PUSH, ("__pregel_pull", "agent"), 2)), ), next=("tools",), - config=(app_w_interrupt.checkpointer.get_tuple(config)).config, - created_at=(app_w_interrupt.checkpointer.get_tuple(config)).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -3208,7 +3441,11 @@ def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: "writes": None, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config + ), ) # modify ai message @@ -3239,7 +3476,7 @@ def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: tasks=(PregelTask(AnyStr(), "tools", (PUSH, (), 0)),), next=("tools",), config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=(app_w_interrupt.checkpointer.get_tuple(config)).checkpoint["ts"], + 
created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -3262,7 +3499,11 @@ def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: }, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -3369,8 +3610,14 @@ def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: PregelTask(AnyStr(), "tools", (PUSH, ("__pregel_pull", "agent"), 3)), ), next=("tools", "tools"), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=(app_w_interrupt.checkpointer.get_tuple(config)).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -3386,7 +3633,11 @@ def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: }, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config + ), ) app_w_interrupt.update_state( @@ -3423,8 +3674,14 @@ def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: }, tasks=(), next=(), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=(app_w_interrupt.checkpointer.get_tuple(config)).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -3437,7 +3694,11 @@ def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: }, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config + ), ) @@ -3702,11 +3963,17 @@ def should_continue(messages): ], tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], - metadata={ - "parents": {}, - "source": "loop", + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), + metadata={ + "parents": {}, + "source": "loop", "step": 1, "writes": { "agent": AIMessage( @@ -3723,7 +3990,11 @@ def should_continue(messages): }, "thread_id": "1", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) # modify ai message @@ -3770,7 +4041,11 @@ def should_continue(messages): }, "thread_id": "1", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -3832,8 +4107,14 @@ def should_continue(messages): ], tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - 
config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -3853,7 +4134,11 @@ def should_continue(messages): }, "thread_id": "1", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) app_w_interrupt.update_state( @@ -3885,8 +4170,14 @@ def should_continue(messages): ], tasks=(), next=(), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -3894,7 +4185,11 @@ def should_continue(messages): "writes": {"agent": AIMessage(content="answer", id="ai2")}, "thread_id": "1", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) app_w_interrupt = workflow.compile( @@ -3938,8 +4233,14 @@ def should_continue(messages): ], tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -3959,7 +4260,11 @@ def should_continue(messages): }, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) # modify ai message @@ -3985,8 +4290,14 @@ def should_continue(messages): ], tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -4006,7 +4317,11 @@ def should_continue(messages): }, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -4068,8 +4383,14 @@ def should_continue(messages): ], tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -4089,7 +4410,11 @@ def should_continue(messages): }, 
"thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) app_w_interrupt.update_state( @@ -4122,8 +4447,14 @@ def should_continue(messages): ], tasks=(), next=(), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -4131,7 +4462,11 @@ def should_continue(messages): "writes": {"agent": AIMessage(content="answer", id="ai2")}, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) # add an extra message as if it came from "tools" node @@ -4164,8 +4499,14 @@ def should_continue(messages): ], tasks=(PregelTask(AnyStr(), "agent", (PULL, "agent")),), next=("agent",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -4173,7 +4514,11 @@ def should_continue(messages): "writes": {"tools": UnsortedSequence("ai", "an extra message")}, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) @@ -4441,8 +4786,14 @@ class State(TypedDict): ], tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -4462,7 +4813,11 @@ class State(TypedDict): }, "thread_id": "1", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) # modify ai message @@ -4509,7 +4864,11 @@ class State(TypedDict): }, "thread_id": "1", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -4572,8 +4931,14 @@ class State(TypedDict): ], tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -4593,7 +4958,11 @@ class State(TypedDict): }, "thread_id": "1", }, - 
parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) app_w_interrupt.update_state( @@ -4626,8 +4995,14 @@ class State(TypedDict): ], tasks=(), next=(), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -4635,7 +5010,11 @@ class State(TypedDict): "writes": {"agent": AIMessage(content="answer", id="ai2")}, "thread_id": "1", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) app_w_interrupt = workflow.compile( @@ -4679,8 +5058,14 @@ class State(TypedDict): ], tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -4700,7 +5085,11 @@ class State(TypedDict): }, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) # modify ai message @@ -4726,8 +5115,14 @@ class State(TypedDict): ], tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -4747,7 +5142,11 @@ class State(TypedDict): }, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -4810,8 +5209,14 @@ class State(TypedDict): ], tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -4831,7 +5236,11 @@ class State(TypedDict): }, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) app_w_interrupt.update_state( @@ -4863,8 +5272,14 @@ class State(TypedDict): ], tasks=(), next=(), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - 
created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -4872,7 +5287,11 @@ class State(TypedDict): "writes": {"agent": AIMessage(content="answer", id="ai2")}, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) # add an extra message as if it came from "tools" node @@ -4905,8 +5324,14 @@ class State(TypedDict): ], tasks=(PregelTask(AnyStr(), "agent", (PULL, "agent")),), next=("agent",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -4914,7 +5339,11 @@ class State(TypedDict): "writes": {"tools": UnsortedSequence("ai", "an extra message")}, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) # create new graph with one more state key, reuse previous thread history @@ -4978,8 +5407,14 @@ class MoreState(TypedDict): }, tasks=(PregelTask(AnyStr(), "agent", (PULL, "agent")),), next=("agent",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -4987,7 +5422,11 @@ class MoreState(TypedDict): "writes": {"tools": UnsortedSequence("ai", "an extra message")}, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(new_app.checkpointer.list(config, limit=2))[-1].config + ), ) # new input is merged to old state @@ -5333,22 +5772,25 @@ def tool_two_node(s: State) -> State: "my_key": "value ⛰️", "market": "DE", } - assert [c.metadata for c in tool_two.checkpointer.list(thread1)] == [ - { - "parents": {}, - "source": "loop", - "step": 0, - "writes": None, - "thread_id": "1", - }, - { - "parents": {}, - "source": "input", - "step": -1, - "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, - "thread_id": "1", - }, - ] + + if "shallow" not in checkpointer_name: + assert [c.metadata for c in tool_two.checkpointer.list(thread1)] == [ + { + "parents": {}, + "source": "loop", + "step": 0, + "writes": None, + "thread_id": "1", + }, + { + "parents": {}, + "source": "input", + "step": -1, + "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, + "thread_id": "1", + }, + ] + assert tool_two.get_state(thread1) == StateSnapshot( values={"my_key": "value ⛰️", "market": "DE"}, next=("tool_two",), @@ -5366,8 +5808,14 @@ def tool_two_node(s: State) -> State: ), ), ), - config=tool_two.checkpointer.get_tuple(thread1).config, - created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + 
"checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -5375,7 +5823,11 @@ def tool_two_node(s: State) -> State: "writes": None, "thread_id": "1", }, - parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread1, limit=2))[-1].config + ), ) # clear the interrupt and next tasks tool_two.update_state(thread1, None, as_node=END) @@ -5384,8 +5836,14 @@ def tool_two_node(s: State) -> State: values={"my_key": "value ⛰️", "market": "DE"}, next=(), tasks=(), - config=tool_two.checkpointer.get_tuple(thread1).config, - created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -5393,7 +5851,11 @@ def tool_two_node(s: State) -> State: "writes": {}, "thread_id": "1", }, - parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread1, limit=2))[-1].config + ), ) @@ -5487,22 +5949,24 @@ def start(state: State) -> list[Union[Send, str]]: "my_key": "value ⛰️ one", "market": "DE", } - assert [c.metadata for c in tool_two.checkpointer.list(thread1)] == [ - { - "parents": {}, - "source": "loop", - "step": 0, - "writes": {"tool_one": {"my_key": " one"}}, - "thread_id": "1", - }, - { - "parents": {}, - "source": "input", - "step": -1, - "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, - "thread_id": "1", - }, - ] + if "shallow" not in checkpointer_name: + assert [c.metadata for c in tool_two.checkpointer.list(thread1)] == [ + { + "parents": {}, + "source": "loop", + "step": 0, + "writes": {"tool_one": {"my_key": " one"}}, + "thread_id": "1", + }, + { + "parents": {}, + "source": "input", + "step": -1, + "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, + "thread_id": "1", + }, + ] + assert tool_two.get_state(thread1) == StateSnapshot( values={"my_key": "value ⛰️ one", "market": "DE"}, next=("tool_two",), @@ -5520,16 +5984,26 @@ def start(state: State) -> list[Union[Send, str]]: ), ), ), - config=tool_two.checkpointer.get_tuple(thread1).config, - created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], - metadata={ - "parents": {}, - "source": "loop", - "step": 0, - "writes": {"tool_one": {"my_key": " one"}}, - "thread_id": "1", + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), + metadata={ + "parents": {}, + "source": "loop", + "step": 0, + "writes": {"tool_one": {"my_key": " one"}}, + "thread_id": "1", }, - parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [*tool_two.checkpointer.list(thread1, limit=2)][-1].config + ), ) # clear the interrupt and next tasks tool_two.update_state(thread1, None) @@ -5545,8 +6019,14 @@ def start(state: State) -> list[Union[Send, str]]: interrupts=(), ), ), - config=tool_two.checkpointer.get_tuple(thread1).config, - created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", 
@@ -5554,7 +6034,11 @@ def start(state: State) -> list[Union[Send, str]]: "writes": {}, "thread_id": "1", }, - parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [*tool_two.checkpointer.list(thread1, limit=2)][-1].config + ), ) @@ -5645,27 +6129,30 @@ class State(TypedDict): "my_key": "value ⛰️", "market": "DE", } - assert [ - c.metadata - for c in tool_two.checkpointer.list( - {"configurable": {"thread_id": "1", "checkpoint_ns": ""}} - ) - ] == [ - { - "parents": {}, - "source": "loop", - "step": 0, - "writes": None, - "thread_id": "1", - }, - { - "parents": {}, - "source": "input", - "step": -1, - "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, - "thread_id": "1", - }, - ] + + if "shallow" not in checkpointer_name: + assert [ + c.metadata + for c in tool_two.checkpointer.list( + {"configurable": {"thread_id": "1", "checkpoint_ns": ""}} + ) + ] == [ + { + "parents": {}, + "source": "loop", + "step": 0, + "writes": None, + "thread_id": "1", + }, + { + "parents": {}, + "source": "input", + "step": -1, + "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, + "thread_id": "1", + }, + ] + assert tool_two.get_state(thread1) == StateSnapshot( values={"my_key": "value ⛰️", "market": "DE"}, next=("tool_two",), @@ -5689,8 +6176,14 @@ class State(TypedDict): }, ), ), - config=tool_two.checkpointer.get_tuple(thread1).config, - created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -5698,11 +6191,15 @@ class State(TypedDict): "writes": None, "thread_id": "1", }, - parent_config=[ - *tool_two.checkpointer.list( - {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}, limit=2 - ) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list( + tool_two.checkpointer.list( + {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}, limit=2 + ) + )[-1].config + ), ) # clear the interrupt and next tasks tool_two.update_state(thread1, None, as_node=END) @@ -5711,8 +6208,14 @@ class State(TypedDict): values={"my_key": "value ⛰️", "market": "DE"}, next=(), tasks=(), - config=tool_two.checkpointer.get_tuple(thread1).config, - created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -5720,11 +6223,15 @@ class State(TypedDict): "writes": {}, "thread_id": "1", }, - parent_config=[ - *tool_two.checkpointer.list( - {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}, limit=2 - ) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list( + tool_two.checkpointer.list( + {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}, limit=2 + ) + )[-1].config + ), ) @@ -5794,30 +6301,39 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "my_key": "value ⛰️", "market": "DE", } - assert [c.metadata for c in tool_two.checkpointer.list(thread1)] == [ - { - "parents": {}, - "source": "loop", - "step": 0, - "writes": None, - "assistant_id": "a", - "thread_id": "1", - }, - { - "parents": {}, - "source": "input", - "step": -1, - "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, - "assistant_id": "a", - "thread_id": 
"1", - }, - ] + + if "shallow" not in checkpointer_name: + assert [c.metadata for c in tool_two.checkpointer.list(thread1)] == [ + { + "parents": {}, + "source": "loop", + "step": 0, + "writes": None, + "assistant_id": "a", + "thread_id": "1", + }, + { + "parents": {}, + "source": "input", + "step": -1, + "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, + "assistant_id": "a", + "thread_id": "1", + }, + ] + assert tool_two.get_state(thread1) == StateSnapshot( values={"my_key": "value ⛰️", "market": "DE"}, tasks=(PregelTask(AnyStr(), "tool_two_slow", (PULL, "tool_two_slow")),), next=("tool_two_slow",), - config=tool_two.checkpointer.get_tuple(thread1).config, - created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -5826,7 +6342,11 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "assistant_id": "a", "thread_id": "1", }, - parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread1, limit=2))[-1].config + ), ) # resume, for same result as above assert tool_two.invoke(None, thread1, debug=1) == { @@ -5837,8 +6357,14 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: values={"my_key": "value ⛰️ slow", "market": "DE"}, tasks=(), next=(), - config=tool_two.checkpointer.get_tuple(thread1).config, - created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -5847,7 +6373,11 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "assistant_id": "a", "thread_id": "1", }, - parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread1, limit=2))[-1].config + ), ) thread2 = {"configurable": {"thread_id": "2", "assistant_id": "a"}} @@ -5860,8 +6390,14 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: values={"my_key": "value", "market": "US"}, tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),), next=("tool_two_fast",), - config=tool_two.checkpointer.get_tuple(thread2).config, - created_at=tool_two.checkpointer.get_tuple(thread2).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -5870,7 +6406,11 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "assistant_id": "a", "thread_id": "2", }, - parent_config=[*tool_two.checkpointer.list(thread2, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread2, limit=2))[-1].config + ), ) # resume, for same result as above assert tool_two.invoke(None, thread2, debug=1) == { @@ -5881,8 +6421,14 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: values={"my_key": "value fast", "market": "US"}, tasks=(), next=(), - config=tool_two.checkpointer.get_tuple(thread2).config, - created_at=tool_two.checkpointer.get_tuple(thread2).checkpoint["ts"], + config={ + "configurable": { + 
"thread_id": "2", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -5891,7 +6437,11 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "assistant_id": "a", "thread_id": "2", }, - parent_config=[*tool_two.checkpointer.list(thread2, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread2, limit=2))[-1].config + ), ) thread3 = {"configurable": {"thread_id": "3", "assistant_id": "b"}} @@ -5904,8 +6454,14 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: values={"my_key": "value", "market": "US"}, tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),), next=("tool_two_fast",), - config=tool_two.checkpointer.get_tuple(thread3).config, - created_at=tool_two.checkpointer.get_tuple(thread3).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "3", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -5914,7 +6470,11 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "assistant_id": "b", "thread_id": "3", }, - parent_config=[*tool_two.checkpointer.list(thread3, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread3, limit=2))[-1].config + ), ) # update state tool_two.update_state(thread3, {"my_key": "key"}) # appends to my_key @@ -5922,8 +6482,14 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: values={"my_key": "valuekey", "market": "US"}, tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),), next=("tool_two_fast",), - config=tool_two.checkpointer.get_tuple(thread3).config, - created_at=tool_two.checkpointer.get_tuple(thread3).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "3", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -5932,7 +6498,11 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "assistant_id": "b", "thread_id": "3", }, - parent_config=[*tool_two.checkpointer.list(thread3, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread3, limit=2))[-1].config + ), ) # resume, for same result as above assert tool_two.invoke(None, thread3, debug=1) == { @@ -5943,8 +6513,14 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: values={"my_key": "valuekey fast", "market": "US"}, tasks=(), next=(), - config=tool_two.checkpointer.get_tuple(thread3).config, - created_at=tool_two.checkpointer.get_tuple(thread3).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "3", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -5953,7 +6529,11 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "assistant_id": "b", "thread_id": "3", }, - parent_config=[*tool_two.checkpointer.list(thread3, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread3, limit=2))[-1].config + ), ) @@ -6307,8 +6887,14 @@ class State(TypedDict): values={"my_key": "value prepared", "market": "DE"}, tasks=(PregelTask(AnyStr(), "tool_two_slow", (PULL, "tool_two_slow")),), next=("tool_two_slow",), - 
config=tool_two.checkpointer.get_tuple(thread1).config, - created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -6316,7 +6902,11 @@ class State(TypedDict): "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "1", }, - parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread1, limit=2))[-1].config + ), ) # resume, for same result as above assert tool_two.invoke(None, thread1, debug=1) == { @@ -6327,8 +6917,14 @@ class State(TypedDict): values={"my_key": "value prepared slow finished", "market": "DE"}, tasks=(), next=(), - config=tool_two.checkpointer.get_tuple(thread1).config, - created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -6336,7 +6932,11 @@ class State(TypedDict): "writes": {"finish": {"my_key": " finished"}}, "thread_id": "1", }, - parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread1, limit=2))[-1].config + ), ) thread2 = {"configurable": {"thread_id": "2"}} @@ -6349,8 +6949,14 @@ class State(TypedDict): values={"my_key": "value prepared", "market": "US"}, tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),), next=("tool_two_fast",), - config=tool_two.checkpointer.get_tuple(thread2).config, - created_at=tool_two.checkpointer.get_tuple(thread2).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -6358,7 +6964,11 @@ class State(TypedDict): "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "2", }, - parent_config=[*tool_two.checkpointer.list(thread2, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread2, limit=2))[-1].config + ), ) # resume, for same result as above assert tool_two.invoke(None, thread2, debug=1) == { @@ -6369,8 +6979,14 @@ class State(TypedDict): values={"my_key": "value prepared fast finished", "market": "US"}, tasks=(), next=(), - config=tool_two.checkpointer.get_tuple(thread2).config, - created_at=tool_two.checkpointer.get_tuple(thread2).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -6378,7 +6994,11 @@ class State(TypedDict): "writes": {"finish": {"my_key": " finished"}}, "thread_id": "2", }, - parent_config=[*tool_two.checkpointer.list(thread2, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread2, limit=2))[-1].config + ), ) tool_two = tool_two_graph.compile( @@ -6399,8 +7019,14 @@ class State(TypedDict): }, tasks=(PregelTask(AnyStr(), "finish", (PULL, "finish")),), next=("finish",), - config=tool_two.checkpointer.get_tuple(thread1).config, - created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], + 
config={ + "configurable": { + "thread_id": "11", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -6408,7 +7034,11 @@ class State(TypedDict): "writes": {"tool_two_slow": {"my_key": " slow"}}, "thread_id": "11", }, - parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread1, limit=2))[-1].config + ), ) # update state @@ -6420,8 +7050,14 @@ class State(TypedDict): }, tasks=(PregelTask(AnyStr(), "finish", (PULL, "finish")),), next=("finish",), - config=tool_two.checkpointer.get_tuple(thread1).config, - created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "11", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -6429,7 +7065,11 @@ class State(TypedDict): "writes": {"tool_two_slow": {"my_key": "er"}}, "thread_id": "11", }, - parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread1, limit=2))[-1].config + ), ) tool_two = tool_two_graph.compile( @@ -6450,8 +7090,14 @@ class State(TypedDict): values={"my_key": "value prepared", "market": "DE"}, tasks=(PregelTask(AnyStr(), "tool_two_slow", (PULL, "tool_two_slow")),), next=("tool_two_slow",), - config=tool_two.checkpointer.get_tuple(thread1).config, - created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "21", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -6459,7 +7105,11 @@ class State(TypedDict): "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "21", }, - parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread1, limit=2))[-1].config + ), ) # resume, for same result as above assert tool_two.invoke(None, thread1, debug=1) == { @@ -6470,8 +7120,14 @@ class State(TypedDict): values={"my_key": "value prepared slow finished", "market": "DE"}, tasks=(), next=(), - config=tool_two.checkpointer.get_tuple(thread1).config, - created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "21", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -6479,7 +7135,11 @@ class State(TypedDict): "writes": {"finish": {"my_key": " finished"}}, "thread_id": "21", }, - parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread1, limit=2))[-1].config + ), ) thread2 = {"configurable": {"thread_id": "22"}} @@ -6492,8 +7152,14 @@ class State(TypedDict): values={"my_key": "value prepared", "market": "US"}, tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),), next=("tool_two_fast",), - config=tool_two.checkpointer.get_tuple(thread2).config, - created_at=tool_two.checkpointer.get_tuple(thread2).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "22", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + 
created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -6501,7 +7167,11 @@ class State(TypedDict): "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "22", }, - parent_config=[*tool_two.checkpointer.list(thread2, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread2, limit=2))[-1].config + ), ) # resume, for same result as above assert tool_two.invoke(None, thread2, debug=1) == { @@ -6512,8 +7182,14 @@ class State(TypedDict): values={"my_key": "value prepared fast finished", "market": "US"}, tasks=(), next=(), - config=tool_two.checkpointer.get_tuple(thread2).config, - created_at=tool_two.checkpointer.get_tuple(thread2).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "22", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -6521,18 +7197,28 @@ class State(TypedDict): "writes": {"finish": {"my_key": " finished"}}, "thread_id": "22", }, - parent_config=[*tool_two.checkpointer.list(thread2, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread2, limit=2))[-1].config + ), ) thread3 = {"configurable": {"thread_id": "23"}} # update an empty thread before first run - uconfig = tool_two.update_state(thread3, {"my_key": "key", "market": "DE"}) + tool_two.update_state(thread3, {"my_key": "key", "market": "DE"}) # check current state assert tool_two.get_state(thread3) == StateSnapshot( values={"my_key": "key", "market": "DE"}, tasks=(PregelTask(AnyStr(), "prepare", (PULL, "prepare")),), next=("prepare",), - config=uconfig, + config={ + "configurable": { + "thread_id": "23", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, created_at=AnyStr(), metadata={ "parents": {}, @@ -6553,8 +7239,14 @@ class State(TypedDict): values={"my_key": "key prepared", "market": "DE"}, tasks=(PregelTask(AnyStr(), "tool_two_slow", (PULL, "tool_two_slow")),), next=("tool_two_slow",), - config=tool_two.checkpointer.get_tuple(thread3).config, - created_at=tool_two.checkpointer.get_tuple(thread3).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "23", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -6562,7 +7254,11 @@ class State(TypedDict): "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "23", }, - parent_config=uconfig, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread3, limit=2))[-1].config + ), ) # resume, for same result as above assert tool_two.invoke(None, thread3, debug=1) == { @@ -6573,8 +7269,14 @@ class State(TypedDict): values={"my_key": "key prepared slow finished", "market": "DE"}, tasks=(), next=(), - config=tool_two.checkpointer.get_tuple(thread3).config, - created_at=tool_two.checkpointer.get_tuple(thread3).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "23", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -6582,7 +7284,11 @@ class State(TypedDict): "writes": {"finish": {"my_key": " finished"}}, "thread_id": "23", }, - parent_config=[*tool_two.checkpointer.list(thread3, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread3, limit=2))[-1].config + ), ) @@ -6656,8 +7362,10 @@ def 
route_to_three(state) -> Literal["3"]: state = graph.get_state(thread1) assert state.next == ("flaky",) # check history - history = [c for c in graph.get_state_history(thread1)] - assert len(history) == 2 + if "shallow" not in checkpointer_name: + history = [c for c in graph.get_state_history(thread1)] + assert len(history) == 2 + # resume execution assert graph.invoke(None, thread1, debug=1) == [ "0", @@ -6678,7 +7386,7 @@ def route_to_three(state) -> Literal["3"]: assert state.next == () # check history history = [c for c in graph.get_state_history(thread1)] - assert history == [ + expected_history = [ StateSnapshot( values=[ "0", @@ -6706,13 +7414,17 @@ def route_to_three(state) -> Literal["3"]: "parents": {}, }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=(), ), StateSnapshot( @@ -6905,6 +7617,10 @@ def route_to_three(state) -> Literal["3"]: ), ), ] + if "shallow" in checkpointer_name: + expected_history = expected_history[:1] + + assert history == expected_history @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) @@ -6990,13 +7706,17 @@ def outer_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ) # now, get_state with subgraphs state assert app.get_state(config, subgraphs=True) == StateSnapshot( @@ -7050,16 +7770,20 @@ def outer_2(state: State): "langgraph_checkpoint_ns": AnyStr("inner:"), }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr("inner:"), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - {"": AnyStr(), AnyStr("child:"): AnyStr()} - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr("inner:"), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + {"": AnyStr(), AnyStr("child:"): AnyStr()} + ), + } } - }, + ), ), ), ), @@ -7079,17 +7803,21 @@ def outer_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ) # get_state_history returns outer graph checkpoints history = list(app.get_state_history(config)) - assert history == [ + expected_history = [ StateSnapshot( values={"my_key": "hi my value"}, tasks=( @@ -7121,13 +7849,17 @@ def outer_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ), StateSnapshot( values={"my_key": "my value"}, @@ -7192,9 +7924,15 @@ def outer_2(state: State): parent_config=None, ), ] + + if "shallow" in checkpointer_name: + 
expected_history = expected_history[:1] + + assert history == expected_history + # get_state_history for a subgraph returns its checkpoints child_history = [*app.get_state_history(history[0].tasks[0].state)] - assert child_history == [ + expected_child_history = [ StateSnapshot( values={"my_key": "hi my value here", "my_other_key": "hi my value"}, next=("inner_2",), @@ -7227,16 +7965,20 @@ def outer_2(state: State): "langgraph_checkpoint_ns": AnyStr("inner:"), }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr("inner:"), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - {"": AnyStr(), AnyStr("child:"): AnyStr()} - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr("inner:"), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + {"": AnyStr(), AnyStr("child:"): AnyStr()} + ), + } } - }, + ), tasks=(PregelTask(AnyStr(), "inner_2", (PULL, "inner_2")),), ), StateSnapshot( @@ -7327,6 +8069,11 @@ def outer_2(state: State): ), ] + if "shallow" in checkpointer_name: + expected_child_history = expected_child_history[:1] + + assert child_history == expected_child_history + # resume app.invoke(None, config, debug=True) # test state w/ nested subgraph state (after resuming from interrupt) @@ -7351,13 +8098,17 @@ def outer_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ) # test full history at the end actual_history = list(app.get_state_history(config)) @@ -7383,13 +8134,17 @@ def outer_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ), StateSnapshot( values={"my_key": "hi my value here and there"}, @@ -7525,6 +8280,9 @@ def outer_2(state: State): parent_config=None, ), ] + if "shallow" in checkpointer_name: + expected_history = expected_history[:1] + assert actual_history == expected_history # test looking up parent state by checkpoint ID for actual_snapshot, expected_snapshot in zip(actual_history, expected_history): @@ -7629,13 +8387,17 @@ def parent_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ) child_state = app.get_state(outer_state.tasks[0].state) assert ( @@ -7671,13 +8433,17 @@ def parent_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr("child:"), - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr("child:"), + "checkpoint_id": AnyStr(), + } } - }, + ), ).tasks[0] ) grandchild_state = app.get_state(child_state.tasks[0].state) @@ -7724,20 +8490,24 @@ 
def parent_2(state: State): "langgraph_triggers": [AnyStr("start:child_1")], }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr(), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - { - "": AnyStr(), - AnyStr("child:"): AnyStr(), - AnyStr(re.compile(r"child:.+|child1:")): AnyStr(), - } - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("child:"): AnyStr(), + AnyStr(re.compile(r"child:.+|child1:")): AnyStr(), + } + ), + } } - }, + ), ) # get state with subgraphs assert app.get_state(config, subgraphs=True) == StateSnapshot( @@ -7804,22 +8574,26 @@ def parent_2(state: State): "langgraph_triggers": [AnyStr("start:child_1")], }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr(), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - { - "": AnyStr(), - AnyStr("child:"): AnyStr(), - AnyStr( - re.compile(r"child:.+|child1:") - ): AnyStr(), - } - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("child:"): AnyStr(), + AnyStr( + re.compile(r"child:.+|child1:") + ): AnyStr(), + } + ), + } } - }, + ), ), ), ), @@ -7848,16 +8622,20 @@ def parent_2(state: State): "langgraph_checkpoint_ns": AnyStr("child:"), }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr("child:"), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - {"": AnyStr(), AnyStr("child:"): AnyStr()} - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr("child:"), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + {"": AnyStr(), AnyStr("child:"): AnyStr()} + ), + } } - }, + ), ), ), ), @@ -7877,13 +8655,17 @@ def parent_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ) # # resume assert [c for c in app.stream(None, config, subgraphs=True)] == [ @@ -7920,15 +8702,23 @@ def parent_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ) ) + + if "shallow" in checkpointer_name: + return + # get outer graph history outer_history = list(app.get_state_history(config)) assert outer_history == [ @@ -8628,19 +9418,23 @@ def edit(state: JokeState): "langgraph_triggers": [PUSH], }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr("generate_joke:"), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - { - "": AnyStr(), - AnyStr("generate_joke:"): AnyStr(), - } - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + 
"checkpoint_ns": AnyStr("generate_joke:"), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("generate_joke:"): AnyStr(), + } + ), + } } - }, + ), tasks=(PregelTask(id=AnyStr(""), name="generate", path=(PULL, "generate")),), ) assert graph.get_state(outer_state.tasks[2].state) == StateSnapshot( @@ -8673,19 +9467,23 @@ def edit(state: JokeState): "langgraph_triggers": [PUSH], }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr("generate_joke:"), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - { - "": AnyStr(), - AnyStr("generate_joke:"): AnyStr(), - } - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr("generate_joke:"), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("generate_joke:"): AnyStr(), + } + ), + } } - }, + ), tasks=(PregelTask(id=AnyStr(""), name="generate", path=(PULL, "generate")),), ) # update state of dogs joke graph @@ -8729,16 +9527,23 @@ def edit(state: JokeState): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ) assert actual_snapshot == expected_snapshot + if "shallow" in checkpointer_name: + return + # test full history actual_history = list(graph.get_state_history(config)) @@ -9000,13 +9805,17 @@ def foo(call: ToolCall): "thread_id": "2", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=( PregelTask( id=AnyStr(), @@ -9083,13 +9892,17 @@ def foo(call: ToolCall): "thread_id": "2", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=(), ) @@ -9159,13 +9972,17 @@ def foo(call: ToolCall): "thread_id": "3", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "3", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "3", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=( PregelTask( id=AnyStr(), @@ -9268,13 +10085,17 @@ def foo(call: ToolCall): "thread_id": "3", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "3", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "3", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=( PregelTask( id=AnyStr(), @@ -9469,13 +10290,17 @@ def foo(call: ToolCall): "thread_id": "2", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + 
"thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=( PregelTask( id=AnyStr(), @@ -9552,13 +10377,17 @@ def foo(call: ToolCall): "thread_id": "2", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=(), ) @@ -9732,13 +10561,17 @@ def weather_graph(state: RouterState): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=( PregelTask( id=AnyStr(), @@ -9820,13 +10653,17 @@ def weather_graph(state: RouterState): "thread_id": "14", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "14", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "14", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=( PregelTask( id=AnyStr(), @@ -9869,19 +10706,23 @@ def weather_graph(state: RouterState): "langgraph_checkpoint_ns": AnyStr("weather_graph:"), }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "14", - "checkpoint_ns": AnyStr("weather_graph:"), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - { - "": AnyStr(), - AnyStr("weather_graph:"): AnyStr(), - } - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "14", + "checkpoint_ns": AnyStr("weather_graph:"), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("weather_graph:"): AnyStr(), + } + ), + } } - }, + ), tasks=( PregelTask( id=AnyStr(), @@ -9920,13 +10761,17 @@ def weather_graph(state: RouterState): "thread_id": "14", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "14", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "14", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=( PregelTask( id=AnyStr(), @@ -9975,19 +10820,23 @@ def weather_graph(state: RouterState): "langgraph_checkpoint_ns": AnyStr("weather_graph:"), }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "14", - "checkpoint_ns": AnyStr("weather_graph:"), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - { - "": AnyStr(), - AnyStr("weather_graph:"): AnyStr(), - } - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "14", + "checkpoint_ns": AnyStr("weather_graph:"), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("weather_graph:"): AnyStr(), + } + ), + } } - }, + ), tasks=(), ), ), diff --git a/libs/langgraph/tests/test_large_cases_async.py b/libs/langgraph/tests/test_large_cases_async.py index b4d656552..f6e807fca 100644 --- a/libs/langgraph/tests/test_large_cases_async.py +++ b/libs/langgraph/tests/test_large_cases_async.py @@ -36,7 +36,11 @@ from langgraph.store.memory import InMemoryStore from langgraph.types import PregelTask, Send, 
StateSnapshot, StreamWriter from tests.any_str import AnyDict, AnyStr -from tests.conftest import ALL_CHECKPOINTERS_ASYNC, awith_checkpointer +from tests.conftest import ( + ALL_CHECKPOINTERS_ASYNC, + REGULAR_CHECKPOINTERS_ASYNC, + awith_checkpointer, +) from tests.fake_chat import FakeChatModel from tests.fake_tracer import FakeTracer from tests.messages import ( @@ -111,6 +115,9 @@ async def test_invoke_two_processes_in_out_interrupt( snapshot = await app.aget_state(thread2) assert snapshot.next == () + if "shallow" in checkpointer_name: + return + # list history history = [c async for c in app.aget_state_history(thread1)] assert history == [ @@ -311,7 +318,7 @@ async def test_invoke_two_processes_in_out_interrupt( ] -@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) +@pytest.mark.parametrize("checkpointer_name", REGULAR_CHECKPOINTERS_ASYNC) async def test_fork_always_re_runs_nodes( checkpointer_name: str, mocker: MockerFixture ) -> None: @@ -844,9 +851,13 @@ async def should_continue(data: dict, config: RunnableConfig) -> str: }, "thread_id": "1", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) await app_w_interrupt.aupdate_state( @@ -874,10 +885,14 @@ async def should_continue(data: dict, config: RunnableConfig) -> str: }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, - created_at=( - await app_w_interrupt.checkpointer.aget_tuple(config) - ).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -894,9 +909,13 @@ async def should_continue(data: dict, config: RunnableConfig) -> str: }, "thread_id": "1", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) assert [c async for c in app_w_interrupt.astream(None, config)] == [ @@ -990,10 +1009,14 @@ async def should_continue(data: dict, config: RunnableConfig) -> str: }, tasks=(), next=(), - config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, - created_at=( - await app_w_interrupt.checkpointer.aget_tuple(config) - ).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -1019,9 +1042,13 @@ async def should_continue(data: dict, config: RunnableConfig) -> str: }, "thread_id": "1", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) # test state get/update methods with interrupt_before @@ -1064,10 +1091,14 @@ async def should_continue(data: dict, config: RunnableConfig) -> str: }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, - created_at=( - await 
app_w_interrupt.checkpointer.aget_tuple(config) - ).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -1086,9 +1117,13 @@ async def should_continue(data: dict, config: RunnableConfig) -> str: }, "thread_id": "2", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) await app_w_interrupt.aupdate_state( @@ -1116,10 +1151,14 @@ async def should_continue(data: dict, config: RunnableConfig) -> str: }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, - created_at=( - await app_w_interrupt.checkpointer.aget_tuple(config) - ).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -1136,9 +1175,13 @@ async def should_continue(data: dict, config: RunnableConfig) -> str: }, "thread_id": "2", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) assert [c async for c in app_w_interrupt.astream(None, config)] == [ @@ -1232,10 +1275,14 @@ async def should_continue(data: dict, config: RunnableConfig) -> str: }, tasks=(), next=(), - config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, - created_at=( - await app_w_interrupt.checkpointer.aget_tuple(config) - ).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -1261,9 +1308,13 @@ async def should_continue(data: dict, config: RunnableConfig) -> str: }, "thread_id": "2", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) # test re-invoke to continue with interrupt_before @@ -1306,10 +1357,14 @@ async def should_continue(data: dict, config: RunnableConfig) -> str: }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, - created_at=( - await app_w_interrupt.checkpointer.aget_tuple(config) - ).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "3", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -1328,9 +1383,13 @@ async def should_continue(data: dict, config: RunnableConfig) -> str: }, "thread_id": "3", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) assert [c async for c in app_w_interrupt.astream(None, config)] == [ @@ -1730,10 +1789,14 @@ def 
should_continue(data: AgentState) -> str: }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, - created_at=( - await app_w_interrupt.checkpointer.aget_tuple(config) - ).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -1749,9 +1812,13 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "1", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) async with assert_ctx_once(): @@ -1777,10 +1844,14 @@ def should_continue(data: AgentState) -> str: }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, - created_at=( - await app_w_interrupt.checkpointer.aget_tuple(config) - ).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -1796,9 +1867,13 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "1", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) async with assert_ctx_once(): @@ -1859,10 +1934,14 @@ def should_continue(data: AgentState) -> str: }, tasks=(), next=(), - config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, - created_at=( - await app_w_interrupt.checkpointer.aget_tuple(config) - ).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -1877,9 +1956,13 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "1", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) # test state get/update methods with interrupt_before @@ -1918,10 +2001,14 @@ def should_continue(data: AgentState) -> str: }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, - created_at=( - await app_w_interrupt.checkpointer.aget_tuple(config) - ).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -1937,9 +2024,13 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "2", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) await app_w_interrupt.aupdate_state( @@ -1964,10 +2055,14 @@ def should_continue(data: AgentState) -> str: 
}, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, - created_at=( - await app_w_interrupt.checkpointer.aget_tuple(config) - ).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -1983,9 +2078,13 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "2", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) assert [c async for c in app_w_interrupt.astream(None, config)] == [ @@ -2044,10 +2143,14 @@ def should_continue(data: AgentState) -> str: }, tasks=(), next=(), - config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, - created_at=( - await app_w_interrupt.checkpointer.aget_tuple(config) - ).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -2062,9 +2165,13 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "2", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) @@ -2679,10 +2786,14 @@ async def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: PregelTask(AnyStr(), "tools", (PUSH, ("__pregel_pull", "agent"), 2)), ), next=("tools",), - config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, - created_at=( - await app_w_interrupt.checkpointer.aget_tuple(config) - ).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -2690,9 +2801,13 @@ async def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: "writes": None, "thread_id": "1", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) # modify ai message @@ -2744,9 +2859,13 @@ async def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: }, "thread_id": "1", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) assert [c async for c in app_w_interrupt.astream(None, config)] == [ @@ -2871,9 +2990,13 @@ async def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: }, "thread_id": "1", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) await app_w_interrupt.aupdate_state( @@ -2921,9 +3044,13 @@ async def 
tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: }, "thread_id": "1", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) # interrupt before tools @@ -2958,7 +3085,7 @@ async def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: }, {"__interrupt__": ()}, ] - + tup = await app_w_interrupt.checkpointer.aget_tuple(config) assert await app_w_interrupt.aget_state(config) == StateSnapshot( values={ "messages": [ @@ -3004,10 +3131,8 @@ async def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: PregelTask(AnyStr(), "tools", (PUSH, ("__pregel_pull", "agent"), 2)), ), next=("tools",), - config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, - created_at=( - await app_w_interrupt.checkpointer.aget_tuple(config) - ).checkpoint["ts"], + config=tup.config, + created_at=tup.checkpoint["ts"], metadata={ "parents": {}, "source": "loop", @@ -3015,9 +3140,13 @@ async def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: "writes": None, "thread_id": "2", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) # modify ai message @@ -3069,9 +3198,13 @@ async def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: }, "thread_id": "2", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) assert [c async for c in app_w_interrupt.astream(None, config)] == [ @@ -3198,9 +3331,13 @@ async def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: }, "thread_id": "2", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) await app_w_interrupt.aupdate_state( @@ -3248,9 +3385,13 @@ async def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: }, "thread_id": "2", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) @@ -3509,9 +3650,13 @@ def should_continue(messages): }, "thread_id": "1", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) # modify ai message @@ -3559,9 +3704,13 @@ def should_continue(messages): }, "thread_id": "1", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) assert [c async 
for c in app_w_interrupt.astream(None, config)] == [ @@ -3645,9 +3794,13 @@ def should_continue(messages): }, "thread_id": "1", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) await app_w_interrupt.aupdate_state( @@ -3689,9 +3842,13 @@ def should_continue(messages): "writes": {"agent": AIMessage(content="answer", id="ai2")}, "thread_id": "1", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) @@ -3984,32 +4141,38 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "my_key": "value", "market": "DE", } - assert [c.metadata async for c in tool_two.checkpointer.alist(thread1)] == [ - { - "parents": {}, - "source": "loop", - "step": 0, - "writes": None, - "assistant_id": "a", - "thread_id": "1", - }, - { - "parents": {}, - "source": "input", - "step": -1, - "writes": {"__start__": {"my_key": "value", "market": "DE"}}, - "assistant_id": "a", - "thread_id": "1", - }, - ] + if "shallow" not in checkpointer_name: + assert [c.metadata async for c in tool_two.checkpointer.alist(thread1)] == [ + { + "parents": {}, + "source": "loop", + "step": 0, + "writes": None, + "assistant_id": "a", + "thread_id": "1", + }, + { + "parents": {}, + "source": "input", + "step": -1, + "writes": {"__start__": {"my_key": "value", "market": "DE"}}, + "assistant_id": "a", + "thread_id": "1", + }, + ] + assert await tool_two.aget_state(thread1) == StateSnapshot( values={"my_key": "value", "market": "DE"}, tasks=(PregelTask(AnyStr(), "tool_two_slow", (PULL, "tool_two_slow")),), next=("tool_two_slow",), - config=(await tool_two.checkpointer.aget_tuple(thread1)).config, - created_at=(await tool_two.checkpointer.aget_tuple(thread1)).checkpoint[ - "ts" - ], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -4018,9 +4181,13 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "assistant_id": "a", "thread_id": "1", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread1, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread1, limit=2)][ + -1 + ].config + ), ) # resume, for same result as above assert await tool_two.ainvoke(None, thread1, debug=1) == { @@ -4031,10 +4198,14 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: values={"my_key": "value slow", "market": "DE"}, tasks=(), next=(), - config=(await tool_two.checkpointer.aget_tuple(thread1)).config, - created_at=(await tool_two.checkpointer.aget_tuple(thread1)).checkpoint[ - "ts" - ], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -4043,9 +4214,13 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "assistant_id": "a", "thread_id": "1", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread1, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in 
checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread1, limit=2)][ + -1 + ].config + ), ) thread2 = {"configurable": {"thread_id": "2", "assistant_id": "a"}} @@ -4058,10 +4233,14 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: values={"my_key": "value", "market": "US"}, tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),), next=("tool_two_fast",), - config=(await tool_two.checkpointer.aget_tuple(thread2)).config, - created_at=(await tool_two.checkpointer.aget_tuple(thread2)).checkpoint[ - "ts" - ], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -4070,9 +4249,13 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "assistant_id": "a", "thread_id": "2", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread2, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread2, limit=2)][ + -1 + ].config + ), ) # resume, for same result as above assert await tool_two.ainvoke(None, thread2, debug=1) == { @@ -4083,10 +4266,14 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: values={"my_key": "value fast", "market": "US"}, tasks=(), next=(), - config=(await tool_two.checkpointer.aget_tuple(thread2)).config, - created_at=(await tool_two.checkpointer.aget_tuple(thread2)).checkpoint[ - "ts" - ], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -4095,9 +4282,13 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "assistant_id": "a", "thread_id": "2", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread2, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread2, limit=2)][ + -1 + ].config + ), ) thread3 = {"configurable": {"thread_id": "3", "assistant_id": "b"}} @@ -4110,10 +4301,14 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: values={"my_key": "value", "market": "US"}, tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),), next=("tool_two_fast",), - config=(await tool_two.checkpointer.aget_tuple(thread3)).config, - created_at=(await tool_two.checkpointer.aget_tuple(thread3)).checkpoint[ - "ts" - ], + config={ + "configurable": { + "thread_id": "3", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -4122,9 +4317,13 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "assistant_id": "b", "thread_id": "3", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread3, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread3, limit=2)][ + -1 + ].config + ), ) # update state await tool_two.aupdate_state(thread3, {"my_key": "key"}) # appends to my_key @@ -4132,10 +4331,14 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: values={"my_key": "valuekey", "market": "US"}, tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),), next=("tool_two_fast",), - config=(await tool_two.checkpointer.aget_tuple(thread3)).config, - created_at=(await 
tool_two.checkpointer.aget_tuple(thread3)).checkpoint[ - "ts" - ], + config={ + "configurable": { + "thread_id": "3", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -4144,9 +4347,13 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "assistant_id": "b", "thread_id": "3", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread3, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread3, limit=2)][ + -1 + ].config + ), ) # resume, for same result as above assert await tool_two.ainvoke(None, thread3, debug=1) == { @@ -4157,10 +4364,14 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: values={"my_key": "valuekey fast", "market": "US"}, tasks=(), next=(), - config=(await tool_two.checkpointer.aget_tuple(thread3)).config, - created_at=(await tool_two.checkpointer.aget_tuple(thread3)).checkpoint[ - "ts" - ], + config={ + "configurable": { + "thread_id": "3", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -4169,9 +4380,13 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "assistant_id": "b", "thread_id": "3", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread3, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread3, limit=2)][ + -1 + ].config + ), ) @@ -4689,10 +4904,14 @@ class State(TypedDict): values={"my_key": "value prepared", "market": "DE"}, tasks=(PregelTask(AnyStr(), "tool_two_slow", (PULL, "tool_two_slow")),), next=("tool_two_slow",), - config=(await tool_two.checkpointer.aget_tuple(thread1)).config, - created_at=(await tool_two.checkpointer.aget_tuple(thread1)).checkpoint[ - "ts" - ], + config={ + "configurable": { + "thread_id": "11", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -4700,9 +4919,13 @@ class State(TypedDict): "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "11", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread1, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread1, limit=2)][ + -1 + ].config + ), ) # resume, for same result as above assert await tool_two.ainvoke(None, thread1, debug=1) == { @@ -4713,10 +4936,14 @@ class State(TypedDict): values={"my_key": "value prepared slow finished", "market": "DE"}, tasks=(), next=(), - config=(await tool_two.checkpointer.aget_tuple(thread1)).config, - created_at=(await tool_two.checkpointer.aget_tuple(thread1)).checkpoint[ - "ts" - ], + config={ + "configurable": { + "thread_id": "11", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -4724,9 +4951,13 @@ class State(TypedDict): "writes": {"finish": {"my_key": " finished"}}, "thread_id": "11", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread1, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread1, limit=2)][ + -1 + ].config + ), ) thread2 = {"configurable": {"thread_id": "12"}} @@ -4739,10 +4970,14 @@ class State(TypedDict): 
values={"my_key": "value prepared", "market": "US"}, tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),), next=("tool_two_fast",), - config=(await tool_two.checkpointer.aget_tuple(thread2)).config, - created_at=(await tool_two.checkpointer.aget_tuple(thread2)).checkpoint[ - "ts" - ], + config={ + "configurable": { + "thread_id": "12", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -4750,9 +4985,13 @@ class State(TypedDict): "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "12", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread2, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread2, limit=2)][ + -1 + ].config + ), ) # resume, for same result as above assert await tool_two.ainvoke(None, thread2, debug=1) == { @@ -4763,10 +5002,14 @@ class State(TypedDict): values={"my_key": "value prepared fast finished", "market": "US"}, tasks=(), next=(), - config=(await tool_two.checkpointer.aget_tuple(thread2)).config, - created_at=(await tool_two.checkpointer.aget_tuple(thread2)).checkpoint[ - "ts" - ], + config={ + "configurable": { + "thread_id": "12", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -4774,9 +5017,13 @@ class State(TypedDict): "writes": {"finish": {"my_key": " finished"}}, "thread_id": "12", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread2, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread2, limit=2)][ + -1 + ].config + ), ) tool_two = tool_two_graph.compile( @@ -4797,10 +5044,14 @@ class State(TypedDict): values={"my_key": "value prepared", "market": "DE"}, tasks=(PregelTask(AnyStr(), "tool_two_slow", (PULL, "tool_two_slow")),), next=("tool_two_slow",), - config=(await tool_two.checkpointer.aget_tuple(thread1)).config, - created_at=(await tool_two.checkpointer.aget_tuple(thread1)).checkpoint[ - "ts" - ], + config={ + "configurable": { + "thread_id": "21", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -4808,9 +5059,13 @@ class State(TypedDict): "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "21", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread1, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread1, limit=2)][ + -1 + ].config + ), ) # resume, for same result as above assert await tool_two.ainvoke(None, thread1, debug=1) == { @@ -4821,10 +5076,14 @@ class State(TypedDict): values={"my_key": "value prepared slow finished", "market": "DE"}, tasks=(), next=(), - config=(await tool_two.checkpointer.aget_tuple(thread1)).config, - created_at=(await tool_two.checkpointer.aget_tuple(thread1)).checkpoint[ - "ts" - ], + config={ + "configurable": { + "thread_id": "21", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -4832,9 +5091,13 @@ class State(TypedDict): "writes": {"finish": {"my_key": " finished"}}, "thread_id": "21", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread1, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in 
checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread1, limit=2)][ + -1 + ].config + ), ) thread2 = {"configurable": {"thread_id": "22"}} @@ -4847,10 +5110,14 @@ class State(TypedDict): values={"my_key": "value prepared", "market": "US"}, tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),), next=("tool_two_fast",), - config=(await tool_two.checkpointer.aget_tuple(thread2)).config, - created_at=(await tool_two.checkpointer.aget_tuple(thread2)).checkpoint[ - "ts" - ], + config={ + "configurable": { + "thread_id": "22", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -4858,9 +5125,13 @@ class State(TypedDict): "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "22", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread2, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread2, limit=2)][ + -1 + ].config + ), ) # resume, for same result as above assert await tool_two.ainvoke(None, thread2, debug=1) == { @@ -4871,10 +5142,14 @@ class State(TypedDict): values={"my_key": "value prepared fast finished", "market": "US"}, tasks=(), next=(), - config=(await tool_two.checkpointer.aget_tuple(thread2)).config, - created_at=(await tool_two.checkpointer.aget_tuple(thread2)).checkpoint[ - "ts" - ], + config={ + "configurable": { + "thread_id": "22", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -4882,9 +5157,13 @@ class State(TypedDict): "writes": {"finish": {"my_key": " finished"}}, "thread_id": "22", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread2, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread2, limit=2)][ + -1 + ].config + ), ) thread3 = {"configurable": {"thread_id": "23"}} @@ -4918,10 +5197,14 @@ class State(TypedDict): values={"my_key": "key prepared", "market": "DE"}, tasks=(PregelTask(AnyStr(), "tool_two_slow", (PULL, "tool_two_slow")),), next=("tool_two_slow",), - config=(await tool_two.checkpointer.aget_tuple(thread3)).config, - created_at=(await tool_two.checkpointer.aget_tuple(thread3)).checkpoint[ - "ts" - ], + config={ + "configurable": { + "thread_id": "23", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -4929,7 +5212,7 @@ class State(TypedDict): "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "23", }, - parent_config=uconfig, + parent_config=(None if "shallow" in checkpointer_name else uconfig), ) # resume, for same result as above assert await tool_two.ainvoke(None, thread3, debug=1) == { @@ -4940,10 +5223,14 @@ class State(TypedDict): values={"my_key": "key prepared slow finished", "market": "DE"}, tasks=(), next=(), - config=(await tool_two.checkpointer.aget_tuple(thread3)).config, - created_at=(await tool_two.checkpointer.aget_tuple(thread3)).checkpoint[ - "ts" - ], + config={ + "configurable": { + "thread_id": "23", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -4951,9 +5238,13 @@ class State(TypedDict): "writes": {"finish": {"my_key": " finished"}}, "thread_id": "23", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread3, limit=2) 
- ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread3, limit=2)][ + -1 + ].config + ), ) @@ -5039,13 +5330,17 @@ def outer_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ) # now, get_state with subgraphs state assert await app.aget_state(config, subgraphs=True) == StateSnapshot( @@ -5099,16 +5394,20 @@ def outer_2(state: State): "langgraph_checkpoint_ns": AnyStr("inner:"), }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr("inner:"), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - {"": AnyStr(), AnyStr("child:"): AnyStr()} - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr("inner:"), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + {"": AnyStr(), AnyStr("child:"): AnyStr()} + ), + } } - }, + ), ), ), ), @@ -5128,17 +5427,21 @@ def outer_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ) # get_state_history returns outer graph checkpoints history = [c async for c in app.aget_state_history(config)] - assert history == [ + expected_history = [ StateSnapshot( values={"my_key": "hi my value"}, tasks=( @@ -5170,13 +5473,17 @@ def outer_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ), StateSnapshot( values={"my_key": "my value"}, @@ -5241,11 +5548,17 @@ def outer_2(state: State): parent_config=None, ), ] + + if "shallow" in checkpointer_name: + expected_history = expected_history[:1] + + assert history == expected_history + # get_state_history for a subgraph returns its checkpoints child_history = [ c async for c in app.aget_state_history(history[0].tasks[0].state) ] - assert child_history == [ + expected_child_history = [ StateSnapshot( values={"my_key": "hi my value here", "my_other_key": "hi my value"}, next=("inner_2",), @@ -5278,16 +5591,20 @@ def outer_2(state: State): "langgraph_checkpoint_ns": AnyStr("inner:"), }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr("inner:"), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - {"": AnyStr(), AnyStr("child:"): AnyStr()} - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr("inner:"), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + {"": AnyStr(), AnyStr("child:"): AnyStr()} + ), + } } - }, + ), tasks=(PregelTask(AnyStr(), "inner_2", (PULL, "inner_2")),), ), StateSnapshot( @@ -5378,6 +5695,11 @@ def outer_2(state: State): ), ] + if "shallow" in 
checkpointer_name: + expected_child_history = expected_child_history[:1] + + assert child_history == expected_child_history + # resume await app.ainvoke(None, config, debug=True) # test state w/ nested subgraph state (after resuming from interrupt) @@ -5402,13 +5724,17 @@ def outer_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ) # test full history at the end actual_history = [c async for c in app.aget_state_history(config)] @@ -5436,13 +5762,17 @@ def outer_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ), StateSnapshot( values={"my_key": "hi my value here and there"}, @@ -5581,6 +5911,9 @@ def outer_2(state: State): parent_config=None, ), ] + if "shallow" in checkpointer_name: + expected_history = expected_history[:1] + assert actual_history == expected_history # test looking up parent state by checkpoint ID for actual_snapshot, expected_snapshot in zip(actual_history, expected_history): @@ -5684,13 +6017,17 @@ def parent_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ) child_state = await app.aget_state(outer_state.tasks[0].state) assert ( @@ -5726,13 +6063,17 @@ def parent_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr("child:"), - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr("child:"), + "checkpoint_id": AnyStr(), + } } - }, + ), ).tasks[0] ) grandchild_state = await app.aget_state(child_state.tasks[0].state) @@ -5779,20 +6120,24 @@ def parent_2(state: State): "langgraph_triggers": [AnyStr("start:child_1")], }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr(), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - { - "": AnyStr(), - AnyStr("child:"): AnyStr(), - AnyStr(re.compile(r"child:.+|child1:")): AnyStr(), - } - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("child:"): AnyStr(), + AnyStr(re.compile(r"child:.+|child1:")): AnyStr(), + } + ), + } } - }, + ), ) # get state with subgraphs assert await app.aget_state(config, subgraphs=True) == StateSnapshot( @@ -5861,22 +6206,28 @@ def parent_2(state: State): "langgraph_triggers": [AnyStr("start:child_1")], }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr(), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - { - "": AnyStr(), - 
AnyStr("child:"): AnyStr(), - AnyStr( - re.compile(r"child:.+|child1:") - ): AnyStr(), - } - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("child:"): AnyStr(), + AnyStr( + re.compile( + r"child:.+|child1:" + ) + ): AnyStr(), + } + ), + } } - }, + ), ), ), ), @@ -5905,16 +6256,20 @@ def parent_2(state: State): "langgraph_checkpoint_ns": AnyStr("child:"), }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr("child:"), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - {"": AnyStr(), AnyStr("child:"): AnyStr()} - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr("child:"), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + {"": AnyStr(), AnyStr("child:"): AnyStr()} + ), + } } - }, + ), ), ), ), @@ -5934,13 +6289,17 @@ def parent_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ) # resume assert [c async for c in app.astream(None, config, subgraphs=True)] == [ @@ -5982,15 +6341,23 @@ def parent_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ) ) + + if "shallow" in checkpointer_name: + return + # get outer graph history outer_history = [c async for c in app.aget_state_history(config)] assert ( @@ -6707,16 +7074,23 @@ async def edit(state: JokeState): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ) assert actual_snapshot == expected_snapshot + if "shallow" in checkpointer_name: + return + # test full history actual_history = [c async for c in graph.aget_state_history(config)] expected_history = [ @@ -6989,13 +7363,17 @@ def get_first_in_list(): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=( PregelTask( id=AnyStr(), @@ -7081,13 +7459,17 @@ def get_first_in_list(): "thread_id": "14", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "14", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "14", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=( PregelTask( id=AnyStr(), @@ -7130,19 +7512,23 @@ def get_first_in_list(): "langgraph_checkpoint_ns": 
AnyStr("weather_graph:"), }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "14", - "checkpoint_ns": AnyStr("weather_graph:"), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - { - "": AnyStr(), - AnyStr("weather_graph:"): AnyStr(), - } - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "14", + "checkpoint_ns": AnyStr("weather_graph:"), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("weather_graph:"): AnyStr(), + } + ), + } } - }, + ), tasks=( PregelTask( id=AnyStr(), @@ -7181,13 +7567,17 @@ def get_first_in_list(): "thread_id": "14", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "14", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "14", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=( PregelTask( id=AnyStr(), @@ -7238,19 +7628,23 @@ def get_first_in_list(): "langgraph_checkpoint_ns": AnyStr("weather_graph:"), }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "14", - "checkpoint_ns": AnyStr("weather_graph:"), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - { - "": AnyStr(), - AnyStr("weather_graph:"): AnyStr(), - } - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "14", + "checkpoint_ns": AnyStr("weather_graph:"), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("weather_graph:"): AnyStr(), + } + ), + } } - }, + ), tasks=(), ), ), diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index b9b858c77..b08cc87a7 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -78,6 +78,7 @@ from tests.conftest import ( ALL_CHECKPOINTERS_SYNC, ALL_STORES_SYNC, + REGULAR_CHECKPOINTERS_SYNC, SHOULD_CHECK_SNAPSHOTS, ) from tests.memory_assert import MemorySaverAssertCheckpointMetadata @@ -624,7 +625,7 @@ def test_invoke_two_processes_in_out(mocker: MockerFixture) -> None: assert step == 2 -@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) +@pytest.mark.parametrize("checkpointer_name", REGULAR_CHECKPOINTERS_SYNC) def test_run_from_checkpoint_id_retains_previous_writes( request: pytest.FixtureRequest, checkpointer_name: str, mocker: MockerFixture ) -> None: @@ -1157,6 +1158,10 @@ def reset(self): # both the pending write and the new write were applied, 1 + 2 + 3 = 6 assert graph.invoke(None, thread1) == {"value": 6} + if "shallow" in checkpointer_name: + assert len(list(checkpointer.list(thread1))) == 1 + return + # check all final checkpoints checkpoints = [c for c in checkpointer.list(thread1)] # we should have 3 @@ -1618,6 +1623,9 @@ def raise_if_above_10(input: int) -> int: assert state.values.get("total") == 5 assert state.next == () + if "shallow" in checkpointer_name: + return + assert len(list(app.get_state_history(thread_1, limit=1))) == 1 # list all checkpoints for thread 1 thread_1_history = [c for c in app.get_state_history(thread_1)] @@ -2270,6 +2278,11 @@ def qa(data: State) -> State: ] app_w_interrupt.update_state(config, {"docs": ["doc5"]}) + expected_parent_config = ( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ) assert app_w_interrupt.get_state(config) == StateSnapshot( values={ "query": 
"analyzed: query: what is weather in sf", @@ -2277,8 +2290,14 @@ def qa(data: State) -> State: }, tasks=(PregelTask(AnyStr(), "qa", (PULL, "qa")),), next=("qa",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -2286,7 +2305,7 @@ def qa(data: State) -> State: "writes": {"retriever_one": {"docs": ["doc5"]}}, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=expected_parent_config, ) assert [c for c in app_w_interrupt.stream(None, config, debug=1)] == [ @@ -4670,13 +4689,17 @@ class CustomParentState(TypedDict): "parents": {}, }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=(), ) @@ -5203,11 +5226,15 @@ def second_node(state: State): assert state is not None assert state.values == {"steps": ["start"], "attempt": 1} # input state saved assert state.next == ("node1",) # Should retry failed node + assert "RuntimeError('Simulated failure')" in state.tasks[0].error # Retry with updated attempt count result = graph.invoke({"steps": [], "attempt": 2}, config) assert result == {"steps": ["start", "node1", "node2"], "attempt": 2} + if "shallow" in checkpointer_name: + return + # Verify checkpoint history shows both attempts history = list(graph.get_state_history(config)) assert len(history) == 6 # Initial + failed attempt + successful attempt diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index 45f1748fb..02b985c62 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -75,6 +75,7 @@ ALL_CHECKPOINTERS_ASYNC, ALL_CHECKPOINTERS_ASYNC_PLUS_NONE, ALL_STORES_ASYNC, + REGULAR_CHECKPOINTERS_ASYNC, SHOULD_CHECK_SNAPSHOTS, awith_checkpointer, awith_store, @@ -347,22 +348,23 @@ async def tool_two_node(s: State) -> State: ) }, ] - assert [c.metadata async for c in tool_two.checkpointer.alist(thread1)] == [ - { - "parents": {}, - "source": "loop", - "step": 0, - "writes": None, - "thread_id": "1", - }, - { - "parents": {}, - "source": "input", - "step": -1, - "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, - "thread_id": "1", - }, - ] + if "shallow" not in checkpointer_name: + assert [c.metadata async for c in tool_two.checkpointer.alist(thread1)] == [ + { + "parents": {}, + "source": "loop", + "step": 0, + "writes": None, + "thread_id": "1", + }, + { + "parents": {}, + "source": "input", + "step": -1, + "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, + "thread_id": "1", + }, + ] tup = await tool_two.checkpointer.aget_tuple(thread1) assert await tool_two.aget_state(thread1) == StateSnapshot( values={"my_key": "value ⛰️", "market": "DE"}, @@ -390,9 +392,13 @@ async def tool_two_node(s: State) -> State: "writes": None, "thread_id": "1", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread1, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread1, limit=2)][ + -1 + 
].config + ), ) # clear the interrupt and next tasks @@ -412,9 +418,13 @@ async def tool_two_node(s: State) -> State: "writes": {}, "thread_id": "1", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread1, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread1, limit=2)][ + -1 + ].config + ), ) @@ -524,22 +534,25 @@ class State(TypedDict): ) }, ] - assert [c.metadata async for c in tool_two.checkpointer.alist(thread1root)] == [ - { - "parents": {}, - "source": "loop", - "step": 0, - "writes": None, - "thread_id": "1", - }, - { - "parents": {}, - "source": "input", - "step": -1, - "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, - "thread_id": "1", - }, - ] + if "shallow" not in checkpointer_name: + assert [ + c.metadata async for c in tool_two.checkpointer.alist(thread1root) + ] == [ + { + "parents": {}, + "source": "loop", + "step": 0, + "writes": None, + "thread_id": "1", + }, + { + "parents": {}, + "source": "input", + "step": -1, + "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, + "thread_id": "1", + }, + ] tup = await tool_two.checkpointer.aget_tuple(thread1) assert await tool_two.aget_state(thread1) == StateSnapshot( values={"my_key": "value ⛰️", "market": "DE"}, @@ -573,9 +586,13 @@ class State(TypedDict): "writes": None, "thread_id": "1", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread1root, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in tool_two.checkpointer.alist(thread1root, limit=2) + ][-1].config + ), ) # clear the interrupt and next tasks @@ -595,9 +612,13 @@ class State(TypedDict): "writes": {}, "thread_id": "1", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread1root, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in tool_two.checkpointer.alist(thread1root, limit=2) + ][-1].config + ), ) @@ -699,22 +720,25 @@ def start(state: State) -> list[Union[Send, str]]: "my_key": "value ⛰️ one", "market": "DE", } - assert [c.metadata async for c in tool_two.checkpointer.alist(thread1)] == [ - { - "parents": {}, - "source": "loop", - "step": 0, - "writes": {"tool_one": {"my_key": " one"}}, - "thread_id": "1", - }, - { - "parents": {}, - "source": "input", - "step": -1, - "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, - "thread_id": "1", - }, - ] + + if "shallow" not in checkpointer_name: + assert [c.metadata async for c in tool_two.checkpointer.alist(thread1)] == [ + { + "parents": {}, + "source": "loop", + "step": 0, + "writes": {"tool_one": {"my_key": " one"}}, + "thread_id": "1", + }, + { + "parents": {}, + "source": "input", + "step": -1, + "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, + "thread_id": "1", + }, + ] + tup = await tool_two.checkpointer.aget_tuple(thread1) assert await tool_two.aget_state(thread1) == StateSnapshot( values={"my_key": "value ⛰️ one", "market": "DE"}, @@ -742,9 +766,13 @@ def start(state: State) -> list[Union[Send, str]]: "writes": {"tool_one": {"my_key": " one"}}, "thread_id": "1", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread1, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread1, limit=2)][ + -1 + ].config + ), ) # clear the interrupt and next tasks await 
tool_two.aupdate_state(thread1, None) @@ -770,9 +798,13 @@ def start(state: State) -> list[Union[Send, str]]: "writes": {}, "thread_id": "1", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread1, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread1, limit=2)][ + -1 + ].config + ), ) @@ -1754,6 +1786,10 @@ def reset(self): # both the pending write and the new write were applied, 1 + 2 + 3 = 6 assert await graph.ainvoke(None, thread1) == {"value": 6} + if "shallow" in checkpointer_name: + assert len([c async for c in checkpointer.alist(thread1)]) == 1 + return + # check all final checkpoints checkpoints = [c async for c in checkpointer.alist(thread1)] # we should have 3 @@ -1911,7 +1947,7 @@ def reset(self): ) -@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) +@pytest.mark.parametrize("checkpointer_name", REGULAR_CHECKPOINTERS_ASYNC) async def test_run_from_checkpoint_id_retains_previous_writes( request: pytest.FixtureRequest, checkpointer_name: str, mocker: MockerFixture ) -> None: @@ -2339,7 +2375,7 @@ async def graph(state: dict) -> dict: ] -@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) +@pytest.mark.parametrize("checkpointer_name", REGULAR_CHECKPOINTERS_ASYNC) async def test_send_dedupe_on_resume(checkpointer_name: str) -> None: if not FF_SEND_V2: pytest.skip("Send deduplication is only available in Send V2") @@ -2792,13 +2828,17 @@ async def foo(call: ToolCall): "thread_id": "2", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=( PregelTask( id=AnyStr(), @@ -2875,13 +2915,17 @@ async def foo(call: ToolCall): "thread_id": "2", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=(), ) @@ -2951,13 +2995,17 @@ async def foo(call: ToolCall): "thread_id": "3", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "3", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "3", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=( PregelTask( id=AnyStr(), @@ -3060,13 +3108,17 @@ async def foo(call: ToolCall): "thread_id": "3", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "3", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "3", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=( PregelTask( id=AnyStr(), @@ -3260,13 +3312,17 @@ async def foo(call: ToolCall): "thread_id": "2", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=( PregelTask( 
id=AnyStr(), @@ -3343,13 +3399,17 @@ async def foo(call: ToolCall): "thread_id": "2", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=(), ) @@ -3572,6 +3632,9 @@ def raise_if_above_10(input: int) -> int: assert state.values.get("total") == 5 assert state.next == () + if "shallow" in checkpointer_name: + return + assert len([c async for c in app.aget_state_history(thread_1, limit=1)]) == 1 # list all checkpoints for thread 1 thread_1_history = [c async for c in app.aget_state_history(thread_1)] @@ -4279,13 +4342,17 @@ async def decider(data: State) -> str: "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ) async with assert_ctx_once(): @@ -5758,13 +5825,17 @@ class CustomParentState(TypedDict): "parents": {}, }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=(), ) @@ -6281,6 +6352,9 @@ async def second_node(state: State): result = await graph.ainvoke({"steps": [], "attempt": 2}, config) assert result == {"steps": ["start", "node1", "node2"], "attempt": 2} + if "shallow" in checkpointer_name: + return + # Verify checkpoint history shows both attempts history = [c async for c in graph.aget_state_history(config)] assert len(history) == 6 # Initial + failed attempt + successful attempt
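
The test expectations above all encode the same contract for the shallow checkpointer variants: only the most recent checkpoint per thread is retained, so `parent_config` is `None`, `get_state_history` yields a single snapshot, and `checkpointer.list(...)` returns exactly one entry. The following sketch shows that contract in isolation. It is illustrative only; the example graph, DSN, and connection options are assumptions rather than part of this patch, and the saver is assumed to be constructed and set up the same way as the regular `PostgresSaver`.

    from typing import TypedDict

    from psycopg import Connection
    from psycopg.rows import dict_row

    from langgraph.checkpoint.postgres.shallow import ShallowPostgresSaver
    from langgraph.graph import START, StateGraph


    class State(TypedDict):
        my_key: str


    def node(state: State) -> State:
        # minimal node for illustration
        return {"my_key": state["my_key"] + " updated"}


    # hypothetical DSN; autocommit and dict_row are assumed to be required,
    # mirroring the connection settings used with the regular PostgresSaver
    with Connection.connect(
        "postgresql://user:pass@localhost:5432/db",
        autocommit=True,
        row_factory=dict_row,
    ) as conn:
        checkpointer = ShallowPostgresSaver(conn)
        checkpointer.setup()  # assumed to run the shallow schema migrations

        builder = StateGraph(State)
        builder.add_node("node", node)
        builder.add_edge(START, "node")
        app = builder.compile(checkpointer=checkpointer)

        config = {"configurable": {"thread_id": "1"}}
        app.invoke({"my_key": "value"}, config)
        app.invoke({"my_key": "value"}, config)

        # Only the latest checkpoint survives, which is what the updated tests
        # assert whenever "shallow" is in checkpointer_name:
        state = app.get_state(config)
        assert state.parent_config is None
        assert len(list(app.get_state_history(config))) == 1
        assert len(list(checkpointer.list(config))) == 1

In exchange, time travel (resuming from an older checkpoint_id) is not supported, which is why the history-walking tests return early or truncate their expected histories to a single entry for the shallow savers.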