Skip to content

Commit

Permalink
Make Channel.from_checkpoint a regular function
Browse files Browse the repository at this point in the history
- context manager no longer needed since Context became a managed value
  • Loading branch information
nfcampos committed Sep 13, 2024
1 parent d1c3423 commit 0931c66
Show file tree
Hide file tree
Showing 11 changed files with 106 additions and 306 deletions.
18 changes: 4 additions & 14 deletions libs/langgraph/langgraph/channels/any_value.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from contextlib import contextmanager
from typing import Generator, Generic, Optional, Sequence, Type
from typing import Generic, Optional, Sequence, Type

from langchain_core.runnables import RunnableConfig
from typing_extensions import Self

from langgraph.channels.base import BaseChannel, Value
Expand All @@ -28,20 +26,12 @@ def UpdateType(self) -> Type[Value]:
"""The type of the update received by the channel."""
return self.typ

@contextmanager
def from_checkpoint(
self, checkpoint: Optional[Value], config: RunnableConfig
) -> Generator[Self, None, None]:
def from_checkpoint(self, checkpoint: Optional[Value]) -> Self:
empty = self.__class__(self.typ)
empty.key = self.key
if checkpoint is not None:
empty.value = checkpoint
try:
yield empty
finally:
try:
del empty.value
except AttributeError:
pass
return empty

def update(self, values: Sequence[Value]) -> bool:
if len(values) == 0:
Expand Down
44 changes: 3 additions & 41 deletions libs/langgraph/langgraph/channels/base.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,6 @@
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager, contextmanager
from typing import (
Any,
AsyncIterator,
Generic,
Iterator,
Optional,
Sequence,
TypeVar,
)

from langchain_core.runnables import RunnableConfig
from typing import Any, Generic, Optional, Sequence, TypeVar

from typing_extensions import Self

from langgraph.errors import EmptyChannelError, InvalidUpdateError
Expand Down Expand Up @@ -41,38 +31,10 @@ def checkpoint(self) -> Optional[C]:
or doesn't support checkpoints."""
return self.get()

@contextmanager
@abstractmethod
def from_checkpoint(
self, checkpoint: Optional[C], config: RunnableConfig
) -> Iterator[Self]:
"""Return a new identical channel, optionally initialized from a checkpoint.
If the checkpoint contains complex data structures, they should be copied."""

@contextmanager
def from_checkpoint_named(
self, checkpoint: Optional[C], config: RunnableConfig
) -> Iterator[Self]:
with self.from_checkpoint(checkpoint, config) as value:
value.key = self.key
yield value

@asynccontextmanager
async def afrom_checkpoint(
self, checkpoint: Optional[C], config: RunnableConfig
) -> AsyncIterator[Self]:
def from_checkpoint(self, checkpoint: Optional[C]) -> Self:
"""Return a new identical channel, optionally initialized from a checkpoint.
If the checkpoint contains complex data structures, they should be copied."""
with self.from_checkpoint(checkpoint, config) as value:
yield value

@asynccontextmanager
async def afrom_checkpoint_named(
self, checkpoint: Optional[C], config: RunnableConfig
) -> AsyncIterator[Self]:
async with self.afrom_checkpoint(checkpoint, config) as value:
value.key = self.key
yield value

# state methods

Expand Down
17 changes: 3 additions & 14 deletions libs/langgraph/langgraph/channels/binop.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import collections.abc
from contextlib import contextmanager
from typing import (
Callable,
Generator,
Generic,
Optional,
Sequence,
Type,
)

from langchain_core.runnables import RunnableConfig
from typing_extensions import NotRequired, Required, Self

from langgraph.channels.base import BaseChannel, Value
Expand Down Expand Up @@ -73,20 +70,12 @@ def UpdateType(self) -> Type[Value]:
"""The type of the update received by the channel."""
return self.typ

@contextmanager
def from_checkpoint(
self, checkpoint: Optional[Value], config: RunnableConfig
) -> Generator[Self, None, None]:
def from_checkpoint(self, checkpoint: Optional[Value]) -> Self:
empty = self.__class__(self.typ, self.operator)
empty.key = self.key
if checkpoint is not None:
empty.value = checkpoint
try:
yield empty
finally:
try:
del empty.value
except AttributeError:
pass
return empty

