From bf19dc7d08ec0369d35cef3ac0d4e570405fc59a Mon Sep 17 00:00:00 2001 From: Vadym Barda Date: Fri, 27 Sep 2024 12:02:03 -0400 Subject: [PATCH] checkpoint-postgres: handle null chars in metadata (#1885) --- .../langgraph/checkpoint/postgres/base.py | 3 ++- libs/checkpoint-postgres/tests/test_async.py | 10 ++++++++++ libs/checkpoint-postgres/tests/test_sync.py | 9 +++++++++ 3 files changed, 21 insertions(+), 1 deletion(-) diff --git a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py index 76232e337..535232370 100644 --- a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py +++ b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py @@ -229,7 +229,8 @@ def _load_metadata(self, metadata: dict[str, Any]) -> CheckpointMetadata: def _dump_metadata(self, metadata: CheckpointMetadata) -> str: serialized_metadata = self.jsonplus_serde.dumps(metadata) - return serialized_metadata.decode() + # NOTE: we're using JSON serializer (not msgpack), so we need to remove null characters before writing + return serialized_metadata.decode().replace("\\u0000", "") def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> str: if current is None: diff --git a/libs/checkpoint-postgres/tests/test_async.py b/libs/checkpoint-postgres/tests/test_async.py index 288996529..d44376fe5 100644 --- a/libs/checkpoint-postgres/tests/test_async.py +++ b/libs/checkpoint-postgres/tests/test_async.py @@ -101,3 +101,13 @@ async def test_asearch(self) -> None: } == {"", "inner"} # TODO: test before and limit params + + async def test_null_chars(self) -> None: + async with AsyncPostgresSaver.from_conn_string(DEFAULT_URI) as saver: + config = await saver.aput( + self.config_1, self.chkpnt_1, {"my_key": "\x00abc"}, {} + ) + assert (await saver.aget_tuple(config)).metadata["my_key"] == "abc" + assert [c async for c in saver.alist(None, filter={"my_key": "abc"})][ + 0 + ].metadata["my_key"] == "abc" diff --git a/libs/checkpoint-postgres/tests/test_sync.py b/libs/checkpoint-postgres/tests/test_sync.py index f64059a28..d196d075b 100644 --- a/libs/checkpoint-postgres/tests/test_sync.py +++ b/libs/checkpoint-postgres/tests/test_sync.py @@ -101,3 +101,12 @@ def test_search(self) -> None: } == {"", "inner"} # TODO: test before and limit params + + def test_null_chars(self) -> None: + with PostgresSaver.from_conn_string(DEFAULT_URI) as saver: + config = saver.put(self.config_1, self.chkpnt_1, {"my_key": "\x00abc"}, {}) + assert saver.get_tuple(config).metadata["my_key"] == "abc" + assert ( + list(saver.list(None, filter={"my_key": "abc"}))[0].metadata["my_key"] + == "abc" + )