diff --git a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py index 0dd7d7733..1d54ec778 100644 --- a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py +++ b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py @@ -1,4 +1,4 @@ -from hashlib import md5 +import random from typing import Any, List, Optional, Tuple from langchain_core.runnables import RunnableConfig @@ -8,7 +8,6 @@ WRITES_IDX_MAP, BaseCheckpointSaver, Checkpoint, - EmptyChannelError, get_checkpoint_id, ) from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer @@ -244,11 +243,8 @@ def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> else: current_v = int(current.split(".")[0]) next_v = current_v + 1 - try: - next_h = md5(self.serde.dumps_typed(channel.checkpoint())[1]).hexdigest() - except EmptyChannelError: - next_h = "" - return f"{next_v:032}.{next_h}" + next_h = random.random() + return f"{next_v:032}.{next_h:016}" def _search_where( self, diff --git a/libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/__init__.py b/libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/__init__.py index 8b6f728e9..215c84206 100644 --- a/libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/__init__.py +++ b/libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/__init__.py @@ -1,7 +1,7 @@ +import random import sqlite3 import threading from contextlib import closing, contextmanager -from hashlib import md5 from typing import Any, AsyncIterator, Dict, Iterator, Optional, Sequence, Tuple from langchain_core.runnables import RunnableConfig @@ -13,7 +13,6 @@ Checkpoint, CheckpointMetadata, CheckpointTuple, - EmptyChannelError, SerializerProtocol, get_checkpoint_id, ) @@ -514,8 +513,5 @@ def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> else: current_v = int(current.split(".")[0]) next_v = current_v + 1 - try: - next_h = md5(self.serde.dumps_typed(channel.checkpoint())[1]).hexdigest() - except EmptyChannelError: - next_h = "" - return f"{next_v:032}.{next_h}" + next_h = random.random() + return f"{next_v:032}.{next_h:016}" diff --git a/libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/aio.py b/libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/aio.py index 3ffd715c1..0e3ab0ecb 100644 --- a/libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/aio.py +++ b/libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/aio.py @@ -1,4 +1,5 @@ import asyncio +import random from contextlib import asynccontextmanager from typing import ( Any, @@ -26,6 +27,7 @@ get_checkpoint_id, ) from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer +from langgraph.checkpoint.serde.types import ChannelProtocol from langgraph.checkpoint.sqlite.utils import search_where T = TypeVar("T", bound=callable) @@ -498,3 +500,23 @@ async def aput_writes( for idx, (channel, value) in enumerate(writes) ], ) + + def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> str: + """Generate the next version ID for a channel. + + This method creates a new version identifier for a channel based on its current version. + + Args: + current (Optional[str]): The current version identifier of the channel. + channel (BaseChannel): The channel being versioned. + + Returns: + str: The next version identifier, which is guaranteed to be monotonically increasing. + """ + if current is None: + current_v = 0 + else: + current_v = int(current.split(".")[0]) + next_v = current_v + 1 + next_h = random.random() + return f"{next_v:032}.{next_h:016}" diff --git a/libs/checkpoint/langgraph/checkpoint/memory/__init__.py b/libs/checkpoint/langgraph/checkpoint/memory/__init__.py index 372e5ba2f..a0c3237ac 100644 --- a/libs/checkpoint/langgraph/checkpoint/memory/__init__.py +++ b/libs/checkpoint/langgraph/checkpoint/memory/__init__.py @@ -1,4 +1,5 @@ import asyncio +import random from collections import defaultdict from contextlib import AbstractAsyncContextManager, AbstractContextManager from functools import partial @@ -17,7 +18,7 @@ SerializerProtocol, get_checkpoint_id, ) -from langgraph.checkpoint.serde.types import TASKS +from langgraph.checkpoint.serde.types import TASKS, ChannelProtocol class MemorySaver( @@ -457,3 +458,14 @@ async def aput_writes( return await asyncio.get_running_loop().run_in_executor( None, self.put_writes, config, writes, task_id ) + + def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> str: + if current is None: + current_v = 0 + elif isinstance(current, int): + current_v = current + else: + current_v = int(current.split(".")[0]) + next_v = current_v + 1 + next_h = random.random() + return f"{next_v:032}.{next_h:016}" diff --git a/libs/checkpoint/langgraph/checkpoint/serde/jsonplus.py b/libs/checkpoint/langgraph/checkpoint/serde/jsonplus.py index e91aa425d..2007ada62 100644 --- a/libs/checkpoint/langgraph/checkpoint/serde/jsonplus.py +++ b/libs/checkpoint/langgraph/checkpoint/serde/jsonplus.py @@ -62,6 +62,8 @@ def _default(self, obj): return self._encode_constructor_args( obj.__class__, method=(None, "construct"), kwargs=obj.dict() ) + elif hasattr(obj, "_asdict") and callable(obj._asdict): + return self._encode_constructor_args(obj.__class__, kwargs=obj._asdict()) elif isinstance(obj, pathlib.Path): return self._encode_constructor_args(pathlib.Path, args=obj.parts) elif isinstance(obj, re.Pattern): diff --git a/libs/checkpoint/tests/test_jsonplus.py b/libs/checkpoint/tests/test_jsonplus.py index 22241a44a..a9f5d71e6 100644 --- a/libs/checkpoint/tests/test_jsonplus.py +++ b/libs/checkpoint/tests/test_jsonplus.py @@ -114,6 +114,8 @@ def test_serde_jsonplus() -> None: "an_int": 1, "a_float": 1.1, "runnable_map": RunnableMap({}), + "a_bytes": b"my bytes", + "a_bytearray": bytearray([42]), } serde = JsonPlusSerializer() @@ -122,7 +124,7 @@ def test_serde_jsonplus() -> None: assert dumped == ( "json", - b"""{"path": {"lc": 2, "type": "constructor", "id": ["pathlib", "Path"], "args": ["foo", "bar"]}, "re": {"lc": 2, "type": "constructor", "id": ["re", "compile"], "args": ["foo", 48]}, "decimal": {"lc": 2, "type": "constructor", "id": ["decimal", "Decimal"], "args": ["1.10101"]}, "ip4": {"lc": 2, "type": "constructor", "id": ["ipaddress", "IPv4Address"], "args": ["192.168.0.1"]}, "deque": {"lc": 2, "type": "constructor", "id": ["collections", "deque"], "args": [[1, 2, 3]]}, "tzn": {"lc": 2, "type": "constructor", "id": ["zoneinfo", "ZoneInfo"], "args": ["America/New_York"]}, "date": {"lc": 2, "type": "constructor", "id": ["datetime", "date"], "args": [2024, 4, 19]}, "time": {"lc": 2, "type": "constructor", "id": ["datetime", "time"], "args": [23, 4, 57, 51022, {"lc": 2, "type": "constructor", "id": ["datetime", "timezone"], "args": [{"lc": 2, "type": "constructor", "id": ["datetime", "timedelta"], "args": [0, 86340, 0]}]}], "kwargs": {"fold": 0}}, "uid": {"lc": 2, "type": "constructor", "id": ["uuid", "UUID"], "args": ["00000000000000000000000000000001"]}, "timestamp": {"lc": 2, "type": "constructor", "id": ["datetime", "datetime"], "method": "fromisoformat", "args": ["2024-04-19T23:04:57.051022+23:59"]}, "my_slotted_class": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyDataclassWSlots"], "kwargs": {"foo": "bar", "bar": 2}}, "my_dataclass": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyDataclass"], "kwargs": {"foo": "foo", "bar": 1}}, "my_enum": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyEnum"], "args": ["foo"]}, "my_pydantic": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyPydantic"], "method": [null, "model_construct"], "kwargs": {"foo": "foo", "bar": 1}}, "my_funny_pydantic": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyFunnyPydantic"], "method": [null, "construct"], "kwargs": {"foo": "foo", "bar": 1}}, "person": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "Person"], "kwargs": {"name": "foo"}}, "a_bool": true, "a_none": null, "a_str": "foo", "a_str_nuc": "foo\\u0000", "a_str_uc": "foo \xe2\x9b\xb0\xef\xb8\x8f", "a_str_ucuc": "foo \xe2\x9b\xb0\xef\xb8\x8f\\u0000", "a_str_ucucuc": "foo \\\\u26f0\\\\ufe0f", "text": ["Hello", "Python", "Surrogate", "Example", "String", "With", "Surrogates", "Embedded", "In", "The", "Text", "\xe6\x94\xb6\xe8\x8a\xb1\xf0\x9f\x99\x84\xc2\xb7\xe5\x88\xb0"], "an_int": 1, "a_float": 1.1, "runnable_map": {"lc": 1, "type": "constructor", "id": ["langchain", "schema", "runnable", "RunnableParallel"], "kwargs": {"steps__": {}}, "name": "RunnableParallel<>", "graph": {"nodes": [{"id": 0, "type": "schema", "data": "Parallel<>Input"}, {"id": 1, "type": "schema", "data": "Parallel<>Output"}], "edges": []}}}""", + b"""{"path": {"lc": 2, "type": "constructor", "id": ["pathlib", "Path"], "args": ["foo", "bar"]}, "re": {"lc": 2, "type": "constructor", "id": ["re", "compile"], "args": ["foo", 48]}, "decimal": {"lc": 2, "type": "constructor", "id": ["decimal", "Decimal"], "args": ["1.10101"]}, "ip4": {"lc": 2, "type": "constructor", "id": ["ipaddress", "IPv4Address"], "args": ["192.168.0.1"]}, "deque": {"lc": 2, "type": "constructor", "id": ["collections", "deque"], "args": [[1, 2, 3]]}, "tzn": {"lc": 2, "type": "constructor", "id": ["zoneinfo", "ZoneInfo"], "args": ["America/New_York"]}, "date": {"lc": 2, "type": "constructor", "id": ["datetime", "date"], "args": [2024, 4, 19]}, "time": {"lc": 2, "type": "constructor", "id": ["datetime", "time"], "args": [23, 4, 57, 51022, {"lc": 2, "type": "constructor", "id": ["datetime", "timezone"], "args": [{"lc": 2, "type": "constructor", "id": ["datetime", "timedelta"], "args": [0, 86340, 0]}]}], "kwargs": {"fold": 0}}, "uid": {"lc": 2, "type": "constructor", "id": ["uuid", "UUID"], "args": ["00000000000000000000000000000001"]}, "timestamp": {"lc": 2, "type": "constructor", "id": ["datetime", "datetime"], "method": "fromisoformat", "args": ["2024-04-19T23:04:57.051022+23:59"]}, "my_slotted_class": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyDataclassWSlots"], "kwargs": {"foo": "bar", "bar": 2}}, "my_dataclass": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyDataclass"], "kwargs": {"foo": "foo", "bar": 1}}, "my_enum": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyEnum"], "args": ["foo"]}, "my_pydantic": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyPydantic"], "method": [null, "model_construct"], "kwargs": {"foo": "foo", "bar": 1}}, "my_funny_pydantic": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "MyFunnyPydantic"], "method": [null, "construct"], "kwargs": {"foo": "foo", "bar": 1}}, "person": {"lc": 2, "type": "constructor", "id": ["tests", "test_jsonplus", "Person"], "kwargs": {"name": "foo"}}, "a_bool": true, "a_none": null, "a_str": "foo", "a_str_nuc": "foo\\u0000", "a_str_uc": "foo \xe2\x9b\xb0\xef\xb8\x8f", "a_str_ucuc": "foo \xe2\x9b\xb0\xef\xb8\x8f\\u0000", "a_str_ucucuc": "foo \\\\u26f0\\\\ufe0f", "text": ["Hello", "Python", "Surrogate", "Example", "String", "With", "Surrogates", "Embedded", "In", "The", "Text", "\xe6\x94\xb6\xe8\x8a\xb1\xf0\x9f\x99\x84\xc2\xb7\xe5\x88\xb0"], "an_int": 1, "a_float": 1.1, "runnable_map": {"lc": 1, "type": "constructor", "id": ["langchain", "schema", "runnable", "RunnableParallel"], "kwargs": {"steps__": {}}, "name": "RunnableParallel<>", "graph": {"nodes": [{"id": 0, "type": "schema", "data": "Parallel<>Input"}, {"id": 1, "type": "schema", "data": "Parallel<>Output"}], "edges": []}}, "a_bytes": {"lc": 2, "type": "constructor", "id": ["builtins", "bytes"], "method": "fromhex", "args": ["6d79206279746573"]}, "a_bytearray": {"lc": 2, "type": "constructor", "id": ["builtins", "bytearray"], "method": "fromhex", "args": ["2a"]}}""", ) assert serde.loads_typed(dumped) == { @@ -130,6 +132,14 @@ def test_serde_jsonplus() -> None: "text": [v.encode("utf-8", "ignore").decode() for v in to_serialize["text"]], } + for key, value in to_serialize.items(): + if key == "text": + assert serde.loads_typed(serde.dumps_typed(value)) == [ + v.encode("utf-8", "ignore").decode() for v in value + ] + else: + assert serde.loads_typed(serde.dumps_typed(value)) == value + def test_serde_jsonplus_bytes() -> None: serde = JsonPlusSerializer() diff --git a/libs/langgraph/langgraph/channels/any_value.py b/libs/langgraph/langgraph/channels/any_value.py index e9bd85572..e9dfb77d6 100644 --- a/libs/langgraph/langgraph/channels/any_value.py +++ b/libs/langgraph/langgraph/channels/any_value.py @@ -1,7 +1,5 @@ -from contextlib import contextmanager -from typing import Generator, Generic, Optional, Sequence, Type +from typing import Generic, Optional, Sequence, Type -from langchain_core.runnables import RunnableConfig from typing_extensions import Self from langgraph.channels.base import BaseChannel, Value @@ -12,8 +10,7 @@ class AnyValue(Generic[Value], BaseChannel[Value, Value, Value]): """Stores the last value received, assumes that if multiple values are received, they are all equal.""" - def __init__(self, typ: Type[Value]) -> None: - self.typ = typ + __slots__ = ("typ", "value") def __eq__(self, value: object) -> bool: return isinstance(value, AnyValue) @@ -28,20 +25,12 @@ def UpdateType(self) -> Type[Value]: """The type of the update received by the channel.""" return self.typ - @contextmanager - def from_checkpoint( - self, checkpoint: Optional[Value], config: RunnableConfig - ) -> Generator[Self, None, None]: + def from_checkpoint(self, checkpoint: Optional[Value]) -> Self: empty = self.__class__(self.typ) + empty.key = self.key if checkpoint is not None: empty.value = checkpoint - try: - yield empty - finally: - try: - del empty.value - except AttributeError: - pass + return empty def update(self, values: Sequence[Value]) -> bool: if len(values) == 0: diff --git a/libs/langgraph/langgraph/channels/base.py b/libs/langgraph/langgraph/channels/base.py index 955b6ab76..61f8908c0 100644 --- a/libs/langgraph/langgraph/channels/base.py +++ b/libs/langgraph/langgraph/channels/base.py @@ -1,16 +1,6 @@ from abc import ABC, abstractmethod -from contextlib import asynccontextmanager, contextmanager -from typing import ( - Any, - AsyncIterator, - Generic, - Iterator, - Optional, - Sequence, - TypeVar, -) - -from langchain_core.runnables import RunnableConfig +from typing import Any, Generic, Optional, Sequence, Type, TypeVar + from typing_extensions import Self from langgraph.errors import EmptyChannelError, InvalidUpdateError @@ -21,7 +11,11 @@ class BaseChannel(Generic[Value, Update, C], ABC): - key: str = "" + __slots__ = ("key", "typ") + + def __init__(self, typ: Type[Any], key: str = "") -> None: + self.typ = typ + self.key = key @property @abstractmethod @@ -41,38 +35,10 @@ def checkpoint(self) -> Optional[C]: or doesn't support checkpoints.""" return self.get() - @contextmanager @abstractmethod - def from_checkpoint( - self, checkpoint: Optional[C], config: RunnableConfig - ) -> Iterator[Self]: - """Return a new identical channel, optionally initialized from a checkpoint. - If the checkpoint contains complex data structures, they should be copied.""" - - @contextmanager - def from_checkpoint_named( - self, checkpoint: Optional[C], config: RunnableConfig - ) -> Iterator[Self]: - with self.from_checkpoint(checkpoint, config) as value: - value.key = self.key - yield value - - @asynccontextmanager - async def afrom_checkpoint( - self, checkpoint: Optional[C], config: RunnableConfig - ) -> AsyncIterator[Self]: + def from_checkpoint(self, checkpoint: Optional[C]) -> Self: """Return a new identical channel, optionally initialized from a checkpoint. If the checkpoint contains complex data structures, they should be copied.""" - with self.from_checkpoint(checkpoint, config) as value: - yield value - - @asynccontextmanager - async def afrom_checkpoint_named( - self, checkpoint: Optional[C], config: RunnableConfig - ) -> AsyncIterator[Self]: - async with self.afrom_checkpoint(checkpoint, config) as value: - value.key = self.key - yield value # state methods diff --git a/libs/langgraph/langgraph/channels/binop.py b/libs/langgraph/langgraph/channels/binop.py index 02e33bd2e..d3fe4fce2 100644 --- a/libs/langgraph/langgraph/channels/binop.py +++ b/libs/langgraph/langgraph/channels/binop.py @@ -1,15 +1,12 @@ import collections.abc -from contextlib import contextmanager from typing import ( Callable, - Generator, Generic, Optional, Sequence, Type, ) -from langchain_core.runnables import RunnableConfig from typing_extensions import NotRequired, Required, Self from langgraph.channels.base import BaseChannel, Value @@ -37,10 +34,11 @@ class BinaryOperatorAggregate(Generic[Value], BaseChannel[Value, Value, Value]): ``` """ + __slots__ = ("value", "operator") + def __init__(self, typ: Type[Value], operator: Callable[[Value, Value], Value]): + super().__init__(typ) self.operator = operator - # keep the type exposed by ValueType/UpdateType as-is - self.typ = typ # special forms from typing or collections.abc are not instantiable # so we need to replace them with their concrete counterparts typ = _strip_extras(typ) @@ -73,20 +71,12 @@ def UpdateType(self) -> Type[Value]: """The type of the update received by the channel.""" return self.typ - @contextmanager - def from_checkpoint( - self, checkpoint: Optional[Value], config: RunnableConfig - ) -> Generator[Self, None, None]: + def from_checkpoint(self, checkpoint: Optional[Value]) -> Self: empty = self.__class__(self.typ, self.operator) + empty.key = self.key if checkpoint is not None: empty.value = checkpoint - try: - yield empty - finally: - try: - del empty.value - except AttributeError: - pass + return empty def update(self, values: Sequence[Value]) -> bool: if not values: diff --git a/libs/langgraph/langgraph/channels/dynamic_barrier_value.py b/libs/langgraph/langgraph/channels/dynamic_barrier_value.py index 64406b8f8..bbecd3d8b 100644 --- a/libs/langgraph/langgraph/channels/dynamic_barrier_value.py +++ b/libs/langgraph/langgraph/channels/dynamic_barrier_value.py @@ -1,7 +1,5 @@ -from contextlib import contextmanager -from typing import Generator, Generic, NamedTuple, Optional, Sequence, Type, Union +from typing import Generic, NamedTuple, Optional, Sequence, Type, Union -from langchain_core.runnables import RunnableConfig from typing_extensions import Self from langgraph.channels.base import BaseChannel, Value @@ -24,11 +22,13 @@ class DynamicBarrierValue( back to the "priming" state. """ + __slots__ = ("names", "seen") + names: Optional[set[Value]] seen: set[Value] def __init__(self, typ: Type[Value]) -> None: - self.typ = typ + super().__init__(typ) self.names = None self.seen = set() @@ -48,22 +48,17 @@ def UpdateType(self) -> Type[Value]: def checkpoint(self) -> tuple[Optional[set[Value]], set[Value]]: return (self.names, self.seen) - @contextmanager def from_checkpoint( self, checkpoint: Optional[tuple[Optional[set[Value]], set[Value]]], - config: RunnableConfig, - ) -> Generator[Self, None, None]: + ) -> Self: empty = self.__class__(self.typ) + empty.key = self.key if checkpoint is not None: names, seen = checkpoint - empty.names = names.copy() if names is not None else None - empty.seen = seen.copy() - - try: - yield empty - finally: - pass + empty.names = names if names is not None else None + empty.seen = seen + return empty def update(self, values: Sequence[Union[Value, WaitForNames]]) -> bool: if wait_for_names := [v for v in values if isinstance(v, WaitForNames)]: diff --git a/libs/langgraph/langgraph/channels/ephemeral_value.py b/libs/langgraph/langgraph/channels/ephemeral_value.py index 59e34a5f0..5beba22eb 100644 --- a/libs/langgraph/langgraph/channels/ephemeral_value.py +++ b/libs/langgraph/langgraph/channels/ephemeral_value.py @@ -1,7 +1,5 @@ -from contextlib import contextmanager -from typing import Generator, Generic, Optional, Sequence, Type +from typing import Generic, Optional, Sequence, Type -from langchain_core.runnables import RunnableConfig from typing_extensions import Self from langgraph.channels.base import BaseChannel, Value @@ -11,8 +9,10 @@ class EphemeralValue(Generic[Value], BaseChannel[Value, Value, Value]): """Stores the value received in the step immediately preceding, clears after.""" + __slots__ = ("value", "guard") + def __init__(self, typ: Type[Value], guard: bool = True) -> None: - self.typ = typ + super().__init__(typ) self.guard = guard def __eq__(self, value: object) -> bool: @@ -28,20 +28,12 @@ def UpdateType(self) -> Type[Value]: """The type of the update received by the channel.""" return self.typ - @contextmanager - def from_checkpoint( - self, checkpoint: Optional[Value], config: RunnableConfig - ) -> Generator[Self, None, None]: + def from_checkpoint(self, checkpoint: Optional[Value]) -> Self: empty = self.__class__(self.typ, self.guard) + empty.key = self.key if checkpoint is not None: empty.value = checkpoint - try: - yield empty - finally: - try: - del empty.value - except AttributeError: - pass + return empty def update(self, values: Sequence[Value]) -> bool: if len(values) == 0: diff --git a/libs/langgraph/langgraph/channels/last_value.py b/libs/langgraph/langgraph/channels/last_value.py index e5e59d111..3f5b35e3b 100644 --- a/libs/langgraph/langgraph/channels/last_value.py +++ b/libs/langgraph/langgraph/channels/last_value.py @@ -1,7 +1,5 @@ -from contextlib import contextmanager -from typing import Generator, Generic, Optional, Sequence, Type +from typing import Generic, Optional, Sequence, Type -from langchain_core.runnables import RunnableConfig from typing_extensions import Self from langgraph.channels.base import BaseChannel, Value @@ -11,8 +9,7 @@ class LastValue(Generic[Value], BaseChannel[Value, Value, Value]): """Stores the last value received, can receive at most one value per step.""" - def __init__(self, typ: Type[Value]) -> None: - self.typ = typ + __slots__ = ("value",) def __eq__(self, value: object) -> bool: return isinstance(value, LastValue) @@ -27,20 +24,12 @@ def UpdateType(self) -> Type[Value]: """The type of the update received by the channel.""" return self.typ - @contextmanager - def from_checkpoint( - self, checkpoint: Optional[Value], config: RunnableConfig - ) -> Generator[Self, None, None]: + def from_checkpoint(self, checkpoint: Optional[Value]) -> Self: empty = self.__class__(self.typ) + empty.key = self.key if checkpoint is not None: empty.value = checkpoint - try: - yield empty - finally: - try: - del empty.value - except AttributeError: - pass + return empty def update(self, values: Sequence[Value]) -> bool: if len(values) == 0: diff --git a/libs/langgraph/langgraph/channels/named_barrier_value.py b/libs/langgraph/langgraph/channels/named_barrier_value.py index 023f54e6c..a804a3052 100644 --- a/libs/langgraph/langgraph/channels/named_barrier_value.py +++ b/libs/langgraph/langgraph/channels/named_barrier_value.py @@ -1,7 +1,5 @@ -from contextlib import contextmanager -from typing import Generator, Generic, Optional, Sequence, Type +from typing import Generic, Optional, Sequence, Type -from langchain_core.runnables import RunnableConfig from typing_extensions import Self from langgraph.channels.base import BaseChannel, Value @@ -11,8 +9,10 @@ class NamedBarrierValue(Generic[Value], BaseChannel[Value, Value, set[Value]]): """A channel that waits until all named values are received before making the value available.""" + __slots__ = ("names", "seen") + def __init__(self, typ: Type[Value], names: set[Value]) -> None: - self.typ = typ + super().__init__(typ) self.names = names self.seen = set() @@ -32,18 +32,12 @@ def UpdateType(self) -> Type[Value]: def checkpoint(self) -> set[Value]: return self.seen - @contextmanager - def from_checkpoint( - self, checkpoint: Optional[set[Value]], config: RunnableConfig - ) -> Generator[Self, None, None]: + def from_checkpoint(self, checkpoint: Optional[set[Value]]) -> Self: empty = self.__class__(self.typ, self.names) + empty.key = self.key if checkpoint is not None: - empty.seen = checkpoint.copy() - - try: - yield empty - finally: - pass + empty.seen = checkpoint + return empty def update(self, values: Sequence[Value]) -> bool: updated = False diff --git a/libs/langgraph/langgraph/channels/topic.py b/libs/langgraph/langgraph/channels/topic.py index 7b4b0b27d..8c1c8d15b 100644 --- a/libs/langgraph/langgraph/channels/topic.py +++ b/libs/langgraph/langgraph/channels/topic.py @@ -1,7 +1,5 @@ -from contextlib import contextmanager -from typing import Any, Generator, Generic, Iterator, Optional, Sequence, Type, Union +from typing import Any, Generic, Iterator, Optional, Sequence, Type, Union -from langchain_core.runnables import RunnableConfig from typing_extensions import Self from langgraph.channels.base import BaseChannel, Value @@ -30,9 +28,11 @@ class Topic( accumulate: Whether to accumulate values across steps. If False, the channel will be emptied after each step. """ + __slots__ = ("values", "accumulate") + def __init__(self, typ: Type[Value], accumulate: bool = False) -> None: + super().__init__(typ) # attrs - self.typ = typ self.accumulate = accumulate # state self.values = list[Value]() @@ -53,22 +53,15 @@ def UpdateType(self) -> Any: def checkpoint(self) -> tuple[set[Value], list[Value]]: return self.values - @contextmanager - def from_checkpoint( - self, - checkpoint: Optional[list[Value]], - config: RunnableConfig, - ) -> Generator[Self, None, None]: + def from_checkpoint(self, checkpoint: Optional[list[Value]]) -> Self: empty = self.__class__(self.typ, self.accumulate) + empty.key = self.key if checkpoint is not None: if isinstance(checkpoint, tuple): - empty.values = checkpoint[1].copy() + empty.values = checkpoint[1] else: - empty.values = checkpoint.copy() - try: - yield empty - finally: - pass + empty.values = checkpoint + return empty def update(self, values: Sequence[Union[Value, list[Value]]]) -> None: current = list(self.values) diff --git a/libs/langgraph/langgraph/channels/untracked_value.py b/libs/langgraph/langgraph/channels/untracked_value.py index a112b0e81..9b1020710 100644 --- a/libs/langgraph/langgraph/channels/untracked_value.py +++ b/libs/langgraph/langgraph/channels/untracked_value.py @@ -1,7 +1,5 @@ -from contextlib import contextmanager -from typing import Generator, Generic, Optional, Sequence, Type +from typing import Generic, Optional, Sequence, Type -from langchain_core.runnables import RunnableConfig from typing_extensions import Self from langgraph.channels.base import BaseChannel, Value @@ -11,8 +9,10 @@ class UntrackedValue(Generic[Value], BaseChannel[Value, Value, Value]): """Stores the last value received, never checkpointed.""" + __slots__ = ("value", "guard") + def __init__(self, typ: Type[Value], guard: bool = True) -> None: - self.typ = typ + super().__init__(typ) self.guard = guard def __eq__(self, value: object) -> bool: @@ -31,18 +31,10 @@ def UpdateType(self) -> Type[Value]: def checkpoint(self) -> Value: raise EmptyChannelError() - @contextmanager - def from_checkpoint( - self, checkpoint: Optional[Value], config: RunnableConfig - ) -> Generator[Self, None, None]: + def from_checkpoint(self, checkpoint: Optional[Value]) -> Self: empty = self.__class__(self.typ, self.guard) - try: - yield empty - finally: - try: - del empty.value - except AttributeError: - pass + empty.key = self.key + return empty def update(self, values: Sequence[Value]) -> bool: if len(values) == 0: diff --git a/libs/langgraph/langgraph/pregel/loop.py b/libs/langgraph/langgraph/pregel/loop.py index 1b4077bd6..52401283f 100644 --- a/libs/langgraph/langgraph/pregel/loop.py +++ b/libs/langgraph/langgraph/pregel/loop.py @@ -50,6 +50,7 @@ INTERRUPT, SCHEDULED, TAG_HIDDEN, + TASKS, ) from langgraph.errors import ( CheckpointNotLatest, @@ -101,6 +102,7 @@ INPUT_DONE = object() INPUT_RESUMING = object() EMPTY_SEQ = () +SPECIAL_CHANNELS = (ERROR, INTERRUPT, SCHEDULED) class StreamProtocol(Protocol): @@ -220,6 +222,19 @@ def put_writes(self, task_id: str, writes: Sequence[tuple[str, Any]]) -> None: """Put writes for a task, to be read by the next tick.""" if not writes: return + # adjust task_writes_left + first_channel = writes[0][0] + any_channel_is_send = any(k == TASKS for k, _ in writes) + always_save = any_channel_is_send or first_channel in SPECIAL_CHANNELS + if not always_save and not self.task_writes_left: + return self._output_writes(task_id, writes) + elif first_channel == INTERRUPT: + # INTERRUPT makes us want to save the last task's writes + # so we don't decrement task_writes_left + pass + else: + self.task_writes_left -= 1 + # save writes self.checkpoint_pending_writes.extend((task_id, k, v) for k, v in writes) if self.checkpointer_put_writes is not None: self.submit( @@ -237,6 +252,7 @@ def put_writes(self, task_id: str, writes: Sequence[tuple[str, Any]]) -> None: writes, task_id, ) + # output writes self._output_writes(task_id, writes) def tick( @@ -321,6 +337,9 @@ def tick( manager=manager, checkpointer=self.checkpointer, ) + # we don't need to save the writes for the last task that completes + # unless in special conditions handled by self.put_writes() + self.task_writes_left = len(self.tasks) - 1 # produce debug output if self._checkpointer_put_after_previous is not None: diff --git a/libs/langgraph/langgraph/pregel/manager.py b/libs/langgraph/langgraph/pregel/manager.py index f70c86d46..a7e0b7283 100644 --- a/libs/langgraph/langgraph/pregel/manager.py +++ b/libs/langgraph/langgraph/pregel/manager.py @@ -42,9 +42,7 @@ def ChannelsManager( with ExitStack() as stack: yield ( { - k: stack.enter_context( - v.from_checkpoint_named(checkpoint["channel_values"].get(k), config) - ) + k: v.from_checkpoint(checkpoint["channel_values"].get(k)) for k, v in channel_specs.items() }, ManagedValueMapping( @@ -100,11 +98,7 @@ async def AsyncChannelsManager( yield ( # channels: enter each channel with checkpoint { - k: await stack.enter_async_context( - v.afrom_checkpoint_named( - checkpoint["channel_values"].get(k), config - ) - ) + k: v.from_checkpoint(checkpoint["channel_values"].get(k)) for k, v in channel_specs.items() }, # managed: build mapping from spec to result diff --git a/libs/langgraph/tests/test_channels.py b/libs/langgraph/tests/test_channels.py index 69a624969..7c6fb162b 100644 --- a/libs/langgraph/tests/test_channels.py +++ b/libs/langgraph/tests/test_channels.py @@ -12,158 +12,77 @@ def test_last_value() -> None: - with LastValue(int).from_checkpoint(None, {}) as channel: - assert channel.ValueType is int - assert channel.UpdateType is int - - with pytest.raises(EmptyChannelError): - channel.get() - with pytest.raises(InvalidUpdateError): - channel.update([5, 6]) - - channel.update([3]) - assert channel.get() == 3 - channel.update([4]) - assert channel.get() == 4 - checkpoint = channel.checkpoint() - with LastValue(int).from_checkpoint(checkpoint, {}) as channel: - assert channel.get() == 4 - - -async def test_last_value_async() -> None: - async with LastValue(int).afrom_checkpoint(None, {}) as channel: - assert channel.ValueType is int - assert channel.UpdateType is int - - with pytest.raises(EmptyChannelError): - channel.get() - with pytest.raises(InvalidUpdateError): - channel.update([5, 6]) - - channel.update([3]) - assert channel.get() == 3 - channel.update([4]) - assert channel.get() == 4 - checkpoint = channel.checkpoint() - async with LastValue(int).afrom_checkpoint(checkpoint, {}) as channel: - assert channel.get() == 4 + channel = LastValue(int).from_checkpoint(None) + assert channel.ValueType is int + assert channel.UpdateType is int + + with pytest.raises(EmptyChannelError): + channel.get() + with pytest.raises(InvalidUpdateError): + channel.update([5, 6]) + + channel.update([3]) + assert channel.get() == 3 + channel.update([4]) + assert channel.get() == 4 + checkpoint = channel.checkpoint() + channel = LastValue(int).from_checkpoint(checkpoint) + assert channel.get() == 4 def test_topic() -> None: - with Topic(str).from_checkpoint(None, {}) as channel: - assert channel.ValueType is Sequence[str] - assert channel.UpdateType is Union[str, list[str]] - - assert channel.update(["a", "b"]) - assert channel.get() == ["a", "b"] - assert channel.update([["c", "d"], "d"]) - assert channel.get() == ["c", "d", "d"] - assert channel.update([]) - with pytest.raises(EmptyChannelError): - channel.get() - assert not channel.update([]), "channel already empty" - assert channel.update(["e"]) - assert channel.get() == ["e"] - checkpoint = channel.checkpoint() - with Topic(str).from_checkpoint(checkpoint, {}) as channel: - assert channel.get() == ["e"] - with Topic(str).from_checkpoint(checkpoint, {}) as channel_copy: - channel_copy.update(["f"]) - assert channel_copy.get() == ["f"] - assert channel.get() == ["e"] - - -async def test_topic_async() -> None: - async with Topic(str).afrom_checkpoint(None, {}) as channel: - assert channel.ValueType is Sequence[str] - assert channel.UpdateType is Union[str, list[str]] - - assert channel.update(["a", "b"]) - assert channel.get() == ["a", "b"] - assert channel.update(["b", ["c", "d"], "d"]) - assert channel.get() == ["b", "c", "d", "d"] - assert channel.update([]) - with pytest.raises(EmptyChannelError): - channel.get() - assert not channel.update([]), "channel already empty" - assert channel.update(["e"]) - assert channel.get() == ["e"] - checkpoint = channel.checkpoint() - async with Topic(str).afrom_checkpoint(checkpoint, {}) as channel: - assert channel.get() == ["e"] + channel = Topic(str).from_checkpoint(None) + assert channel.ValueType is Sequence[str] + assert channel.UpdateType is Union[str, list[str]] + + assert channel.update(["a", "b"]) + assert channel.get() == ["a", "b"] + assert channel.update([["c", "d"], "d"]) + assert channel.get() == ["c", "d", "d"] + assert channel.update([]) + with pytest.raises(EmptyChannelError): + channel.get() + assert not channel.update([]), "channel already empty" + assert channel.update(["e"]) + assert channel.get() == ["e"] + checkpoint = channel.checkpoint() + channel = Topic(str).from_checkpoint(checkpoint) + assert channel.get() == ["e"] + channel_copy = Topic(str).from_checkpoint(checkpoint) + channel_copy.update(["f"]) + assert channel_copy.get() == ["f"] + assert channel.get() == ["e"] def test_topic_accumulate() -> None: - with Topic(str, accumulate=True).from_checkpoint(None, {}) as channel: - assert channel.ValueType is Sequence[str] - assert channel.UpdateType is Union[str, list[str]] - - assert channel.update(["a", "b"]) - assert channel.get() == ["a", "b"] - assert channel.update(["b", ["c", "d"], "d"]) - assert channel.get() == ["a", "b", "b", "c", "d", "d"] - assert not channel.update([]) - assert channel.get() == ["a", "b", "b", "c", "d", "d"] - checkpoint = channel.checkpoint() - with Topic(str, accumulate=True).from_checkpoint(checkpoint, {}) as channel: - assert channel.get() == ["a", "b", "b", "c", "d", "d"] - assert channel.update(["e"]) - assert channel.get() == ["a", "b", "b", "c", "d", "d", "e"] - - -async def test_topic_accumulate_async() -> None: - async with Topic(str, accumulate=True).afrom_checkpoint(None, {}) as channel: - assert channel.ValueType is Sequence[str] - assert channel.UpdateType is Union[str, list[str]] - - assert channel.update(["a", "b"]) - assert channel.get() == ["a", "b"] - assert channel.update(["b", ["c", "d"], "d"]) - assert channel.get() == ["a", "b", "b", "c", "d", "d"] - assert not channel.update([]) - assert channel.get() == ["a", "b", "b", "c", "d", "d"] - checkpoint = channel.checkpoint() - async with Topic(str, accumulate=True).afrom_checkpoint(checkpoint, {}) as channel: - assert channel.get() == ["a", "b", "b", "c", "d", "d"] - assert channel.update(["e"]) - assert channel.get() == ["a", "b", "b", "c", "d", "d", "e"] + channel = Topic(str, accumulate=True).from_checkpoint(None) + assert channel.ValueType is Sequence[str] + assert channel.UpdateType is Union[str, list[str]] + + assert channel.update(["a", "b"]) + assert channel.get() == ["a", "b"] + assert channel.update(["b", ["c", "d"], "d"]) + assert channel.get() == ["a", "b", "b", "c", "d", "d"] + assert not channel.update([]) + assert channel.get() == ["a", "b", "b", "c", "d", "d"] + checkpoint = channel.checkpoint() + channel = Topic(str, accumulate=True).from_checkpoint(checkpoint) + assert channel.get() == ["a", "b", "b", "c", "d", "d"] + assert channel.update(["e"]) + assert channel.get() == ["a", "b", "b", "c", "d", "d", "e"] def test_binop() -> None: - with BinaryOperatorAggregate(int, operator.add).from_checkpoint( - None, {} - ) as channel: - assert channel.ValueType is int - assert channel.UpdateType is int - - assert channel.get() == 0 - - channel.update([1, 2, 3]) - assert channel.get() == 6 - channel.update([4]) - assert channel.get() == 10 - checkpoint = channel.checkpoint() - with BinaryOperatorAggregate(int, operator.add).from_checkpoint( - checkpoint, {} - ) as channel: - assert channel.get() == 10 - - -async def test_binop_async() -> None: - async with BinaryOperatorAggregate(int, operator.add).afrom_checkpoint( - None, {} - ) as channel: - assert channel.ValueType is int - assert channel.UpdateType is int - - assert channel.get() == 0 - - channel.update([1, 2, 3]) - assert channel.get() == 6 - channel.update([4]) - assert channel.get() == 10 - checkpoint = channel.checkpoint() - async with BinaryOperatorAggregate(int, operator.add).afrom_checkpoint( - checkpoint, {} - ) as channel: - assert channel.get() == 10 + channel = BinaryOperatorAggregate(int, operator.add).from_checkpoint(None) + assert channel.ValueType is int + assert channel.UpdateType is int + + assert channel.get() == 0 + + channel.update([1, 2, 3]) + assert channel.get() == 6 + channel.update([4]) + assert channel.get() == 10 + checkpoint = channel.checkpoint() + channel = BinaryOperatorAggregate(int, operator.add).from_checkpoint(checkpoint) + assert channel.get() == 10 diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index ec154a2c7..b59309081 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -1656,11 +1656,7 @@ def reset(self): "writes": {"__start__": {"value": 1}}, }, parent_config=None, - pending_writes=UnsortedSequence( - (AnyStr(), "value", 1), - (AnyStr(), "start:one", "__start__"), - (AnyStr(), "start:two", "__start__"), - ), + pending_writes=[], ) diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index 1c1be830e..e55e50fac 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -1851,11 +1851,7 @@ def reset(self): "writes": {"__start__": {"value": 1}}, }, parent_config=None, - pending_writes=UnsortedSequence( - (AnyStr(), "value", 1), - (AnyStr(), "start:one", "__start__"), - (AnyStr(), "start:two", "__start__"), - ), + pending_writes=[], )