diff --git a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py index 3ffce2168..475761808 100644 --- a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py +++ b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py @@ -19,6 +19,7 @@ ) from langgraph.checkpoint.postgres import _internal from langgraph.checkpoint.postgres.base import BasePostgresSaver +from langgraph.checkpoint.postgres.shallow import ShallowPostgresSaver from langgraph.checkpoint.serde.base import SerializerProtocol Conn = _internal.Conn # For backward compatibility @@ -396,83 +397,4 @@ def _cursor(self, *, pipeline: bool = False) -> Iterator[Cursor[DictRow]]: yield cur -class ShallowPostgresSaver(PostgresSaver): - def put( - self, - config: RunnableConfig, - checkpoint: Checkpoint, - metadata: CheckpointMetadata, - new_versions: ChannelVersions, - ) -> RunnableConfig: - """Save a checkpoint to the database. - - This method saves a checkpoint to the Postgres database. The checkpoint is associated - with the provided config and its parent config (if any). - - Args: - config (RunnableConfig): The config to associate with the checkpoint. - checkpoint (Checkpoint): The checkpoint to save. - metadata (CheckpointMetadata): Additional metadata to save with the checkpoint. - new_versions (ChannelVersions): New channel versions as of this write. - - Returns: - RunnableConfig: Updated configuration after storing the checkpoint. - - Examples: - - >>> from langgraph.checkpoint.postgres import PostgresSaver - >>> DB_URI = "postgres://postgres:postgres@localhost:5432/postgres?sslmode=disable" - >>> with PostgresSaver.from_conn_string(DB_URI) as memory: - >>> config = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}} - >>> checkpoint = {"ts": "2024-05-04T06:32:42.235444+00:00", "id": "1ef4f797-8335-6428-8001-8a1503f9b875", "channel_values": {"key": "value"}} - >>> saved_config = memory.put(config, checkpoint, {"source": "input", "step": 1, "writes": {"key": "value"}}, {}) - >>> print(saved_config) - {'configurable': {'thread_id': '1', 'checkpoint_ns': '', 'checkpoint_id': '1ef4f797-8335-6428-8001-8a1503f9b875'}} - """ - configurable = config["configurable"].copy() - thread_id = configurable.pop("thread_id") - checkpoint_ns = configurable.pop("checkpoint_ns") - checkpoint_id = configurable.pop( - "checkpoint_id", configurable.pop("thread_ts", None) - ) - - copy = checkpoint.copy() - next_config = { - "configurable": { - "thread_id": thread_id, - "checkpoint_ns": checkpoint_ns, - "checkpoint_id": checkpoint["id"], - } - } - - with self._cursor(pipeline=True) as cur: - # clear existing checkpoints - cur.execute( - "DELETE FROM checkpoints WHERE thread_id = %s AND checkpoint_ns = %s", - (thread_id, checkpoint_ns), - ) - # write new checkpoint - cur.executemany( - self.UPSERT_CHECKPOINT_BLOBS_SQL, - self._dump_blobs( - thread_id, - checkpoint_ns, - copy.pop("channel_values"), # type: ignore[misc] - new_versions, - ), - ) - cur.execute( - self.UPSERT_CHECKPOINTS_SQL, - ( - thread_id, - checkpoint_ns, - checkpoint["id"], - checkpoint_id, - Jsonb(self._dump_checkpoint(copy)), - self._dump_metadata(metadata), - ), - ) - return next_config - - -__all__ = ["PostgresSaver", "BasePostgresSaver", "Conn"] +__all__ = ["PostgresSaver", "BasePostgresSaver", "ShallowPostgresSaver", "Conn"] diff --git a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/shallow.py b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/shallow.py new file mode 100644 index 000000000..7b1560a68 --- /dev/null +++ b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/shallow.py @@ -0,0 +1,370 @@ +import threading +from collections.abc import Iterator, Sequence +from contextlib import contextmanager +from typing import Any, Optional + +from langchain_core.runnables import RunnableConfig +from psycopg import Capabilities, Connection, Cursor, Pipeline +from psycopg.rows import DictRow, dict_row +from psycopg.types.json import Jsonb +from psycopg_pool import ConnectionPool + +from langgraph.checkpoint.base import ( + WRITES_IDX_MAP, + ChannelVersions, + Checkpoint, + CheckpointMetadata, + CheckpointTuple, +) +from langgraph.checkpoint.postgres import _internal +from langgraph.checkpoint.postgres.base import BasePostgresSaver +from langgraph.checkpoint.serde.base import SerializerProtocol +from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer +from langgraph.checkpoint.serde.types import TASKS, ChannelProtocol + +Conn = _internal.Conn # For backward compatibility + + +""" +To add a new migration, add a new string to the MIGRATIONS list. +The position of the migration in the list is the version number. +""" +MIGRATIONS = [ + """CREATE TABLE IF NOT EXISTS checkpoint_migrations ( + v INTEGER PRIMARY KEY +);""", + """CREATE TABLE IF NOT EXISTS checkpoints ( + thread_id TEXT NOT NULL, + checkpoint_ns TEXT NOT NULL DEFAULT '', + type TEXT, + checkpoint JSONB NOT NULL, + metadata JSONB NOT NULL DEFAULT '{}', + PRIMARY KEY (thread_id, checkpoint_ns) +);""", + """CREATE TABLE IF NOT EXISTS checkpoint_blobs ( + thread_id TEXT NOT NULL, + checkpoint_ns TEXT NOT NULL DEFAULT '', + channel TEXT NOT NULL, + version TEXT NOT NULL, + type TEXT NOT NULL, + blob BYTEA, + PRIMARY KEY (thread_id, checkpoint_ns, channel, version) +);""", + """CREATE TABLE IF NOT EXISTS checkpoint_writes ( + thread_id TEXT NOT NULL, + checkpoint_ns TEXT NOT NULL DEFAULT '', + checkpoint_id TEXT NOT NULL, + task_id TEXT NOT NULL, + idx INTEGER NOT NULL, + channel TEXT NOT NULL, + type TEXT, + blob BYTEA NOT NULL, + PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) +);""", + """ + CREATE INDEX CONCURRENTLY IF NOT EXISTS checkpoints_thread_id_idx ON checkpoints(thread_id); + """, + """ + CREATE INDEX CONCURRENTLY IF NOT EXISTS checkpoint_blobs_thread_id_idx ON checkpoint_blobs(thread_id); + """, + """ + CREATE INDEX CONCURRENTLY IF NOT EXISTS checkpoint_writes_thread_id_idx ON checkpoint_writes(thread_id); + """, +] + +SELECT_SQL = f""" +select + thread_id, + checkpoint, + checkpoint_ns, + metadata, + ( + select array_agg(array[bl.channel::bytea, bl.type::bytea, bl.blob]) + from jsonb_each_text(checkpoint -> 'channel_versions') + inner join checkpoint_blobs bl + on bl.thread_id = checkpoints.thread_id + and bl.checkpoint_ns = checkpoints.checkpoint_ns + and bl.channel = jsonb_each_text.key + and bl.version = jsonb_each_text.value + ) as channel_values, + ( + select + array_agg(array[cw.task_id::text::bytea, cw.channel::bytea, cw.type::bytea, cw.blob] order by cw.task_id, cw.idx) + from checkpoint_writes cw + where cw.thread_id = checkpoints.thread_id + and cw.checkpoint_ns = checkpoints.checkpoint_ns + and cw.checkpoint_id = (checkpoint->>'id') + ) as pending_writes, + ( + select array_agg(array[cw.type::bytea, cw.blob] order by cw.task_id, cw.idx) + from checkpoint_writes cw + where cw.thread_id = checkpoints.thread_id + and cw.checkpoint_ns = checkpoints.checkpoint_ns + and cw.channel = '{TASKS}' + ) as pending_sends +from checkpoints """ + +UPSERT_CHECKPOINT_BLOBS_SQL = """ + INSERT INTO checkpoint_blobs (thread_id, checkpoint_ns, channel, version, type, blob) + VALUES (%s, %s, %s, %s, %s, %s) + ON CONFLICT (thread_id, checkpoint_ns, channel, version) DO NOTHING +""" + +UPSERT_CHECKPOINTS_SQL = """ + INSERT INTO checkpoints (thread_id, checkpoint_ns, checkpoint, metadata) + VALUES (%s, %s, %s, %s) + ON CONFLICT (thread_id, checkpoint_ns) + DO UPDATE SET + checkpoint = EXCLUDED.checkpoint, + metadata = EXCLUDED.metadata; +""" + +UPSERT_CHECKPOINT_WRITES_SQL = """ + INSERT INTO checkpoint_writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, blob) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s) + ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) DO UPDATE SET + channel = EXCLUDED.channel, + type = EXCLUDED.type, + blob = EXCLUDED.blob; +""" + +INSERT_CHECKPOINT_WRITES_SQL = """ + INSERT INTO checkpoint_writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, blob) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s) + ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) DO NOTHING +""" + + +class ShallowPostgresSaver(BasePostgresSaver): + SELECT_SQL = SELECT_SQL + MIGRATIONS = MIGRATIONS + UPSERT_CHECKPOINT_BLOBS_SQL = UPSERT_CHECKPOINT_BLOBS_SQL + UPSERT_CHECKPOINTS_SQL = UPSERT_CHECKPOINTS_SQL + UPSERT_CHECKPOINT_WRITES_SQL = UPSERT_CHECKPOINT_WRITES_SQL + INSERT_CHECKPOINT_WRITES_SQL = INSERT_CHECKPOINT_WRITES_SQL + + lock: threading.Lock + + def __init__( + self, + conn: _internal.Conn, + pipe: Optional[Pipeline] = None, + serde: Optional[SerializerProtocol] = None, + ) -> None: + super().__init__(serde=serde) + if isinstance(conn, ConnectionPool) and pipe is not None: + raise ValueError( + "Pipeline should be used only with a single Connection, not ConnectionPool." + ) + + self.conn = conn + self.pipe = pipe + self.lock = threading.Lock() + self.supports_pipeline = Capabilities().has_pipeline() + + @classmethod + @contextmanager + def from_conn_string( + cls, conn_string: str, *, pipeline: bool = False + ) -> Iterator["ShallowPostgresSaver"]: + """Create a new ShallowPostgresSaver instance from a connection string. + + Args: + conn_string (str): The Postgres connection info string. + pipeline (bool): whether to use Pipeline + + Returns: + ShallowPostgresSaver: A new ShallowPostgresSaver instance. + """ + with Connection.connect( + conn_string, autocommit=True, prepare_threshold=0, row_factory=dict_row + ) as conn: + if pipeline: + with conn.pipeline() as pipe: + yield cls(conn, pipe) + else: + yield cls(conn) + + def setup(self) -> None: + """Set up the checkpoint database asynchronously. + + This method creates the necessary tables in the Postgres database if they don't + already exist and runs database migrations. It MUST be called directly by the user + the first time checkpointer is used. + """ + with self._cursor() as cur: + cur.execute(self.MIGRATIONS[0]) + results = cur.execute( + "SELECT v FROM checkpoint_migrations ORDER BY v DESC LIMIT 1" + ) + row = results.fetchone() + if row is None: + version = -1 + else: + version = row["v"] + for v, migration in zip( + range(version + 1, len(self.MIGRATIONS)), + self.MIGRATIONS[version + 1 :], + ): + cur.execute(migration) + cur.execute(f"INSERT INTO checkpoint_migrations (v) VALUES ({v})") + if self.pipe: + self.pipe.sync() + + def list( + self, + config: Optional[RunnableConfig], + *, + filter: Optional[dict[str, Any]] = None, + before: Optional[RunnableConfig] = None, + limit: Optional[int] = None, + ) -> Iterator[CheckpointTuple]: + checkpoint_tuple = self.get_tuple(config) + return [] if checkpoint_tuple is None else [checkpoint_tuple] + + def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: + thread_id = config["configurable"]["thread_id"] + checkpoint_ns = config["configurable"].get("checkpoint_ns", "") + args = (thread_id, checkpoint_ns) + where = "WHERE thread_id = %s AND checkpoint_ns = %s" + + with self._cursor() as cur: + cur.execute( + self.SELECT_SQL + where, + args, + binary=True, + ) + + for value in cur: + checkpoint = self._load_checkpoint( + value["checkpoint"], + value["channel_values"], + value["pending_sends"], + ) + return CheckpointTuple( + config={ + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": checkpoint["id"], + } + }, + checkpoint=checkpoint, + metadata=self._load_metadata(value["metadata"]), + pending_writes=self._load_writes(value["pending_writes"]), + ) + + def put( + self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: ChannelVersions, + ) -> RunnableConfig: + configurable = config["configurable"].copy() + thread_id = configurable.pop("thread_id") + checkpoint_ns = configurable.pop("checkpoint_ns") + + copy = checkpoint.copy() + next_config = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": checkpoint["id"], + } + } + + with self._cursor(pipeline=True) as cur: + cur.executemany( + self.UPSERT_CHECKPOINT_BLOBS_SQL, + self._dump_blobs( + thread_id, + checkpoint_ns, + copy.pop("channel_values"), # type: ignore[misc] + new_versions, + ), + ) + cur.execute( + self.UPSERT_CHECKPOINTS_SQL, + ( + thread_id, + checkpoint_ns, + Jsonb(self._dump_checkpoint(copy)), + self._dump_metadata(metadata), + ), + ) + return next_config + + def put_writes( + self, + config: RunnableConfig, + writes: Sequence[tuple[str, Any]], + task_id: str, + ) -> None: + """Store intermediate writes linked to a checkpoint. + + This method saves intermediate writes associated with a checkpoint to the Postgres database. + + Args: + config (RunnableConfig): Configuration of the related checkpoint. + writes (List[Tuple[str, Any]]): List of writes to store. + task_id (str): Identifier for the task creating the writes. + """ + query = ( + self.UPSERT_CHECKPOINT_WRITES_SQL + if all(w[0] in WRITES_IDX_MAP for w in writes) + else self.INSERT_CHECKPOINT_WRITES_SQL + ) + with self._cursor(pipeline=True) as cur: + cur.executemany( + query, + self._dump_writes( + config["configurable"]["thread_id"], + config["configurable"]["checkpoint_ns"], + config["configurable"]["checkpoint_id"], + task_id, + writes, + ), + ) + + @contextmanager + def _cursor(self, *, pipeline: bool = False) -> Iterator[Cursor[DictRow]]: + """Create a database cursor as a context manager. + + Args: + pipeline (bool): whether to use pipeline for the DB operations inside the context manager. + Will be applied regardless of whether the ShallowPostgresSaver instance was initialized with a pipeline. + If pipeline mode is not supported, will fall back to using transaction context manager. + """ + with _internal.get_connection(self.conn) as conn: + if self.pipe: + # a connection in pipeline mode can be used concurrently + # in multiple threads/coroutines, but only one cursor can be + # used at a time + try: + with conn.cursor(binary=True, row_factory=dict_row) as cur: + yield cur + finally: + if pipeline: + self.pipe.sync() + elif pipeline: + # a connection not in pipeline mode can only be used by one + # thread/coroutine at a time, so we acquire a lock + if self.supports_pipeline: + with ( + self.lock, + conn.pipeline(), + conn.cursor(binary=True, row_factory=dict_row) as cur, + ): + yield cur + else: + # Use connection's transaction context manager when pipeline mode not supported + with ( + self.lock, + conn.transaction(), + conn.cursor(binary=True, row_factory=dict_row) as cur, + ): + yield cur + else: + with self.lock, conn.cursor(binary=True, row_factory=dict_row) as cur: + yield cur diff --git a/libs/checkpoint/langgraph/checkpoint/memory/__init__.py b/libs/checkpoint/langgraph/checkpoint/memory/__init__.py index ac579080e..cb6b7b852 100644 --- a/libs/checkpoint/langgraph/checkpoint/memory/__init__.py +++ b/libs/checkpoint/langgraph/checkpoint/memory/__init__.py @@ -468,49 +468,6 @@ def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> return f"{next_v:032}.{next_h:016}" -class ShallowMemorySaver(MemorySaver): - def put( - self, - config: RunnableConfig, - checkpoint: Checkpoint, - metadata: CheckpointMetadata, - new_versions: ChannelVersions, - ) -> RunnableConfig: - """Save a checkpoint to the in-memory storage. - - This method saves a checkpoint to the in-memory storage. The checkpoint is associated - with the provided config. - - Args: - config (RunnableConfig): The config to associate with the checkpoint. - checkpoint (Checkpoint): The checkpoint to save. - metadata (CheckpointMetadata): Additional metadata to save with the checkpoint. - new_versions (dict): New versions as of this write - - Returns: - RunnableConfig: The updated config containing the saved checkpoint's timestamp. - """ - c = checkpoint.copy() - c.pop("pending_sends") # type: ignore[misc] - thread_id = config["configurable"]["thread_id"] - checkpoint_ns = config["configurable"]["checkpoint_ns"] - # always overwrite the last checkpoint - self.storage[thread_id][checkpoint_ns] = { - checkpoint["id"]: ( - self.serde.dumps_typed(c), - self.serde.dumps_typed(metadata), - config["configurable"].get("checkpoint_id"), # parent - ) - } - return { - "configurable": { - "thread_id": thread_id, - "checkpoint_ns": checkpoint_ns, - "checkpoint_id": checkpoint["id"], - } - } - - class PersistentDict(defaultdict): """Persistent dictionary with an API compatible with shelve and anydbm. diff --git a/libs/langgraph/tests/conftest.py b/libs/langgraph/tests/conftest.py index 39ad32489..aea36597c 100644 --- a/libs/langgraph/tests/conftest.py +++ b/libs/langgraph/tests/conftest.py @@ -55,13 +55,6 @@ def checkpointer_memory(): yield MemorySaverAssertImmutable() -@pytest.fixture(scope="function") -def checkpointer_shallow_memory(): - from langgraph.checkpoint.memory import ShallowMemorySaver - - yield ShallowMemorySaver() - - @pytest.fixture(scope="function") def checkpointer_sqlite(): with SqliteSaver.from_conn_string(":memory:") as checkpointer: @@ -443,7 +436,7 @@ async def awith_store(store_name: Optional[str]) -> AsyncIterator[BaseStore]: raise NotImplementedError(f"Unknown store {store_name}") -SHALLOW_CHECKPOINTERS_SYNC = ["shallow_memory", "shallow_postgres"] +SHALLOW_CHECKPOINTERS_SYNC = ["shallow_postgres"] REGULAR_CHECKPOINTERS_SYNC = [ "memory", "sqlite", diff --git a/libs/langgraph/tests/test_large_cases.py b/libs/langgraph/tests/test_large_cases.py index 72df87b9b..426d7692b 100644 --- a/libs/langgraph/tests/test_large_cases.py +++ b/libs/langgraph/tests/test_large_cases.py @@ -775,13 +775,11 @@ def should_continue(data: dict) -> str: }, "thread_id": "1", }, - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert ( app_w_interrupt.checkpointer.get_tuple(config).config["configurable"][ @@ -839,13 +837,11 @@ def should_continue(data: dict) -> str: }, "thread_id": "1", }, - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -972,13 +968,11 @@ def should_continue(data: dict) -> str: }, "thread_id": "1", }, - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) # test state get/update methods with interrupt_before @@ -1040,13 +1034,11 @@ def should_continue(data: dict) -> str: }, "thread_id": "2", }, - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) app_w_interrupt.update_state( @@ -1098,13 +1090,11 @@ def should_continue(data: dict) -> str: }, "thread_id": "2", }, - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -1231,13 +1221,11 @@ def should_continue(data: dict) -> str: }, "thread_id": "2", }, - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) # test re-invoke to continue with interrupt_before @@ -1299,13 +1287,11 @@ def should_continue(data: dict) -> str: }, "thread_id": "3", }, - parent_config={ - "configurable": { - "thread_id": "3", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -1697,13 +1683,11 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "1", }, - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) with assert_ctx_once(): @@ -1752,13 +1736,11 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "1", }, - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) with assert_ctx_once(): @@ -1841,13 +1823,11 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "1", }, - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) # test state get/update methods with interrupt_before @@ -1905,13 +1885,11 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "2", }, - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) app_w_interrupt.update_state( @@ -1959,13 +1937,11 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "2", }, - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -2046,13 +2022,11 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "2", }, - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) # test w interrupt before all @@ -2091,13 +2065,11 @@ def should_continue(data: AgentState) -> str: "writes": None, "thread_id": "3", }, - parent_config={ - "configurable": { - "thread_id": "3", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -2143,13 +2115,11 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "3", }, - parent_config={ - "configurable": { - "thread_id": "3", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -2216,13 +2186,11 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "3", }, - parent_config={ - "configurable": { - "thread_id": "3", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -2291,13 +2259,11 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "4", }, - parent_config={ - "configurable": { - "thread_id": "4", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -2364,13 +2330,11 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "4", }, - parent_config={ - "configurable": { - "thread_id": "4", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -3952,13 +3916,11 @@ def should_continue(messages): }, "thread_id": "1", }, - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) # modify ai message @@ -4005,13 +3967,11 @@ def should_continue(messages): }, "thread_id": "1", }, - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -4100,13 +4060,11 @@ def should_continue(messages): }, "thread_id": "1", }, - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) app_w_interrupt.update_state( @@ -4153,13 +4111,11 @@ def should_continue(messages): "writes": {"agent": AIMessage(content="answer", id="ai2")}, "thread_id": "1", }, - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) app_w_interrupt = workflow.compile( @@ -4230,13 +4186,11 @@ def should_continue(messages): }, "thread_id": "2", }, - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) # modify ai message @@ -4289,13 +4243,11 @@ def should_continue(messages): }, "thread_id": "2", }, - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -4384,13 +4336,11 @@ def should_continue(messages): }, "thread_id": "2", }, - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) app_w_interrupt.update_state( @@ -4438,13 +4388,11 @@ def should_continue(messages): "writes": {"agent": AIMessage(content="answer", id="ai2")}, "thread_id": "2", }, - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) # add an extra message as if it came from "tools" node @@ -4492,13 +4440,11 @@ def should_continue(messages): "writes": {"tools": UnsortedSequence("ai", "an extra message")}, "thread_id": "2", }, - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) @@ -4793,13 +4739,11 @@ class State(TypedDict): }, "thread_id": "1", }, - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) # modify ai message @@ -4846,13 +4790,11 @@ class State(TypedDict): }, "thread_id": "1", }, - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -4942,13 +4884,11 @@ class State(TypedDict): }, "thread_id": "1", }, - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) app_w_interrupt.update_state( @@ -4996,13 +4936,11 @@ class State(TypedDict): "writes": {"agent": AIMessage(content="answer", id="ai2")}, "thread_id": "1", }, - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) app_w_interrupt = workflow.compile( @@ -5073,13 +5011,11 @@ class State(TypedDict): }, "thread_id": "2", }, - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) # modify ai message @@ -5132,13 +5068,11 @@ class State(TypedDict): }, "thread_id": "2", }, - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -5228,13 +5162,11 @@ class State(TypedDict): }, "thread_id": "2", }, - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) app_w_interrupt.update_state( @@ -5281,13 +5213,11 @@ class State(TypedDict): "writes": {"agent": AIMessage(content="answer", id="ai2")}, "thread_id": "2", }, - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) # add an extra message as if it came from "tools" node @@ -5335,13 +5265,11 @@ class State(TypedDict): "writes": {"tools": UnsortedSequence("ai", "an extra message")}, "thread_id": "2", }, - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) # create new graph with one more state key, reuse previous thread history @@ -5420,13 +5348,11 @@ class MoreState(TypedDict): "writes": {"tools": UnsortedSequence("ai", "an extra message")}, "thread_id": "2", }, - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(new_app.checkpointer.list(config, limit=2))[-1].config + ), ) # new input is merged to old state @@ -5823,13 +5749,11 @@ def tool_two_node(s: State) -> State: "writes": None, "thread_id": "1", }, - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread1, limit=2))[-1].config + ), ) # clear the interrupt and next tasks tool_two.update_state(thread1, None, as_node=END) @@ -5853,13 +5777,11 @@ def tool_two_node(s: State) -> State: "writes": {}, "thread_id": "1", }, - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread1, limit=2))[-1].config + ), ) @@ -6173,13 +6095,11 @@ class State(TypedDict): "writes": None, "thread_id": "1", }, - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr(), - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread1, limit=2))[-1].config + ), ) # clear the interrupt and next tasks tool_two.update_state(thread1, None, as_node=END) @@ -6203,13 +6123,11 @@ class State(TypedDict): "writes": {}, "thread_id": "1", }, - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr(), - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread1, limit=2))[-1].config + ), ) @@ -6320,13 +6238,11 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "assistant_id": "a", "thread_id": "1", }, - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr(), - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread1, limit=2))[-1].config + ), ) # resume, for same result as above assert tool_two.invoke(None, thread1, debug=1) == { @@ -6353,13 +6269,11 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "assistant_id": "a", "thread_id": "1", }, - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr(), - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread1, limit=2))[-1].config + ), ) thread2 = {"configurable": {"thread_id": "2", "assistant_id": "a"}} @@ -6388,13 +6302,11 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "assistant_id": "a", "thread_id": "2", }, - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": AnyStr(), - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread2, limit=2))[-1].config + ), ) # resume, for same result as above assert tool_two.invoke(None, thread2, debug=1) == { @@ -6421,13 +6333,11 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "assistant_id": "a", "thread_id": "2", }, - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": AnyStr(), - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread2, limit=2))[-1].config + ), ) thread3 = {"configurable": {"thread_id": "3", "assistant_id": "b"}} @@ -6456,13 +6366,11 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "assistant_id": "b", "thread_id": "3", }, - parent_config={ - "configurable": { - "thread_id": "3", - "checkpoint_ns": AnyStr(), - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread3, limit=2))[-1].config + ), ) # update state tool_two.update_state(thread3, {"my_key": "key"}) # appends to my_key @@ -6486,13 +6394,11 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "assistant_id": "b", "thread_id": "3", }, - parent_config={ - "configurable": { - "thread_id": "3", - "checkpoint_ns": AnyStr(), - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread3, limit=2))[-1].config + ), ) # resume, for same result as above assert tool_two.invoke(None, thread3, debug=1) == { @@ -6519,13 +6425,11 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "assistant_id": "b", "thread_id": "3", }, - parent_config={ - "configurable": { - "thread_id": "3", - "checkpoint_ns": AnyStr(), - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread3, limit=2))[-1].config + ), ) @@ -6894,13 +6798,11 @@ class State(TypedDict): "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "1", }, - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr(), - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread1, limit=2))[-1].config + ), ) # resume, for same result as above assert tool_two.invoke(None, thread1, debug=1) == { @@ -6926,13 +6828,11 @@ class State(TypedDict): "writes": {"finish": {"my_key": " finished"}}, "thread_id": "1", }, - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr(), - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread1, limit=2))[-1].config + ), ) thread2 = {"configurable": {"thread_id": "2"}} @@ -6960,13 +6860,11 @@ class State(TypedDict): "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "2", }, - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": AnyStr(), - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread2, limit=2))[-1].config + ), ) # resume, for same result as above assert tool_two.invoke(None, thread2, debug=1) == { @@ -6992,13 +6890,11 @@ class State(TypedDict): "writes": {"finish": {"my_key": " finished"}}, "thread_id": "2", }, - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": AnyStr(), - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread2, limit=2))[-1].config + ), ) tool_two = tool_two_graph.compile( @@ -7034,13 +6930,11 @@ class State(TypedDict): "writes": {"tool_two_slow": {"my_key": " slow"}}, "thread_id": "11", }, - parent_config={ - "configurable": { - "thread_id": "11", - "checkpoint_ns": AnyStr(), - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread1, limit=2))[-1].config + ), ) # update state @@ -7067,13 +6961,11 @@ class State(TypedDict): "writes": {"tool_two_slow": {"my_key": "er"}}, "thread_id": "11", }, - parent_config={ - "configurable": { - "thread_id": "11", - "checkpoint_ns": AnyStr(), - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread1, limit=2))[-1].config + ), ) tool_two = tool_two_graph.compile( @@ -7109,13 +7001,11 @@ class State(TypedDict): "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "21", }, - parent_config={ - "configurable": { - "thread_id": "21", - "checkpoint_ns": AnyStr(), - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread1, limit=2))[-1].config + ), ) # resume, for same result as above assert tool_two.invoke(None, thread1, debug=1) == { @@ -7141,13 +7031,11 @@ class State(TypedDict): "writes": {"finish": {"my_key": " finished"}}, "thread_id": "21", }, - parent_config={ - "configurable": { - "thread_id": "21", - "checkpoint_ns": AnyStr(), - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread1, limit=2))[-1].config + ), ) thread2 = {"configurable": {"thread_id": "22"}} @@ -7175,13 +7063,11 @@ class State(TypedDict): "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "22", }, - parent_config={ - "configurable": { - "thread_id": "22", - "checkpoint_ns": AnyStr(), - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread2, limit=2))[-1].config + ), ) # resume, for same result as above assert tool_two.invoke(None, thread2, debug=1) == { @@ -7207,24 +7093,28 @@ class State(TypedDict): "writes": {"finish": {"my_key": " finished"}}, "thread_id": "22", }, - parent_config={ - "configurable": { - "thread_id": "22", - "checkpoint_ns": AnyStr(), - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread2, limit=2))[-1].config + ), ) thread3 = {"configurable": {"thread_id": "23"}} # update an empty thread before first run - uconfig = tool_two.update_state(thread3, {"my_key": "key", "market": "DE"}) + tool_two.update_state(thread3, {"my_key": "key", "market": "DE"}) # check current state assert tool_two.get_state(thread3) == StateSnapshot( values={"my_key": "key", "market": "DE"}, tasks=(PregelTask(AnyStr(), "prepare", (PULL, "prepare")),), next=("prepare",), - config=uconfig, + config={ + "configurable": { + "thread_id": "23", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, created_at=AnyStr(), metadata={ "parents": {}, @@ -7260,7 +7150,11 @@ class State(TypedDict): "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "23", }, - parent_config=uconfig, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread3, limit=2))[-1].config + ), ) # resume, for same result as above assert tool_two.invoke(None, thread3, debug=1) == { @@ -7286,13 +7180,11 @@ class State(TypedDict): "writes": {"finish": {"my_key": " finished"}}, "thread_id": "23", }, - parent_config={ - "configurable": { - "thread_id": "23", - "checkpoint_ns": AnyStr(), - "checkpoint_id": AnyStr(), - } - }, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread3, limit=2))[-1].config + ), ) @@ -7700,13 +7592,17 @@ def outer_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ) # now, get_state with subgraphs state assert app.get_state(config, subgraphs=True) == StateSnapshot( @@ -7760,16 +7656,20 @@ def outer_2(state: State): "langgraph_checkpoint_ns": AnyStr("inner:"), }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr("inner:"), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - {"": AnyStr(), AnyStr("child:"): AnyStr()} - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr("inner:"), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + {"": AnyStr(), AnyStr("child:"): AnyStr()} + ), + } } - }, + ), ), ), ), @@ -7789,13 +7689,17 @@ def outer_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ) # get_state_history returns outer graph checkpoints history = list(app.get_state_history(config)) @@ -7831,13 +7735,17 @@ def outer_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ), StateSnapshot( values={"my_key": "my value"}, @@ -7943,16 +7851,20 @@ def outer_2(state: State): "langgraph_checkpoint_ns": AnyStr("inner:"), }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr("inner:"), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - {"": AnyStr(), AnyStr("child:"): AnyStr()} - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr("inner:"), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + {"": AnyStr(), AnyStr("child:"): AnyStr()} + ), + } } - }, + ), tasks=(PregelTask(AnyStr(), "inner_2", (PULL, "inner_2")),), ), StateSnapshot( @@ -8072,13 +7984,17 @@ def outer_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ) # test full history at the end actual_history = list(app.get_state_history(config)) @@ -8104,13 +8020,17 @@ def outer_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ), StateSnapshot( values={"my_key": "hi my value here and there"}, @@ -8353,13 +8273,17 @@ def parent_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ) child_state = app.get_state(outer_state.tasks[0].state) assert ( @@ -8395,13 +8319,17 @@ def parent_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr("child:"), - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr("child:"), + "checkpoint_id": AnyStr(), + } } - }, + ), ).tasks[0] ) grandchild_state = app.get_state(child_state.tasks[0].state) @@ -8448,20 +8376,24 @@ def parent_2(state: State): "langgraph_triggers": [AnyStr("start:child_1")], }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr(), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - { - "": AnyStr(), - AnyStr("child:"): AnyStr(), - AnyStr(re.compile(r"child:.+|child1:")): AnyStr(), - } - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("child:"): AnyStr(), + AnyStr(re.compile(r"child:.+|child1:")): AnyStr(), + } + ), + } } - }, + ), ) # get state with subgraphs assert app.get_state(config, subgraphs=True) == StateSnapshot( @@ -8528,22 +8460,26 @@ def parent_2(state: State): "langgraph_triggers": [AnyStr("start:child_1")], }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr(), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - { - "": AnyStr(), - AnyStr("child:"): AnyStr(), - AnyStr( - re.compile(r"child:.+|child1:") - ): AnyStr(), - } - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("child:"): AnyStr(), + AnyStr( + re.compile(r"child:.+|child1:") + ): AnyStr(), + } + ), + } } - }, + ), ), ), ), @@ -8572,16 +8508,20 @@ def parent_2(state: State): "langgraph_checkpoint_ns": AnyStr("child:"), }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr("child:"), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - {"": AnyStr(), AnyStr("child:"): AnyStr()} - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr("child:"), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + {"": AnyStr(), AnyStr("child:"): AnyStr()} + ), + } } - }, + ), ), ), ), @@ -8601,13 +8541,17 @@ def parent_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ) # # resume assert [c for c in app.stream(None, config, subgraphs=True)] == [ @@ -8644,13 +8588,17 @@ def parent_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ) ) @@ -10460,13 +10408,17 @@ def weather_graph(state: RouterState): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=( PregelTask( id=AnyStr(), @@ -10548,13 +10500,17 @@ def weather_graph(state: RouterState): "thread_id": "14", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "14", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "14", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=( PregelTask( id=AnyStr(), @@ -10597,19 +10553,23 @@ def weather_graph(state: RouterState): "langgraph_checkpoint_ns": AnyStr("weather_graph:"), }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "14", - "checkpoint_ns": AnyStr("weather_graph:"), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - { - "": AnyStr(), - AnyStr("weather_graph:"): AnyStr(), - } - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "14", + "checkpoint_ns": AnyStr("weather_graph:"), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("weather_graph:"): AnyStr(), + } + ), + } } - }, + ), tasks=( PregelTask( id=AnyStr(), @@ -10648,13 +10608,17 @@ def weather_graph(state: RouterState): "thread_id": "14", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "14", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "14", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=( PregelTask( id=AnyStr(), @@ -10703,19 +10667,23 @@ def weather_graph(state: RouterState): "langgraph_checkpoint_ns": AnyStr("weather_graph:"), }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "14", - "checkpoint_ns": AnyStr("weather_graph:"), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - { - "": AnyStr(), - AnyStr("weather_graph:"): AnyStr(), - } - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "14", + "checkpoint_ns": AnyStr("weather_graph:"), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("weather_graph:"): AnyStr(), + } + ), + } } - }, + ), tasks=(), ), ), diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index 968dd8ebf..185c57765 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -2278,6 +2278,11 @@ def qa(data: State) -> State: ] app_w_interrupt.update_state(config, {"docs": ["doc5"]}) + expected_parent_config = ( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ) assert app_w_interrupt.get_state(config) == StateSnapshot( values={ "query": "analyzed: query: what is weather in sf", @@ -2300,13 +2305,7 @@ def qa(data: State) -> State: "writes": {"retriever_one": {"docs": ["doc5"]}}, "thread_id": "2", }, - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=expected_parent_config, ) assert [c for c in app_w_interrupt.stream(None, config, debug=1)] == [ @@ -4661,6 +4660,11 @@ class CustomParentState(TypedDict): ], "user_name": "Meow", } + expected_parent_config = ( + None + if "shallow" in checkpointer_name + else list(graph.checkpointer.list(config, limit=2))[-1].config + ) assert graph.get_state(config) == StateSnapshot( values={ "messages": [ @@ -4690,13 +4694,7 @@ class CustomParentState(TypedDict): "parents": {}, }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), - } - }, + parent_config=expected_parent_config, tasks=(), )