From 886df0fa86841388cbd9bf257733a50555174551 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 18 Nov 2024 08:04:05 -0800 Subject: [PATCH] checkpoint: Add option to use persistent dict for in-memory checkpointer (#2439) - This should only be used in very specific circumstances, sqlite or postgres adapters are much more appropriate in most circumstances --------- Co-authored-by: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com> --- .../langgraph/checkpoint/memory/__init__.py | 100 ++++++++++++++++-- libs/checkpoint/tests/test_store.py | 10 +- libs/langgraph/tests/memory_assert.py | 11 +- 3 files changed, 108 insertions(+), 13 deletions(-) diff --git a/libs/checkpoint/langgraph/checkpoint/memory/__init__.py b/libs/checkpoint/langgraph/checkpoint/memory/__init__.py index 25575e481..81791b539 100644 --- a/libs/checkpoint/langgraph/checkpoint/memory/__init__.py +++ b/libs/checkpoint/langgraph/checkpoint/memory/__init__.py @@ -1,10 +1,14 @@ import asyncio +import logging +import os +import pickle import random +import shutil from collections import defaultdict -from contextlib import AbstractAsyncContextManager, AbstractContextManager +from contextlib import AbstractAsyncContextManager, AbstractContextManager, ExitStack from functools import partial from types import TracebackType -from typing import Any, AsyncIterator, Dict, Iterator, Optional, Sequence, Tuple +from typing import Any, AsyncIterator, Dict, Iterator, Optional, Sequence, Tuple, Type from langchain_core.runnables import RunnableConfig @@ -20,6 +24,8 @@ ) from langgraph.checkpoint.serde.types import TASKS, ChannelProtocol +logger = logging.getLogger(__name__) + class MemorySaver( BaseCheckpointSaver[str], AbstractContextManager, AbstractAsyncContextManager @@ -68,13 +74,18 @@ def __init__( self, *, serde: Optional[SerializerProtocol] = None, + factory: Type[defaultdict] = defaultdict, ) -> None: super().__init__(serde=serde) - self.storage = defaultdict(lambda: defaultdict(dict)) - 
self.writes = defaultdict(dict) + self.storage = factory(lambda: defaultdict(dict)) + self.writes = factory(dict) + self.stack = ExitStack() + if factory is not defaultdict: + self.stack.enter_context(self.storage) # type: ignore[arg-type] + self.stack.enter_context(self.writes) # type: ignore[arg-type] def __enter__(self) -> "MemorySaver": - return self + return self.stack.__enter__() def __exit__( self, @@ -82,10 +93,10 @@ def __exit__( exc_value: Optional[BaseException], traceback: Optional[TracebackType], ) -> Optional[bool]: - return + return self.stack.__exit__(exc_type, exc_value, traceback) async def __aenter__(self) -> "MemorySaver": - return self + return self.stack.__enter__() async def __aexit__( self, @@ -93,7 +104,7 @@ async def __aexit__( __exc_value: Optional[BaseException], __traceback: Optional[TracebackType], ) -> Optional[bool]: - return + return self.stack.__exit__(__exc_type, __exc_value, __traceback) def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: """Get a checkpoint tuple from the in-memory storage. @@ -478,3 +489,76 @@ def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> next_v = current_v + 1 next_h = random.random() return f"{next_v:032}.{next_h:016}" + + +class PersistentDict(defaultdict): + """Persistent dictionary with an API compatible with shelve and anydbm. + + The dict is kept in memory, so the dictionary operations run as fast as + a regular dictionary. + + Write to disk is delayed until close or sync (similar to gdbm's fast mode). + + Input file format is automatically discovered. + Output file format is selectable between pickle, json, and csv. + All three serialization formats are backed by fast C implementations. 
+ + Adapted from https://code.activestate.com/recipes/576642-persistent-dict-with-multiple-standard-file-format/ + + """ + + def __init__(self, *args: Any, filename: str, **kwds: Any) -> None: + self.flag = "c" # r=readonly, c=create, or n=new + self.mode = None # None or an octal triple like 0644 + self.format = "pickle" # 'csv', 'json', or 'pickle' + self.filename = filename + super().__init__(*args, **kwds) + + def sync(self) -> None: + "Write dict to disk" + if self.flag == "r": + return + tempname = self.filename + ".tmp" + fileobj = open(tempname, "wb" if self.format == "pickle" else "w") + try: + self.dump(fileobj) + except Exception: + os.remove(tempname) + raise + finally: + fileobj.close() + shutil.move(tempname, self.filename) # atomic commit + if self.mode is not None: + os.chmod(self.filename, self.mode) + + def close(self) -> None: + self.sync() + self.clear() + + def __enter__(self) -> "PersistentDict": + return self + + def __exit__(self, *exc_info: Any) -> None: + self.close() + + def dump(self, fileobj: Any) -> None: + if self.format == "pickle": + pickle.dump(dict(self), fileobj, 2) + else: + raise NotImplementedError("Unknown format: " + repr(self.format)) + + def load(self) -> None: + # try formats from most restrictive to least restrictive + if self.flag == "n": + return + with open(self.filename, "rb" if self.format == "pickle" else "r") as fileobj: + for loader in (pickle.load,): + fileobj.seek(0) + try: + return self.update(loader(fileobj)) + except EOFError: + return + except Exception: + logger.error(f"Failed to load file: {fileobj.name}") + raise + raise ValueError("File not in a supported format") diff --git a/libs/checkpoint/tests/test_store.py b/libs/checkpoint/tests/test_store.py index d9fbc5084..9d06281d0 100644 --- a/libs/checkpoint/tests/test_store.py +++ b/libs/checkpoint/tests/test_store.py @@ -314,7 +314,7 @@ async def test_cannot_put_empty_namespace() -> None: assert store.get(("langgraph", "foo"), "bar") is None class 
MockAsyncBatchedStore(AsyncBatchedBaseStore): - def __init__(self): + def __init__(self) -> None: super().__init__() self._store = InMemoryStore() @@ -340,13 +340,17 @@ async def abatch(self, ops: Iterable[Op]) -> list[Result]: await async_store.aput(("langgraph", "foo"), "bar", doc) await async_store.aput(("foo", "langgraph", "foo"), "bar", doc) - assert (await async_store.aget(("foo", "langgraph", "foo"), "bar")).value == doc + val = await async_store.aget(("foo", "langgraph", "foo"), "bar") + assert val is not None + assert val.value == doc assert (await async_store.asearch(("foo", "langgraph", "foo")))[0].value == doc await async_store.adelete(("foo", "langgraph", "foo"), "bar") assert (await async_store.aget(("foo", "langgraph", "foo"), "bar")) is None await async_store.abatch([PutOp(("valid", "namespace"), "key", doc)]) - assert (await async_store.aget(("valid", "namespace"), "key")).value == doc + val = await async_store.aget(("valid", "namespace"), "key") + assert val is not None + assert val.value == doc assert (await async_store.asearch(("valid", "namespace")))[0].value == doc await async_store.adelete(("valid", "namespace"), "key") assert (await async_store.aget(("valid", "namespace"), "key")) is None diff --git a/libs/langgraph/tests/memory_assert.py b/libs/langgraph/tests/memory_assert.py index 6b44051f7..0a9f13a47 100644 --- a/libs/langgraph/tests/memory_assert.py +++ b/libs/langgraph/tests/memory_assert.py @@ -1,5 +1,8 @@ import asyncio +import os +import tempfile from collections import defaultdict +from functools import partial from typing import Any, Optional from langchain_core.runnables import RunnableConfig @@ -12,7 +15,7 @@ SerializerProtocol, copy_checkpoint, ) -from langgraph.checkpoint.memory import MemorySaver +from langgraph.checkpoint.memory import MemorySaver, PersistentDict class NoopSerializer(SerializerProtocol): @@ -32,9 +35,13 @@ def __init__( serde: Optional[SerializerProtocol] = None, put_sleep: Optional[float] = None, ) -> None: 
- super().__init__(serde=serde) + _, filename = tempfile.mkstemp() + super().__init__( + serde=serde, factory=partial(PersistentDict, filename=filename) + ) self.storage_for_copies = defaultdict(lambda: defaultdict(dict)) self.put_sleep = put_sleep + self.stack.callback(os.remove, filename) def put( self,