Skip to content

Commit

Permalink
Use __slots__ for Channels
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Sep 13, 2024
1 parent 911f433 commit 2cf6661
Show file tree
Hide file tree
Showing 9 changed files with 26 additions and 13 deletions.
3 changes: 1 addition & 2 deletions libs/langgraph/langgraph/channels/any_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ class AnyValue(Generic[Value], BaseChannel[Value, Value, Value]):
"""Stores the last value received, assumes that if multiple values are
received, they are all equal."""

def __init__(self, typ: Type[Value]) -> None:
self.typ = typ
__slots__ = ("typ", "value")

def __eq__(self, value: object) -> bool:
return isinstance(value, AnyValue)
Expand Down
8 changes: 6 additions & 2 deletions libs/langgraph/langgraph/channels/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Generic, Optional, Sequence, TypeVar
from typing import Any, Generic, Optional, Sequence, Type, TypeVar

from typing_extensions import Self

Expand All @@ -11,7 +11,11 @@


class BaseChannel(Generic[Value, Update, C], ABC):
key: str = ""
__slots__ = ("key", "typ")

def __init__(self, typ: Type[Any], key: str = "") -> None:
self.typ = typ
self.key = key

@property
@abstractmethod
Expand Down
5 changes: 3 additions & 2 deletions libs/langgraph/langgraph/channels/binop.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,11 @@ class BinaryOperatorAggregate(Generic[Value], BaseChannel[Value, Value, Value]):
```
"""

__slots__ = ("value", "operator")

def __init__(self, typ: Type[Value], operator: Callable[[Value, Value], Value]):
super().__init__(typ)
self.operator = operator
# keep the type exposed by ValueType/UpdateType as-is
self.typ = typ
# special forms from typing or collections.abc are not instantiable
# so we need to replace them with their concrete counterparts
typ = _strip_extras(typ)
Expand Down
4 changes: 3 additions & 1 deletion libs/langgraph/langgraph/channels/dynamic_barrier_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@ class DynamicBarrierValue(
back to the "priming" state.
"""

__slots__ = ("names", "seen")

names: Optional[set[Value]]
seen: set[Value]

def __init__(self, typ: Type[Value]) -> None:
self.typ = typ
super().__init__(typ)
self.names = None
self.seen = set()

Expand Down
4 changes: 3 additions & 1 deletion libs/langgraph/langgraph/channels/ephemeral_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
class EphemeralValue(Generic[Value], BaseChannel[Value, Value, Value]):
"""Stores the value received in the step immediately preceding, clears after."""

__slots__ = ("value", "guard")

def __init__(self, typ: Type[Value], guard: bool = True) -> None:
self.typ = typ
super().__init__(typ)
self.guard = guard

def __eq__(self, value: object) -> bool:
Expand Down
3 changes: 1 addition & 2 deletions libs/langgraph/langgraph/channels/last_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
class LastValue(Generic[Value], BaseChannel[Value, Value, Value]):
"""Stores the last value received, can receive at most one value per step."""

def __init__(self, typ: Type[Value]) -> None:
self.typ = typ
__slots__ = ("value",)

def __eq__(self, value: object) -> bool:
return isinstance(value, LastValue)
Expand Down
4 changes: 3 additions & 1 deletion libs/langgraph/langgraph/channels/named_barrier_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
class NamedBarrierValue(Generic[Value], BaseChannel[Value, Value, set[Value]]):
"""A channel that waits until all named values are received before making the value available."""

__slots__ = ("names", "seen")

def __init__(self, typ: Type[Value], names: set[Value]) -> None:
self.typ = typ
super().__init__(typ)
self.names = names
self.seen = set()

Expand Down
4 changes: 3 additions & 1 deletion libs/langgraph/langgraph/channels/topic.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@ class Topic(
accumulate: Whether to accumulate values across steps. If False, the channel will be emptied after each step.
"""

__slots__ = ("values", "accumulate")

def __init__(self, typ: Type[Value], accumulate: bool = False) -> None:
super().__init__(typ)
# attrs
self.typ = typ
self.accumulate = accumulate
# state
self.values = list[Value]()
Expand Down
4 changes: 3 additions & 1 deletion libs/langgraph/langgraph/channels/untracked_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
class UntrackedValue(Generic[Value], BaseChannel[Value, Value, Value]):
"""Stores the last value received, never checkpointed."""

__slots__ = ("value", "guard")

def __init__(self, typ: Type[Value], guard: bool = True) -> None:
self.typ = typ
super().__init__(typ)
self.guard = guard

def __eq__(self, value: object) -> bool:
Expand Down

0 comments on commit 2cf6661

Please sign in to comment.