diff --git a/README.md b/README.md index c52232d..cb29328 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,9 @@ Implementation of LangGraph CheckpointSaver that uses MySQL. > [!TIP] > The code in this repository tries to mimic the code in [langgraph-checkpoint-postgres](https://github.com/langchain-ai/langgraph/tree/main/libs/checkpoint-postgres) as much as possible to enable keeping in sync with the official checkpointer implementation. +> [!NOTE] +> In order to keep the queries close to the Postgres queries, we use features from recent versions of MySQL 8. I'm not sure what the exact minimum version is. + ## Dependencies To use synchronous `PyMySQLSaver`, install `langgraph-checkpoint-mysql[pymysql]`. To use asynchronous `AIOMySQLSaver`, install `langgraph-checkpoint-mysql[aiomysql]`. diff --git a/langgraph/checkpoint/mysql/__init__.py b/langgraph/checkpoint/mysql/__init__.py index dd9dab7..27feae1 100644 --- a/langgraph/checkpoint/mysql/__init__.py +++ b/langgraph/checkpoint/mysql/__init__.py @@ -20,6 +20,11 @@ from langgraph.checkpoint.mysql.base import ( BaseMySQLSaver, ) +from langgraph.checkpoint.mysql.utils import ( + deserialize_channel_values, + deserialize_pending_sends, + deserialize_pending_writes, +) from langgraph.checkpoint.serde.base import SerializerProtocol @@ -170,8 +175,8 @@ def list( }, self._load_checkpoint( json.loads(value["checkpoint"]), - value["channel_values"], - value["pending_sends"], + deserialize_channel_values(value["channel_values"]), + deserialize_pending_sends(value["pending_sends"]), ), self._load_metadata(value["metadata"]), { @@ -183,7 +188,9 @@ def list( } if value["parent_checkpoint_id"] else None, - self._load_writes(value["pending_writes"]), + self._load_writes( + deserialize_pending_writes(value["pending_writes"]) + ), ) def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: @@ -248,8 +255,8 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: }, self._load_checkpoint( 
json.loads(value["checkpoint"]), - value["channel_values"], - value["pending_sends"], + deserialize_channel_values(value["channel_values"]), + deserialize_pending_sends(value["pending_sends"]), ), self._load_metadata(value["metadata"]), { @@ -261,7 +268,9 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: } if value["parent_checkpoint_id"] else None, - self._load_writes(value["pending_writes"]), + self._load_writes( + deserialize_pending_writes(value["pending_writes"]) + ), ) def put( @@ -302,7 +311,6 @@ def put( checkpoint_id = configurable.pop( "checkpoint_id", configurable.pop("thread_ts", None) ) - copy = checkpoint.copy() next_config = { "configurable": { diff --git a/langgraph/checkpoint/mysql/aio.py b/langgraph/checkpoint/mysql/aio.py index 00c6fb8..011713d 100644 --- a/langgraph/checkpoint/mysql/aio.py +++ b/langgraph/checkpoint/mysql/aio.py @@ -19,6 +19,11 @@ get_checkpoint_id, ) from langgraph.checkpoint.mysql.base import BaseMySQLSaver +from langgraph.checkpoint.mysql.utils import ( + deserialize_channel_values, + deserialize_pending_sends, + deserialize_pending_writes, +) from langgraph.checkpoint.serde.base import SerializerProtocol Conn = Union[aiomysql.Connection, aiomysql.Pool] @@ -152,8 +157,8 @@ async def alist( await asyncio.to_thread( self._load_checkpoint, json.loads(value["checkpoint"]), - value["channel_values"], - value["pending_sends"], + deserialize_channel_values(value["channel_values"]), + deserialize_pending_sends(value["pending_sends"]), ), self._load_metadata(value["metadata"]), { @@ -165,7 +170,10 @@ async def alist( } if value["parent_checkpoint_id"] else None, - await asyncio.to_thread(self._load_writes, value["pending_writes"]), + await asyncio.to_thread( + self._load_writes, + deserialize_pending_writes(value["pending_writes"]), + ), ) async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: @@ -210,8 +218,8 @@ async def aget_tuple(self, config: RunnableConfig) -> 
Optional[CheckpointTuple]: await asyncio.to_thread( self._load_checkpoint, json.loads(value["checkpoint"]), - value["channel_values"], - value["pending_sends"], + deserialize_channel_values(value["channel_values"]), + deserialize_pending_sends(value["pending_sends"]), ), self._load_metadata(value["metadata"]), { @@ -223,7 +231,10 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: } if value["parent_checkpoint_id"] else None, - await asyncio.to_thread(self._load_writes, value["pending_writes"]), + await asyncio.to_thread( + self._load_writes, + deserialize_pending_writes(value["pending_writes"]), + ), ) async def aput( diff --git a/langgraph/checkpoint/mysql/base.py b/langgraph/checkpoint/mysql/base.py index 32ec95a..1fc7a69 100644 --- a/langgraph/checkpoint/mysql/base.py +++ b/langgraph/checkpoint/mysql/base.py @@ -68,29 +68,32 @@ metadata, ( select json_arrayagg(json_array(bl.channel, bl.type, bl.blob)) - from json_table( - checkpoint, - '$.channel_versions[*]' columns ( - `key` VARCHAR(255) PATH '$.key', - value VARCHAR(255) PATH '$.value' - ) + from + ( + select channel, json_unquote( + json_extract(checkpoint, concat('$.channel_versions.', channel)) + ) as version + from json_table( + json_keys(checkpoint, '$.channel_versions'), + '$[*]' columns (channel VARCHAR(150) PATH '$') + ) as channels ) as channel_versions inner join checkpoint_blobs bl on bl.thread_id = checkpoints.thread_id and bl.checkpoint_ns = checkpoints.checkpoint_ns - and bl.channel = channel_versions.key - and bl.version = channel_versions.value + and bl.channel = channel_versions.channel + and bl.version = channel_versions.version ) as channel_values, ( select - json_arrayagg(json_array(cw.task_id, cw.channel, cw.type, cw.blob)) + json_arrayagg(json_array(cw.task_id, cw.channel, cw.type, cw.blob, cw.idx)) from checkpoint_writes cw where cw.thread_id = checkpoints.thread_id and cw.checkpoint_ns = checkpoints.checkpoint_ns and cw.checkpoint_id = 
checkpoints.checkpoint_id ) as pending_writes, ( - select json_arrayagg(json_array(cw.type, cw.blob)) + select json_arrayagg(json_array(cw.type, cw.blob, cw.idx)) from checkpoint_writes cw where cw.thread_id = checkpoints.thread_id and cw.checkpoint_ns = checkpoints.checkpoint_ns @@ -140,13 +143,13 @@ class BaseMySQLSaver(BaseCheckpointSaver[str]): def _load_checkpoint( self, checkpoint: dict[str, Any], - channel_values: list[tuple[bytes, bytes, bytes]], - pending_sends: list[tuple[bytes, bytes]], + channel_values: list[tuple[str, str, bytes]], + pending_sends: list[tuple[str, bytes]], ) -> Checkpoint: return { **checkpoint, "pending_sends": [ - self.serde.loads_typed((c.decode(), b)) for c, b in pending_sends or [] + self.serde.loads_typed((c, b)) for c, b in pending_sends or [] ], "channel_values": self._load_blobs(channel_values), } @@ -154,15 +157,11 @@ def _load_checkpoint( def _dump_checkpoint(self, checkpoint: Checkpoint) -> dict[str, Any]: return {**checkpoint, "pending_sends": []} - def _load_blobs( - self, blob_values: list[tuple[bytes, bytes, bytes]] - ) -> dict[str, Any]: + def _load_blobs(self, blob_values: list[tuple[str, str, bytes]]) -> dict[str, Any]: if not blob_values: return {} return { - k.decode(): self.serde.loads_typed((t.decode(), v)) - for k, t, v in blob_values - if t.decode() != "empty" + k: self.serde.loads_typed((t, v)) for k, t, v in blob_values if t != "empty" } def _dump_blobs( @@ -191,14 +190,14 @@ def _dump_blobs( ] def _load_writes( - self, writes: list[tuple[bytes, bytes, bytes, bytes]] + self, writes: list[tuple[str, str, str, bytes]] ) -> list[tuple[str, str, Any]]: return ( [ ( - tid.decode(), - channel.decode(), - self.serde.loads_typed((t.decode(), v)), + tid, + channel, + self.serde.loads_typed((t, v)), ) for tid, channel, t, v in writes ] diff --git a/langgraph/checkpoint/mysql/utils.py b/langgraph/checkpoint/mysql/utils.py new file mode 100644 index 0000000..d93bb4d --- /dev/null +++ b/langgraph/checkpoint/mysql/utils.py 
import base64
import json
from typing import NamedTuple

# When MySQL returns a BLOB inside a JSON array (e.g. via json_arrayagg), it
# base64-encodes the bytes and attaches a prefix such as "base64:type251:".
MySQLBase64Blob = str


def decode_base64_blob(base64_blob: MySQLBase64Blob) -> bytes:
    """Decode a MySQL JSON-encoded blob string back into raw bytes.

    MySQL emits blobs inside JSON values as ``base64:typeNNN:<data>``;
    everything up to the last ``:`` is metadata we discard. A value with no
    ``:`` at all is treated as bare base64 data rather than raising
    ``ValueError`` (the original two-name unpack would fail on such input).
    """
    # rsplit(...)[-1] yields the payload whether or not the prefix is present.
    data = base64_blob.rsplit(":", 1)[-1]
    return base64.b64decode(data)


class MySQLPendingWrite(NamedTuple):
    """The pending write tuple we receive from our DB query."""

    task_id: str
    channel: str
    type_: str
    blob: MySQLBase64Blob
    idx: int


def deserialize_pending_writes(value: str) -> list[tuple[str, str, str, bytes]]:
    """Parse the JSON aggregate of pending writes from the checkpoint query.

    ``value`` is the ``json_arrayagg(json_array(task_id, channel, type, blob,
    idx))`` result (or None/"" when the aggregate is NULL). Returns
    ``(task_id, channel, type, blob_bytes)`` tuples ordered by
    ``(task_id, idx)`` so each task's writes come out in insertion order.
    """
    if not value:
        return []

    values = (MySQLPendingWrite(*write) for write in json.loads(value))

    return [
        (db.task_id, db.channel, db.type_, decode_base64_blob(db.blob))
        for db in sorted(values, key=lambda db: (db.task_id, db.idx))
    ]


class MySQLPendingSend(NamedTuple):
    """The pending send tuple we receive from our DB query."""

    type_: str
    blob: MySQLBase64Blob
    idx: int


def deserialize_pending_sends(value: str) -> list[tuple[str, bytes]]:
    """Parse the JSON aggregate of pending sends from the checkpoint query.

    ``value`` is the ``json_arrayagg(json_array(type, blob, idx))`` result
    (or None/"" when the aggregate is NULL). Returns ``(type, blob_bytes)``
    tuples ordered by ``idx`` to preserve insertion order.
    """
    if not value:
        return []

    values = (MySQLPendingSend(*send) for send in json.loads(value))

    return [
        (db.type_, decode_base64_blob(db.blob))
        for db in sorted(values, key=lambda db: db.idx)
    ]


class MySQLChannelValue(NamedTuple):
    """The channel value tuple we receive from our DB query."""

    channel: str
    type_: str
    blob: MySQLBase64Blob


def deserialize_channel_values(value: str) -> list[tuple[str, str, bytes]]:
    """Parse the JSON aggregate of channel values from the checkpoint query.

    ``value`` is the ``json_arrayagg(json_array(channel, type, blob))`` result
    (or None/"" when the aggregate is NULL). Order is left as the database
    returned it; callers treat the result as a mapping.
    """
    if not value:
        return []

    values = (MySQLChannelValue(*channel_value) for channel_value in json.loads(value))

    return [(db.channel, db.type_, decode_base64_blob(db.blob)) for db in values]


# NOTE(review): the original patch span continued with a pyproject.toml hunk
# bumping the package version 1.0.0 -> 1.0.1 (description line unchanged).
authors = ["Theodore Ni "] license = "MIT" diff --git a/tests/test_async.py b/tests/test_async.py index beecd11..e935760 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -5,12 +5,14 @@ from langchain_core.runnables import RunnableConfig from langgraph.checkpoint.base import ( + ChannelVersions, Checkpoint, CheckpointMetadata, create_checkpoint, empty_checkpoint, ) from langgraph.checkpoint.mysql.aio import AIOMySQLSaver +from langgraph.checkpoint.serde.types import TASKS class TestAIOMySQLSaver: @@ -101,3 +103,51 @@ async def test_asearch(self) -> None: } == {"", "inner"} # TODO: test before and limit params + + async def test_write_and_read_pending_writes_and_sends(self) -> None: + async with AIOMySQLSaver.from_conn_string(DEFAULT_URI) as saver: + config: RunnableConfig = { + "configurable": { + "thread_id": "thread-1", + "checkpoint_id": "1", + "checkpoint_ns": "", + } + } + chkpnt = create_checkpoint(self.chkpnt_1, {}, 1, id="1") + + await saver.aput(config, chkpnt, {}, {}) + await saver.aput_writes(config, [("w1", "w1v"), ("w2", "w2v")], "world") + await saver.aput_writes(config, [(TASKS, "w3v")], "hello") + + result = [c async for c in saver.alist({})][0] + + assert result.pending_writes == [ + ("hello", TASKS, "w3v"), + ("world", "w1", "w1v"), + ("world", "w2", "w2v"), + ] + + assert result.checkpoint["pending_sends"] == ["w3v"] + + async def test_write_and_read_channel_values(self) -> None: + async with AIOMySQLSaver.from_conn_string(DEFAULT_URI) as saver: + config: RunnableConfig = { + "configurable": { + "thread_id": "thread-4", + "checkpoint_id": "4", + "checkpoint_ns": "", + } + } + chkpnt = empty_checkpoint() + chkpnt["id"] = "4" + chkpnt["channel_values"] = { + "channel1": "channel1v", + } + + newversions: ChannelVersions = {"channel1": 1} + chkpnt["channel_versions"] = newversions + + await saver.aput(config, chkpnt, {}, newversions) + + result = [c async for c in saver.alist({})][0] + assert result.checkpoint["channel_values"] == 
{"channel1": "channel1v"} diff --git a/tests/test_sync.py b/tests/test_sync.py index 5429793..7563a85 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -5,12 +5,14 @@ from langchain_core.runnables import RunnableConfig from langgraph.checkpoint.base import ( + ChannelVersions, Checkpoint, CheckpointMetadata, create_checkpoint, empty_checkpoint, ) from langgraph.checkpoint.mysql import PyMySQLSaver +from langgraph.checkpoint.serde.types import TASKS class TestPyMySQLSaver: @@ -101,3 +103,51 @@ def test_search(self) -> None: } == {"", "inner"} # TODO: test before and limit params + + def test_write_and_read_pending_writes_and_sends(self) -> None: + with PyMySQLSaver.from_conn_string(DEFAULT_URI) as saver: + config: RunnableConfig = { + "configurable": { + "thread_id": "thread-1", + "checkpoint_id": "1", + "checkpoint_ns": "", + } + } + chkpnt = create_checkpoint(self.chkpnt_1, {}, 1, id="1") + + saver.put(config, chkpnt, {}, {}) + saver.put_writes(config, [("w1", "w1v"), ("w2", "w2v")], "world") + saver.put_writes(config, [(TASKS, "w3v")], "hello") + + result = next(saver.list({})) + + assert result.pending_writes == [ + ("hello", TASKS, "w3v"), + ("world", "w1", "w1v"), + ("world", "w2", "w2v"), + ] + + assert result.checkpoint["pending_sends"] == ["w3v"] + + def test_write_and_read_channel_values(self) -> None: + with PyMySQLSaver.from_conn_string(DEFAULT_URI) as saver: + config: RunnableConfig = { + "configurable": { + "thread_id": "thread-4", + "checkpoint_id": "4", + "checkpoint_ns": "", + } + } + chkpnt = empty_checkpoint() + chkpnt["id"] = "4" + chkpnt["channel_values"] = { + "channel1": "channel1v", + } + + newversions: ChannelVersions = {"channel1": 1} + chkpnt["channel_versions"] = newversions + + saver.put(config, chkpnt, {}, newversions) + + result = next(saver.list({})) + assert result.checkpoint["channel_values"] == {"channel1": "channel1v"}