Skip to content

Commit

Permalink
checkpoint: Add option to use persistent dict for in-memory checkpoin…
Browse files Browse the repository at this point in the history
…ter (#2439)

- This should only be used in very specific circumstances; the sqlite or
postgres adapters are much more appropriate in most circumstances

---------

Co-authored-by: William Fu-Hinthorn <[email protected]>
  • Loading branch information
nfcampos and hinthornw authored Nov 18, 2024
1 parent 3f1792d commit 886df0f
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 13 deletions.
100 changes: 92 additions & 8 deletions libs/checkpoint/langgraph/checkpoint/memory/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import asyncio
import logging
import os
import pickle
import random
import shutil
from collections import defaultdict
from contextlib import AbstractAsyncContextManager, AbstractContextManager
from contextlib import AbstractAsyncContextManager, AbstractContextManager, ExitStack
from functools import partial
from types import TracebackType
from typing import Any, AsyncIterator, Dict, Iterator, Optional, Sequence, Tuple
from typing import Any, AsyncIterator, Dict, Iterator, Optional, Sequence, Tuple, Type

from langchain_core.runnables import RunnableConfig

Expand All @@ -20,6 +24,8 @@
)
from langgraph.checkpoint.serde.types import TASKS, ChannelProtocol

logger = logging.getLogger(__name__)


