Skip to content

Commit

Permalink
More performance improvements in checkpointing and channels (#1685)
Browse files Browse the repository at this point in the history
* Performance improvements in checkpointer libs

- Use sha1 instead of md5 for hashing (faster in python 3.x)
- Use orjson instead of json for json dumping (sadly can't use for json loading)

* Update tests

* Update

* Use random number instead of hash for get_version_number

* Avoid saving writes for the last task to complete in each step

- only when possible, exceptions for ERROR, INTERRUPT, SEND

* Make Channel.from_checkpoint a regular function

- context manager no longer needed since Context became a managed value

* Use __slots__ for Channels

* Fix for kafka
  • Loading branch information
nfcampos authored Sep 13, 2024
1 parent 8a80b1d commit 66fc7c9
Show file tree
Hide file tree
Showing 20 changed files with 206 additions and 344 deletions.
10 changes: 3 additions & 7 deletions libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from hashlib import md5
import random
from typing import Any, List, Optional, Tuple

from langchain_core.runnables import RunnableConfig
Expand All @@ -8,7 +8,6 @@
WRITES_IDX_MAP,
BaseCheckpointSaver,
Checkpoint,
EmptyChannelError,
get_checkpoint_id,
)
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
Expand Down Expand Up @@ -244,11 +243,8 @@ def get_next_version(self, current: Optional[str], channel: ChannelProtocol) ->
else:
current_v = int(current.split(".")[0])
next_v = current_v + 1
try:
next_h = md5(self.serde.dumps_typed(channel.checkpoint())[1]).hexdigest()
except EmptyChannelError:
next_h = ""
return f"{next_v:032}.{next_h}"
next_h = random.random()
return f"{next_v:032}.{next_h:016}"

def _search_where(
self,
Expand Down
10 changes: 3 additions & 7 deletions libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import random
import sqlite3
import threading
from contextlib import closing, contextmanager
from hashlib import md5
from typing import Any, AsyncIterator, Dict, Iterator, Optional, Sequence, Tuple

from langchain_core.runnables import RunnableConfig
Expand All @@ -13,7 +13,6 @@
Checkpoint,
CheckpointMetadata,
CheckpointTuple,
EmptyChannelError,
SerializerProtocol,
get_checkpoint_id,
)
Expand Down Expand Up @@ -514,8 +513,5 @@ def get_next_version(self, current: Optional[str], channel: ChannelProtocol) ->
else:
current_v = int(current.split(".")[0])
next_v = current_v + 1
try:
next_h = md5(self.serde.dumps_typed(channel.checkpoint())[1]).hexdigest()
except EmptyChannelError:
next_h = ""
return f"{next_v:032}.{next_h}"
next_h = random.random()
return f"{next_v:032}.{next_h:016}"
22 changes: 22 additions & 0 deletions libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/aio.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import random
from contextlib import asynccontextmanager
from typing import (
Any,
Expand Down Expand Up @@ -26,6 +27,7 @@
get_checkpoint_id,
)
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
from langgraph.checkpoint.serde.types import ChannelProtocol
from langgraph.checkpoint.sqlite.utils import search_where

T = TypeVar("T", bound=callable)
Expand Down Expand Up @@ -498,3 +500,23 @@ async def aput_writes(
for idx, (channel, value) in enumerate(writes)
],
)

def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> str:
"""Generate the next version ID for a channel.
This method creates a new version identifier for a channel based on its current version.
Args:
current (Optional[str]): The current version identifier of the channel.
channel (BaseChannel): The channel being versioned.
Returns:
str: The next version identifier, which is guaranteed to be monotonically increasing.
"""
if current is None:
current_v = 0
else:
current_v = int(current.split(".")[0])
next_v = current_v + 1
next_h = random.random()
return f"{next_v:032}.{next_h:016}"
14 changes: 13 additions & 1 deletion libs/checkpoint/langgraph/checkpoint/memory/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import random
from collections import defaultdict
from contextlib import AbstractAsyncContextManager, AbstractContextManager
from functools import partial
Expand All @@ -17,7 +18,7 @@
SerializerProtocol,
get_checkpoint_id,
)
from langgraph.checkpoint.serde.types import TASKS
from langgraph.checkpoint.serde.types import TASKS, ChannelProtocol