def update(self, values: Sequence[Value]) -> bool:
if not values:
Expand Down
19 changes: 6 additions & 13 deletions libs/langgraph/langgraph/channels/dynamic_barrier_value.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from contextlib import contextmanager
from typing import Generator, Generic, NamedTuple, Optional, Sequence, Type, Union
from typing import Generic, NamedTuple, Optional, Sequence, Type, Union

from langchain_core.runnables import RunnableConfig
from typing_extensions import Self

from langgraph.channels.base import BaseChannel, Value
Expand Down Expand Up @@ -48,22 +46,17 @@ def UpdateType(self) -> Type[Value]:
def checkpoint(self) -> tuple[Optional[set[Value]], set[Value]]:
return (self.names, self.seen)

@contextmanager
def from_checkpoint(
self,
checkpoint: Optional[tuple[Optional[set[Value]], set[Value]]],
config: RunnableConfig,
) -> Generator[Self, None, None]:
) -> Self:
empty = self.__class__(self.typ)
empty.key = self.key
if checkpoint is not None:
names, seen = checkpoint
empty.names = names.copy() if names is not None else None
empty.seen = seen.copy()

try:
yield empty
finally:
pass
empty.names = names if names is not None else None
empty.seen = seen
return empty

def update(self, values: Sequence[Union[Value, WaitForNames]]) -> bool:
if wait_for_names := [v for v in values if isinstance(v, WaitForNames)]:
Expand Down
18 changes: 4 additions & 14 deletions libs/langgraph/langgraph/channels/ephemeral_value.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from contextlib import contextmanager
from typing import Generator, Generic, Optional, Sequence, Type
from typing import Generic, Optional, Sequence, Type

from langchain_core.runnables import RunnableConfig
from typing_extensions import Self

from langgraph.channels.base import BaseChannel, Value
Expand All @@ -28,20 +26,12 @@ def UpdateType(self) -> Type[Value]:
"""The type of the update received by the channel."""
return self.typ

@contextmanager
def from_checkpoint(
self, checkpoint: Optional[Value], config: RunnableConfig
) -> Generator[Self, None, None]:
def from_checkpoint(self, checkpoint: Optional[Value]) -> Self:
empty = self.__class__(self.typ, self.guard)
empty.key = self.key
if checkpoint is not None:
empty.value = checkpoint
try:
yield empty
finally:
try:
del empty.value
except AttributeError:
pass
return empty

def update(self, values: Sequence[Value]) -> bool:
if len(values) == 0:
Expand Down
18 changes: 4 additions & 14 deletions libs/langgraph/langgraph/channels/last_value.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from contextlib import contextmanager
from typing import Generator, Generic, Optional, Sequence, Type
from typing import Generic, Optional, Sequence, Type

from langchain_core.runnables import RunnableConfig
from typing_extensions import Self

from langgraph.channels.base import BaseChannel, Value
Expand All @@ -27,20 +25,12 @@ def UpdateType(self) -> Type[Value]:
"""The type of the update received by the channel."""
return self.typ

@contextmanager
def from_checkpoint(
self, checkpoint: Optional[Value], config: RunnableConfig
) -> Generator[Self, None, None]:
def from_checkpoint(self, checkpoint: Optional[Value]) -> Self:
empty = self.__class__(self.typ)
empty.key = self.key
if checkpoint is not None:
empty.value = checkpoint
try:
yield empty
finally:
try:
del empty.value
except AttributeError:
pass
return empty

def update(self, values: Sequence[Value]) -> bool:
if len(values) == 0:
Expand Down
18 changes: 5 additions & 13 deletions libs/langgraph/langgraph/channels/named_barrier_value.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from contextlib import contextmanager
from typing import Generator, Generic, Optional, Sequence, Type
from typing import Generic, Optional, Sequence, Type

from langchain_core.runnables import RunnableConfig
from typing_extensions import Self

