Skip to content

Commit

Permalink
standalone implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarda committed Dec 19, 2024
1 parent a00d50a commit 99badd2
Show file tree
Hide file tree
Showing 6 changed files with 920 additions and 712 deletions.
82 changes: 2 additions & 80 deletions libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)
from langgraph.checkpoint.postgres import _internal
from langgraph.checkpoint.postgres.base import BasePostgresSaver
from langgraph.checkpoint.postgres.shallow import ShallowPostgresSaver
from langgraph.checkpoint.serde.base import SerializerProtocol

Conn = _internal.Conn # For backward compatibility
Expand Down Expand Up @@ -396,83 +397,4 @@ def _cursor(self, *, pipeline: bool = False) -> Iterator[Cursor[DictRow]]:
yield cur


class ShallowPostgresSaver(PostgresSaver):
def put(
self,
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: ChannelVersions,
) -> RunnableConfig:
"""Save a checkpoint to the database.
This method saves a checkpoint to the Postgres database. The checkpoint is associated
with the provided config and its parent config (if any).
Args:
config (RunnableConfig): The config to associate with the checkpoint.
checkpoint (Checkpoint): The checkpoint to save.
metadata (CheckpointMetadata): Additional metadata to save with the checkpoint.
new_versions (ChannelVersions): New channel versions as of this write.
Returns:
RunnableConfig: Updated configuration after storing the checkpoint.
Examples:
>>> from langgraph.checkpoint.postgres import PostgresSaver
>>> DB_URI = "postgres://postgres:postgres@localhost:5432/postgres?sslmode=disable"
>>> with PostgresSaver.from_conn_string(DB_URI) as memory:
>>> config = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}
>>> checkpoint = {"ts": "2024-05-04T06:32:42.235444+00:00", "id": "1ef4f797-8335-6428-8001-8a1503f9b875", "channel_values": {"key": "value"}}
>>> saved_config = memory.put(config, checkpoint, {"source": "input", "step": 1, "writes": {"key": "value"}}, {})
>>> print(saved_config)
{'configurable': {'thread_id': '1', 'checkpoint_ns': '', 'checkpoint_id': '1ef4f797-8335-6428-8001-8a1503f9b875'}}
"""
configurable = config["configurable"].copy()
thread_id = configurable.pop("thread_id")
checkpoint_ns = configurable.pop("checkpoint_ns")
checkpoint_id = configurable.pop(
"checkpoint_id", configurable.pop("thread_ts", None)
)

copy = checkpoint.copy()
next_config = {
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": checkpoint["id"],
}
}

with self._cursor(pipeline=True) as cur:
# clear existing checkpoints
cur.execute(
"DELETE FROM checkpoints WHERE thread_id = %s AND checkpoint_ns = %s",
(thread_id, checkpoint_ns),
)
# write new checkpoint
cur.executemany(
self.UPSERT_CHECKPOINT_BLOBS_SQL,
self._dump_blobs(
thread_id,
checkpoint_ns,
copy.pop("channel_values"), # type: ignore[misc]
new_versions,
),
)
cur.execute(
self.UPSERT_CHECKPOINTS_SQL,
(
thread_id,
checkpoint_ns,
checkpoint["id"],
checkpoint_id,
Jsonb(self._dump_checkpoint(copy)),
self._dump_metadata(metadata),
),
)
return next_config


__all__ = ["PostgresSaver", "BasePostgresSaver", "Conn"]
__all__ = ["PostgresSaver", "BasePostgresSaver", "ShallowPostgresSaver", "Conn"]
Loading

0 comments on commit 99badd2

Please sign in to comment.