class MemorySaver(
BaseCheckpointSaver[str], AbstractContextManager, AbstractAsyncContextManager
def __init__(
    self,
    *,
    serde: Optional[SerializerProtocol] = None,
    factory: Type[defaultdict] = defaultdict,
) -> None:
    """Initialize the in-memory checkpoint saver.

    Args:
        serde: Optional serializer used for checkpoint (de)serialization.
        factory: Mapping type backing ``storage`` and ``writes``. Defaults to
            ``defaultdict``; a custom factory (e.g. ``PersistentDict``) may be
            supplied to persist contents. Non-default factories are assumed to
            be context managers and are entered on the saver's ExitStack so
            they are closed when the saver's context exits.
    """
    super().__init__(serde=serde)
    self.storage = factory(lambda: defaultdict(dict))
    self.writes = factory(dict)
    self.stack = ExitStack()
    if factory is not defaultdict:
        self.stack.enter_context(self.storage)  # type: ignore[arg-type]
        self.stack.enter_context(self.writes)  # type: ignore[arg-type]

def __enter__(self) -> "MemorySaver":
    """Enter the saver's context.

    Activates the ExitStack holding any persistent backing stores, and
    returns the saver itself (not the stack) so that
    ``with MemorySaver(...) as saver:`` binds a ``MemorySaver``, matching
    the annotated return type.
    """
    self.stack.__enter__()
    return self

def __exit__(
    self,
    exc_type: Optional[type[BaseException]],
    exc_value: Optional[BaseException],
    traceback: Optional[TracebackType],
) -> Optional[bool]:
    """Exit the saver's context, closing persistent backing stores.

    Delegates to the ExitStack so any context-managed storage factories
    (e.g. PersistentDict) are flushed and closed.
    """
    return self.stack.__exit__(exc_type, exc_value, traceback)

async def __aenter__(self) -> "MemorySaver":
    """Async enter: same semantics as ``__enter__``.

    The underlying ExitStack is synchronous, so it is entered directly;
    the saver itself is returned to honor the annotated return type.
    """
    self.stack.__enter__()
    return self

async def __aexit__(
    self,
    __exc_type: Optional[type[BaseException]],
    __exc_value: Optional[BaseException],
    __traceback: Optional[TracebackType],
) -> Optional[bool]:
    """Async exit: same semantics as ``__exit__``.

    Delegates to the synchronous ExitStack so persistent backing stores
    are flushed and closed.
    """
    return self.stack.__exit__(__exc_type, __exc_value, __traceback)

def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
"""Get a checkpoint tuple from the in-memory storage.
Expand Down Expand Up @@ -478,3 +489,76 @@ def get_next_version(self, current: Optional[str], channel: ChannelProtocol) ->
next_v = current_v + 1
next_h = random.random()
return f"{next_v:032}.{next_h:016}"


class PersistentDict(defaultdict):
    """Persistent dictionary with an API compatible with shelve and anydbm.

    The dict is kept in memory, so the dictionary operations run as fast as
    a regular dictionary.

    Write to disk is delayed until close or sync (similar to gdbm's fast mode).

    Input file format is automatically discovered.
    Output file format is selectable between pickle, json, and csv.
    All three serialization formats are backed by fast C implementations.

    Adapted from https://code.activestate.com/recipes/576642-persistent-dict-with-multiple-standard-file-format/
    """

    def __init__(self, *args: Any, filename: str, **kwds: Any) -> None:
        # flag semantics follow anydbm: r=readonly, c=create, or n=new
        self.flag = "c"
        # None, or an octal triple like 0o644 applied to the file after sync
        self.mode = None
        # Only 'pickle' is currently implemented; see dump()/load().
        self.format = "pickle"
        self.filename = filename
        super().__init__(*args, **kwds)

    def sync(self) -> None:
        """Write the dict to disk via a temp file, then rename (atomic commit)."""
        if self.flag == "r":
            return
        tempname = self.filename + ".tmp"
        fileobj = open(tempname, "wb" if self.format == "pickle" else "w")
        try:
            self.dump(fileobj)
        except Exception:
            # Don't leave a partially-written temp file behind on failure.
            os.remove(tempname)
            raise
        finally:
            fileobj.close()
        shutil.move(tempname, self.filename)  # atomic commit
        if self.mode is not None:
            os.chmod(self.filename, self.mode)

    def close(self) -> None:
        """Flush contents to disk, then drop the in-memory copy."""
        self.sync()
        self.clear()

    def __enter__(self) -> "PersistentDict":
        return self

    def __exit__(self, *exc_info: Any) -> None:
        self.close()

    def dump(self, fileobj: Any) -> None:
        """Serialize the current contents to an already-open file object.

        Raises:
            NotImplementedError: if ``self.format`` is not ``"pickle"``.
        """
        if self.format == "pickle":
            pickle.dump(dict(self), fileobj, 2)
        else:
            raise NotImplementedError("Unknown format: " + repr(self.format))

    def load(self) -> None:
        """Populate the dict from ``self.filename``.

        An empty file (EOFError) is treated as "nothing to load"; any other
        load failure is logged and re-raised.
        """
        # try formats from most restrictive to least restrictive
        if self.flag == "n":
            return
        with open(self.filename, "rb" if self.format == "pickle" else "r") as fileobj:
            for loader in (pickle.load,):
                fileobj.seek(0)
                try:
                    return self.update(loader(fileobj))
                except EOFError:
                    # Freshly created / empty file: start with an empty dict.
                    return
                except Exception:
                    # Lazy %-formatting so the message is only built on error.
                    logging.error("Failed to load file: %s", fileobj.name)
                    raise
        raise ValueError("File not in a supported format")
10 changes: 7 additions & 3 deletions libs/checkpoint/tests/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ async def test_cannot_put_empty_namespace() -> None:
assert store.get(("langgraph", "foo"), "bar") is None

class MockAsyncBatchedStore(AsyncBatchedBaseStore):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self._store = InMemoryStore()

Expand All @@ -340,13 +340,17 @@ async def abatch(self, ops: Iterable[Op]) -> list[Result]:
await async_store.aput(("langgraph", "foo"), "bar", doc)

await async_store.aput(("foo", "langgraph", "foo"), "bar", doc)
assert (await async_store.aget(("foo", "langgraph", "foo"), "bar")).value == doc
val = await async_store.aget(("foo", "langgraph", "foo"), "bar")
assert val is not None
assert val.value == doc
assert (await async_store.asearch(("foo", "langgraph", "foo")))[0].value == doc
await async_store.adelete(("foo", "langgraph", "foo"), "bar")
assert (await async_store.aget(("foo", "langgraph", "foo"), "bar")) is None

await async_store.abatch([PutOp(("valid", "namespace"), "key", doc)])
assert (await async_store.aget(("valid", "namespace"), "key")).value == doc
val = await async_store.aget(("valid", "namespace"), "key")
assert val is not None
assert val.value == doc
assert (await async_store.asearch(("valid", "namespace")))[0].value == doc
await async_store.adelete(("valid", "namespace"), "key")
assert (await async_store.aget(("valid", "namespace"), "key")) is None
11 changes: 9 additions & 2 deletions libs/langgraph/tests/memory_assert.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import asyncio
import os
import tempfile
from collections import defaultdict
from functools import partial
from typing import Any, Optional

from langchain_core.runnables import RunnableConfig
Expand All @@ -12,7 +15,7 @@
SerializerProtocol,
copy_checkpoint,
)
from langgraph.checkpoint.memory import MemorySaver
from langgraph.checkpoint.memory import MemorySaver, PersistentDict


class NoopSerializer(SerializerProtocol):
Expand All @@ -32,9 +35,13 @@ def __init__(
serde: Optional[SerializerProtocol] = None,
put_sleep: Optional[float] = None,
) -> None:
super().__init__(serde=serde)
_, filename = tempfile.mkstemp()
super().__init__(
serde=serde, factory=partial(PersistentDict, filename=filename)
)
self.storage_for_copies = defaultdict(lambda: defaultdict(dict))
self.put_sleep = put_sleep
self.stack.callback(os.remove, filename)

def put(
self,
Expand Down

0 comments on commit 886df0f

Please sign in to comment.