from langgraph.channels.base import BaseChannel, Value
Expand Down Expand Up @@ -32,18 +30,12 @@ def UpdateType(self) -> Type[Value]:
def checkpoint(self) -> set[Value]:
return self.seen

@contextmanager
def from_checkpoint(
self, checkpoint: Optional[set[Value]], config: RunnableConfig
) -> Generator[Self, None, None]:
def from_checkpoint(self, checkpoint: Optional[set[Value]]) -> Self:
empty = self.__class__(self.typ, self.names)
empty.key = self.key
if checkpoint is not None:
empty.seen = checkpoint.copy()

try:
yield empty
finally:
pass
empty.seen = checkpoint
return empty

def update(self, values: Sequence[Value]) -> bool:
updated = False
Expand Down
21 changes: 6 additions & 15 deletions libs/langgraph/langgraph/channels/topic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from contextlib import contextmanager
from typing import Any, Generator, Generic, Iterator, Optional, Sequence, Type, Union
from typing import Any, Generic, Iterator, Optional, Sequence, Type, Union

from langchain_core.runnables import RunnableConfig
from typing_extensions import Self

from langgraph.channels.base import BaseChannel, Value
Expand Down Expand Up @@ -53,22 +51,15 @@ def UpdateType(self) -> Any:
def checkpoint(self) -> tuple[set[Value], list[Value]]:
return self.values

@contextmanager
def from_checkpoint(
self,
checkpoint: Optional[list[Value]],
config: RunnableConfig,
) -> Generator[Self, None, None]:
def from_checkpoint(self, checkpoint: Optional[list[Value]]) -> Self:
empty = self.__class__(self.typ, self.accumulate)
empty.key = self.key
if checkpoint is not None:
if isinstance(checkpoint, tuple):
empty.values = checkpoint[1].copy()
empty.values = checkpoint[1]
else:
empty.values = checkpoint.copy()
try:
yield empty
finally:
pass
empty.values = checkpoint
return empty

def update(self, values: Sequence[Union[Value, list[Value]]]) -> None:
current = list(self.values)
Expand Down
18 changes: 4 additions & 14 deletions libs/langgraph/langgraph/channels/untracked_value.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from contextlib import contextmanager
from typing import Generator, Generic, Optional, Sequence, Type
from typing import Generic, Optional, Sequence, Type

from langchain_core.runnables import RunnableConfig
from typing_extensions import Self

from langgraph.channels.base import BaseChannel, Value
Expand Down Expand Up @@ -31,18 +29,10 @@ def UpdateType(self) -> Type[Value]:
def checkpoint(self) -> Value:
raise EmptyChannelError()

@contextmanager
def from_checkpoint(
self, checkpoint: Optional[Value], config: RunnableConfig
) -> Generator[Self, None, None]:
def from_checkpoint(self, checkpoint: Optional[Value]) -> Self:
empty = self.__class__(self.typ, self.guard)
try:
yield empty
finally:
try:
del empty.value
except AttributeError:
pass
empty.key = self.key
return empty

def update(self, values: Sequence[Value]) -> bool:
if len(values) == 0:
Expand Down
10 changes: 2 additions & 8 deletions libs/langgraph/langgraph/pregel/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,7 @@ def ChannelsManager(
with ExitStack() as stack:
yield (
{
k: stack.enter_context(
v.from_checkpoint_named(checkpoint["channel_values"].get(k), config)
)
k: v.from_checkpoint(checkpoint["channel_values"].get(k))
for k, v in channel_specs.items()
},
ManagedValueMapping(
Expand Down Expand Up @@ -100,11 +98,7 @@ async def AsyncChannelsManager(
yield (
# channels: enter each channel with checkpoint
{
k: await stack.enter_async_context(
v.afrom_checkpoint_named(
checkpoint["channel_values"].get(k), config
)
)
k: v.from_checkpoint(checkpoint["channel_values"].get(k))
for k, v in channel_specs.items()
},
# managed: build mapping from spec to result
Expand Down
Loading

0 comments on commit 0931c66

Please sign in to comment.