Skip to content

Commit

Permalink
Fix pending writes and sends and channel versions.
Browse files Browse the repository at this point in the history
MySQL lacks array functions, so we use JSON arrays as a proxy. Extra
compatibility code is needed since MySQL will encode the blobs in each
array in base64.

MySQL also lacks ORDER BY over its JSON array functions, so sorting
needs to happen in the application.

Lastly, MySQL lacks a convenient JSON function for extracting the
channel versions from a JSON document, because the channel names (the
object keys) are dynamic. We work around that too, using JSON functions
that are only available in modern versions of MySQL.
  • Loading branch information
tjni committed Sep 25, 2024
1 parent cc1f7ba commit f51a990
Show file tree
Hide file tree
Showing 8 changed files with 227 additions and 37 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ Implementation of LangGraph CheckpointSaver that uses MySQL.
> [!TIP]
> The code in this repository tries to mimic the code in [langgraph-checkpoint-postgres](https://github.com/langchain-ai/langgraph/tree/main/libs/checkpoint-postgres) as much as possible to enable keeping in sync with the official checkpointer implementation.
> [!NOTE]
> In order to keep the queries close to the Postgres queries, we use features from recent versions of MySQL 8. I'm not sure what the exact minimum version is.
## Dependencies

To use synchronous `PyMySQLSaver`, install `langgraph-checkpoint-mysql[pymysql]`. To use asynchronous `AIOMySQLSaver`, install `langgraph-checkpoint-mysql[aiomysql]`.
Expand Down
22 changes: 15 additions & 7 deletions langgraph/checkpoint/mysql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@
from langgraph.checkpoint.mysql.base import (
BaseMySQLSaver,
)
from langgraph.checkpoint.mysql.utils import (
deserialize_channel_values,
deserialize_pending_sends,
deserialize_pending_writes,
)
from langgraph.checkpoint.serde.base import SerializerProtocol


Expand Down Expand Up @@ -170,8 +175,8 @@ def list(
},
self._load_checkpoint(
json.loads(value["checkpoint"]),
value["channel_values"],
value["pending_sends"],
deserialize_channel_values(value["channel_values"]),
deserialize_pending_sends(value["pending_sends"]),
),
self._load_metadata(value["metadata"]),
{
Expand All @@ -183,7 +188,9 @@ def list(
}
if value["parent_checkpoint_id"]
else None,
self._load_writes(value["pending_writes"]),
self._load_writes(
deserialize_pending_writes(value["pending_writes"])
),
)

def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
Expand Down Expand Up @@ -248,8 +255,8 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
},
self._load_checkpoint(
json.loads(value["checkpoint"]),
value["channel_values"],
value["pending_sends"],
deserialize_channel_values(value["channel_values"]),
deserialize_pending_sends(value["pending_sends"]),
),
self._load_metadata(value["metadata"]),
{
Expand All @@ -261,7 +268,9 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
}
if value["parent_checkpoint_id"]
else None,
self._load_writes(value["pending_writes"]),
self._load_writes(
deserialize_pending_writes(value["pending_writes"])
),
)