class MemorySaver(
Expand Down Expand Up @@ -457,3 +458,14 @@ async def aput_writes(
return await asyncio.get_running_loop().run_in_executor(
None, self.put_writes, config, writes, task_id
)

def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> str:
if current is None:
current_v = 0
elif isinstance(current, int):
current_v = current
else:
current_v = int(current.split(".")[0])
next_v = current_v + 1
next_h = random.random()
return f"{next_v:032}.{next_h:016}"
2 changes: 2 additions & 0 deletions libs/checkpoint/langgraph/checkpoint/serde/jsonplus.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def _default(self, obj):
return self._encode_constructor_args(
obj.__class__, method=(None, "construct"), kwargs=obj.dict()
)
elif hasattr(obj, "_asdict") and callable(obj._asdict):
return self._encode_constructor_args(obj.__class__, kwargs=obj._asdict())
elif isinstance(obj, pathlib.Path):
return self._encode_constructor_args(pathlib.Path, args=obj.parts)
elif isinstance(obj, re.Pattern):
Expand Down
12 changes: 11 additions & 1 deletion libs/checkpoint/tests/test_jsonplus.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ def test_serde_jsonplus() -> None:
"an_int": 1,
"a_float": 1.1,
"runnable_map": RunnableMap({}),
"a_bytes": b"my bytes",
"a_bytearray": bytearray([42]),
}

serde = JsonPlusSerializer()
Expand All @@ -122,14 +124,22 @@ def test_serde_jsonplus() -> None:

assert dumped == (
"json",
b"""{"path": {"lc": 2, "type": "constructor", "id": ["pathlib", "Path"], "args": ["foo", "bar"]}, "re": {"lc": 2, "type": "constructor", "id": ["re", "compile"], "args": ["foo", 48]}, "decimal": {"lc": 2, "type": "constructor", "id": ["decimal", "Decimal"], "args": ["1.10101"]}, "ip4": {"lc": 2, "type": "constructor", "id": ["ipaddress", "IPv4Address"], "args": ["192.168.0.1"]}, "deque": {"lc": 2, "type": "constructor", "id": ["collections", "deque"], "args": [[1, 2, 3]]}, "tzn": {"lc": 2, "type": "constructor", "id": ["zoneinfo", "ZoneInfo"], "args": ["America/New_York"]}, "date": {"lc": 2, "type": "constructor", "id": ["datetime", "date"], "args": [2024, 4, 19]}, "time": {"lc": 2, "type": "constructor", "id": ["datetime", "time"], "args": [23, 4, 57, 51022, {"lc": 2, "type": "constructor", "id": ["datetime", "timezone"], "args": [{"lc": 2, "type": "constructor", "id": ["datetime", "timedelta"], "args": [0, 86340, 0]}]}], "kwargs": {"fold": 0}}, "uid": {"lc": 2, "type": "constructor", "id": ["uuid", "UUID"], "args": ["00000000000000000000000000000001"]}, "timestamp": {"lc": 2, "type": "constructor", "id": ["datetime", "datetime"], "method": "fromisoformat", "args": ["2024-04-19T23:04:57.051022+23:59"]}, "my_slotted_class": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyDataclassWSlots"], "kwargs": {"foo": "bar", "bar": 2}}, "my_dataclass": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyDataclass"], "kwargs": {"foo": "foo", "bar": 1}}, "my_enum": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyEnum"], "args": ["foo"]}, "my_pydantic": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyPydantic"], "method": [null, "model_construct"], "kwargs": {"foo": "foo", "bar": 1}}, "my_funny_pydantic": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyFunnyPydantic"], "method": [null, "construct"], "kwargs": {"foo": "foo", "bar": 1}}, "person": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "Person"], "kwargs": {"name": "foo"}}, "a_bool": true, "a_none": null, "a_str": "foo", "a_str_nuc": "foo\\u0000", "a_str_uc": "foo \xe2\x9b\xb0\xef\xb8\x8f", "a_str_ucuc": "foo \xe2\x9b\xb0\xef\xb8\x8f\\u0000", "a_str_ucucuc": "foo \\\\u26f0\\\\ufe0f", "text": ["Hello", "Python", "Surrogate", "Example", "String", "With", "Surrogates", "Embedded", "In", "The", "Text", "\xe6\x94\xb6\xe8\x8a\xb1\xf0\x9f\x99\x84\xc2\xb7\xe5\x88\xb0"], "an_int": 1, "a_float": 1.1, "runnable_map": {"lc": 1, "type": "constructor", "id": ["langchain", "schema", "runnable", "RunnableParallel"], "kwargs": {"steps__": {}}, "name": "RunnableParallel<>", "graph": {"nodes": [{"id": 0, "type": "schema", "data": "Parallel<>Input"}, {"id": 1, "type": "schema", "data": "Parallel<>Output"}], "edges": []}}}""",
b"""{"path": {"lc": 2, "type": "constructor", "id": ["pathlib", "Path"], "args": ["foo", "bar"]}, "re": {"lc": 2, "type": "constructor", "id": ["re", "compile"], "args": ["foo", 48]}, "decimal": {"lc": 2, "type": "constructor", "id": ["decimal", "Decimal"], "args": ["1.10101"]}, "ip4": {"lc": 2, "type": "constructor", "id": ["ipaddress", "IPv4Address"], "args": ["192.168.0.1"]}, "deque": {"lc": 2, "type": "constructor", "id": ["collections", "deque"], "args": [[1, 2, 3]]}, "tzn": {"lc": 2, "type": "constructor", "id": ["zoneinfo", "ZoneInfo"], "args": ["America/New_York"]}, "date": {"lc": 2, "type": "constructor", "id": ["datetime", "date"], "args": [2024, 4, 19]}, "time": {"lc": 2, "type": "constructor", "id": ["datetime", "time"], "args": [23, 4, 57, 51022, {"lc": 2, "type": "constructor", "id": ["datetime", "timezone"], "args": [{"lc": 2, "type": "constructor", "id": ["datetime", "timedelta"], "args": [0, 86340, 0]}]}], "kwargs": {"fold": 0}}, "uid": {"lc": 2, "type": "constructor", "id": ["uuid", "UUID"], "args": ["00000000000000000000000000000001"]}, "timestamp": {"lc": 2, "type": "constructor", "id": ["datetime", "datetime"], "method": "fromisoformat", "args": ["2024-04-19T23:04:57.051022+23:59"]}, "my_slotted_class": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyDataclassWSlots"], "kwargs": {"foo": "bar", "bar": 2}}, "my_dataclass": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyDataclass"], "kwargs": {"foo": "foo", "bar": 1}}, "my_enum": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyEnum"], "args": ["foo"]}, "my_pydantic": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyPydantic"], "method": [null, "model_construct"], "kwargs": {"foo": "foo", "bar": 1}}, "my_funny_pydantic": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyFunnyPydantic"], "method": [null, "construct"], "kwargs": {"foo": "foo", "bar": 1}}, "person": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "Person"], "kwargs": {"name": "foo"}}, "a_bool": true, "a_none": null, "a_str": "foo", "a_str_nuc": "foo\\u0000", "a_str_uc": "foo \xe2\x9b\xb0\xef\xb8\x8f", "a_str_ucuc": "foo \xe2\x9b\xb0\xef\xb8\x8f\\u0000", "a_str_ucucuc": "foo \\\\u26f0\\\\ufe0f", "text": ["Hello", "Python", "Surrogate", "Example", "String", "With", "Surrogates", "Embedded", "In", "The", "Text", "\xe6\x94\xb6\xe8\x8a\xb1\xf0\x9f\x99\x84\xc2\xb7\xe5\x88\xb0"], "an_int": 1, "a_float": 1.1, "runnable_map": {"lc": 1, "type": "constructor", "id": ["langchain", "schema", "runnable", "RunnableParallel"], "kwargs": {"steps__": {}}, "name": "RunnableParallel<>", "graph": {"nodes": [{"id": 0, "type": "schema", "data": "Parallel<>Input"}, {"id": 1, "type": "schema", "data": "Parallel<>Output"}], "edges": []}}, "a_bytes": {"lc": 2, "type": "constructor", "id": ["builtins", "bytes"], "method": "fromhex", "args": ["6d79206279746573"]}, "a_bytearray": {"lc": 2, "type": "constructor", "id": ["builtins", "bytearray"], "method": "fromhex", "args": ["2a"]}}""",
)

assert serde.loads_typed(dumped) == {
**to_serialize,
"text": [v.encode("utf-8", "ignore").decode() for v in to_serialize["text"]],
}

for key, value in to_serialize.items():
if key == "text":
assert serde.loads_typed(serde.dumps_typed(value)) == [
v.encode("utf-8", "ignore").decode() for v in value
]
else:
assert serde.loads_typed(serde.dumps_typed(value)) == value


def test_serde_jsonplus_bytes() -> None:
serde = JsonPlusSerializer()
Expand Down
21 changes: 5 additions & 16 deletions libs/langgraph/langgraph/channels/any_value.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from contextlib import contextmanager
from typing import Generator, Generic, Optional, Sequence, Type
from typing import Generic, Optional, Sequence, Type

from langchain_core.runnables import RunnableConfig
from typing_extensions import Self

from langgraph.channels.base import BaseChannel, Value
Expand All @@ -12,8 +10,7 @@ class AnyValue(Generic[Value], BaseChannel[Value, Value, Value]):
"""Stores the last value received, assumes that if multiple values are
received, they are all equal."""

def __init__(self, typ: Type[Value]) -> None:
self.typ = typ
__slots__ = ("typ", "value")

def __eq__(self, value: object) -> bool:
return isinstance(value, AnyValue)
Expand All @@ -28,20 +25,12 @@ def UpdateType(self) -> Type[Value]:
"""The type of the update received by the channel."""
return self.typ

@contextmanager
def from_checkpoint(
self, checkpoint: Optional[Value], config: RunnableConfig
) -> Generator[Self, None, None]:
def from_checkpoint(self, checkpoint: Optional[Value]) -> Self:
empty = self.__class__(self.typ)
empty.key = self.key
if checkpoint is not None:
empty.value = checkpoint
try:
yield empty
finally:
try:
del empty.value
except AttributeError:
pass
return empty

def update(self, values: Sequence[Value]) -> bool:
if len(values) == 0:
Expand Down
50 changes: 8 additions & 42 deletions libs/langgraph/langgraph/channels/base.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,6 @@
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager, contextmanager
from typing import (
Any,
AsyncIterator,
Generic,
Iterator,
Optional,
Sequence,
TypeVar,
)

from langchain_core.runnables import RunnableConfig
from typing import Any, Generic, Optional, Sequence, Type, TypeVar

from typing_extensions import Self

from langgraph.errors import EmptyChannelError, InvalidUpdateError
Expand All @@ -21,7 +11,11 @@


class BaseChannel(Generic[Value, Update, C], ABC):
key: str = ""
__slots__ = ("key", "typ")

def __init__(self, typ: Type[Any], key: str = "") -> None:
self.typ = typ
self.key = key

@property
@abstractmethod
Expand All @@ -41,38 +35,10 @@ def checkpoint(self) -> Optional[C]:
or doesn't support checkpoints."""
return self.get()

@contextmanager
@abstractmethod
def from_checkpoint(
self, checkpoint: Optional[C], config: RunnableConfig
) -> Iterator[Self]:
"""Return a new identical channel, optionally initialized from a checkpoint.
If the checkpoint contains complex data structures, they should be copied."""

@contextmanager
def from_checkpoint_named(
self, checkpoint: Optional[C], config: RunnableConfig
) -> Iterator[Self]:
with self.from_checkpoint(checkpoint, config) as value:
value.key = self.key
yield value

@asynccontextmanager
async def afrom_checkpoint(
self, checkpoint: Optional[C], config: RunnableConfig
) -> AsyncIterator[Self]:
def from_checkpoint(self, checkpoint: Optional[C]) -> Self:
"""Return a new identical channel, optionally initialized from a checkpoint.
If the checkpoint contains complex data structures, they should be copied."""
with self.from_checkpoint(checkpoint, config) as value:
yield value

@asynccontextmanager
async def afrom_checkpoint_named(
self, checkpoint: Optional[C], config: RunnableConfig
) -> AsyncIterator[Self]:
async with self.afrom_checkpoint(checkpoint, config) as value:
value.key = self.key
yield value

# state methods

Expand Down
22 changes: 6 additions & 16 deletions libs/langgraph/langgraph/channels/binop.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import collections.abc
from contextlib import contextmanager
from typing import (
Callable,
Generator,
Generic,
Optional,
Sequence,
Type,
)

from langchain_core.runnables import RunnableConfig
from typing_extensions import NotRequired, Required, Self

from langgraph.channels.base import BaseChannel, Value
Expand Down Expand Up @@ -37,10 +34,11 @@ class BinaryOperatorAggregate(Generic[Value], BaseChannel[Value, Value, Value]):
```
"""

__slots__ = ("value", "operator")

def __init__(self, typ: Type[Value], operator: Callable[[Value, Value], Value]):
super().__init__(typ)
self.operator = operator
# keep the type exposed by ValueType/UpdateType as-is
self.typ = typ
# special forms from typing or collections.abc are not instantiable
# so we need to replace them with their concrete counterparts
typ = _strip_extras(typ)
Expand Down Expand Up @@ -73,20 +71,12 @@ def UpdateType(self) -> Type[Value]:
"""The type of the update received by the channel."""
return self.typ

@contextmanager
def from_checkpoint(
self, checkpoint: Optional[Value], config: RunnableConfig
) -> Generator[Self, None, None]:
def from_checkpoint(self, checkpoint: Optional[Value]) -> Self:
empty = self.__class__(self.typ, self.operator)
empty.key = self.key
if checkpoint is not None:
empty.value = checkpoint
try:
yield empty
finally:
try:
del empty.value
except AttributeError:
pass
return empty

def update(self, values: Sequence[Value]) -> bool:
if not values:
Expand Down
23 changes: 9 additions & 14 deletions libs/langgraph/langgraph/channels/dynamic_barrier_value.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from contextlib import contextmanager
from typing import Generator, Generic, NamedTuple, Optional, Sequence, Type, Union
from typing import Generic, NamedTuple, Optional, Sequence, Type, Union

from langchain_core.runnables import RunnableConfig
from typing_extensions import Self

from langgraph.channels.base import BaseChannel, Value
Expand All @@ -24,11 +22,13 @@ class DynamicBarrierValue(
back to the "priming" state.
"""

__slots__ = ("names", "seen")

names: Optional[set[Value]]
seen: set[Value]

def __init__(self, typ: Type[Value]) -> None:
self.typ = typ
super().__init__(typ)
self.names = None
self.seen = set()

Expand All @@ -48,22 +48,17 @@ def UpdateType(self) -> Type[Value]:
def checkpoint(self) -> tuple[Optional[set[Value]], set[Value]]:
return (self.names, self.seen)

@contextmanager
def from_checkpoint(
self,
checkpoint: Optional[tuple[Optional[set[Value]], set[Value]]],
config: RunnableConfig,
) -> Generator[Self, None, None]:
) -> Self:
empty = self.__class__(self.typ)
empty.key = self.key
if checkpoint is not None:
names, seen = checkpoint
empty.names = names.copy() if names is not None else None
empty.seen = seen.copy()

try:
yield empty
finally:
pass
empty.names = names if names is not None else None
empty.seen = seen
return empty

def update(self, values: Sequence[Union[Value, WaitForNames]]) -> bool:
if wait_for_names := [v for v in values if isinstance(v, WaitForNames)]:
Expand Down
Loading

0 comments on commit 66fc7c9

Please sign in to comment.