Skip to content

Commit

Permalink
cr
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarda committed Dec 19, 2024
1 parent de6e881 commit b2b51a3
Showing 1 changed file with 30 additions and 9 deletions.
39 changes: 30 additions & 9 deletions libs/checkpoint-postgres/langgraph/checkpoint/postgres/shallow.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@
from langgraph.checkpoint.serde.base import SerializerProtocol
from langgraph.checkpoint.serde.types import TASKS

Conn = _internal.Conn # For backward compatibility


"""
To add a new migration, add a new string to the MIGRATIONS list.
The position of the migration in the list is the version number.
Expand All @@ -44,10 +41,9 @@
thread_id TEXT NOT NULL,
checkpoint_ns TEXT NOT NULL DEFAULT '',
channel TEXT NOT NULL,
version TEXT NOT NULL,
type TEXT NOT NULL,
blob BYTEA,
PRIMARY KEY (thread_id, checkpoint_ns, channel, version)
PRIMARY KEY (thread_id, checkpoint_ns, channel)
);""",
"""CREATE TABLE IF NOT EXISTS checkpoint_writes (
thread_id TEXT NOT NULL,
Expand Down Expand Up @@ -84,7 +80,6 @@
on bl.thread_id = checkpoints.thread_id
and bl.checkpoint_ns = checkpoints.checkpoint_ns
and bl.channel = jsonb_each_text.key
and bl.version = jsonb_each_text.value
) as channel_values,
(
select
Expand All @@ -104,9 +99,11 @@
from checkpoints """

UPSERT_CHECKPOINT_BLOBS_SQL = """
INSERT INTO checkpoint_blobs (thread_id, checkpoint_ns, channel, version, type, blob)
VALUES (%s, %s, %s, %s, %s, %s)
ON CONFLICT (thread_id, checkpoint_ns, channel, version) DO NOTHING
INSERT INTO checkpoint_blobs (thread_id, checkpoint_ns, channel, type, blob)
VALUES (%s, %s, %s, %s, %s)
ON CONFLICT (thread_id, checkpoint_ns, channel) DO UPDATE SET
type = EXCLUDED.type,
blob = EXCLUDED.blob;
"""

UPSERT_CHECKPOINTS_SQL = """
Expand Down Expand Up @@ -161,6 +158,30 @@ def __init__(
self.lock = threading.Lock()
self.supports_pipeline = Capabilities().has_pipeline()

def _dump_blobs(
self,
thread_id: str,
checkpoint_ns: str,
values: dict[str, Any],
versions: ChannelVersions,
) -> list[tuple[str, str, str, str, str, Optional[bytes]]]:
if not versions:
return []

return [
(
thread_id,
checkpoint_ns,
k,
*(
self.serde.dumps_typed(values[k])
if k in values
else ("empty", None)
),
)
for k, ver in versions.items()
]

@classmethod
@contextmanager
def from_conn_string(
Expand Down

0 comments on commit b2b51a3

Please sign in to comment.