def put(
Expand Down Expand Up @@ -302,7 +311,6 @@ def put(
checkpoint_id = configurable.pop(
"checkpoint_id", configurable.pop("thread_ts", None)
)

copy = checkpoint.copy()
next_config = {
"configurable": {
Expand Down
23 changes: 17 additions & 6 deletions langgraph/checkpoint/mysql/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@
get_checkpoint_id,
)
from langgraph.checkpoint.mysql.base import BaseMySQLSaver
from langgraph.checkpoint.mysql.utils import (
deserialize_channel_values,
deserialize_pending_sends,
deserialize_pending_writes,
)
from langgraph.checkpoint.serde.base import SerializerProtocol

Conn = Union[aiomysql.Connection, aiomysql.Pool]
Expand Down Expand Up @@ -152,8 +157,8 @@ async def alist(
await asyncio.to_thread(
self._load_checkpoint,
json.loads(value["checkpoint"]),
value["channel_values"],
value["pending_sends"],
deserialize_channel_values(value["channel_values"]),
deserialize_pending_sends(value["pending_sends"]),
),
self._load_metadata(value["metadata"]),
{
Expand All @@ -165,7 +170,10 @@ async def alist(
}
if value["parent_checkpoint_id"]
else None,
await asyncio.to_thread(self._load_writes, value["pending_writes"]),
await asyncio.to_thread(
self._load_writes,
deserialize_pending_writes(value["pending_writes"]),
),
)

async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
Expand Down Expand Up @@ -210,8 +218,8 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
await asyncio.to_thread(
self._load_checkpoint,
json.loads(value["checkpoint"]),
value["channel_values"],
value["pending_sends"],
deserialize_channel_values(value["channel_values"]),
deserialize_pending_sends(value["pending_sends"]),
),
self._load_metadata(value["metadata"]),
{
Expand All @@ -223,7 +231,10 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
}
if value["parent_checkpoint_id"]
else None,
await asyncio.to_thread(self._load_writes, value["pending_writes"]),
await asyncio.to_thread(
self._load_writes,
deserialize_pending_writes(value["pending_writes"]),
),
)

async def aput(
Expand Down
45 changes: 22 additions & 23 deletions langgraph/checkpoint/mysql/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,29 +68,32 @@
metadata,
(
select json_arrayagg(json_array(bl.channel, bl.type, bl.blob))
from json_table(
checkpoint,
'$.channel_versions[*]' columns (
`key` VARCHAR(255) PATH '$.key',
value VARCHAR(255) PATH '$.value'
)
from
(
select channel, json_unquote(
json_extract(checkpoint, concat('$.channel_versions.', channel))
) as version
from json_table(
json_keys(checkpoint, '$.channel_versions'),
'$[*]' columns (channel VARCHAR(150) PATH '$')
) as channels
) as channel_versions
inner join checkpoint_blobs bl
on bl.thread_id = checkpoints.thread_id
and bl.checkpoint_ns = checkpoints.checkpoint_ns
and bl.channel = channel_versions.key
and bl.version = channel_versions.value
and bl.channel = channel_versions.channel
and bl.version = channel_versions.version
) as channel_values,
(
select
json_arrayagg(json_array(cw.task_id, cw.channel, cw.type, cw.blob))
json_arrayagg(json_array(cw.task_id, cw.channel, cw.type, cw.blob, cw.idx))
from checkpoint_writes cw
where cw.thread_id = checkpoints.thread_id
and cw.checkpoint_ns = checkpoints.checkpoint_ns
and cw.checkpoint_id = checkpoints.checkpoint_id
) as pending_writes,
(
select json_arrayagg(json_array(cw.type, cw.blob))
select json_arrayagg(json_array(cw.type, cw.blob, cw.idx))
from checkpoint_writes cw
where cw.thread_id = checkpoints.thread_id
and cw.checkpoint_ns = checkpoints.checkpoint_ns
Expand Down Expand Up @@ -140,29 +143,25 @@ class BaseMySQLSaver(BaseCheckpointSaver[str]):
def _load_checkpoint(
self,
checkpoint: dict[str, Any],
channel_values: list[tuple[bytes, bytes, bytes]],
pending_sends: list[tuple[bytes, bytes]],
channel_values: list[tuple[str, str, bytes]],
pending_sends: list[tuple[str, bytes]],
) -> Checkpoint:
return {
**checkpoint,
"pending_sends": [
self.serde.loads_typed((c.decode(), b)) for c, b in pending_sends or []
self.serde.loads_typed((c, b)) for c, b in pending_sends or []
],
"channel_values": self._load_blobs(channel_values),
}

def _dump_checkpoint(self, checkpoint: Checkpoint) -> dict[str, Any]:
return {**checkpoint, "pending_sends": []}

def _load_blobs(
self, blob_values: list[tuple[bytes, bytes, bytes]]
) -> dict[str, Any]:
def _load_blobs(self, blob_values: list[tuple[str, str, bytes]]) -> dict[str, Any]:
if not blob_values:
return {}
return {
k.decode(): self.serde.loads_typed((t.decode(), v))
for k, t, v in blob_values
if t.decode() != "empty"
k: self.serde.loads_typed((t, v)) for k, t, v in blob_values if t != "empty"
}

def _dump_blobs(
Expand Down Expand Up @@ -191,14 +190,14 @@ def _dump_blobs(
]

def _load_writes(
self, writes: list[tuple[bytes, bytes, bytes, bytes]]
self, writes: list[tuple[str, str, str, bytes]]
) -> list[tuple[str, str, Any]]:
return (
[
(
tid.decode(),
channel.decode(),
self.serde.loads_typed((t.decode(), v)),
tid,
channel,
self.serde.loads_typed((t, v)),
)
for tid, channel, t, v in writes
]
Expand Down
69 changes: 69 additions & 0 deletions langgraph/checkpoint/mysql/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import base64
import json
from typing import NamedTuple, Optional

# When MySQL returns a blob inside a JSON array, the blob is base64 encoded and
# has a prefix such as "base64:type251:" attached to it.
MySQLBase64Blob = str


def decode_base64_blob(base64_blob: MySQLBase64Blob) -> bytes:
    """Decode a blob that MySQL embedded in a JSON value.

    Args:
        base64_blob: The string MySQL produced for the blob: the base64
            payload, normally preceded by a ``base64:typeNNN:`` prefix.

    Returns:
        The raw bytes of the blob.

    Raises:
        binascii.Error: If the payload is not valid base64.
    """
    # Everything after the last ":" is the payload; ":" is not part of the
    # base64 alphabet, so this cannot cut into the data. Unlike unpacking the
    # result of rsplit, rpartition also accepts a string without any prefix
    # instead of failing with an opaque unpacking ValueError.
    _, _, data = base64_blob.rpartition(":")
    return base64.b64decode(data)


class MySQLPendingWrite(NamedTuple):
    """
    The pending write tuple we receive from our DB query.

    Instances are built positionally from each inner JSON array of the query
    result, so the field order must match the order of the values in that
    array.
    """

    task_id: str
    channel: str
    type_: str
    # Still in MySQL's JSON encoding of a blob; must be decoded to get bytes.
    blob: MySQLBase64Blob
    # Used by deserialize_pending_writes as the secondary sort key.
    idx: int


def deserialize_pending_writes(value: str) -> list[tuple[str, str, str, bytes]]:
    """Turn the JSON-encoded pending writes of a query row into tuples.

    Args:
        value: JSON text holding an array of
            ``[task_id, channel, type, blob, idx]`` arrays. A falsy value
            (e.g. ``None`` or an empty string) means there are no writes.

    Returns:
        ``(task_id, channel, type, blob_bytes)`` tuples, ordered by task id
        and then by write index. The ordering is done here because the query
        result itself is not ordered.
    """
    if not value:
        return []

    rows = [MySQLPendingWrite(*fields) for fields in json.loads(value)]
    rows.sort(key=lambda row: (row.task_id, row.idx))

    return [
        (row.task_id, row.channel, row.type_, decode_base64_blob(row.blob))
        for row in rows
    ]


class MySQLPendingSend(NamedTuple):
    """
    The pending send tuple we receive from our DB query.

    Instances are built positionally from each inner JSON array of the query
    result, so the field order must match the order of the values in that
    array.
    """

    type_: str
    # Still in MySQL's JSON encoding of a blob; must be decoded to get bytes.
    blob: MySQLBase64Blob
    # Used by deserialize_pending_sends as the sort key.
    idx: int


def deserialize_pending_sends(value: str) -> list[tuple[str, bytes]]:
    """Turn the JSON-encoded pending sends of a query row into tuples.

    Args:
        value: JSON text holding an array of ``[type, blob, idx]`` arrays.
            A falsy value (e.g. ``None`` or an empty string) means there are
            no sends.

    Returns:
        ``(type, blob_bytes)`` tuples ordered by write index. The ordering is
        done here because the query result itself is not ordered.
    """
    if not value:
        return []

    sends = [MySQLPendingSend(*fields) for fields in json.loads(value)]
    sends.sort(key=lambda send: send.idx)

    return [(send.type_, decode_base64_blob(send.blob)) for send in sends]


class MySQLChannelValue(NamedTuple):
    """
    The channel value tuple we receive from our DB query.

    Instances are built positionally from each inner JSON array of the query
    result, so the field order must match the order of the values in that
    array.
    """

    channel: str
    type_: str
    # None when the stored blob is NULL, which arrives here as a JSON null.
    # NOTE(review): presumably the case for channels stored with type
    # "empty" -- verify against _dump_blobs.
    blob: Optional[MySQLBase64Blob]


def deserialize_channel_values(
    value: str,
) -> list[tuple[str, str, Optional[bytes]]]:
    """Turn the JSON-encoded channel values of a query row into tuples.

    Args:
        value: JSON text holding an array of ``[channel, type, blob]``
            arrays. A falsy value (e.g. ``None`` or an empty string) means
            there are no channel values.

    Returns:
        ``(channel, type, blob_bytes)`` tuples in the order they appear in
        ``value``. ``blob_bytes`` is ``None`` for entries whose blob is a
        JSON null.
    """
    if not value:
        return []

    values = (MySQLChannelValue(*channel_value) for channel_value in json.loads(value))

    return [
        (
            db.channel,
            db.type_,
            # A null blob has no "base64:..." string to decode; decoding it
            # would raise AttributeError, so pass it through unchanged.
            decode_base64_blob(db.blob) if db.blob is not None else None,
        )
        for db in values
    ]
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langgraph-checkpoint-mysql"
version = "1.0.0"
version = "1.0.1"
description = "Library with a MySQL implementation of LangGraph checkpoint saver."
authors = ["Theodore Ni <[email protected]>"]
license = "MIT"
Expand Down
50 changes: 50 additions & 0 deletions tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
from langchain_core.runnables import RunnableConfig

from langgraph.checkpoint.base import (
ChannelVersions,
Checkpoint,
CheckpointMetadata,
create_checkpoint,
empty_checkpoint,
)
from langgraph.checkpoint.mysql.aio import AIOMySQLSaver
from langgraph.checkpoint.serde.types import TASKS


class TestAIOMySQLSaver:
Expand Down Expand Up @@ -101,3 +103,51 @@ async def test_asearch(self) -> None:
} == {"", "inner"}

# TODO: test before and limit params

async def test_write_and_read_pending_writes_and_sends(self) -> None:
    """Round-trip pending writes and sends through the async saver.

    Stores a checkpoint plus writes for two tasks, then checks that listing
    returns the writes ordered by task id and that the value written to the
    TASKS channel shows up in the checkpoint's pending sends.
    """
    async with AIOMySQLSaver.from_conn_string(DEFAULT_URI) as saver:
        config: RunnableConfig = {
            "configurable": {
                "thread_id": "thread-1",
                "checkpoint_id": "1",
                "checkpoint_ns": "",
            }
        }
        chkpnt = create_checkpoint(self.chkpnt_1, {}, 1, id="1")

        await saver.aput(config, chkpnt, {}, {})
        # Deliberately written in the opposite order ("world" before "hello")
        # of the order expected below.
        await saver.aput_writes(config, [("w1", "w1v"), ("w2", "w2v")], "world")
        await saver.aput_writes(config, [(TASKS, "w3v")], "hello")

        # NOTE(review): takes the first entry of an unfiltered listing;
        # presumably no other checkpoint sorts ahead of this one -- confirm
        # against the test setup.
        result = [c async for c in saver.alist({})][0]

        # Expected order: by task id ("hello" < "world"), then by the
        # position of each write within its task.
        assert result.pending_writes == [
            ("hello", TASKS, "w3v"),
            ("world", "w1", "w1v"),
            ("world", "w2", "w2v"),
        ]

        assert result.checkpoint["pending_sends"] == ["w3v"]

async def test_write_and_read_channel_values(self) -> None:
    """Round-trip a checkpoint's channel values through the async saver.

    Stores a checkpoint with a single channel value and a matching channel
    version, then checks that listing returns that value.
    """
    async with AIOMySQLSaver.from_conn_string(DEFAULT_URI) as saver:
        config: RunnableConfig = {
            "configurable": {
                "thread_id": "thread-4",
                "checkpoint_id": "4",
                "checkpoint_ns": "",
            }
        }
        chkpnt = empty_checkpoint()
        chkpnt["id"] = "4"
        chkpnt["channel_values"] = {
            "channel1": "channel1v",
        }

        # The same mapping is stored in the checkpoint and passed to aput as
        # the new versions.
        newversions: ChannelVersions = {"channel1": 1}
        chkpnt["channel_versions"] = newversions

        await saver.aput(config, chkpnt, {}, newversions)

        # NOTE(review): takes the first entry of an unfiltered listing;
        # presumably no other checkpoint sorts ahead of this one -- confirm
        # against the test setup.
        result = [c async for c in saver.alist({})][0]
        assert result.checkpoint["channel_values"] == {"channel1": "channel1v"}
Loading

0 comments on commit f51a990

Please sign in to comment.