diff --git a/libs/langgraph/langgraph/channels/any_value.py b/libs/langgraph/langgraph/channels/any_value.py index e9bd85572e..3231b2e4d3 100644 --- a/libs/langgraph/langgraph/channels/any_value.py +++ b/libs/langgraph/langgraph/channels/any_value.py @@ -1,7 +1,5 @@ -from contextlib import contextmanager -from typing import Generator, Generic, Optional, Sequence, Type +from typing import Generic, Optional, Sequence, Type -from langchain_core.runnables import RunnableConfig from typing_extensions import Self from langgraph.channels.base import BaseChannel, Value @@ -28,20 +26,12 @@ def UpdateType(self) -> Type[Value]: """The type of the update received by the channel.""" return self.typ - @contextmanager - def from_checkpoint( - self, checkpoint: Optional[Value], config: RunnableConfig - ) -> Generator[Self, None, None]: + def from_checkpoint(self, checkpoint: Optional[Value]) -> Self: empty = self.__class__(self.typ) + empty.key = self.key if checkpoint is not None: empty.value = checkpoint - try: - yield empty - finally: - try: - del empty.value - except AttributeError: - pass + return empty def update(self, values: Sequence[Value]) -> bool: if len(values) == 0: diff --git a/libs/langgraph/langgraph/channels/base.py b/libs/langgraph/langgraph/channels/base.py index 955b6ab76c..6ad86ee404 100644 --- a/libs/langgraph/langgraph/channels/base.py +++ b/libs/langgraph/langgraph/channels/base.py @@ -1,16 +1,6 @@ from abc import ABC, abstractmethod -from contextlib import asynccontextmanager, contextmanager -from typing import ( - Any, - AsyncIterator, - Generic, - Iterator, - Optional, - Sequence, - TypeVar, -) - -from langchain_core.runnables import RunnableConfig +from typing import Any, Generic, Optional, Sequence, TypeVar + from typing_extensions import Self from langgraph.errors import EmptyChannelError, InvalidUpdateError @@ -41,38 +31,10 @@ def checkpoint(self) -> Optional[C]: or doesn't support checkpoints.""" return self.get() - @contextmanager @abstractmethod - def from_checkpoint( - self, checkpoint: Optional[C], config: RunnableConfig - ) -> Iterator[Self]: - """Return a new identical channel, optionally initialized from a checkpoint. - If the checkpoint contains complex data structures, they should be copied.""" - - @contextmanager - def from_checkpoint_named( - self, checkpoint: Optional[C], config: RunnableConfig - ) -> Iterator[Self]: - with self.from_checkpoint(checkpoint, config) as value: - value.key = self.key - yield value - - @asynccontextmanager - async def afrom_checkpoint( - self, checkpoint: Optional[C], config: RunnableConfig - ) -> AsyncIterator[Self]: + def from_checkpoint(self, checkpoint: Optional[C]) -> Self: """Return a new identical channel, optionally initialized from a checkpoint. If the checkpoint contains complex data structures, they should be copied.""" - with self.from_checkpoint(checkpoint, config) as value: - yield value - - @asynccontextmanager - async def afrom_checkpoint_named( - self, checkpoint: Optional[C], config: RunnableConfig - ) -> AsyncIterator[Self]: - async with self.afrom_checkpoint(checkpoint, config) as value: - value.key = self.key - yield value # state methods diff --git a/libs/langgraph/langgraph/channels/binop.py b/libs/langgraph/langgraph/channels/binop.py index 02e33bd2e9..fa21eeaea8 100644 --- a/libs/langgraph/langgraph/channels/binop.py +++ b/libs/langgraph/langgraph/channels/binop.py @@ -1,15 +1,12 @@ import collections.abc -from contextlib import contextmanager from typing import ( Callable, - Generator, Generic, Optional, Sequence, Type, ) -from langchain_core.runnables import RunnableConfig from typing_extensions import NotRequired, Required, Self from langgraph.channels.base import BaseChannel, Value @@ -73,20 +70,12 @@ def UpdateType(self) -> Type[Value]: """The type of the update received by the channel.""" return self.typ - @contextmanager - def from_checkpoint( - self, checkpoint: Optional[Value], config: RunnableConfig - ) -> Generator[Self, None, None]: + def from_checkpoint(self, checkpoint: Optional[Value]) -> Self: empty = self.__class__(self.typ, self.operator) + empty.key = self.key if checkpoint is not None: empty.value = checkpoint - try: - yield empty - finally: - try: - del empty.value - except AttributeError: - pass + return empty def update(self, values: Sequence[Value]) -> bool: if not values: diff --git a/libs/langgraph/langgraph/channels/dynamic_barrier_value.py b/libs/langgraph/langgraph/channels/dynamic_barrier_value.py index 64406b8f82..730cf8d455 100644 --- a/libs/langgraph/langgraph/channels/dynamic_barrier_value.py +++ b/libs/langgraph/langgraph/channels/dynamic_barrier_value.py @@ -1,7 +1,5 @@ -from contextlib import contextmanager -from typing import Generator, Generic, NamedTuple, Optional, Sequence, Type, Union +from typing import Generic, NamedTuple, Optional, Sequence, Type, Union -from langchain_core.runnables import RunnableConfig from typing_extensions import Self from langgraph.channels.base import BaseChannel, Value @@ -48,22 +46,17 @@ def UpdateType(self) -> Type[Value]: def checkpoint(self) -> tuple[Optional[set[Value]], set[Value]]: return (self.names, self.seen) - @contextmanager def from_checkpoint( self, checkpoint: Optional[tuple[Optional[set[Value]], set[Value]]], - config: RunnableConfig, - ) -> Generator[Self, None, None]: + ) -> Self: empty = self.__class__(self.typ) + empty.key = self.key if checkpoint is not None: names, seen = checkpoint - empty.names = names.copy() if names is not None else None - empty.seen = seen.copy() - - try: - yield empty - finally: - pass + empty.names = names if names is not None else None + empty.seen = seen + return empty def update(self, values: Sequence[Union[Value, WaitForNames]]) -> bool: if wait_for_names := [v for v in values if isinstance(v, WaitForNames)]: diff --git a/libs/langgraph/langgraph/channels/ephemeral_value.py b/libs/langgraph/langgraph/channels/ephemeral_value.py index 59e34a5f07..3b393ac631 100644 --- a/libs/langgraph/langgraph/channels/ephemeral_value.py +++ b/libs/langgraph/langgraph/channels/ephemeral_value.py @@ -1,7 +1,5 @@ -from contextlib import contextmanager -from typing import Generator, Generic, Optional, Sequence, Type +from typing import Generic, Optional, Sequence, Type -from langchain_core.runnables import RunnableConfig from typing_extensions import Self from langgraph.channels.base import BaseChannel, Value @@ -28,20 +26,12 @@ def UpdateType(self) -> Type[Value]: """The type of the update received by the channel.""" return self.typ - @contextmanager - def from_checkpoint( - self, checkpoint: Optional[Value], config: RunnableConfig - ) -> Generator[Self, None, None]: + def from_checkpoint(self, checkpoint: Optional[Value]) -> Self: empty = self.__class__(self.typ, self.guard) + empty.key = self.key if checkpoint is not None: empty.value = checkpoint - try: - yield empty - finally: - try: - del empty.value - except AttributeError: - pass + return empty def update(self, values: Sequence[Value]) -> bool: if len(values) == 0: diff --git a/libs/langgraph/langgraph/channels/last_value.py b/libs/langgraph/langgraph/channels/last_value.py index e5e59d111b..c4ef027972 100644 --- a/libs/langgraph/langgraph/channels/last_value.py +++ b/libs/langgraph/langgraph/channels/last_value.py @@ -1,7 +1,5 @@ -from contextlib import contextmanager -from typing import Generator, Generic, Optional, Sequence, Type +from typing import Generic, Optional, Sequence, Type -from langchain_core.runnables import RunnableConfig from typing_extensions import Self from langgraph.channels.base import BaseChannel, Value @@ -27,20 +25,12 @@ def UpdateType(self) -> Type[Value]: """The type of the update received by the channel.""" return self.typ - @contextmanager - def from_checkpoint( - self, checkpoint: Optional[Value], config: RunnableConfig - ) -> Generator[Self, None, None]: + def from_checkpoint(self, checkpoint: Optional[Value]) -> Self: empty = self.__class__(self.typ) + empty.key = self.key if checkpoint is not None: empty.value = checkpoint - try: - yield empty - finally: - try: - del empty.value - except AttributeError: - pass + return empty def update(self, values: Sequence[Value]) -> bool: if len(values) == 0: diff --git a/libs/langgraph/langgraph/channels/named_barrier_value.py b/libs/langgraph/langgraph/channels/named_barrier_value.py index 023f54e6cc..1d24a58b8c 100644 --- a/libs/langgraph/langgraph/channels/named_barrier_value.py +++ b/libs/langgraph/langgraph/channels/named_barrier_value.py @@ -1,7 +1,5 @@ -from contextlib import contextmanager -from typing import Generator, Generic, Optional, Sequence, Type +from typing import Generic, Optional, Sequence, Type -from langchain_core.runnables import RunnableConfig from typing_extensions import Self from langgraph.channels.base import BaseChannel, Value @@ -32,18 +30,12 @@ def UpdateType(self) -> Type[Value]: def checkpoint(self) -> set[Value]: return self.seen - @contextmanager - def from_checkpoint( - self, checkpoint: Optional[set[Value]], config: RunnableConfig - ) -> Generator[Self, None, None]: + def from_checkpoint(self, checkpoint: Optional[set[Value]]) -> Self: empty = self.__class__(self.typ, self.names) + empty.key = self.key if checkpoint is not None: - empty.seen = checkpoint.copy() - - try: - yield empty - finally: - pass + empty.seen = checkpoint + return empty def update(self, values: Sequence[Value]) -> bool: updated = False diff --git a/libs/langgraph/langgraph/channels/topic.py b/libs/langgraph/langgraph/channels/topic.py index 7b4b0b27db..ebbee63889 100644 --- a/libs/langgraph/langgraph/channels/topic.py +++ b/libs/langgraph/langgraph/channels/topic.py @@ -1,7 +1,5 @@ -from contextlib import contextmanager -from typing import Any, Generator, Generic, Iterator, Optional, Sequence, Type, Union +from typing import Any, Generic, Iterator, Optional, Sequence, Type, Union -from langchain_core.runnables import RunnableConfig from typing_extensions import Self from langgraph.channels.base import BaseChannel, Value @@ -53,22 +51,15 @@ def UpdateType(self) -> Any: def checkpoint(self) -> tuple[set[Value], list[Value]]: return self.values - @contextmanager - def from_checkpoint( - self, - checkpoint: Optional[list[Value]], - config: RunnableConfig, - ) -> Generator[Self, None, None]: + def from_checkpoint(self, checkpoint: Optional[list[Value]]) -> Self: empty = self.__class__(self.typ, self.accumulate) + empty.key = self.key if checkpoint is not None: if isinstance(checkpoint, tuple): - empty.values = checkpoint[1].copy() + empty.values = checkpoint[1] else: - empty.values = checkpoint.copy() - try: - yield empty - finally: - pass + empty.values = checkpoint + return empty def update(self, values: Sequence[Union[Value, list[Value]]]) -> None: current = list(self.values) diff --git a/libs/langgraph/langgraph/channels/untracked_value.py b/libs/langgraph/langgraph/channels/untracked_value.py index a112b0e819..10cd38515a 100644 --- a/libs/langgraph/langgraph/channels/untracked_value.py +++ b/libs/langgraph/langgraph/channels/untracked_value.py @@ -1,7 +1,5 @@ -from contextlib import contextmanager -from typing import Generator, Generic, Optional, Sequence, Type +from typing import Generic, Optional, Sequence, Type -from langchain_core.runnables import RunnableConfig from typing_extensions import Self from langgraph.channels.base import BaseChannel, Value @@ -31,18 +29,10 @@ def UpdateType(self) -> Type[Value]: def checkpoint(self) -> Value: raise EmptyChannelError() - @contextmanager - def from_checkpoint( - self, checkpoint: Optional[Value], config: RunnableConfig - ) -> Generator[Self, None, None]: + def from_checkpoint(self, checkpoint: Optional[Value]) -> Self: empty = self.__class__(self.typ, self.guard) - try: - yield empty - finally: - try: - del empty.value - except AttributeError: - pass + empty.key = self.key + return empty def update(self, values: Sequence[Value]) -> bool: if len(values) == 0: diff --git a/libs/langgraph/langgraph/pregel/manager.py b/libs/langgraph/langgraph/pregel/manager.py index f70c86d46e..a7e0b72833 100644 --- a/libs/langgraph/langgraph/pregel/manager.py +++ b/libs/langgraph/langgraph/pregel/manager.py @@ -42,9 +42,7 @@ def ChannelsManager( with ExitStack() as stack: yield ( { - k: stack.enter_context( - v.from_checkpoint_named(checkpoint["channel_values"].get(k), config) - ) + k: v.from_checkpoint(checkpoint["channel_values"].get(k)) for k, v in channel_specs.items() }, ManagedValueMapping( @@ -100,11 +98,7 @@ async def AsyncChannelsManager( yield ( # channels: enter each channel with checkpoint { - k: await stack.enter_async_context( - v.afrom_checkpoint_named( - checkpoint["channel_values"].get(k), config - ) - ) + k: v.from_checkpoint(checkpoint["channel_values"].get(k)) for k, v in channel_specs.items() }, # managed: build mapping from spec to result diff --git a/libs/langgraph/tests/test_channels.py b/libs/langgraph/tests/test_channels.py index 69a624969c..7c6fb162b6 100644 --- a/libs/langgraph/tests/test_channels.py +++ b/libs/langgraph/tests/test_channels.py @@ -12,158 +12,77 @@ def test_last_value() -> None: - with LastValue(int).from_checkpoint(None, {}) as channel: - assert channel.ValueType is int - assert channel.UpdateType is int - - with pytest.raises(EmptyChannelError): - channel.get() - with pytest.raises(InvalidUpdateError): - channel.update([5, 6]) - - channel.update([3]) - assert channel.get() == 3 - channel.update([4]) - assert channel.get() == 4 - checkpoint = channel.checkpoint() - with LastValue(int).from_checkpoint(checkpoint, {}) as channel: - assert channel.get() == 4 - - -async def test_last_value_async() -> None: - async with LastValue(int).afrom_checkpoint(None, {}) as channel: - assert channel.ValueType is int - assert channel.UpdateType is int - - with pytest.raises(EmptyChannelError): - channel.get() - with pytest.raises(InvalidUpdateError): - channel.update([5, 6]) - - channel.update([3]) - assert channel.get() == 3 - channel.update([4]) - assert channel.get() == 4 - checkpoint = channel.checkpoint() - async with LastValue(int).afrom_checkpoint(checkpoint, {}) as channel: - assert channel.get() == 4 + channel = LastValue(int).from_checkpoint(None) + assert channel.ValueType is int + assert channel.UpdateType is int + + with pytest.raises(EmptyChannelError): + channel.get() + with pytest.raises(InvalidUpdateError): + channel.update([5, 6]) + + channel.update([3]) + assert channel.get() == 3 + channel.update([4]) + assert channel.get() == 4 + checkpoint = channel.checkpoint() + channel = LastValue(int).from_checkpoint(checkpoint) + assert channel.get() == 4 def test_topic() -> None: - with Topic(str).from_checkpoint(None, {}) as channel: - assert channel.ValueType is Sequence[str] - assert channel.UpdateType is Union[str, list[str]] - - assert channel.update(["a", "b"]) - assert channel.get() == ["a", "b"] - assert channel.update([["c", "d"], "d"]) - assert channel.get() == ["c", "d", "d"] - assert channel.update([]) - with pytest.raises(EmptyChannelError): - channel.get() - assert not channel.update([]), "channel already empty" - assert channel.update(["e"]) - assert channel.get() == ["e"] - checkpoint = channel.checkpoint() - with Topic(str).from_checkpoint(checkpoint, {}) as channel: - assert channel.get() == ["e"] - with Topic(str).from_checkpoint(checkpoint, {}) as channel_copy: - channel_copy.update(["f"]) - assert channel_copy.get() == ["f"] - assert channel.get() == ["e"] - - -async def test_topic_async() -> None: - async with Topic(str).afrom_checkpoint(None, {}) as channel: - assert channel.ValueType is Sequence[str] - assert channel.UpdateType is Union[str, list[str]] - - assert channel.update(["a", "b"]) - assert channel.get() == ["a", "b"] - assert channel.update(["b", ["c", "d"], "d"]) - assert channel.get() == ["b", "c", "d", "d"] - assert channel.update([]) - with pytest.raises(EmptyChannelError): - channel.get() - assert not channel.update([]), "channel already empty" - assert channel.update(["e"]) - assert channel.get() == ["e"] - checkpoint = channel.checkpoint() - async with Topic(str).afrom_checkpoint(checkpoint, {}) as channel: - assert channel.get() == ["e"] + channel = Topic(str).from_checkpoint(None) + assert channel.ValueType is Sequence[str] + assert channel.UpdateType is Union[str, list[str]] + + assert channel.update(["a", "b"]) + assert channel.get() == ["a", "b"] + assert channel.update([["c", "d"], "d"]) + assert channel.get() == ["c", "d", "d"] + assert channel.update([]) + with pytest.raises(EmptyChannelError): + channel.get() + assert not channel.update([]), "channel already empty" + assert channel.update(["e"]) + assert channel.get() == ["e"] + checkpoint = channel.checkpoint() + channel = Topic(str).from_checkpoint(checkpoint) + assert channel.get() == ["e"] + channel_copy = Topic(str).from_checkpoint(checkpoint) + channel_copy.update(["f"]) + assert channel_copy.get() == ["f"] + assert channel.get() == ["e"] def test_topic_accumulate() -> None: - with Topic(str, accumulate=True).from_checkpoint(None, {}) as channel: - assert channel.ValueType is Sequence[str] - assert channel.UpdateType is Union[str, list[str]] - - assert channel.update(["a", "b"]) - assert channel.get() == ["a", "b"] - assert channel.update(["b", ["c", "d"], "d"]) - assert channel.get() == ["a", "b", "b", "c", "d", "d"] - assert not channel.update([]) - assert channel.get() == ["a", "b", "b", "c", "d", "d"] - checkpoint = channel.checkpoint() - with Topic(str, accumulate=True).from_checkpoint(checkpoint, {}) as channel: - assert channel.get() == ["a", "b", "b", "c", "d", "d"] - assert channel.update(["e"]) - assert channel.get() == ["a", "b", "b", "c", "d", "d", "e"] - - -async def test_topic_accumulate_async() -> None: - async with Topic(str, accumulate=True).afrom_checkpoint(None, {}) as channel: - assert channel.ValueType is Sequence[str] - assert channel.UpdateType is Union[str, list[str]] - - assert channel.update(["a", "b"]) - assert channel.get() == ["a", "b"] - assert channel.update(["b", ["c", "d"], "d"]) - assert channel.get() == ["a", "b", "b", "c", "d", "d"] - assert not channel.update([]) - assert channel.get() == ["a", "b", "b", "c", "d", "d"] - checkpoint = channel.checkpoint() - async with Topic(str, accumulate=True).afrom_checkpoint(checkpoint, {}) as channel: - assert channel.get() == ["a", "b", "b", "c", "d", "d"] - assert channel.update(["e"]) - assert channel.get() == ["a", "b", "b", "c", "d", "d", "e"] + channel = Topic(str, accumulate=True).from_checkpoint(None) + assert channel.ValueType is Sequence[str] + assert channel.UpdateType is Union[str, list[str]] + + assert channel.update(["a", "b"]) + assert channel.get() == ["a", "b"] + assert channel.update(["b", ["c", "d"], "d"]) + assert channel.get() == ["a", "b", "b", "c", "d", "d"] + assert not channel.update([]) + assert channel.get() == ["a", "b", "b", "c", "d", "d"] + checkpoint = channel.checkpoint() + channel = Topic(str, accumulate=True).from_checkpoint(checkpoint) + assert channel.get() == ["a", "b", "b", "c", "d", "d"] + assert channel.update(["e"]) + assert channel.get() == ["a", "b", "b", "c", "d", "d", "e"] def test_binop() -> None: - with BinaryOperatorAggregate(int, operator.add).from_checkpoint( - None, {} - ) as channel: - assert channel.ValueType is int - assert channel.UpdateType is int - - assert channel.get() == 0 - - channel.update([1, 2, 3]) - assert channel.get() == 6 - channel.update([4]) - assert channel.get() == 10 - checkpoint = channel.checkpoint() - with BinaryOperatorAggregate(int, operator.add).from_checkpoint( - checkpoint, {} - ) as channel: - assert channel.get() == 10 - - -async def test_binop_async() -> None: - async with BinaryOperatorAggregate(int, operator.add).afrom_checkpoint( - None, {} - ) as channel: - assert channel.ValueType is int - assert channel.UpdateType is int - - assert channel.get() == 0 - - channel.update([1, 2, 3]) - assert channel.get() == 6 - channel.update([4]) - assert channel.get() == 10 - checkpoint = channel.checkpoint() - async with BinaryOperatorAggregate(int, operator.add).afrom_checkpoint( - checkpoint, {} - ) as channel: - assert channel.get() == 10 + channel = BinaryOperatorAggregate(int, operator.add).from_checkpoint(None) + assert channel.ValueType is int + assert channel.UpdateType is int + + assert channel.get() == 0 + + channel.update([1, 2, 3]) + assert channel.get() == 6 + channel.update([4]) + assert channel.get() == 10 + checkpoint = channel.checkpoint() + channel = BinaryOperatorAggregate(int, operator.add).from_checkpoint(checkpoint) + assert channel.get() == 10