From b2b51a362d9f49216bf12f59a1dfa59bbb9f9e5c Mon Sep 17 00:00:00 2001 From: vbarda Date: Thu, 19 Dec 2024 13:27:06 -0500 Subject: [PATCH] cr --- .../langgraph/checkpoint/postgres/shallow.py | 39 ++++++++++++++----- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/shallow.py b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/shallow.py index 5cc0471e3..9e448ae2f 100644 --- a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/shallow.py +++ b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/shallow.py @@ -21,9 +21,6 @@ from langgraph.checkpoint.serde.base import SerializerProtocol from langgraph.checkpoint.serde.types import TASKS -Conn = _internal.Conn # For backward compatibility - - """ To add a new migration, add a new string to the MIGRATIONS list. The position of the migration in the list is the version number. @@ -44,10 +41,9 @@ thread_id TEXT NOT NULL, checkpoint_ns TEXT NOT NULL DEFAULT '', channel TEXT NOT NULL, - version TEXT NOT NULL, type TEXT NOT NULL, blob BYTEA, - PRIMARY KEY (thread_id, checkpoint_ns, channel, version) + PRIMARY KEY (thread_id, checkpoint_ns, channel) );""", """CREATE TABLE IF NOT EXISTS checkpoint_writes ( thread_id TEXT NOT NULL, @@ -84,7 +80,6 @@ on bl.thread_id = checkpoints.thread_id and bl.checkpoint_ns = checkpoints.checkpoint_ns and bl.channel = jsonb_each_text.key - and bl.version = jsonb_each_text.value ) as channel_values, ( select @@ -104,9 +99,11 @@ from checkpoints """ UPSERT_CHECKPOINT_BLOBS_SQL = """ - INSERT INTO checkpoint_blobs (thread_id, checkpoint_ns, channel, version, type, blob) - VALUES (%s, %s, %s, %s, %s, %s) - ON CONFLICT (thread_id, checkpoint_ns, channel, version) DO NOTHING + INSERT INTO checkpoint_blobs (thread_id, checkpoint_ns, channel, type, blob) + VALUES (%s, %s, %s, %s, %s) + ON CONFLICT (thread_id, checkpoint_ns, channel) DO UPDATE SET + type = EXCLUDED.type, + blob = EXCLUDED.blob; """ UPSERT_CHECKPOINTS_SQL = """ @@ -161,6 +158,30 @@ def __init__( self.lock = threading.Lock() self.supports_pipeline = Capabilities().has_pipeline() + def _dump_blobs( + self, + thread_id: str, + checkpoint_ns: str, + values: dict[str, Any], + versions: ChannelVersions, + ) -> list[tuple[str, str, str, str, str, Optional[bytes]]]: + if not versions: + return [] + + return [ + ( + thread_id, + checkpoint_ns, + k, + *( + self.serde.dumps_typed(values[k]) + if k in values + else ("empty", None) + ), + ) + for k, ver in versions.items() + ] + @classmethod @contextmanager def from_conn_string(