From 9a8cc75ea29e170c1a077f423d2a2b9fbe5c935c Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 19 Sep 2024 11:38:41 -0700 Subject: [PATCH 1/9] Fix some typing issues in langgraph lib --- .../langgraph/checkpoint/sqlite/aio.py | 3 +- .../langgraph/checkpoint/base/__init__.py | 5 +- .../langgraph/checkpoint/serde/types.py | 11 +-- libs/langgraph/Makefile | 3 +- libs/langgraph/bench/react_agent.py | 2 +- libs/langgraph/langgraph/_api/deprecation.py | 8 +- libs/langgraph/langgraph/channels/binop.py | 2 +- .../langgraph/channels/named_barrier_value.py | 5 +- libs/langgraph/langgraph/graph/graph.py | 47 ++++++----- libs/langgraph/langgraph/graph/message.py | 15 +++- libs/langgraph/langgraph/graph/state.py | 22 ++++-- libs/langgraph/langgraph/managed/base.py | 6 +- libs/langgraph/langgraph/managed/context.py | 39 ++++++--- .../langgraph/managed/shared_value.py | 11 +-- .../langgraph/prebuilt/chat_agent_executor.py | 34 ++++---- .../langgraph/prebuilt/tool_executor.py | 7 +- .../langgraph/langgraph/prebuilt/tool_node.py | 7 +- .../langgraph/prebuilt/tool_validator.py | 2 +- libs/langgraph/langgraph/pregel/__init__.py | 79 ++++++++++--------- libs/langgraph/langgraph/pregel/algo.py | 48 +++++++---- libs/langgraph/langgraph/pregel/debug.py | 16 +++- libs/langgraph/langgraph/pregel/executor.py | 10 ++- libs/langgraph/langgraph/pregel/io.py | 10 +-- libs/langgraph/langgraph/pregel/loop.py | 53 +++++++------ libs/langgraph/langgraph/pregel/manager.py | 10 +-- libs/langgraph/langgraph/pregel/messages.py | 15 ++-- libs/langgraph/langgraph/pregel/read.py | 2 +- libs/langgraph/langgraph/pregel/retry.py | 4 +- libs/langgraph/langgraph/pregel/runner.py | 48 +++++------ libs/langgraph/langgraph/pregel/types.py | 4 +- libs/langgraph/langgraph/pregel/utils.py | 4 +- libs/langgraph/langgraph/pregel/write.py | 4 +- libs/langgraph/langgraph/utils/config.py | 33 +++++--- libs/langgraph/langgraph/utils/fields.py | 2 +- libs/langgraph/langgraph/utils/pydantic.py | 2 +- libs/langgraph/langgraph/utils/queue.py | 10 ++- libs/langgraph/langgraph/utils/runnable.py | 62 ++++++++++----- libs/langgraph/poetry.lock | 74 ++++++++++------- libs/langgraph/pyproject.toml | 9 ++- libs/langgraph/tests/test_pregel_async.py | 2 +- 40 files changed, 432 insertions(+), 298 deletions(-) diff --git a/libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/aio.py b/libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/aio.py index 39476ae8e..d2347006e 100644 --- a/libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/aio.py +++ b/libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/aio.py @@ -7,7 +7,6 @@ Callable, Dict, Iterator, - List, Optional, Sequence, Tuple, @@ -216,7 +215,7 @@ def put( ).result() def put_writes( - self, config: RunnableConfig, writes: List[Tuple[str, Any]], task_id: str + self, config: RunnableConfig, writes: Sequence[Tuple[str, Any]], task_id: str ) -> None: return asyncio.run_coroutine_threadsafe( self.aput_writes(config, writes, task_id), self.loop diff --git a/libs/checkpoint/langgraph/checkpoint/base/__init__.py b/libs/checkpoint/langgraph/checkpoint/base/__init__.py index 80441dad1..822389cce 100644 --- a/libs/checkpoint/langgraph/checkpoint/base/__init__.py +++ b/libs/checkpoint/langgraph/checkpoint/base/__init__.py @@ -10,6 +10,7 @@ Mapping, NamedTuple, Optional, + Sequence, Tuple, TypedDict, Union, @@ -301,7 +302,7 @@ def put( def put_writes( self, config: RunnableConfig, - writes: List[Tuple[str, Any]], + writes: Sequence[Tuple[str, Any]], task_id: str, ) -> None: """Store intermediate writes linked to a checkpoint. @@ -393,7 +394,7 @@ async def aput( async def aput_writes( self, config: RunnableConfig, - writes: List[Tuple[str, Any]], + writes: Sequence[Tuple[str, Any]], task_id: str, ) -> None: """Asynchronously store intermediate writes linked to a checkpoint. diff --git a/libs/checkpoint/langgraph/checkpoint/serde/types.py b/libs/checkpoint/langgraph/checkpoint/serde/types.py index f86c2e558..43a5bf878 100644 --- a/libs/checkpoint/langgraph/checkpoint/serde/types.py +++ b/libs/checkpoint/langgraph/checkpoint/serde/types.py @@ -1,7 +1,5 @@ from typing import ( Any, - AsyncGenerator, - Generator, Optional, Protocol, Sequence, @@ -9,7 +7,6 @@ runtime_checkable, ) -from langchain_core.runnables import RunnableConfig from typing_extensions import Self ERROR = "__error__" @@ -31,13 +28,7 @@ def UpdateType(self) -> Any: ... def checkpoint(self) -> Optional[C]: ... - def from_checkpoint( - self, checkpoint: Optional[C], config: RunnableConfig - ) -> Generator[Self, None, None]: ... - - async def afrom_checkpoint( - self, checkpoint: Optional[C], config: RunnableConfig - ) -> AsyncGenerator[Self, None]: ... + def from_checkpoint(self, checkpoint: Optional[C]) -> Self: ... def update(self, values: Sequence[Update]) -> bool: ... diff --git a/libs/langgraph/Makefile b/libs/langgraph/Makefile index 3d8a175e7..62ef303bb 100644 --- a/libs/langgraph/Makefile +++ b/libs/langgraph/Makefile @@ -74,7 +74,8 @@ lint lint_diff lint_package lint_tests: poetry run ruff check . [ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff [ "$(PYTHON_FILES)" = "" ] || poetry run ruff check --select I $(PYTHON_FILES) - [ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) || poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE) + [ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) + [ "$(PYTHON_FILES)" = "" ] || poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE) format format_diff: poetry run ruff format $(PYTHON_FILES) diff --git a/libs/langgraph/bench/react_agent.py b/libs/langgraph/bench/react_agent.py index 4ad671f89..a6f84f2dc 100644 --- a/libs/langgraph/bench/react_agent.py +++ b/libs/langgraph/bench/react_agent.py @@ -14,7 +14,7 @@ from langgraph.pregel import Pregel -def react_agent(n_tools: int, checkpointer: BaseCheckpointSaver) -> Pregel: +def react_agent(n_tools: int, checkpointer: Optional[BaseCheckpointSaver]) -> Pregel: class FakeFuntionChatModel(FakeMessagesListChatModel): def bind_tools(self, functions: list): return self diff --git a/libs/langgraph/langgraph/_api/deprecation.py b/libs/langgraph/langgraph/_api/deprecation.py index 6fa419e83..c93e09de8 100644 --- a/libs/langgraph/langgraph/_api/deprecation.py +++ b/libs/langgraph/langgraph/_api/deprecation.py @@ -21,14 +21,14 @@ def decorator(obj: Union[F, C]) -> Union[F, C]: f" removed in {removal_str}. Use {alternative} instead.{example}" ) if isinstance(obj, type): - original_init = obj.__init__ + original_init = obj.__init__ # type: ignore[misc] @functools.wraps(original_init) - def new_init(self, *args: Any, **kwargs: Any) -> None: + def new_init(self, *args: Any, **kwargs: Any) -> None: # type: ignore[no-untyped-def] warnings.warn(message, LangGraphDeprecationWarning, stacklevel=2) original_init(self, *args, **kwargs) - obj.__init__ = new_init + obj.__init__ = new_init # type: ignore[misc] docstring = ( f"**Deprecated**: This class is deprecated as of version {since}. " @@ -68,7 +68,7 @@ def deprecated_parameter( ) -> Callable[[F], F]: def decorator(func: F) -> F: @functools.wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args, **kwargs): # type: ignore[no-untyped-def] if arg_name in kwargs: warnings.warn( f"Parameter '{arg_name}' in function '{func.__name__}' is " diff --git a/libs/langgraph/langgraph/channels/binop.py b/libs/langgraph/langgraph/channels/binop.py index d3fe4fce2..a2360142b 100644 --- a/libs/langgraph/langgraph/channels/binop.py +++ b/libs/langgraph/langgraph/channels/binop.py @@ -14,7 +14,7 @@ # Adapted from typing_extensions -def _strip_extras(t): +def _strip_extras(t): # type: ignore[no-untyped-def] """Strips Annotated, Required and NotRequired from a given type.""" if hasattr(t, "__origin__"): return _strip_extras(t.__origin__) diff --git a/libs/langgraph/langgraph/channels/named_barrier_value.py b/libs/langgraph/langgraph/channels/named_barrier_value.py index a804a3052..4a1d990ca 100644 --- a/libs/langgraph/langgraph/channels/named_barrier_value.py +++ b/libs/langgraph/langgraph/channels/named_barrier_value.py @@ -11,10 +11,13 @@ class NamedBarrierValue(Generic[Value], BaseChannel[Value, Value, set[Value]]): __slots__ = ("names", "seen") + names: set[Value] + seen: set[Value] + def __init__(self, typ: Type[Value], names: set[Value]) -> None: super().__init__(typ) self.names = names - self.seen = set() + self.seen: set[str] = set() def __eq__(self, value: object) -> bool: return isinstance(value, NamedBarrierValue) and value.names == self.names diff --git a/libs/langgraph/langgraph/graph/graph.py b/libs/langgraph/langgraph/graph/graph.py index b04594bdd..706e5253c 100644 --- a/libs/langgraph/langgraph/graph/graph.py +++ b/libs/langgraph/langgraph/graph/graph.py @@ -56,9 +56,11 @@ class Branch(NamedTuple): def run( self, - writer: Callable[[list[str], RunnableConfig], None], + writer: Callable[ + [list[Union[str, Send]], RunnableConfig], Optional[ChannelWrite] + ], reader: Optional[Callable[[RunnableConfig], Any]] = None, - ) -> None: + ) -> RunnableCallable: return ChannelWrite.register_writer( RunnableCallable( func=self._route, @@ -75,8 +77,10 @@ def _route( input: Any, config: RunnableConfig, *, - reader: Optional[Callable[[], Any]], - writer: Callable[[list[str], RunnableConfig], None], + reader: Optional[Callable[[RunnableConfig], Any]], + writer: Callable[ + [list[Union[str, Send]], RunnableConfig], Optional[ChannelWrite] + ], ) -> Runnable: if reader: value = reader(config) @@ -94,8 +98,10 @@ async def _aroute( input: Any, config: RunnableConfig, *, - reader: Optional[Callable[[], Any]], - writer: Callable[[list[str], RunnableConfig], Optional[Runnable]], + reader: Optional[Callable[[RunnableConfig], Any]], + writer: Callable[ + [list[Union[str, Send]], RunnableConfig], Optional[ChannelWrite] + ], ) -> Runnable: if reader: value = await asyncio.to_thread(reader, config) @@ -110,7 +116,9 @@ async def _aroute( def _finish( self, - writer: Callable[[list[str], RunnableConfig], None], + writer: Callable[ + [list[Union[str, Send]], RunnableConfig], Optional[ChannelWrite] + ], input: Any, result: Any, config: RunnableConfig, @@ -378,8 +386,8 @@ def validate(self, interrupt: Optional[Sequence[str]] = None) -> None: def compile( self, checkpointer: Optional[BaseCheckpointSaver] = None, - interrupt_before: Optional[Union[All, Sequence[str]]] = None, - interrupt_after: Optional[Union[All, Sequence[str]]] = None, + interrupt_before: Optional[Union[All, list[str]]] = None, + interrupt_after: Optional[Union[All, list[str]]] = None, debug: bool = False, ) -> "CompiledGraph": # assign default values @@ -451,7 +459,7 @@ def attach_edge(self, start: str, end: str) -> None: else: # subscribe to start channel self.nodes[end].triggers.append(start) - self.nodes[end].channels.append(start) + cast(list[str], self.nodes[end].channels).append(start) def attach_branch(self, start: str, name: str, branch: Branch) -> None: def branch_writer( @@ -530,17 +538,18 @@ def add_edge( subgraph.trim_first_node() subgraph.trim_last_node() if len(subgraph.nodes) > 1: - end_nodes[key], start_nodes[key] = graph.extend( - subgraph, prefix=key - ) + e, s = graph.extend(subgraph, prefix=key) + if s is None or e is None: + raise ValueError(f"Could not extend subgraph {key}") + end_nodes[key], start_nodes[key] = e, s else: - n = graph.add_node(node, key, metadata=metadata or None) - start_nodes[key] = n - end_nodes[key] = n + nn = graph.add_node(node, key, metadata=metadata or None) + start_nodes[key] = nn + end_nodes[key] = nn else: - n = graph.add_node(node, key, metadata=metadata or None) - start_nodes[key] = n - end_nodes[key] = n + nn = graph.add_node(node, key, metadata=metadata or None) + start_nodes[key] = nn + end_nodes[key] = nn for start, end in sorted(self.builder._all_edges): add_edge(start, end) for start, branches in self.builder.branches.items(): diff --git a/libs/langgraph/langgraph/graph/message.py b/libs/langgraph/langgraph/graph/message.py index 402812f84..34bc9c090 100644 --- a/libs/langgraph/langgraph/graph/message.py +++ b/libs/langgraph/langgraph/graph/message.py @@ -1,8 +1,9 @@ import uuid -from typing import Annotated, TypedDict, Union +from typing import Annotated, TypedDict, Union, cast from langchain_core.messages import ( AnyMessage, + BaseMessageChunk, MessageLikeRepresentation, RemoveMessage, convert_to_messages, @@ -66,8 +67,14 @@ def add_messages(left: Messages, right: Messages) -> Messages: if not isinstance(right, list): right = [right] # coerce to message - left = [message_chunk_to_message(m) for m in convert_to_messages(left)] - right = [message_chunk_to_message(m) for m in convert_to_messages(right)] + left = [ + message_chunk_to_message(cast(BaseMessageChunk, m)) + for m in convert_to_messages(left) + ] + right = [ + message_chunk_to_message(cast(BaseMessageChunk, m)) + for m in convert_to_messages(right) + ] # assign missing ids for m in left: if m.id is None: @@ -144,7 +151,7 @@ class MessageGraph(StateGraph): """ def __init__(self) -> None: - super().__init__(Annotated[list[AnyMessage], add_messages]) + super().__init__(Annotated[list[AnyMessage], add_messages]) # type: ignore[arg-type] class MessagesState(TypedDict): diff --git a/libs/langgraph/langgraph/graph/state.py b/libs/langgraph/langgraph/graph/state.py index 33601dd5e..e66cff53d 100644 --- a/libs/langgraph/langgraph/graph/state.py +++ b/libs/langgraph/langgraph/graph/state.py @@ -66,7 +66,7 @@ def _warn_invalid_state_schema(schema: Union[Type[Any], Any]) -> None: class StateNodeSpec(NamedTuple): runnable: Runnable - metadata: dict[str, Any] + metadata: Optional[dict[str, Any]] input: Type[Any] retry_policy: Optional[RetryPolicy] @@ -318,7 +318,11 @@ def add_node( ) if not isinstance(node, str): action = node - node = getattr(action, "name", action.__name__) + node = getattr(action, "name", getattr(action, "__name__")) + if node is None: + raise ValueError( + "Node name must be provided if action is not a function" + ) if node in self.nodes: raise ValueError(f"Node `{node}` already present.") if node == END or node == START: @@ -392,8 +396,8 @@ def compile( checkpointer: Optional[BaseCheckpointSaver] = None, *, store: Optional[BaseStore] = None, - interrupt_before: Optional[Union[All, Sequence[str]]] = None, - interrupt_after: Optional[Union[All, Sequence[str]]] = None, + interrupt_before: Optional[Union[All, list[str]]] = None, + interrupt_after: Optional[Union[All, list[str]]] = None, debug: bool = False, ) -> "CompiledStateGraph": """Compiles the state graph into a `CompiledGraph` object. @@ -554,7 +558,7 @@ def _get_state_key(input: Union[None, dict, Any], *, key: str) -> Any: ), ], ) - else: + elif node is not None: input_schema = node.input if node else self.builder.schema input_values = {k: k for k in self.builder.schemas[input_schema]} is_single_input = len(input_values) == 1 and "__root__" in input_values @@ -582,6 +586,8 @@ def _get_state_key(input: Union[None, dict, Any], *, key: str) -> Any: retry_policy=node.retry_policy, bound=node.runnable, ) + else: + raise RuntimeError def attach_edge(self, starts: Union[str, Sequence[str]], end: str) -> None: if isinstance(starts, str): @@ -613,7 +619,7 @@ def attach_edge(self, starts: Union[str, Sequence[str]], end: str) -> None: def attach_branch(self, start: str, name: str, branch: Branch) -> None: def branch_writer( packets: list[Union[str, Send]], config: RunnableConfig - ) -> Optional[ChannelWrite]: + ) -> None: if filtered := [p for p in packets if p != END]: writes = [ ( @@ -782,12 +788,12 @@ def _get_schema( else: keys = list(schemas[typ].keys()) if len(keys) == 1 and keys[0] == "__root__": - return create_model( # type: ignore[call-overload] + return create_model( name, root=(channels[keys[0]].UpdateType, None), ) else: - return create_model( # type: ignore[call-overload] + return create_model( name, field_definitions={ k: ( diff --git a/libs/langgraph/langgraph/managed/base.py b/libs/langgraph/langgraph/managed/base.py index b86388930..3d4eb69f3 100644 --- a/libs/langgraph/langgraph/managed/base.py +++ b/libs/langgraph/langgraph/managed/base.py @@ -106,7 +106,9 @@ def is_writable_managed_value(value: Any) -> TypeGuard[Type[WritableManagedValue class ManagedValueMapping(dict[str, ManagedValue]): - def replace_runtime_values(self, step: int, values: Union[dict[str, Any], Any]): + def replace_runtime_values( + self, step: int, values: Union[dict[str, Any], Any] + ) -> None: if not self or not values: return if all(not mv.runtime for mv in self.values()): @@ -128,7 +130,7 @@ def replace_runtime_values(self, step: int, values: Union[dict[str, Any], Any]): def replace_runtime_placeholders( self, step: int, values: Union[dict[str, Any], Any] - ): + ) -> None: if not self or not values: return if all(not mv.runtime for mv in self.values()): diff --git a/libs/langgraph/langgraph/managed/context.py b/libs/langgraph/langgraph/managed/context.py index 43cff5e67..df64419ec 100644 --- a/libs/langgraph/langgraph/managed/context.py +++ b/libs/langgraph/langgraph/managed/context.py @@ -4,7 +4,9 @@ Any, AsyncContextManager, AsyncIterator, + Callable, ContextManager, + Generic, Iterator, Optional, Type, @@ -17,15 +19,26 @@ from langgraph.managed.base import ConfiguredManagedValue, ManagedValue, V -class Context(ManagedValue): +class Context(ManagedValue[V], Generic[V]): runtime = True value: V @staticmethod def of( - ctx: Union[None, Type[ContextManager[V]], Type[AsyncContextManager[V]]] = None, - actx: Optional[Type[AsyncContextManager[V]]] = None, + ctx: Union[ + None, + Callable[..., ContextManager[V]], + Type[ContextManager[V]], + Callable[..., AsyncContextManager[V]], + Type[AsyncContextManager[V]], + ] = None, + actx: Optional[ + Union[ + Callable[..., AsyncContextManager[V]], + Type[AsyncContextManager[V]], + ] + ] = None, ) -> ConfiguredManagedValue: if ctx is None and actx is None: raise ValueError("Must provide either sync or async context manager.") @@ -40,11 +53,11 @@ def enter(cls, config: RunnableConfig, **kwargs: Any) -> Iterator[Self]: "Synchronous context manager not found. Please initialize Context value with a sync context manager, or invoke your graph asynchronously." ) ctx = ( - self.ctx(config) + self.ctx(config) # type: ignore[call-arg] if signature(self.ctx).parameters.get("config") else self.ctx() ) - with ctx as v: + with ctx as v: # type: ignore[union-attr] self.value = v yield self @@ -54,24 +67,32 @@ async def aenter(cls, config: RunnableConfig, **kwargs: Any) -> AsyncIterator[Se async with super().aenter(config, **kwargs) as self: if self.actx is not None: ctx = ( - self.actx(config) + self.actx(config) # type: ignore[call-arg] if signature(self.actx).parameters.get("config") else self.actx() ) - else: + elif self.ctx is not None: ctx = ( - self.ctx(config) + self.ctx(config) # type: ignore if signature(self.ctx).parameters.get("config") else self.ctx() ) + else: + raise ValueError( + "Asynchronous context manager not found. Please initialize Context value with an async context manager, or invoke your graph synchronously." + ) if hasattr(ctx, "__aenter__"): async with ctx as v: self.value = v yield self - else: + elif hasattr(ctx, "__enter__") and hasattr(ctx, "__exit__"): with ctx as v: self.value = v yield self + else: + raise ValueError( + "Context manager must have either __enter__ or __aenter__ method." + ) def __init__( self, diff --git a/libs/langgraph/langgraph/managed/shared_value.py b/libs/langgraph/langgraph/managed/shared_value.py index f5e0561bd..9a624c685 100644 --- a/libs/langgraph/langgraph/managed/shared_value.py +++ b/libs/langgraph/langgraph/managed/shared_value.py @@ -7,6 +7,7 @@ Optional, Sequence, Type, + cast, ) from langchain_core.runnables import RunnableConfig @@ -30,7 +31,7 @@ # Adapted from typing_extensions -def _strip_extras(t): +def _strip_extras(t): # type: ignore[no-untyped-def] """Strips Annotated, Required and NotRequired from a given type.""" if hasattr(t, "__origin__"): return _strip_extras(t.__origin__) @@ -82,9 +83,9 @@ def __init__( raise ValueError("SharedValue must be a dict") self.scope = scope self.value: Value = {} - self.store: BaseStore = config["configurable"].get(CONFIG_KEY_STORE) + self.store = cast(BaseStore, config["configurable"].get(CONFIG_KEY_STORE)) if self.store is None: - self.ns: Optional[str] = None + pass elif scope_value := config["configurable"].get(self.scope): self.ns = f"scoped:{scope}:{key}:{scope_value}" else: @@ -98,12 +99,12 @@ def __call__(self, step: int) -> Value: def _process_update( self, values: Sequence[Update] ) -> list[tuple[str, str, Optional[dict[str, Any]]]]: - writes = [] + writes: list[tuple[str, str, Optional[dict[str, Any]]]] = [] for vv in values: for k, v in vv.items(): if v is None: if k in self.value: - self.value[k] = None + del self.value[k] writes.append((self.ns, k, None)) elif not isinstance(v, dict): raise InvalidUpdateError("Received a non-dict value") diff --git a/libs/langgraph/langgraph/prebuilt/chat_agent_executor.py b/libs/langgraph/langgraph/prebuilt/chat_agent_executor.py index d5a1dc88e..d4a6157cb 100644 --- a/libs/langgraph/langgraph/prebuilt/chat_agent_executor.py +++ b/libs/langgraph/langgraph/prebuilt/chat_agent_executor.py @@ -1,6 +1,7 @@ from typing import ( Annotated, Callable, + Literal, Optional, Sequence, Type, @@ -9,7 +10,7 @@ Union, ) -from langchain_core.language_models import LanguageModelLike +from langchain_core.language_models import BaseChatModel from langchain_core.messages import ( AIMessage, BaseMessage, @@ -129,15 +130,15 @@ def _get_model_preprocessing_runnable( @deprecated_parameter("messages_modifier", "0.1.9", "state_modifier", removal="0.3.0") def create_react_agent( - model: LanguageModelLike, + model: BaseChatModel, tools: Union[ToolExecutor, Sequence[BaseTool], ToolNode], *, state_schema: Optional[StateSchemaType] = None, messages_modifier: Optional[MessagesModifier] = None, state_modifier: Optional[StateModifier] = None, checkpointer: Optional[BaseCheckpointSaver] = None, - interrupt_before: Optional[Sequence[str]] = None, - interrupt_after: Optional[Sequence[str]] = None, + interrupt_before: Optional[list[str]] = None, + interrupt_after: Optional[list[str]] = None, debug: bool = False, ) -> CompiledGraph: """Creates a graph that works with a chat model that utilizes tool calling. @@ -421,7 +422,7 @@ class Agent,Tools otherClass tool_classes = tools.tools tool_node = ToolNode(tool_classes) elif isinstance(tools, ToolNode): - tool_classes = tools.tools_by_name.values() + tool_classes = list(tools.tools_by_name.values()) tool_node = tools else: tool_classes = tools @@ -429,11 +430,11 @@ class Agent,Tools otherClass model = model.bind_tools(tool_classes) # Define the function that determines whether to continue or not - def should_continue(state: AgentState): + def should_continue(state: AgentState) -> Literal["continue", "end"]: messages = state["messages"] last_message = messages[-1] # If there is no function call, then we finish - if not last_message.tool_calls: + if not isinstance(last_message, AIMessage) or not last_message.tool_calls: return "end" # Otherwise if there is, we continue else: @@ -443,12 +444,13 @@ def should_continue(state: AgentState): model_runnable = preprocessor | model # Define the function that calls the model - def call_model( - state: AgentState, - config: RunnableConfig, - ): + def call_model(state: AgentState, config: RunnableConfig) -> AgentState: response = model_runnable.invoke(state, config) - if state["is_last_step"] and response.tool_calls: + if ( + state["is_last_step"] + and isinstance(response, AIMessage) + and response.tool_calls + ): return { "messages": [ AIMessage( @@ -460,9 +462,13 @@ def call_model( # We return a list, because this will get added to the existing list return {"messages": [response]} - async def acall_model(state: AgentState, config: RunnableConfig): + async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState: response = await model_runnable.ainvoke(state, config) - if state["is_last_step"] and response.tool_calls: + if ( + state["is_last_step"] + and isinstance(response, AIMessage) + and response.tool_calls + ): return { "messages": [ AIMessage( diff --git a/libs/langgraph/langgraph/prebuilt/tool_executor.py b/libs/langgraph/langgraph/prebuilt/tool_executor.py index 341f4874d..d939f23d9 100644 --- a/libs/langgraph/langgraph/prebuilt/tool_executor.py +++ b/libs/langgraph/langgraph/prebuilt/tool_executor.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Sequence, Union +from typing import Any, Callable, Sequence, Union, cast from langchain_core.load.serializable import Serializable from langchain_core.runnables import RunnableConfig @@ -101,10 +101,11 @@ def __init__( ) -> None: super().__init__(self._execute, afunc=self._aexecute, trace=False) tools_ = [ - tool if isinstance(tool, BaseTool) else create_tool(tool) for tool in tools + tool if isinstance(tool, BaseTool) else cast(BaseTool, create_tool(tool)) + for tool in tools ] self.tools = tools_ - self.tool_map = {t.name: t for t in tools} + self.tool_map = {t.name: t for t in tools_} self.invalid_tool_msg_template = invalid_tool_msg_template def _execute( diff --git a/libs/langgraph/langgraph/prebuilt/tool_node.py b/libs/langgraph/langgraph/prebuilt/tool_node.py index a596e829f..c0f7c83f6 100644 --- a/libs/langgraph/langgraph/prebuilt/tool_node.py +++ b/libs/langgraph/langgraph/prebuilt/tool_node.py @@ -94,7 +94,7 @@ def __init__( self.handle_tool_errors = handle_tool_errors for tool_ in tools: if not isinstance(tool_, BaseTool): - tool_ = create_tool(tool_) + tool_ = cast(BaseTool, create_tool(tool_)) self.tools_by_name[tool_.name] = tool_ def _func( @@ -188,10 +188,7 @@ def _parse_input( if not isinstance(message, AIMessage): raise ValueError("Last message is not an AIMessage") - tool_calls = [ - self._inject_state(call, input) - for call in cast(AIMessage, message).tool_calls - ] + tool_calls = [self._inject_state(call, input) for call in message.tool_calls] return tool_calls, output_type def _validate_tool_call(self, call: ToolCall) -> Optional[ToolMessage]: diff --git a/libs/langgraph/langgraph/prebuilt/tool_validator.py b/libs/langgraph/langgraph/prebuilt/tool_validator.py index 2222e7f2f..401a35db7 100644 --- a/libs/langgraph/langgraph/prebuilt/tool_validator.py +++ b/libs/langgraph/langgraph/prebuilt/tool_validator.py @@ -211,7 +211,7 @@ def _func( """Validate and run tool calls synchronously.""" output_type, message = self._get_message(input) - def run_one(call: ToolCall): + def run_one(call: ToolCall) -> ToolMessage: schema = self.schemas_by_name[call["name"]] try: if issubclass(schema, BaseModel): diff --git a/libs/langgraph/langgraph/pregel/__init__.py b/libs/langgraph/langgraph/pregel/__init__.py index 802690688..5389b7ddf 100644 --- a/libs/langgraph/langgraph/pregel/__init__.py +++ b/libs/langgraph/langgraph/pregel/__init__.py @@ -95,7 +95,7 @@ patch_configurable, ) from langgraph.utils.pydantic import create_model -from langgraph.utils.queue import AsyncQueue, SyncQueue +from langgraph.utils.queue import AsyncQueue, SyncQueue # type: ignore[attr-defined] from langgraph.utils.runnable import RunnableCallable WriteValue = Union[Callable[[Input], Output], Any] @@ -172,9 +172,9 @@ def write_to( class Pregel(Runnable[Union[dict[str, Any], Any], Union[dict[str, Any], Any]]): - nodes: Mapping[str, PregelNode] + nodes: dict[str, PregelNode] - channels: Mapping[str, Union[BaseChannel, ManagedValueSpec]] + channels: dict[str, Union[BaseChannel, ManagedValueSpec]] stream_mode: StreamMode = "values" """Mode to stream output, defaults to 'values'.""" @@ -214,8 +214,8 @@ class Pregel(Runnable[Union[dict[str, Any], Any], Union[dict[str, Any], Any]]): def __init__( self, *, - nodes: Mapping[str, PregelNode], - channels: Mapping[str, Union[BaseChannel, ManagedValueSpec]] = None, + nodes: dict[str, PregelNode], + channels: Optional[dict[str, Union[BaseChannel, ManagedValueSpec]]], auto_validate: bool = True, stream_mode: StreamMode = "values", output_channels: Union[str, Sequence[str]], @@ -256,12 +256,14 @@ def copy(self, update: dict[str, Any]) -> Self: return self.__class__(**attrs) def with_config(self, config: RunnableConfig | None = None, **kwargs: Any) -> Self: - return self.copy({"config": merge_configs(self.config, config, kwargs)}) + return self.copy( + {"config": merge_configs(self.config, config, cast(RunnableConfig, kwargs))} + ) def validate(self) -> Self: validate_graph( self.nodes, - self.channels, + {k: v for k, v in self.channels.items() if isinstance(v, BaseChannel)}, self.input_channels, self.output_channels, self.stream_channels, @@ -312,11 +314,12 @@ def get_input_schema( if isinstance(self.input_channels, str): return super().get_input_schema(config) else: - return create_model( # type: ignore[call-overload] + return create_model( self.get_name("Input"), field_definitions={ k: (self.channels[k].UpdateType, None) for k in self.input_channels or self.channels.keys() + if isinstance(self.channels[k], BaseChannel) }, ) @@ -341,10 +344,12 @@ def get_output_schema( if isinstance(self.output_channels, str): return super().get_output_schema(config) else: - return create_model( # type: ignore[call-overload] + return create_model( self.get_name("Output"), field_definitions={ - k: (self.channels[k].ValueType, None) for k in self.output_channels + k: (self.channels[k].ValueType, None) + for k in self.output_channels + if isinstance(self.channels[k], BaseChannel) }, ) @@ -413,7 +418,7 @@ def _prepare_state_snapshot( self, config: RunnableConfig, saved: Optional[CheckpointTuple], - recurse: Optional[BaseCheckpointSaver] = False, + recurse: Optional[BaseCheckpointSaver] = None, ) -> StateSnapshot: if not saved: return StateSnapshot( @@ -486,7 +491,7 @@ async def _aprepare_state_snapshot( self, config: RunnableConfig, saved: Optional[CheckpointTuple], - recurse: Optional[BaseCheckpointSaver] = False, + recurse: Optional[BaseCheckpointSaver] = None, ) -> StateSnapshot: if not saved: return StateSnapshot( @@ -545,7 +550,7 @@ async def _aprepare_state_snapshot( } } task_states[task.id] = await subgraphs[task.name].aget_state( - config, subgraphs=recurse + config, subgraphs=True ) # assemble the state snapshot return StateSnapshot( @@ -828,7 +833,7 @@ def update_state( writers = self.nodes[as_node].flat_writers if not writers: raise InvalidUpdateError(f"Node {as_node} has no writers") - writes = deque() + writes: deque[tuple[str, Any]] = deque() task = PregelTaskWrites(as_node, writes, [INTERRUPT]) task_id = str(uuid5(UUID(checkpoint["id"]), INTERRUPT)) run = RunnableSequence(*writers) if len(writers) > 1 else writers[0] @@ -925,21 +930,12 @@ async def aupdate_state( ) step = saved.metadata.get("step", -1) if saved else -1 # merge configurable fields with previous checkpoint config - checkpoint_config = { - **config, - "configurable": { - **config["configurable"], - # TODO: add proper support for updating nested subgraph state - "checkpoint_ns": "", - }, - } + checkpoint_config = patch_configurable( + config, + {"checkpoint_ns": config["configurable"].get("checkpoint_ns", "")}, + ) if saved: - checkpoint_config = { - "configurable": { - **config.get("configurable", {}), - **saved.config["configurable"], - } - } + checkpoint_config = patch_configurable(config, saved.config["configurable"]) # find last node that updated the state, if not provided if values is None and as_node is None: next_config = await checkpointer.aput( @@ -986,7 +982,7 @@ async def aupdate_state( writers = self.nodes[as_node].flat_writers if not writers: raise InvalidUpdateError(f"Node {as_node} has no writers") - writes = deque() + writes: deque[tuple[str, Any]] = deque() task = PregelTaskWrites(as_node, writes, [INTERRUPT]) task_id = str(uuid5(UUID(checkpoint["id"]), INTERRUPT)) run = RunnableSequence(*writers) if len(writers) > 1 else writers[0] @@ -1052,7 +1048,7 @@ def _defaults( debug: Optional[bool], ) -> tuple[ bool, - Sequence[StreamMode], + set[StreamMode], Union[str, Sequence[str]], Optional[Sequence[str]], Optional[Sequence[str]], @@ -1079,7 +1075,7 @@ def _defaults( checkpointer = self.checkpointer return ( debug, - stream_mode, + set(stream_mode), output_keys, interrupt_before, interrupt_after, @@ -1249,7 +1245,7 @@ def output() -> Iterator: # a pending waiter to return immediately loop.stack.callback(stream._count.release) - def get_waiter() -> asyncio.Task[None]: + def get_waiter() -> concurrent.futures.Future[None]: nonlocal waiter if waiter is None or waiter.done(): waiter = loop.submit(stream.wait) @@ -1390,13 +1386,6 @@ def output() -> Iterator: else: yield payload - if subgraphs: - - def get_waiter() -> asyncio.Task[None]: - return aioloop.create_task(stream.wait()) - else: - get_waiter = None - config = ensure_config(self.config, config) callback_manager = get_async_callback_manager_for_config(config) run_manager = await callback_manager.on_chain_start( @@ -1437,6 +1426,11 @@ def get_waiter() -> asyncio.Task[None]: interrupt_after=interrupt_after, debug=debug, ) + # set up messages stream mode + if "messages" in stream_modes: + run_manager.inheritable_handlers.append( + StreamMessagesHandler(stream.put) + ) async with AsyncPregelLoop( input, stream=StreamProtocol(stream.put_nowait, stream_modes), @@ -1457,6 +1451,13 @@ def get_waiter() -> asyncio.Task[None]: # enable subgraph streaming if subgraphs: loop.config["configurable"][CONFIG_KEY_STREAM] = loop.stream + # enable concurrent streaming + if subgraphs or "messages" in stream_modes: + + def get_waiter() -> asyncio.Task[None]: + return aioloop.create_task(stream.wait()) + else: + get_waiter = None # Similarly to Bulk Synchronous Parallel / Pregel model # computation proceeds in steps, while there are channel updates # channel updates from step N are only visible in step N+1 diff --git a/libs/langgraph/langgraph/pregel/algo.py b/libs/langgraph/langgraph/pregel/algo.py index c2c5866e4..bf4baeb0f 100644 --- a/libs/langgraph/langgraph/pregel/algo.py +++ b/libs/langgraph/langgraph/pregel/algo.py @@ -4,6 +4,7 @@ from typing import ( Any, Callable, + Iterable, Iterator, Literal, Mapping, @@ -20,7 +21,12 @@ from langchain_core.runnables.config import RunnableConfig from langgraph.channels.base import BaseChannel -from langgraph.checkpoint.base import BaseCheckpointSaver, Checkpoint, copy_checkpoint +from langgraph.checkpoint.base import ( + BaseCheckpointSaver, + Checkpoint, + V, + copy_checkpoint, +) from langgraph.constants import ( CONFIG_KEY_CHECKPOINT_MAP, CONFIG_KEY_CHECKPOINTER, @@ -46,13 +52,18 @@ from langgraph.pregel.types import All, PregelExecutableTask, PregelTask from langgraph.utils.config import merge_configs, patch_config -EMPTY_SEQ = tuple() +EMPTY_SEQ: tuple[str, ...] = tuple() class WritesProtocol(Protocol): - name: str - writes: Sequence[tuple[str, Any]] - triggers: Sequence[str] + @property + def name(self) -> str: ... + + @property + def writes(self) -> Sequence[tuple[str, Any]]: ... + + @property + def triggers(self) -> Sequence[str]: ... class PregelTaskWrites(NamedTuple): @@ -64,14 +75,14 @@ class PregelTaskWrites(NamedTuple): def should_interrupt( checkpoint: Checkpoint, interrupt_nodes: Union[All, Sequence[str]], - tasks: list[PregelExecutableTask], + tasks: Iterable[PregelExecutableTask], ) -> list[PregelExecutableTask]: version_type = type(next(iter(checkpoint["channel_versions"].values()), None)) - null_version = version_type() + null_version = version_type() # type: ignore[misc] seen = checkpoint["versions_seen"].get(INTERRUPT, {}) # interrupt if any channel has been updated since last interrupt any_updates_since_prev_interrupt = any( - version > seen.get(chan, null_version) + version > seen.get(chan, null_version) # type: ignore[operator] for chan, version in checkpoint["channel_versions"].items() ) # and any triggered node is in interrupt_nodes list @@ -161,8 +172,8 @@ def increment(current: Optional[int], channel: BaseChannel) -> int: def apply_writes( checkpoint: Checkpoint, channels: Mapping[str, BaseChannel], - tasks: Sequence[WritesProtocol], - get_next_version: Optional[Callable[[int, BaseChannel], int]], + tasks: Iterable[WritesProtocol], + get_next_version: Optional[Callable[[Optional[V], BaseChannel], V]], ) -> dict[str, list[Any]]: # update seen versions for task in tasks: @@ -189,7 +200,8 @@ def apply_writes( }: if channels[chan].consume() and get_next_version is not None: checkpoint["channel_versions"][chan] = get_next_version( - max_version, channels[chan] + max_version, # type: ignore[arg-type] + channels[chan], ) # clear pending sends @@ -222,7 +234,8 @@ def apply_writes( if chan in channels: if channels[chan].update(vals) and get_next_version is not None: checkpoint["channel_versions"][chan] = get_next_version( - max_version, channels[chan] + max_version, # type: ignore[arg-type] + channels[chan], ) updated_channels.add(chan) @@ -231,7 +244,8 @@ def apply_writes( if chan not in updated_channels: if channels[chan].update([]) and get_next_version is not None: checkpoint["channel_versions"][chan] = get_next_version( - max_version, channels[chan] + max_version, # type: ignore[arg-type] + channels[chan], ) # Return managed values writes to be applied externally @@ -280,7 +294,7 @@ def prepare_next_tasks( checkpointer: Optional[BaseCheckpointSaver] = None, manager: Union[None, ParentRunManager, AsyncParentRunManager] = None, ) -> Union[dict[str, PregelTask], dict[str, PregelExecutableTask]]: - tasks: Union[dict[str, PregelTask], dict[str, PregelExecutableTask]] = {} + tasks: dict[str, Union[PregelTask, PregelExecutableTask]] = {} # Consume pending packets for idx, _ in enumerate(checkpoint["pending_sends"]): if task := prepare_single_task( @@ -377,7 +391,7 @@ def prepare_single_task( managed.replace_runtime_placeholders(step, packet.arg) if proc.metadata: metadata.update(proc.metadata) - writes = deque() + writes: deque[tuple[str, Any]] = deque() return PregelExecutableTask( packet.node, packet.arg, @@ -438,7 +452,7 @@ def prepare_single_task( return proc = processes[name] version_type = type(next(iter(checkpoint["channel_versions"].values()), None)) - null_version = version_type() + null_version = version_type() # type: ignore[misc] if null_version is None: return seen = checkpoint["versions_seen"].get(name, {}) @@ -449,7 +463,7 @@ def prepare_single_task( if not isinstance( read_channel(channels, chan, return_exception=True), EmptyChannelError ) - and checkpoint["channel_versions"].get(chan, null_version) + and checkpoint["channel_versions"].get(chan, null_version) # type: ignore[operator] > seen.get(chan, null_version) ): try: diff --git a/libs/langgraph/langgraph/pregel/debug.py b/libs/langgraph/langgraph/pregel/debug.py index a83c95709..782d5a13c 100644 --- a/libs/langgraph/langgraph/pregel/debug.py +++ b/libs/langgraph/langgraph/pregel/debug.py @@ -2,7 +2,17 @@ from dataclasses import asdict from datetime import datetime, timezone from pprint import pformat -from typing import Any, Iterator, Literal, Mapping, Optional, Sequence, TypedDict, Union +from typing import ( + Any, + Iterable, + Iterator, + Literal, + Mapping, + Optional, + Sequence, + TypedDict, + Union, +) from uuid import UUID from langchain_core.runnables.config import RunnableConfig @@ -48,8 +58,6 @@ class CheckpointPayload(TypedDict): class DebugOutputBase(TypedDict): timestamp: str step: int - type: str - payload: dict[str, Any] class DebugOutputTask(DebugOutputBase): @@ -201,7 +209,7 @@ def print_step_checkpoint( def tasks_w_writes( - tasks: list[PregelExecutableTask], + tasks: Iterable[Union[PregelTask, PregelExecutableTask]], pending_writes: Optional[list[PendingWrite]], states: Optional[dict[str, Union[RunnableConfig, StateSnapshot]]], ) -> tuple[PregelTask, ...]: diff --git a/libs/langgraph/langgraph/pregel/executor.py b/libs/langgraph/langgraph/pregel/executor.py index f606d4c28..46f1c3f64 100644 --- a/libs/langgraph/langgraph/pregel/executor.py +++ b/libs/langgraph/langgraph/pregel/executor.py @@ -9,9 +9,11 @@ Awaitable, Callable, ContextManager, + Coroutine, Optional, Protocol, TypeVar, + cast, ) from langchain_core.runnables import RunnableConfig @@ -42,7 +44,7 @@ def __init__(self, config: RunnableConfig) -> None: self.executor = self.stack.enter_context(get_executor_for_config(config)) self.tasks: dict[concurrent.futures.Future, tuple[bool, bool]] = {} - def submit( + def submit( # type: ignore[valid-type] self, fn: Callable[P, T], *args: P.args, @@ -68,7 +70,7 @@ def done(self, task: concurrent.futures.Future) -> None: else: self.tasks.pop(task) - def __enter__(self) -> "submit": + def __enter__(self) -> Submit: return self.submit def __exit__( @@ -105,7 +107,7 @@ def __init__(self) -> None: self.sentinel = object() self.loop = asyncio.get_running_loop() - def submit( + def submit( # type: ignore[valid-type] self, fn: Callable[P, Awaitable[T]], *args: P.args, @@ -114,7 +116,7 @@ def submit( __reraise_on_exit__: bool = True, **kwargs: P.kwargs, ) -> asyncio.Task[T]: - coro = fn(*args, **kwargs) + coro = cast(Coroutine[None, None, T], fn(*args, **kwargs)) if self.context_not_supported: task = self.loop.create_task(coro, name=__name__) else: diff --git a/libs/langgraph/langgraph/pregel/io.py b/libs/langgraph/langgraph/pregel/io.py index a02afbdec..ad2252c9d 100644 --- a/libs/langgraph/langgraph/pregel/io.py +++ b/libs/langgraph/langgraph/pregel/io.py @@ -28,7 +28,7 @@ def read_channel( def read_channels( channels: Mapping[str, BaseChannel], - select: Union[list[str], str], + select: Union[Sequence[str], str], *, skip_empty: bool = True, ) -> Union[dict[str, Any], Any]: @@ -97,7 +97,7 @@ def __radd__(self, other: dict[str, Any]) -> "AddableUpdatesDict": raise TypeError("AddableUpdatesDict does not support right-side addition") -EMPTY_SEQ = tuple() +EMPTY_SEQ: tuple[str, ...] = tuple() def map_output_updates( @@ -131,16 +131,16 @@ def map_output_updates( for task, writes in output_tasks if any(chan in output_channels for chan, _ in writes) ) - grouped = {t.name: [] for t, _ in output_tasks} + grouped: dict[str, list[Any]] = {t.name: [] for t, _ in output_tasks} for node, value in updated: grouped[node].append(value) for node, value in grouped.items(): if len(value) == 0: - grouped[node] = None + grouped[node] = None # type: ignore[assignment] if len(value) == 1: grouped[node] = value[0] if cached: - grouped["__metadata__"] = {"cached": cached} + grouped["__metadata__"] = {"cached": cached} # type: ignore[assignment] yield AddableUpdatesDict(grouped) diff --git a/libs/langgraph/langgraph/pregel/loop.py b/libs/langgraph/langgraph/pregel/loop.py index 7560905ed..ba98b1605 100644 --- a/libs/langgraph/langgraph/pregel/loop.py +++ b/libs/langgraph/langgraph/pregel/loop.py @@ -14,7 +14,6 @@ Mapping, Optional, Sequence, - Tuple, Type, TypeVar, Union, @@ -28,6 +27,7 @@ from langgraph.channels.base import BaseChannel from langgraph.checkpoint.base import ( BaseCheckpointSaver, + ChannelVersions, Checkpoint, CheckpointMetadata, CheckpointTuple, @@ -92,7 +92,7 @@ ) from langgraph.pregel.manager import AsyncChannelsManager, ChannelsManager from langgraph.pregel.read import PregelNode -from langgraph.pregel.types import PregelExecutableTask +from langgraph.pregel.types import PregelExecutableTask, StreamMode from langgraph.pregel.utils import get_new_channel_versions from langgraph.store.base import BaseStore from langgraph.store.batch import AsyncBatchedStore @@ -105,31 +105,32 @@ EMPTY_SEQ = () SPECIAL_CHANNELS = (ERROR, INTERRUPT, SCHEDULED) +StreamChunk = tuple[tuple[str, ...], str, Any] + class StreamProtocol: __slots__ = ("modes", "__call__") - modes: Sequence[Literal["values", "updates", "debug"]] + modes: set[StreamMode] - __call__: Callable[[Tuple[str, str, Any]], None] + __call__: Callable[[StreamChunk], None] def __init__( self, - __call__: Callable[[Tuple[str, str, Any]], None], - modes: Sequence[Literal["values", "updates", "debug"]], + __call__: Callable[[StreamChunk], None], + modes: set[StreamMode], ) -> None: self.__call__ = __call__ self.modes = modes -class DuplexStream(StreamProtocol): - def __init__(self, *streams: StreamProtocol) -> None: - def __call__(value: Tuple[str, str, Any]) -> None: - for stream in streams: - if value[1] in stream.modes: - stream(value) +def DuplexStream(*streams: StreamProtocol) -> StreamProtocol: + def __call__(value: StreamChunk) -> None: + for stream in streams: + if value[1] in stream.modes: + stream(value) # type: ignore - super().__init__(__call__, {mode for s in streams for mode in s.modes}) + return StreamProtocol(__call__, {mode for s in streams for mode in s.modes}) class PregelLoop: @@ -156,6 +157,7 @@ class PregelLoop: RunnableConfig, Sequence[tuple[str, Any]], str, + ChannelVersions, ], Any, ] @@ -209,7 +211,7 @@ def __init__( or CONFIG_KEY_DEDUPE_TASKS in config["configurable"] ) self.debug = debug - if CONFIG_KEY_STREAM in config["configurable"]: + if self.stream is not None and CONFIG_KEY_STREAM in config["configurable"]: self.stream = DuplexStream( self.stream, config["configurable"][CONFIG_KEY_STREAM] ) @@ -233,7 +235,7 @@ def __init__( else: self.checkpoint_config = config self.checkpoint_ns = ( - tuple(self.config["configurable"].get("checkpoint_ns").split(NS_SEP)) + tuple(cast(str, self.config["configurable"]["checkpoint_ns"]).split(NS_SEP)) if self.config["configurable"].get("checkpoint_ns") else () ) @@ -435,7 +437,7 @@ def tick( # debug flag if self.debug: - print_step_tasks(self.step, self.tasks.values()) + print_step_tasks(self.step, list(self.tasks.values())) return True @@ -482,6 +484,7 @@ def _first(self, *, input_keys: Union[str, Sequence[str]]) -> None: self.config, self.step, for_execution=True, + checkpointer=None, manager=None, ) # apply input writes @@ -589,7 +592,7 @@ def _emit( if mode not in self.stream.modes: return for v in values(*args, **kwargs): - self.stream((self.checkpoint_ns, mode, v)) + self.stream((self.checkpoint_ns, mode, v)) # type: ignore def _output_writes( self, task_id: str, writes: Sequence[tuple[str, Any]], *, cached: bool = False @@ -650,7 +653,7 @@ def __init__( self.checkpointer_put_writes = checkpointer.put_writes else: self.checkpointer_get_next_version = increment - self._checkpointer_put_after_previous = None + self._checkpointer_put_after_previous = None # type: ignore[assignment] self.checkpointer_put_writes = None def _checkpointer_put_after_previous( @@ -659,13 +662,15 @@ def _checkpointer_put_after_previous( config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata, - new_versions: Optional[dict[str, Union[str, float, int]]], + new_versions: ChannelVersions, ) -> RunnableConfig: try: if prev is not None: prev.result() finally: - self.checkpointer.put(config, checkpoint, metadata, new_versions) + cast(BaseCheckpointSaver, self.checkpointer).put( + config, checkpoint, metadata, new_versions + ) def _update_mv(self, key: str, values: Sequence[Any]) -> None: return self.submit(cast(WritableManagedValue, self.managed[key]).update, values) @@ -766,7 +771,7 @@ def __init__( self.checkpointer_put_writes = checkpointer.aput_writes else: self.checkpointer_get_next_version = increment - self._checkpointer_put_after_previous = None + self._checkpointer_put_after_previous = None # type: ignore[method-assign] self.checkpointer_put_writes = None async def _checkpointer_put_after_previous( @@ -775,13 +780,15 @@ async def _checkpointer_put_after_previous( config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata, - new_versions: Optional[dict[str, Union[str, float, int]]], + new_versions: ChannelVersions, ) -> RunnableConfig: try: if prev is not None: await prev finally: - await self.checkpointer.aput(config, checkpoint, metadata, new_versions) + await cast(BaseCheckpointSaver, self.checkpointer).aput( + config, checkpoint, metadata, new_versions + ) def _update_mv(self, key: str, values: Sequence[Any]) -> None: return self.submit( diff --git a/libs/langgraph/langgraph/pregel/manager.py b/libs/langgraph/langgraph/pregel/manager.py index a7e0b7283..c6d6c07aa 100644 --- a/libs/langgraph/langgraph/pregel/manager.py +++ b/libs/langgraph/langgraph/pregel/manager.py @@ -28,8 +28,8 @@ def ChannelsManager( ) -> Iterator[tuple[Mapping[str, BaseChannel], ManagedValueMapping]]: """Manage channels for the lifetime of a Pregel invocation (multiple steps).""" config_for_managed = patch_configurable(config, {CONFIG_KEY_STORE: store}) - channel_specs: Mapping[str, BaseChannel] = {} - managed_specs: Mapping[str, ManagedValueSpec] = {} + channel_specs: dict[str, BaseChannel] = {} + managed_specs: dict[str, ManagedValueSpec] = {} for k, v in specs.items(): if isinstance(v, BaseChannel): channel_specs[k] = v @@ -66,11 +66,11 @@ async def AsyncChannelsManager( store: Optional[BaseStore] = None, *, skip_context: bool = False, -) -> AsyncIterator[Mapping[str, BaseChannel]]: +) -> AsyncIterator[tuple[Mapping[str, BaseChannel], ManagedValueMapping]]: """Manage channels for the lifetime of a Pregel invocation (multiple steps).""" config_for_managed = patch_configurable(config, {CONFIG_KEY_STORE: store}) - channel_specs: Mapping[str, BaseChannel] = {} - managed_specs: Mapping[str, ManagedValueSpec] = {} + channel_specs: dict[str, BaseChannel] = {} + managed_specs: dict[str, ManagedValueSpec] = {} for k, v in specs.items(): if isinstance(v, BaseChannel): channel_specs[k] = v diff --git a/libs/langgraph/langgraph/pregel/messages.py b/libs/langgraph/langgraph/pregel/messages.py index 3ade1dd5f..0a96f0fee 100644 --- a/libs/langgraph/langgraph/pregel/messages.py +++ b/libs/langgraph/langgraph/pregel/messages.py @@ -8,6 +8,8 @@ Optional, Sequence, Tuple, + Union, + cast, ) from uuid import UUID, uuid4 @@ -17,13 +19,14 @@ from langchain_core.tracers._streaming import T, _StreamingCallbackHandler from langgraph.constants import NS_SEP +from langgraph.pregel.loop import StreamChunk class StreamMessagesHandler(BaseCallbackHandler, _StreamingCallbackHandler): - def __init__(self, stream: Callable[[Tuple[str, str, Any]], None]): + def __init__(self, stream: Callable[[StreamChunk], None]): self.stream = stream - self.metadata: dict[str, tuple[str, dict[str, Any]]] = {} - self.seen = set() + self.metadata: dict[UUID, tuple[tuple[str, ...], dict[str, Any]]] = {} + self.seen: set[Union[int, str]] = set() def _emit( self, @@ -31,7 +34,7 @@ def _emit( message: BaseMessage, *, dedupe: bool = False, - ): + ) -> None: ident = id(message) if dedupe and message.id in self.seen: return @@ -65,7 +68,7 @@ def on_chat_model_start( ) -> Any: if metadata: self.metadata[run_id] = ( - tuple(metadata["langgraph_checkpoint_ns"].split(NS_SEP)), + tuple(cast(str, metadata["langgraph_checkpoint_ns"]).split(NS_SEP)), metadata, ) @@ -116,7 +119,7 @@ def on_chain_start( ) -> Any: if metadata and kwargs.get("name") == metadata.get("langgraph_node"): self.metadata[run_id] = ( - tuple(metadata["langgraph_checkpoint_ns"].split(NS_SEP)), + tuple(cast(str, metadata["langgraph_checkpoint_ns"]).split(NS_SEP)), metadata, ) diff --git a/libs/langgraph/langgraph/pregel/read.py b/libs/langgraph/langgraph/pregel/read.py index e0e483ffd..79643a090 100644 --- a/libs/langgraph/langgraph/pregel/read.py +++ b/libs/langgraph/langgraph/pregel/read.py @@ -27,7 +27,7 @@ from langgraph.utils.config import merge_configs from langgraph.utils.runnable import RunnableCallable, RunnableSeq -READ_TYPE = Callable[[str, bool], Union[Any, dict[str, Any]]] +READ_TYPE = Callable[[Union[str, Sequence[str]], bool], Union[Any, dict[str, Any]]] class ChannelRead(RunnableCallable): diff --git a/libs/langgraph/langgraph/pregel/retry.py b/libs/langgraph/langgraph/pregel/retry.py index 553b25468..90ccaa7d0 100644 --- a/libs/langgraph/langgraph/pregel/retry.py +++ b/libs/langgraph/langgraph/pregel/retry.py @@ -47,7 +47,7 @@ def run_with_retry( if not isinstance(exc, retry_policy.retry_on): raise elif callable(retry_policy.retry_on): - if not retry_policy.retry_on(exc): + if not retry_policy.retry_on(exc): # type: ignore[call-arg] raise else: raise TypeError( @@ -113,7 +113,7 @@ async def arun_with_retry( if not isinstance(exc, retry_policy.retry_on): raise elif callable(retry_policy.retry_on): - if not retry_policy.retry_on(exc): + if not retry_policy.retry_on(exc): # type: ignore[call-arg] raise else: raise TypeError( diff --git a/libs/langgraph/langgraph/pregel/runner.py b/libs/langgraph/langgraph/pregel/runner.py index 1f282c584..7e1122485 100644 --- a/libs/langgraph/langgraph/pregel/runner.py +++ b/libs/langgraph/langgraph/pregel/runner.py @@ -45,12 +45,12 @@ def tick( yield # fast path if single task with no timeout if len(tasks) == 1 and timeout is None: - task = tasks[0] + t = tasks[0] try: - run_with_retry(task, retry_policy) - self.commit(task, None) + run_with_retry(t, retry_policy) + self.commit(t, None) except Exception as exc: - self.commit(task, exc) + self.commit(t, exc) if reraise: raise return @@ -64,16 +64,16 @@ def tick( # execute tasks, and wait for one to fail or all to finish. # each task is independent from all other concurrent tasks # yield updates/debug output as each task finishes - for task in tasks: - if not task.writes: + for t in tasks: + if not t.writes: futures[ self.submit( run_with_retry, - task, + t, retry_policy, __reraise_on_exit__=reraise, ) - ] = task + ] = t all_futures = futures.copy() end_time = timeout + time.monotonic() if timeout else None while len(futures) > (1 if get_waiter is not None else 0): @@ -88,7 +88,7 @@ def tick( task = futures.pop(fut) if task is None: # waiter task finished, schedule another - if inflight: + if inflight and get_waiter is not None: futures[get_waiter()] = None else: # task finished, commit writes @@ -119,12 +119,12 @@ async def atick( yield # fast path if single task with no waiter and no timeout if len(tasks) == 1 and get_waiter is None and timeout is None: - task = tasks[0] + t = tasks[0] try: - await arun_with_retry(task, retry_policy, stream=self.use_astream) - self.commit(task, None) + await arun_with_retry(t, retry_policy, stream=self.use_astream) + self.commit(t, None) except Exception as exc: - self.commit(task, exc) + self.commit(t, exc) if reraise: raise return @@ -138,19 +138,19 @@ async def atick( # execute tasks, and wait for one to fail or all to finish. # each task is independent from all other concurrent tasks # yield updates/debug output as each task finishes - for task in tasks: - if not task.writes: + for t in tasks: + if not t.writes: futures[ self.submit( arun_with_retry, - task, + t, retry_policy, stream=self.use_astream, - __name__=task.name, + __name__=t.name, __cancel_on_exit__=True, __reraise_on_exit__=reraise, ) - ] = task + ] = t all_futures = futures.copy() end_time = timeout + loop.time() if timeout else None while len(futures) > (1 if get_waiter is not None else 0): @@ -165,7 +165,7 @@ async def atick( task = futures.pop(fut) if task is None: # waiter task finished, schedule another - if inflight: + if inflight and get_waiter is not None: futures[get_waiter()] = None else: # task finished, commit writes @@ -208,7 +208,7 @@ def commit( def _should_stop_others( - done: Union[set[concurrent.futures.Future[Any]], set[asyncio.Task[Any]]], + done: Union[set[concurrent.futures.Future[Any]], set[asyncio.Future[Any]]], ) -> bool: for fut in done: if fut.cancelled(): @@ -220,10 +220,10 @@ def _should_stop_others( def _exception( - fut: Union[concurrent.futures.Future[Any], asyncio.Task[Any]], + fut: Union[concurrent.futures.Future[Any], asyncio.Future[Any]], ) -> Optional[BaseException]: if fut.cancelled(): - if isinstance(fut, asyncio.Task): + if isinstance(fut, asyncio.Future): return asyncio.CancelledError() else: return concurrent.futures.CancelledError() @@ -240,8 +240,8 @@ def _panic_or_proceed( timeout_exc_cls: Type[Exception] = TimeoutError, panic: bool = True, ) -> None: - done: set[Union[concurrent.futures.Future[Any], asyncio.Task[Any]]] = set() - inflight: set[Union[concurrent.futures.Future[Any], asyncio.Task[Any]]] = set() + done: set[Union[concurrent.futures.Future[Any], asyncio.Future[Any]]] = set() + inflight: set[Union[concurrent.futures.Future[Any], asyncio.Future[Any]]] = set() for fut, val in futs.items(): if val is None: continue diff --git a/libs/langgraph/langgraph/pregel/types.py b/libs/langgraph/langgraph/pregel/types.py index f9ad46608..218456085 100644 --- a/libs/langgraph/langgraph/pregel/types.py +++ b/libs/langgraph/langgraph/pregel/types.py @@ -66,7 +66,7 @@ class CachePolicy(NamedTuple): class PregelTask(NamedTuple): id: str name: str - path: tuple[str, ...] + path: tuple[Union[str, int], ...] error: Optional[Exception] = None interrupts: tuple[Interrupt, ...] = () state: Union[None, RunnableConfig, "StateSnapshot"] = None @@ -82,7 +82,7 @@ class PregelExecutableTask(NamedTuple): retry_policy: Optional[RetryPolicy] cache_policy: Optional[CachePolicy] id: str - path: tuple[str, ...] + path: tuple[Union[str, int], ...] scheduled: bool = False diff --git a/libs/langgraph/langgraph/pregel/utils.py b/libs/langgraph/langgraph/pregel/utils.py index d3d0d989f..c6dc064d3 100644 --- a/libs/langgraph/langgraph/pregel/utils.py +++ b/libs/langgraph/langgraph/pregel/utils.py @@ -7,11 +7,11 @@ def get_new_channel_versions( """Get new channel versions.""" if previous_versions: version_type = type(next(iter(current_versions.values()), None)) - null_version = version_type() + null_version = version_type() # type: ignore[misc] new_versions = { k: v for k, v in current_versions.items() - if v > previous_versions.get(k, null_version) + if v > previous_versions.get(k, null_version) # type: ignore[operator] } else: new_versions = current_versions diff --git a/libs/langgraph/langgraph/pregel/write.py b/libs/langgraph/langgraph/pregel/write.py index fd732966a..9c3b7782d 100644 --- a/libs/langgraph/langgraph/pregel/write.py +++ b/libs/langgraph/langgraph/pregel/write.py @@ -50,7 +50,7 @@ def __init__( self, writes: Sequence[Union[ChannelWriteEntry, Send]], *, - tags: Optional[list[str]] = None, + tags: Optional[Sequence[str]] = None, require_at_least_one_of: Optional[Sequence[str]] = None, ): super().__init__(func=self._write, afunc=self._awrite, name=None, tags=tags) @@ -158,6 +158,6 @@ def register_writer(runnable: R) -> R: def _mk_future(val: Any) -> asyncio.Future: - fut = asyncio.Future() + fut: asyncio.Future[Any] = asyncio.Future() fut.set_result(val) return fut diff --git a/libs/langgraph/langgraph/utils/config.py b/libs/langgraph/langgraph/utils/config.py index 4a69050a1..cc352cf07 100644 --- a/libs/langgraph/langgraph/utils/config.py +++ b/libs/langgraph/langgraph/utils/config.py @@ -1,7 +1,12 @@ from collections import ChainMap from typing import Any, Optional, Sequence -from langchain_core.callbacks import AsyncCallbackManager, CallbackManager, Callbacks +from langchain_core.callbacks import ( + AsyncCallbackManager, + BaseCallbackManager, + CallbackManager, + Callbacks, +) from langchain_core.runnables import RunnableConfig from langchain_core.runnables.config import ( CONFIG_KEYS, @@ -63,20 +68,20 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig: if not value: continue if key == "metadata": - if base_value := base.get(key): # type: ignore + if base_value := base.get(key): base[key] = {**base_value, **value} # type: ignore else: - base[key] = value + base[key] = value # type: ignore[literal-required] elif key == "tags": - if base_value := base.get(key): # type: ignore + if base_value := base.get(key): base[key] = [*base_value, *value] # type: ignore else: - base[key] = value + base[key] = value # type: ignore[literal-required] elif key == "configurable": - if base_value := base.get(key): # type: ignore + if base_value := base.get(key): base[key] = {**base_value, **value} # type: ignore else: - base[key] = value + base[key] = value # type: ignore[literal-required] elif key == "callbacks": base_callbacks = base.get("callbacks") # callbacks can be either None, list[handler] or manager @@ -92,7 +97,7 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig: for callback in value: mngr.add_handler(callback, inherit=True) base["callbacks"] = mngr - else: + elif isinstance(value, BaseCallbackManager): # value is a manager if base_callbacks is None: base["callbacks"] = value.copy() @@ -104,11 +109,13 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig: else: # base_callbacks is also a manager base["callbacks"] = base_callbacks.merge(value) + else: + raise NotImplementedError elif key == "recursion_limit": if config["recursion_limit"] != DEFAULT_RECURSION_LIMIT: base["recursion_limit"] = config["recursion_limit"] else: - base[key] = config[key] + base[key] = config[key] # type: ignore[literal-required] return base @@ -138,7 +145,7 @@ def patch_config( Returns: RunnableConfig: The patched config. """ - config = config.copy() or {} + config = config.copy() if config is not None else {} if callbacks is not None: # If we're replacing callbacks, we need to unset run_name # As that should apply only to the same run as the original callbacks @@ -176,7 +183,7 @@ def get_callback_manager_for_config( if all_tags is not None and tags is not None: all_tags = [*all_tags, *tags] elif tags is not None: - all_tags = tags + all_tags = list(tags) # use existing callbacks if they exist if (callbacks := config.get("callbacks")) and isinstance( callbacks, CallbackManager @@ -214,7 +221,7 @@ def get_async_callback_manager_for_config( if all_tags is not None and tags is not None: all_tags = [*all_tags, *tags] elif tags is not None: - all_tags = tags + all_tags = list(tags) # use existing callbacks if they exist if (callbacks := config.get("callbacks")) and isinstance( callbacks, AsyncCallbackManager @@ -263,7 +270,7 @@ def ensure_config(*configs: Optional[RunnableConfig]) -> RunnableConfig: continue for k, v in config.items(): if v is not None and k in CONFIG_KEYS: - empty[k] = v + empty[k] = v # type: ignore[literal-required] for k, v in config.items(): if v is not None and k not in CONFIG_KEYS: empty["configurable"][k] = v diff --git a/libs/langgraph/langgraph/utils/fields.py b/libs/langgraph/langgraph/utils/fields.py index a4c29a9ec..f4786cb34 100644 --- a/libs/langgraph/langgraph/utils/fields.py +++ b/libs/langgraph/langgraph/utils/fields.py @@ -59,7 +59,7 @@ def _is_readonly_type(type_: Any) -> bool: return False -_DEFAULT_KEYS = frozenset() +_DEFAULT_KEYS: frozenset[str] = frozenset() def get_field_default(name: str, type_: Any, schema: Type[Any]) -> Any: diff --git a/libs/langgraph/langgraph/utils/pydantic.py b/libs/langgraph/langgraph/utils/pydantic.py index 9accc66e1..cd0984202 100644 --- a/libs/langgraph/langgraph/utils/pydantic.py +++ b/libs/langgraph/langgraph/utils/pydantic.py @@ -19,7 +19,7 @@ def create_model( """ try: # for langchain-core >= 0.3.0 - from langchain_core.runnables.pydantic import create_model_v2 + from langchain_core.utils.pydantic import create_model_v2 return create_model_v2( model_name, diff --git a/libs/langgraph/langgraph/utils/queue.py b/libs/langgraph/langgraph/utils/queue.py index 99d94f5c4..b68e15322 100644 --- a/libs/langgraph/langgraph/utils/queue.py +++ b/libs/langgraph/langgraph/utils/queue.py @@ -1,3 +1,5 @@ +# type: ignore + import asyncio import queue import sys @@ -5,6 +7,7 @@ import types from collections import deque from time import monotonic +from typing import Optional PY_310 = sys.version_info >= (3, 10) @@ -14,7 +17,7 @@ class AsyncQueue(asyncio.Queue): Subclassed from asyncio.Queue, adding a wait() method.""" - async def wait(self): + async def wait(self) -> None: """If queue is empty, wait until an item is available. Copied from Queue.get(), removing the call to .get_nowait(), @@ -47,7 +50,7 @@ async def wait(self): class Semaphore(threading.Semaphore): """Semaphore subclass with a wait() method.""" - def wait(self, blocking: bool = True, timeout: float = None): + def wait(self, blocking: bool = True, timeout: Optional[float] = None): """Block until the semaphore can be acquired, but don't acquire it.""" if not blocking and timeout is not None: raise ValueError("can't specify timeout for non-blocking acquire") @@ -125,3 +128,6 @@ def qsize(self): return len(self._queue) __class_getitem__ = classmethod(types.GenericAlias) + + +__all__ = ["AsyncQueue", "SyncQueue"] diff --git a/libs/langgraph/langgraph/utils/runnable.py b/libs/langgraph/langgraph/utils/runnable.py index 7230d67bb..f0b16442d 100644 --- a/libs/langgraph/langgraph/utils/runnable.py +++ b/libs/langgraph/langgraph/utils/runnable.py @@ -5,7 +5,18 @@ from contextlib import AsyncExitStack from contextvars import copy_context from functools import partial, wraps -from typing import Any, AsyncIterator, Awaitable, Callable, Iterator, Optional, Sequence +from typing import ( + Any, + AsyncIterator, + Awaitable, + Callable, + Coroutine, + Iterator, + Optional, + Sequence, + Union, + cast, +) from langchain_core.runnables.base import ( Runnable, @@ -19,7 +30,7 @@ run_in_executor, var_child_runnable_config, ) -from langchain_core.runnables.utils import Input, Output, accepts_config +from langchain_core.runnables.utils import Input, accepts_config from langchain_core.tracers._streaming import _StreamingCallbackHandler from typing_extensions import TypeGuard @@ -52,8 +63,8 @@ class RunnableCallable(Runnable): def __init__( self, - func: Callable[..., Optional[Runnable]], - afunc: Optional[Callable[..., Awaitable[Optional[Runnable]]]] = None, + func: Optional[Callable[..., Union[Any, Runnable]]], + afunc: Optional[Callable[..., Awaitable[Union[Any, Runnable]]]] = None, *, name: Optional[str] = None, tags: Optional[Sequence[str]] = None, @@ -155,7 +166,7 @@ async def ainvoke( try: child_config = patch_config(config, callbacks=run_manager.get_child()) context.run(_set_config_context, child_config) - coro = self.afunc(input, **kwargs) + coro = cast(Coroutine[None, None, Any], self.afunc(input, **kwargs)) if ASYNCIO_ACCEPTS_CONTEXT: ret = await asyncio.create_task(coro, context=context) else: @@ -168,9 +179,8 @@ async def ainvoke( else: context.run(_set_config_context, config) if ASYNCIO_ACCEPTS_CONTEXT: - ret = await asyncio.create_task( - self.afunc(input, **kwargs), context=context - ) + coro = cast(Coroutine[None, None, Any], self.afunc(input, **kwargs)) + ret = await asyncio.create_task(coro, context=context) else: ret = await self.afunc(input, **kwargs) if isinstance(ret, Runnable) and self.recurse: @@ -200,7 +210,9 @@ def is_async_generator( ) -def coerce_to_runnable(thing: RunnableLike, *, name: str, trace: bool) -> Runnable: +def coerce_to_runnable( + thing: RunnableLike, *, name: Optional[str], trace: bool +) -> Runnable: """Coerce a runnable-like object into a Runnable. Args: @@ -219,7 +231,7 @@ def coerce_to_runnable(thing: RunnableLike, *, name: str, trace: bool) -> Runnab else: return RunnableCallable( thing, - wraps(thing)(partial(run_in_executor, None, thing)), + wraps(thing)(partial(run_in_executor, None, thing)), # type: ignore[arg-type] name=name, trace=trace, ) @@ -257,7 +269,7 @@ def __init__( if isinstance(step, RunnableSequence): steps_flat.extend(step.steps) elif isinstance(step, RunnableSeq): - steps_flat.extend(step.steps) + steps_flat.extend(step.steps) # type: ignore[has-type] else: steps_flat.append(coerce_to_runnable(step, name=None, trace=True)) if len(steps_flat) < 2: @@ -288,7 +300,7 @@ def __or__( else: return RunnableSeq( *self.steps, - coerce_to_runnable(other), + coerce_to_runnable(other, name=None, trace=True), name=self.name, ) @@ -312,14 +324,16 @@ def __ror__( ) else: return RunnableSequence( - coerce_to_runnable(other), + coerce_to_runnable(other, name=None, trace=True), *self.steps, name=self.name, ) def invoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any - ) -> Output: + ) -> Any: + if config is None: + config = ensure_config() # setup callbacks and context callback_manager = get_callback_manager_for_config(config) # start the root run @@ -356,7 +370,9 @@ async def ainvoke( input: Input, config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], - ) -> Output: + ) -> Any: + if config is None: + config = ensure_config() # setup callbacks callback_manager = get_async_callback_manager_for_config(config) # start the root run @@ -397,7 +413,9 @@ def stream( input: Input, config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], - ) -> Iterator[Output]: + ) -> Iterator[Any]: + if config is None: + config = ensure_config() # setup callbacks callback_manager = get_callback_manager_for_config(config) # start the root run @@ -424,7 +442,7 @@ def stream( iterator = step.transform(iterator, config) if stream_handler := next( ( - h + cast(_StreamingCallbackHandler, h) for h in run_manager.handlers if isinstance(h, _StreamingCallbackHandler) ), @@ -432,7 +450,7 @@ def stream( ): # populates streamed_output in astream_log() output if needed iterator = stream_handler.tap_output_iter(run_manager.run_id, iterator) - output: Output = None + output: Any = None add_supported = False for chunk in iterator: yield chunk @@ -458,7 +476,9 @@ async def astream( input: Input, config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], - ) -> AsyncIterator[Output]: + ) -> AsyncIterator[Any]: + if config is None: + config = ensure_config() # setup callbacks callback_manager = get_async_callback_manager_for_config(config) # start the root run @@ -488,7 +508,7 @@ async def astream( stack.push_async_callback(aiterator.aclose) if stream_handler := next( ( - h + cast(_StreamingCallbackHandler, h) for h in run_manager.handlers if isinstance(h, _StreamingCallbackHandler) ), @@ -498,7 +518,7 @@ async def astream( aiterator = stream_handler.tap_output_aiter( run_manager.run_id, aiterator ) - output: Output = None + output: Any = None add_supported = False async for chunk in aiterator: yield chunk diff --git a/libs/langgraph/poetry.lock b/libs/langgraph/poetry.lock index ba50565a8..24de562c4 100644 --- a/libs/langgraph/poetry.lock +++ b/libs/langgraph/poetry.lock @@ -1478,44 +1478,44 @@ files = [ [[package]] name = "mypy" -version = "1.10.0" +version = "1.11.2" description = "Optional static typing for Python" optional = false python-versions = ">=3.8" files = [ - {file = "mypy-1.10.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:da1cbf08fb3b851ab3b9523a884c232774008267b1f83371ace57f412fe308c2"}, - {file = "mypy-1.10.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:12b6bfc1b1a66095ab413160a6e520e1dc076a28f3e22f7fb25ba3b000b4ef99"}, - {file = "mypy-1.10.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e36fb078cce9904c7989b9693e41cb9711e0600139ce3970c6ef814b6ebc2b2"}, - {file = "mypy-1.10.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:2b0695d605ddcd3eb2f736cd8b4e388288c21e7de85001e9f85df9187f2b50f9"}, - {file = "mypy-1.10.0-cp310-cp310-win_amd64.whl", hash = "sha256:cd777b780312ddb135bceb9bc8722a73ec95e042f911cc279e2ec3c667076051"}, - {file = "mypy-1.10.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3be66771aa5c97602f382230165b856c231d1277c511c9a8dd058be4784472e1"}, - {file = "mypy-1.10.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8b2cbaca148d0754a54d44121b5825ae71868c7592a53b7292eeb0f3fdae95ee"}, - {file = "mypy-1.10.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ec404a7cbe9fc0e92cb0e67f55ce0c025014e26d33e54d9e506a0f2d07fe5de"}, - {file = "mypy-1.10.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e22e1527dc3d4aa94311d246b59e47f6455b8729f4968765ac1eacf9a4760bc7"}, - {file = "mypy-1.10.0-cp311-cp311-win_amd64.whl", hash = "sha256:a87dbfa85971e8d59c9cc1fcf534efe664d8949e4c0b6b44e8ca548e746a8d53"}, - {file = "mypy-1.10.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:a781f6ad4bab20eef8b65174a57e5203f4be627b46291f4589879bf4e257b97b"}, - {file = "mypy-1.10.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b808e12113505b97d9023b0b5e0c0705a90571c6feefc6f215c1df9381256e30"}, - {file = "mypy-1.10.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f55583b12156c399dce2df7d16f8a5095291354f1e839c252ec6c0611e86e2e"}, - {file = "mypy-1.10.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4cf18f9d0efa1b16478c4c129eabec36148032575391095f73cae2e722fcf9d5"}, - {file = "mypy-1.10.0-cp312-cp312-win_amd64.whl", hash = "sha256:bc6ac273b23c6b82da3bb25f4136c4fd42665f17f2cd850771cb600bdd2ebeda"}, - {file = "mypy-1.10.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9fd50226364cd2737351c79807775136b0abe084433b55b2e29181a4c3c878c0"}, - {file = "mypy-1.10.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f90cff89eea89273727d8783fef5d4a934be2fdca11b47def50cf5d311aff727"}, - {file = "mypy-1.10.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fcfc70599efde5c67862a07a1aaf50e55bce629ace26bb19dc17cece5dd31ca4"}, - {file = "mypy-1.10.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:075cbf81f3e134eadaf247de187bd604748171d6b79736fa9b6c9685b4083061"}, - {file = "mypy-1.10.0-cp38-cp38-win_amd64.whl", hash = "sha256:3f298531bca95ff615b6e9f2fc0333aae27fa48052903a0ac90215021cdcfa4f"}, - {file = "mypy-1.10.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:fa7ef5244615a2523b56c034becde4e9e3f9b034854c93639adb667ec9ec2976"}, - {file = "mypy-1.10.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3236a4c8f535a0631f85f5fcdffba71c7feeef76a6002fcba7c1a8e57c8be1ec"}, - {file = "mypy-1.10.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4a2b5cdbb5dd35aa08ea9114436e0d79aceb2f38e32c21684dcf8e24e1e92821"}, - {file = "mypy-1.10.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:92f93b21c0fe73dc00abf91022234c79d793318b8a96faac147cd579c1671746"}, - {file = "mypy-1.10.0-cp39-cp39-win_amd64.whl", hash = "sha256:28d0e038361b45f099cc086d9dd99c15ff14d0188f44ac883010e172ce86c38a"}, - {file = "mypy-1.10.0-py3-none-any.whl", hash = "sha256:f8c083976eb530019175aabadb60921e73b4f45736760826aa1689dda8208aee"}, - {file = "mypy-1.10.0.tar.gz", hash = "sha256:3d087fcbec056c4ee34974da493a826ce316947485cef3901f511848e687c131"}, + {file = "mypy-1.11.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d42a6dd818ffce7be66cce644f1dff482f1d97c53ca70908dff0b9ddc120b77a"}, + {file = "mypy-1.11.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:801780c56d1cdb896eacd5619a83e427ce436d86a3bdf9112527f24a66618fef"}, + {file = "mypy-1.11.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:41ea707d036a5307ac674ea172875f40c9d55c5394f888b168033177fce47383"}, + {file = "mypy-1.11.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6e658bd2d20565ea86da7d91331b0eed6d2eee22dc031579e6297f3e12c758c8"}, + {file = "mypy-1.11.2-cp310-cp310-win_amd64.whl", hash = "sha256:478db5f5036817fe45adb7332d927daa62417159d49783041338921dcf646fc7"}, + {file = "mypy-1.11.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:75746e06d5fa1e91bfd5432448d00d34593b52e7e91a187d981d08d1f33d4385"}, + {file = "mypy-1.11.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a976775ab2256aadc6add633d44f100a2517d2388906ec4f13231fafbb0eccca"}, + {file = "mypy-1.11.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:cd953f221ac1379050a8a646585a29574488974f79d8082cedef62744f0a0104"}, + {file = "mypy-1.11.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:57555a7715c0a34421013144a33d280e73c08df70f3a18a552938587ce9274f4"}, + {file = "mypy-1.11.2-cp311-cp311-win_amd64.whl", hash = "sha256:36383a4fcbad95f2657642a07ba22ff797de26277158f1cc7bd234821468b1b6"}, + {file = "mypy-1.11.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e8960dbbbf36906c5c0b7f4fbf2f0c7ffb20f4898e6a879fcf56a41a08b0d318"}, + {file = "mypy-1.11.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:06d26c277962f3fb50e13044674aa10553981ae514288cb7d0a738f495550b36"}, + {file = "mypy-1.11.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6e7184632d89d677973a14d00ae4d03214c8bc301ceefcdaf5c474866814c987"}, + {file = "mypy-1.11.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3a66169b92452f72117e2da3a576087025449018afc2d8e9bfe5ffab865709ca"}, + {file = "mypy-1.11.2-cp312-cp312-win_amd64.whl", hash = "sha256:969ea3ef09617aff826885a22ece0ddef69d95852cdad2f60c8bb06bf1f71f70"}, + {file = "mypy-1.11.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:37c7fa6121c1cdfcaac97ce3d3b5588e847aa79b580c1e922bb5d5d2902df19b"}, + {file = "mypy-1.11.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4a8a53bc3ffbd161b5b2a4fff2f0f1e23a33b0168f1c0778ec70e1a3d66deb86"}, + {file = "mypy-1.11.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2ff93107f01968ed834f4256bc1fc4475e2fecf6c661260066a985b52741ddce"}, + {file = "mypy-1.11.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:edb91dded4df17eae4537668b23f0ff6baf3707683734b6a818d5b9d0c0c31a1"}, + {file = "mypy-1.11.2-cp38-cp38-win_amd64.whl", hash = "sha256:ee23de8530d99b6db0573c4ef4bd8f39a2a6f9b60655bf7a1357e585a3486f2b"}, + {file = "mypy-1.11.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:801ca29f43d5acce85f8e999b1e431fb479cb02d0e11deb7d2abb56bdaf24fd6"}, + {file = "mypy-1.11.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:af8d155170fcf87a2afb55b35dc1a0ac21df4431e7d96717621962e4b9192e70"}, + {file = "mypy-1.11.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f7821776e5c4286b6a13138cc935e2e9b6fde05e081bdebf5cdb2bb97c9df81d"}, + {file = "mypy-1.11.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:539c570477a96a4e6fb718b8d5c3e0c0eba1f485df13f86d2970c91f0673148d"}, + {file = "mypy-1.11.2-cp39-cp39-win_amd64.whl", hash = "sha256:3f14cd3d386ac4d05c5a39a51b84387403dadbd936e17cb35882134d4f8f0d24"}, + {file = "mypy-1.11.2-py3-none-any.whl", hash = "sha256:b499bc07dbdcd3de92b0a8b29fdf592c111276f6a12fe29c30f6c417dd546d12"}, + {file = "mypy-1.11.2.tar.gz", hash = "sha256:7f9993ad3e0ffdc95c2a14b66dee63729f021968bff8ad911867579c65d13a79"}, ] [package.dependencies] mypy-extensions = ">=1.0.0" tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} -typing-extensions = ">=4.1.0" +typing-extensions = ">=4.6.0" [package.extras] dmypy = ["psutil (>=4.0)"] @@ -2979,6 +2979,20 @@ files = [ {file = "types_python_dateutil-2.9.0.20240316-py3-none-any.whl", hash = "sha256:6b8cb66d960771ce5ff974e9dd45e38facb81718cc1e208b10b1baccbfdbee3b"}, ] +[[package]] +name = "types-requests" +version = "2.32.0.20240914" +description = "Typing stubs for requests" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-requests-2.32.0.20240914.tar.gz", hash = "sha256:2850e178db3919d9bf809e434eef65ba49d0e7e33ac92d588f4a5e295fffd405"}, + {file = "types_requests-2.32.0.20240914-py3-none-any.whl", hash = "sha256:59c2f673eb55f32a99b2894faf6020e1a9f4a402ad0f192bfee0b64469054310"}, +] + +[package.dependencies] +urllib3 = ">=2" + [[package]] name = "typing-extensions" version = "4.12.2" @@ -3202,4 +3216,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.9.0,<4.0" -content-hash = "73c2dec0a0e833ad8742ebfca86d8e3d602a8a63671a782d21d8e0079a02d448" +content-hash = "2c74c10f4650f14f2757e1a688761a9680ecd251da088ea1e8c5ceda51aec067" diff --git a/libs/langgraph/pyproject.toml b/libs/langgraph/pyproject.toml index a583db910..954ca2973 100644 --- a/libs/langgraph/pyproject.toml +++ b/libs/langgraph/pyproject.toml @@ -32,6 +32,7 @@ psycopg = {extras = ["binary"], version = ">=3.0.0", python = ">=3.10"} uvloop = "^0.20.0" pyperf = "^2.7.0" py-spy = "^0.3.14" +types-requests = "^2.32.0.20240914" [tool.ruff] lint.select = [ "E", "F", "I" ] @@ -49,8 +50,14 @@ docstring-code-format = false docstring-code-line-length = "dynamic" [tool.mypy] -ignore_missing_imports = "True" +# https://mypy.readthedocs.io/en/stable/config_file.html disallow_untyped_defs = "True" +explicit_package_bases = "True" +warn_no_return = "False" +warn_unused_ignores = "True" +warn_redundant_casts = "True" +allow_redefinition = "True" +disable_error_code = "typeddict-item, return-value, override" [tool.coverage.run] omit = ["tests/*"] diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index 6bbbd0506..fec2aeafa 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -3962,7 +3962,7 @@ def search_api(query: str) -> str: assert [ c - for c in app.stream( + async for c in app.astream( {"messages": [HumanMessage(content="what is weather in sf")]}, stream_mode="messages", ) From cc318b1156cf18305a48a27065479a815d758240 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 19 Sep 2024 11:43:30 -0700 Subject: [PATCH 2/9] Fix --- .../langgraph/checkpoint/postgres/__init__.py | 4 ++-- .../langgraph/checkpoint/postgres/aio.py | 6 +++--- .../langgraph/checkpoint/postgres/base.py | 4 ++-- libs/checkpoint/langgraph/checkpoint/memory/__init__.py | 6 +++--- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py index 821b74af6..a2274cd60 100644 --- a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py +++ b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py @@ -1,6 +1,6 @@ import threading from contextlib import contextmanager -from typing import Any, Iterator, List, Optional, Union +from typing import Any, Iterator, Optional, Sequence, Union from langchain_core.runnables import RunnableConfig from psycopg import Connection, Cursor, Pipeline @@ -332,7 +332,7 @@ def put( def put_writes( self, config: RunnableConfig, - writes: List[tuple[str, Any]], + writes: Sequence[tuple[str, Any]], task_id: str, ) -> None: """Store intermediate writes linked to a checkpoint. diff --git a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py index 43a8cfe47..5b6269fe0 100644 --- a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py +++ b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py @@ -1,6 +1,6 @@ import asyncio from contextlib import asynccontextmanager -from typing import Any, AsyncIterator, Iterator, List, Optional, Union +from typing import Any, AsyncIterator, Iterator, Optional, Sequence, Union from langchain_core.runnables import RunnableConfig from psycopg import AsyncConnection, AsyncCursor, AsyncPipeline @@ -291,7 +291,7 @@ async def aput( async def aput_writes( self, config: RunnableConfig, - writes: list[tuple[str, Any]], + writes: Sequence[tuple[str, Any]], task_id: str, ) -> None: """Store intermediate writes linked to a checkpoint asynchronously. @@ -424,7 +424,7 @@ def put( def put_writes( self, config: RunnableConfig, - writes: List[tuple[str, Any]], + writes: Sequence[tuple[str, Any]], task_id: str, ) -> None: """Store intermediate writes linked to a checkpoint. diff --git a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py index f2fa21eba..76232e337 100644 --- a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py +++ b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py @@ -1,5 +1,5 @@ import random -from typing import Any, List, Optional, Tuple, cast +from typing import Any, List, Optional, Sequence, Tuple, cast from langchain_core.runnables import RunnableConfig from psycopg.types.json import Jsonb @@ -209,7 +209,7 @@ def _dump_writes( checkpoint_ns: str, checkpoint_id: str, task_id: str, - writes: list[tuple[str, Any]], + writes: Sequence[tuple[str, Any]], ) -> list[tuple[str, str, str, str, int, str, str, bytes]]: return [ ( diff --git a/libs/checkpoint/langgraph/checkpoint/memory/__init__.py b/libs/checkpoint/langgraph/checkpoint/memory/__init__.py index 20d53ea06..176aec24c 100644 --- a/libs/checkpoint/langgraph/checkpoint/memory/__init__.py +++ b/libs/checkpoint/langgraph/checkpoint/memory/__init__.py @@ -4,7 +4,7 @@ from contextlib import AbstractAsyncContextManager, AbstractContextManager from functools import partial from types import TracebackType -from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple +from typing import Any, AsyncIterator, Dict, Iterator, Optional, Sequence, Tuple from langchain_core.runnables import RunnableConfig @@ -344,7 +344,7 @@ def put( def put_writes( self, config: RunnableConfig, - writes: List[Tuple[str, Any]], + writes: Sequence[Tuple[str, Any]], task_id: str, ) -> None: """Save a list of writes to the in-memory storage. @@ -447,7 +447,7 @@ async def aput( async def aput_writes( self, config: RunnableConfig, - writes: List[Tuple[str, Any]], + writes: Sequence[Tuple[str, Any]], task_id: str, ) -> None: """Asynchronous version of put_writes. From 4f767cd2ca8047034eb90c7c17dbc4be0c12dca4 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 19 Sep 2024 11:46:21 -0700 Subject: [PATCH 3/9] Fix --- libs/langgraph/langgraph/pregel/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langgraph/langgraph/pregel/__init__.py b/libs/langgraph/langgraph/pregel/__init__.py index 5389b7ddf..3cf791ef8 100644 --- a/libs/langgraph/langgraph/pregel/__init__.py +++ b/libs/langgraph/langgraph/pregel/__init__.py @@ -1429,7 +1429,7 @@ def output() -> Iterator: # set up messages stream mode if "messages" in stream_modes: run_manager.inheritable_handlers.append( - StreamMessagesHandler(stream.put) + StreamMessagesHandler(stream.put_nowait) ) async with AsyncPregelLoop( input, From 5de9b354162bfee89b859ad72a404fb59488c35f Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 19 Sep 2024 13:01:54 -0700 Subject: [PATCH 4/9] Finish --- .../langgraph/checkpoint/base/__init__.py | 4 +- libs/langgraph/Makefile | 2 +- .../channels/dynamic_barrier_value.py | 2 +- .../langgraph/channels/ephemeral_value.py | 4 +- libs/langgraph/langgraph/graph/graph.py | 79 ++++++++++++------- libs/langgraph/langgraph/graph/message.py | 4 +- libs/langgraph/langgraph/graph/state.py | 38 ++++++--- .../langgraph/prebuilt/chat_agent_executor.py | 2 +- .../langgraph/langgraph/prebuilt/tool_node.py | 2 +- libs/langgraph/langgraph/pregel/__init__.py | 44 ++++++----- libs/langgraph/langgraph/pregel/algo.py | 10 ++- libs/langgraph/langgraph/pregel/debug.py | 4 +- libs/langgraph/langgraph/pregel/loop.py | 22 ++++-- libs/langgraph/langgraph/pregel/messages.py | 13 +-- libs/langgraph/langgraph/pregel/read.py | 12 +-- libs/langgraph/langgraph/pregel/runner.py | 25 +++--- libs/langgraph/langgraph/pregel/validate.py | 16 ++-- libs/langgraph/langgraph/pregel/write.py | 5 +- libs/langgraph/langgraph/utils/runnable.py | 2 +- libs/langgraph/pyproject.toml | 2 +- 20 files changed, 171 insertions(+), 121 deletions(-) diff --git a/libs/checkpoint/langgraph/checkpoint/base/__init__.py b/libs/checkpoint/langgraph/checkpoint/base/__init__.py index 822389cce..ae98e5df5 100644 --- a/libs/checkpoint/langgraph/checkpoint/base/__init__.py +++ b/libs/checkpoint/langgraph/checkpoint/base/__init__.py @@ -13,11 +13,11 @@ Sequence, Tuple, TypedDict, + TypeVar, Union, ) from langchain_core.runnables import ConfigurableFieldSpec, RunnableConfig -from typing_extensions import TypeVar from langgraph.checkpoint.base.id import uuid6 from langgraph.checkpoint.serde.base import SerializerProtocol, maybe_add_typed_methods @@ -29,7 +29,7 @@ SendProtocol, ) -V = TypeVar("V", int, float, str, default=int) +V = TypeVar("V", int, float, str) PendingWrite = Tuple[str, str, Any] diff --git a/libs/langgraph/Makefile b/libs/langgraph/Makefile index 62ef303bb..7ec12f80a 100644 --- a/libs/langgraph/Makefile +++ b/libs/langgraph/Makefile @@ -75,7 +75,7 @@ lint lint_diff lint_package lint_tests: [ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff [ "$(PYTHON_FILES)" = "" ] || poetry run ruff check --select I $(PYTHON_FILES) [ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) - [ "$(PYTHON_FILES)" = "" ] || poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE) + [ "$(PYTHON_FILES)" != "langgraph" ] || poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE) format format_diff: poetry run ruff format $(PYTHON_FILES) diff --git a/libs/langgraph/langgraph/channels/dynamic_barrier_value.py b/libs/langgraph/langgraph/channels/dynamic_barrier_value.py index bbecd3d8b..dfa77f350 100644 --- a/libs/langgraph/langgraph/channels/dynamic_barrier_value.py +++ b/libs/langgraph/langgraph/channels/dynamic_barrier_value.py @@ -6,7 +6,7 @@ from langgraph.errors import EmptyChannelError, InvalidUpdateError -class WaitForNames(NamedTuple): +class WaitForNames(NamedTuple, Generic[Value]): names: set[Value] diff --git a/libs/langgraph/langgraph/channels/ephemeral_value.py b/libs/langgraph/langgraph/channels/ephemeral_value.py index 5beba22eb..537a8763c 100644 --- a/libs/langgraph/langgraph/channels/ephemeral_value.py +++ b/libs/langgraph/langgraph/channels/ephemeral_value.py @@ -1,4 +1,4 @@ -from typing import Generic, Optional, Sequence, Type +from typing import Any, Generic, Optional, Sequence, Type from typing_extensions import Self @@ -11,7 +11,7 @@ class EphemeralValue(Generic[Value], BaseChannel[Value, Value, Value]): __slots__ = ("value", "guard") - def __init__(self, typ: Type[Value], guard: bool = True) -> None: + def __init__(self, typ: Any, guard: bool = True) -> None: super().__init__(typ) self.guard = guard diff --git a/libs/langgraph/langgraph/graph/graph.py b/libs/langgraph/langgraph/graph/graph.py index 706e5253c..93fe28b3d 100644 --- a/libs/langgraph/langgraph/graph/graph.py +++ b/libs/langgraph/langgraph/graph/graph.py @@ -57,7 +57,7 @@ class Branch(NamedTuple): def run( self, writer: Callable[ - [list[Union[str, Send]], RunnableConfig], Optional[ChannelWrite] + [Sequence[Union[str, Send]], RunnableConfig], Optional[ChannelWrite] ], reader: Optional[Callable[[RunnableConfig], Any]] = None, ) -> RunnableCallable: @@ -79,7 +79,7 @@ def _route( *, reader: Optional[Callable[[RunnableConfig], Any]], writer: Callable[ - [list[Union[str, Send]], RunnableConfig], Optional[ChannelWrite] + [Sequence[Union[str, Send]], RunnableConfig], Optional[ChannelWrite] ], ) -> Runnable: if reader: @@ -100,7 +100,7 @@ async def _aroute( *, reader: Optional[Callable[[RunnableConfig], Any]], writer: Callable[ - [list[Union[str, Send]], RunnableConfig], Optional[ChannelWrite] + [Sequence[Union[str, Send]], RunnableConfig], Optional[ChannelWrite] ], ) -> Runnable: if reader: @@ -117,18 +117,20 @@ async def _aroute( def _finish( self, writer: Callable[ - [list[Union[str, Send]], RunnableConfig], Optional[ChannelWrite] + [Sequence[Union[str, Send]], RunnableConfig], Optional[ChannelWrite] ], input: Any, result: Any, config: RunnableConfig, - ): + ) -> Union[Runnable, Any]: if not isinstance(result, list): result = [result] if self.ends: - destinations = [r if isinstance(r, Send) else self.ends[r] for r in result] + destinations: Sequence[Union[Send, str]] = [ + r if isinstance(r, Send) else self.ends[r] for r in result + ] else: - destinations = result + destinations = cast(Sequence[Union[Send, str]], result) if any(dest is None or dest == START for dest in destinations): raise ValueError("Branch did not return a valid destination") if any(p.node == END for p in destinations if isinstance(p, Send)): @@ -186,14 +188,20 @@ def add_node( ) if not isinstance(node, str): action = node - node = getattr(action, "name", action.__name__) + node = getattr(action, "name", getattr(action, "__name__")) + if node is None: + raise ValueError( + "Node name must be provided if action is not a function" + ) + if action is None: + raise RuntimeError if node in self.nodes: raise ValueError(f"Node `{node}` already present.") if node == END or node == START: raise ValueError(f"Node `{node}` is reserved.") - self.nodes[node] = NodeSpec( - coerce_to_runnable(action, name=node, trace=False), metadata + self.nodes[cast(str, node)] = NodeSpec( + coerce_to_runnable(action, name=cast(str, node), trace=False), metadata ) def add_edge(self, start_key: str, end_key: str) -> None: @@ -257,16 +265,20 @@ def add_conditional_edges( # coerce path_map to a dictionary try: if isinstance(path_map, dict): - path_map = path_map.copy() + path_map_ = path_map.copy() elif isinstance(path_map, list): - path_map = {name: name for name in path_map} - elif rtn_type := get_type_hints(path.__call__).get( - "return" - ) or get_type_hints(path).get("return"): + path_map_ = {name: name for name in path_map} + elif callable(path) and ( + rtn_type := get_type_hints(path.__call__).get("return") + if hasattr(path, "__call__") + else get_type_hints(path).get("return") + ): if get_origin(rtn_type) is Literal: - path_map = {name: name for name in get_args(rtn_type)} + path_map_ = {name: name for name in get_args(rtn_type)} + else: + path_map_ = None except Exception: - pass + path_map_ = None # find a name for the condition path = coerce_to_runnable(path, name=None, trace=True) name = path.name or "condition" @@ -276,7 +288,7 @@ def add_conditional_edges( f"Branch with name `{path.name}` already exists for node " f"`{source}`" ) # save it - self.branches[source][name] = Branch(path, path_map, then) + self.branches[source][name] = Branch(path, path_map_, then) def set_entry_point(self, key: str) -> None: """Specifies the first node to be called in the graph. @@ -405,7 +417,6 @@ def compile( # create empty compiled graph compiled = CompiledGraph( - builder=self, nodes={}, channels={START: EphemeralValue(Any), END: EphemeralValue(Any)}, input_channels=START, @@ -418,6 +429,7 @@ def compile( auto_validate=False, debug=debug, ) + compiled.builder = self # attach nodes, edges, and branches for key, node in self.nodes.items(): @@ -437,10 +449,6 @@ def compile( class CompiledGraph(Pregel): builder: Graph - def __init__(self, *, builder: Graph, **kwargs): - super().__init__(**kwargs) - self.builder = builder - def attach_node(self, key: str, node: NodeSpec) -> None: self.channels[key] = EphemeralValue(Any) self.nodes[key] = ( @@ -463,7 +471,7 @@ def attach_edge(self, start: str, end: str) -> None: def attach_branch(self, start: str, name: str, branch: Branch) -> None: def branch_writer( - packets: list[Union[str, Send]], config: RunnableConfig + packets: Sequence[Union[str, Send]], config: RunnableConfig ) -> Optional[ChannelWrite]: writes = [ ( @@ -473,7 +481,10 @@ def branch_writer( ) for p in packets ] - return ChannelWrite(writes, tags=[TAG_HIDDEN]) + return ChannelWrite( + cast(Sequence[Union[ChannelWriteEntry, Send]], writes), + tags=[TAG_HIDDEN], + ) # add hidden start node if start == START and start not in self.nodes: @@ -489,7 +500,7 @@ def branch_writer( channel_name = f"branch:{start}:{name}:{end}" self.channels[channel_name] = EphemeralValue(Any) self.nodes[end].triggers.append(channel_name) - self.nodes[end].channels.append(channel_name) + cast(list[str], self.nodes[end].channels).append(channel_name) def get_graph( self, @@ -504,17 +515,25 @@ def get_graph( } end_nodes: dict[str, DrawableNode] = {} if xray: - subgraphs = dict(self.get_subgraphs()) + subgraphs = { + k: v for k, v in self.get_subgraphs() if isinstance(v, CompiledGraph) + } else: subgraphs = {} def add_edge( - start: str, end: str, label: Optional[str] = None, conditional: bool = False + start: str, + end: str, + label: Optional[Hashable] = None, + conditional: bool = False, ) -> None: if end == END and END not in end_nodes: end_nodes[END] = graph.add_node(self.get_output_schema(config), END) return graph.add_edge( - start_nodes[start], end_nodes[end], label, conditional + start_nodes[start], + end_nodes[end], + str(label) if label is not None else None, + conditional, ) for key, n in self.builder.nodes.items(): @@ -563,7 +582,7 @@ def add_edge( elif branch.then is not None: ends = {k: k for k in default_ends if k not in (END, branch.then)} else: - ends = default_ends + ends = cast(dict[Hashable, str], default_ends) for label, end in ends.items(): add_edge( start, diff --git a/libs/langgraph/langgraph/graph/message.py b/libs/langgraph/langgraph/graph/message.py index 34bc9c090..6575bd10c 100644 --- a/libs/langgraph/langgraph/graph/message.py +++ b/libs/langgraph/langgraph/graph/message.py @@ -63,9 +63,9 @@ def add_messages(left: Messages, right: Messages) -> Messages: """ # coerce to list if not isinstance(left, list): - left = [left] + left = [left] # type: ignore[assignment] if not isinstance(right, list): - right = [right] + right = [right] # type: ignore[assignment] # coerce to message left = [ message_chunk_to_message(cast(BaseMessageChunk, m)) diff --git a/libs/langgraph/langgraph/graph/state.py b/libs/langgraph/langgraph/graph/state.py index e66cff53d..803434316 100644 --- a/libs/langgraph/langgraph/graph/state.py +++ b/libs/langgraph/langgraph/graph/state.py @@ -7,11 +7,13 @@ from typing import ( Any, Callable, + Literal, NamedTuple, Optional, Sequence, Type, Union, + cast, get_origin, get_type_hints, overload, @@ -122,7 +124,7 @@ class StateGraph(Graph): >>> print(step1) {'x': [0.5, 0.75]}""" - nodes: dict[str, StateNodeSpec] + nodes: dict[str, StateNodeSpec] # type: ignore[assignment] channels: dict[str, BaseChannel] managed: dict[str, ManagedValueSpec] schemas: dict[Type[Any], dict[str, Union[BaseChannel, ManagedValueSpec]]] @@ -302,7 +304,7 @@ def add_node( if not isinstance(node, str): action = node if isinstance(action, Runnable): - node = action.name + node = action.get_name() else: node = getattr(action, "__name__", action.__class__.__name__) if node is None: @@ -323,13 +325,15 @@ def add_node( raise ValueError( "Node name must be provided if action is not a function" ) + if action is None: + raise RuntimeError if node in self.nodes: raise ValueError(f"Node `{node}` already present.") if node == END or node == START: raise ValueError(f"Node `{node}` is reserved.") for character in (NS_SEP, NS_END): - if character in node: + if character in cast(str, node): raise ValueError( f"'{character}' is a reserved character and is not allowed in the node names." ) @@ -349,8 +353,8 @@ def add_node( pass if input is not None: self._add_schema(input) - self.nodes[node] = StateNodeSpec( - coerce_to_runnable(action, name=node, trace=False), + self.nodes[cast(str, node)] = StateNodeSpec( + coerce_to_runnable(action, name=cast(str, node), trace=False), metadata, input=input or self.schema, retry_policy=retry, @@ -449,7 +453,6 @@ def compile( ) compiled = CompiledStateGraph( - builder=self, config_type=self.config_schema, nodes={}, channels={ @@ -468,6 +471,7 @@ def compile( debug=debug, store=store, ) + compiled.builder = self compiled.attach_node(START, None) for key, node in self.nodes.items(): @@ -618,7 +622,7 @@ def attach_edge(self, starts: Union[str, Sequence[str]], end: str) -> None: def attach_branch(self, start: str, name: str, branch: Branch) -> None: def branch_writer( - packets: list[Union[str, Send]], config: RunnableConfig + packets: Sequence[Union[str, Send]], config: RunnableConfig ) -> None: if filtered := [p for p in packets if p != END]: writes = [ @@ -638,7 +642,9 @@ def branch_writer( ), ) ) - ChannelWrite.do_write(config, writes) + ChannelWrite.do_write( + config, cast(Sequence[Union[Send, ChannelWriteEntry]], writes) + ) # attach branch publisher schema = ( @@ -708,11 +714,23 @@ def _get_channels( if name != "__slots__" } return ( - {k: v for k, v in all_keys.items() if not is_managed_value(v)}, + {k: v for k, v in all_keys.items() if isinstance(v, BaseChannel)}, {k: v for k, v in all_keys.items() if is_managed_value(v)}, ) +@overload +def _get_channel( + name: str, annotation: Any, *, allow_managed: Literal[False] +) -> BaseChannel: ... + + +@overload +def _get_channel( + name: str, annotation: Any, *, allow_managed: Literal[True] = True +) -> Union[BaseChannel, ManagedValueSpec]: ... + + def _get_channel( name: str, annotation: Any, *, allow_managed: bool = True ) -> Union[BaseChannel, ManagedValueSpec]: @@ -728,7 +746,7 @@ def _get_channel( channel.key = name return channel - fallback = LastValue(annotation) + fallback: LastValue = LastValue(annotation) fallback.key = name return fallback diff --git a/libs/langgraph/langgraph/prebuilt/chat_agent_executor.py b/libs/langgraph/langgraph/prebuilt/chat_agent_executor.py index d4a6157cb..dd2cddb5e 100644 --- a/libs/langgraph/langgraph/prebuilt/chat_agent_executor.py +++ b/libs/langgraph/langgraph/prebuilt/chat_agent_executor.py @@ -419,7 +419,7 @@ class Agent,Tools otherClass raise ValueError(f"Missing required key(s) {missing_keys} in state_schema") if isinstance(tools, ToolExecutor): - tool_classes = tools.tools + tool_classes: Sequence[BaseTool] = tools.tools tool_node = ToolNode(tool_classes) elif isinstance(tools, ToolNode): tool_classes = list(tools.tools_by_name.values()) diff --git a/libs/langgraph/langgraph/prebuilt/tool_node.py b/libs/langgraph/langgraph/prebuilt/tool_node.py index c0f7c83f6..52b80f75a 100644 --- a/libs/langgraph/langgraph/prebuilt/tool_node.py +++ b/libs/langgraph/langgraph/prebuilt/tool_node.py @@ -382,7 +382,7 @@ def _get_state_args(tool: BaseTool) -> Dict[str, Optional[str]]: full_schema = tool.get_input_schema() tool_args_to_state_fields: Dict = {} - def _is_injection(type_arg: Any): + def _is_injection(type_arg: Any) -> bool: if isinstance(type_arg, InjectedState) or ( isinstance(type_arg, type) and issubclass(type_arg, InjectedState) ): diff --git a/libs/langgraph/langgraph/pregel/__init__.py b/libs/langgraph/langgraph/pregel/__init__.py index 3cf791ef8..9f16ac698 100644 --- a/libs/langgraph/langgraph/pregel/__init__.py +++ b/libs/langgraph/langgraph/pregel/__init__.py @@ -138,7 +138,7 @@ def subscribe_to( ) return PregelNode( channels=cast( - Union[Mapping[None, str], Mapping[str, str]], + Union[list[str], Mapping[str, str]], ( {key: channels} if isinstance(channels, str) and key is not None @@ -305,7 +305,9 @@ def config_specs(self) -> list[ConfigurableFieldSpec]: @property def InputType(self) -> Any: if isinstance(self.input_channels, str): - return self.channels[self.input_channels].UpdateType + channel = self.channels[self.input_channels] + if isinstance(channel, BaseChannel): + return channel.UpdateType def get_input_schema( self, config: Optional[RunnableConfig] = None @@ -317,9 +319,9 @@ def get_input_schema( return create_model( self.get_name("Input"), field_definitions={ - k: (self.channels[k].UpdateType, None) + k: (c.UpdateType, None) for k in self.input_channels or self.channels.keys() - if isinstance(self.channels[k], BaseChannel) + if (c := self.channels[k]) and isinstance(c, BaseChannel) }, ) @@ -335,7 +337,9 @@ def get_input_jsonschema( @property def OutputType(self) -> Any: if isinstance(self.output_channels, str): - return self.channels[self.output_channels].ValueType + channel = self.channels[self.output_channels] + if isinstance(channel, BaseChannel): + return channel.ValueType def get_output_schema( self, config: Optional[RunnableConfig] = None @@ -347,9 +351,9 @@ def get_output_schema( return create_model( self.get_name("Output"), field_definitions={ - k: (self.channels[k].ValueType, None) + k: (c.ValueType, None) for k in self.output_channels - if isinstance(self.channels[k], BaseChannel) + if (c := self.channels[k]) and isinstance(c, BaseChannel) }, ) @@ -1050,8 +1054,8 @@ def _defaults( bool, set[StreamMode], Union[str, Sequence[str]], - Optional[Sequence[str]], - Optional[Sequence[str]], + Union[All, Sequence[str]], + Union[All, Sequence[str]], Optional[BaseCheckpointSaver], ]: debug = debug if debug is not None else self.debug @@ -1199,8 +1203,8 @@ def output() -> Iterator: debug, stream_modes, output_keys, - interrupt_before, - interrupt_after, + interrupt_before_, + interrupt_after_, checkpointer, ) = self._defaults( config, @@ -1253,7 +1257,7 @@ def get_waiter() -> concurrent.futures.Future[None]: else: return waiter else: - get_waiter = None + get_waiter = None # type: ignore[assignment] # Similarly to Bulk Synchronous Parallel / Pregel model # computation proceeds in steps, while there are channel updates # channel updates from step N are only visible in step N+1 @@ -1261,8 +1265,8 @@ def get_waiter() -> concurrent.futures.Future[None]: # with channel updates applied only at the transition between steps while loop.tick( input_keys=self.input_channels, - interrupt_before=interrupt_before, - interrupt_after=interrupt_after, + interrupt_before=interrupt_before_, + interrupt_after=interrupt_after_, manager=run_manager, ): for _ in runner.tick( @@ -1397,7 +1401,7 @@ def output() -> Iterator: # if running from astream_log() run each proc with streaming do_stream = next( ( - h + cast(_StreamingCallbackHandler, h) for h in run_manager.handlers if isinstance(h, _StreamingCallbackHandler) ), @@ -1415,8 +1419,8 @@ def output() -> Iterator: debug, stream_modes, output_keys, - interrupt_before, - interrupt_after, + interrupt_before_, + interrupt_after_, checkpointer, ) = self._defaults( config, @@ -1457,7 +1461,7 @@ def output() -> Iterator: def get_waiter() -> asyncio.Task[None]: return aioloop.create_task(stream.wait()) else: - get_waiter = None + get_waiter = None # type: ignore[assignment] # Similarly to Bulk Synchronous Parallel / Pregel model # computation proceeds in steps, while there are channel updates # channel updates from step N are only visible in step N+1 @@ -1465,8 +1469,8 @@ def get_waiter() -> asyncio.Task[None]: # with channel updates applied only at the transition between steps while loop.tick( input_keys=self.input_channels, - interrupt_before=interrupt_before, - interrupt_after=interrupt_after, + interrupt_before=interrupt_before_, + interrupt_after=interrupt_after_, manager=run_manager, ): async for _ in runner.atick( diff --git a/libs/langgraph/langgraph/pregel/algo.py b/libs/langgraph/langgraph/pregel/algo.py index bf4baeb0f..40a10a40f 100644 --- a/libs/langgraph/langgraph/pregel/algo.py +++ b/libs/langgraph/langgraph/pregel/algo.py @@ -52,6 +52,8 @@ from langgraph.pregel.types import All, PregelExecutableTask, PregelTask from langgraph.utils.config import merge_configs, patch_config +GetNextVersion = Callable[[Optional[V], BaseChannel], V] + EMPTY_SEQ: tuple[str, ...] = tuple() @@ -173,7 +175,7 @@ def apply_writes( checkpoint: Checkpoint, channels: Mapping[str, BaseChannel], tasks: Iterable[WritesProtocol], - get_next_version: Optional[Callable[[Optional[V], BaseChannel], V]], + get_next_version: Optional[GetNextVersion], ) -> dict[str, list[Any]]: # update seen versions for task in tasks: @@ -200,7 +202,7 @@ def apply_writes( }: if channels[chan].consume() and get_next_version is not None: checkpoint["channel_versions"][chan] = get_next_version( - max_version, # type: ignore[arg-type] + max_version, channels[chan], ) @@ -234,7 +236,7 @@ def apply_writes( if chan in channels: if channels[chan].update(vals) and get_next_version is not None: checkpoint["channel_versions"][chan] = get_next_version( - max_version, # type: ignore[arg-type] + max_version, channels[chan], ) updated_channels.add(chan) @@ -244,7 +246,7 @@ def apply_writes( if chan not in updated_channels: if channels[chan].update([]) and get_next_version is not None: checkpoint["channel_versions"][chan] = get_next_version( - max_version, # type: ignore[arg-type] + max_version, channels[chan], ) diff --git a/libs/langgraph/langgraph/pregel/debug.py b/libs/langgraph/langgraph/pregel/debug.py index 782d5a13c..56f1eb9c6 100644 --- a/libs/langgraph/langgraph/pregel/debug.py +++ b/libs/langgraph/langgraph/pregel/debug.py @@ -82,7 +82,7 @@ class DebugOutputCheckpoint(DebugOutputBase): def map_debug_tasks( - step: int, tasks: list[PregelExecutableTask] + step: int, tasks: Iterable[PregelExecutableTask] ) -> Iterator[DebugOutputTask]: ts = datetime.now(timezone.utc).isoformat() for task in tasks: @@ -132,7 +132,7 @@ def map_debug_checkpoint( stream_channels: Union[str, Sequence[str]], metadata: CheckpointMetadata, checkpoint: Checkpoint, - tasks: list[PregelExecutableTask], + tasks: Iterable[PregelExecutableTask], pending_writes: list[PendingWrite], ) -> Iterator[DebugOutputCheckpoint]: yield { diff --git a/libs/langgraph/langgraph/pregel/loop.py b/libs/langgraph/langgraph/pregel/loop.py index ba98b1605..49baa1846 100644 --- a/libs/langgraph/langgraph/pregel/loop.py +++ b/libs/langgraph/langgraph/pregel/loop.py @@ -64,6 +64,7 @@ WritableManagedValue, ) from langgraph.pregel.algo import ( + GetNextVersion, PregelTaskWrites, apply_writes, increment, @@ -92,7 +93,7 @@ ) from langgraph.pregel.manager import AsyncChannelsManager, ChannelsManager from langgraph.pregel.read import PregelNode -from langgraph.pregel.types import PregelExecutableTask, StreamMode +from langgraph.pregel.types import All, PregelExecutableTask, StreamMode from langgraph.pregel.utils import get_new_channel_versions from langgraph.store.base import BaseStore from langgraph.store.batch import AsyncBatchedStore @@ -146,7 +147,7 @@ class PregelLoop: skip_done_tasks: bool is_nested: bool - checkpointer_get_next_version: Callable[[Optional[V]], V] + checkpointer_get_next_version: GetNextVersion checkpointer_put_writes: Optional[ Callable[[RunnableConfig, Sequence[tuple[str, Any]], str], Any] ] @@ -281,8 +282,8 @@ def tick( self, *, input_keys: Union[str, Sequence[str]], - interrupt_after: Sequence[str] = EMPTY_SEQ, - interrupt_before: Sequence[str] = EMPTY_SEQ, + interrupt_after: Union[All, Sequence[str]] = EMPTY_SEQ, + interrupt_before: Union[All, Sequence[str]] = EMPTY_SEQ, manager: Union[None, AsyncParentRunManager, ParentRunManager] = None, ) -> bool: """Execute a single iteration of the Pregel loop. @@ -681,6 +682,10 @@ def __enter__(self) -> Self: if self.config.get("configurable", {}).get( CONFIG_KEY_ENSURE_LATEST ) and self.checkpoint_config["configurable"].get("checkpoint_id"): + if self.checkpointer is None: + raise RuntimeError( + "Cannot ensure latest checkpoint without checkpointer" + ) saved = self.checkpointer.get_tuple( patch_configurable(self.checkpoint_config, {"checkpoint_id": None}) ) @@ -771,7 +776,7 @@ def __init__( self.checkpointer_put_writes = checkpointer.aput_writes else: self.checkpointer_get_next_version = increment - self._checkpointer_put_after_previous = None # type: ignore[method-assign] + self._checkpointer_put_after_previous = None # type: ignore[assignment] self.checkpointer_put_writes = None async def _checkpointer_put_after_previous( @@ -801,6 +806,10 @@ async def __aenter__(self) -> Self: if self.config.get("configurable", {}).get( CONFIG_KEY_ENSURE_LATEST ) and self.checkpoint_config["configurable"].get("checkpoint_id"): + if self.checkpointer is None: + raise RuntimeError( + "Cannot ensure latest checkpoint without checkpointer" + ) saved = await self.checkpointer.aget_tuple( patch_configurable(self.checkpoint_config, {"checkpoint_id": None}) ) @@ -858,6 +867,3 @@ async def __aexit__( return await asyncio.shield( self.stack.__aexit__(exc_type, exc_value, traceback) ) - - -EMPTY_SEQ = tuple() diff --git a/libs/langgraph/langgraph/pregel/messages.py b/libs/langgraph/langgraph/pregel/messages.py index 0a96f0fee..7c3f90b10 100644 --- a/libs/langgraph/langgraph/pregel/messages.py +++ b/libs/langgraph/langgraph/pregel/messages.py @@ -7,7 +7,6 @@ List, Optional, Sequence, - Tuple, Union, cast, ) @@ -21,20 +20,16 @@ from langgraph.constants import NS_SEP from langgraph.pregel.loop import StreamChunk +Meta = tuple[tuple[str, ...], dict[str, Any]] + class StreamMessagesHandler(BaseCallbackHandler, _StreamingCallbackHandler): def __init__(self, stream: Callable[[StreamChunk], None]): self.stream = stream - self.metadata: dict[UUID, tuple[tuple[str, ...], dict[str, Any]]] = {} + self.metadata: dict[UUID, Meta] = {} self.seen: set[Union[int, str]] = set() - def _emit( - self, - meta: Tuple[str, dict[str, Any]], - message: BaseMessage, - *, - dedupe: bool = False, - ) -> None: + def _emit(self, meta: Meta, message: BaseMessage, *, dedupe: bool = False) -> None: ident = id(message) if dedupe and message.id in self.seen: return diff --git a/libs/langgraph/langgraph/pregel/read.py b/libs/langgraph/langgraph/pregel/read.py index 79643a090..3ad988b89 100644 --- a/libs/langgraph/langgraph/pregel/read.py +++ b/libs/langgraph/langgraph/pregel/read.py @@ -18,7 +18,7 @@ RunnablePassthrough, RunnableSerializable, ) -from langchain_core.runnables.base import Input, Other, Output, coerce_to_runnable +from langchain_core.runnables.base import Input, Other, coerce_to_runnable from langchain_core.runnables.utils import ConfigurableFieldSpec from langgraph.constants import CONFIG_KEY_READ @@ -206,7 +206,7 @@ def __or__( Mapping[str, Runnable[Any, Other] | Callable[[Any], Other]], ], ) -> PregelNode: - if ChannelWrite.is_writer(other): + if isinstance(other, Runnable) and ChannelWrite.is_writer(other): return self.copy(update=dict(writers=[*self.writers, other])) elif self.bound is DEFAULT_BOUND: return self.copy(update=dict(bound=coerce_to_runnable(other))) @@ -237,7 +237,7 @@ def invoke( input: Input, config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], - ) -> Output: + ) -> Any: return self.bound.invoke( input, merge_configs({"metadata": self.metadata, "tags": self.tags}, config), @@ -249,7 +249,7 @@ async def ainvoke( input: Input, config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], - ) -> Output: + ) -> Any: return await self.bound.ainvoke( input, merge_configs({"metadata": self.metadata, "tags": self.tags}, config), @@ -261,7 +261,7 @@ def stream( input: Input, config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], - ) -> Iterator[Output]: + ) -> Iterator[Any]: yield from self.bound.stream( input, merge_configs({"metadata": self.metadata, "tags": self.tags}, config), @@ -273,7 +273,7 @@ async def astream( input: Input, config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], - ) -> AsyncIterator[Output]: + ) -> AsyncIterator[Any]: async for item in self.bound.astream( input, merge_configs({"metadata": self.metadata, "tags": self.tags}, config), diff --git a/libs/langgraph/langgraph/pregel/runner.py b/libs/langgraph/langgraph/pregel/runner.py index 7e1122485..14e84352f 100644 --- a/libs/langgraph/langgraph/pregel/runner.py +++ b/libs/langgraph/langgraph/pregel/runner.py @@ -5,11 +5,13 @@ Any, AsyncIterator, Callable, + Iterable, Iterator, Optional, Sequence, Type, Union, + cast, ) from langgraph.constants import ERROR, INTERRUPT, NO_WRITES @@ -33,7 +35,7 @@ def __init__( def tick( self, - tasks: Sequence[PregelExecutableTask], + tasks: Iterable[PregelExecutableTask], *, reraise: bool = True, timeout: Optional[float] = None, @@ -106,7 +108,7 @@ def tick( async def atick( self, - tasks: Sequence[PregelExecutableTask], + tasks: Iterable[PregelExecutableTask], *, reraise: bool = True, timeout: Optional[float] = None, @@ -141,14 +143,17 @@ async def atick( for t in tasks: if not t.writes: futures[ - self.submit( - arun_with_retry, - t, - retry_policy, - stream=self.use_astream, - __name__=t.name, - __cancel_on_exit__=True, - __reraise_on_exit__=reraise, + cast( + asyncio.Future, + self.submit( + arun_with_retry, + t, + retry_policy, + stream=self.use_astream, + __name__=t.name, + __cancel_on_exit__=True, + __reraise_on_exit__=reraise, + ), ) ] = t all_futures = futures.copy() diff --git a/libs/langgraph/langgraph/pregel/validate.py b/libs/langgraph/langgraph/pregel/validate.py index 8627642e9..232014240 100644 --- a/libs/langgraph/langgraph/pregel/validate.py +++ b/libs/langgraph/langgraph/pregel/validate.py @@ -1,4 +1,4 @@ -from typing import Mapping, Optional, Sequence, Union +from typing import Any, Mapping, Optional, Sequence, Union from langgraph.channels.base import BaseChannel from langgraph.constants import RESERVED @@ -65,18 +65,18 @@ def validate_graph( raise ValueError(f"Output channel '{chan}' not in 'channels'") if interrupt_after_nodes != "*": - for node in interrupt_after_nodes: - if node not in nodes: - raise ValueError(f"Node {node} not in nodes") + for n in interrupt_after_nodes: + if n not in nodes: + raise ValueError(f"Node {n} not in nodes") if interrupt_before_nodes != "*": - for node in interrupt_before_nodes: - if node not in nodes: - raise ValueError(f"Node {node} not in nodes") + for n in interrupt_before_nodes: + if n not in nodes: + raise ValueError(f"Node {n} not in nodes") def validate_keys( keys: Optional[Union[str, Sequence[str]]], - channels: Mapping[str, BaseChannel], + channels: Mapping[str, Any], ) -> None: if isinstance(keys, str): if keys not in channels: diff --git a/libs/langgraph/langgraph/pregel/write.py b/libs/langgraph/langgraph/pregel/write.py index 9c3b7782d..2adcab757 100644 --- a/libs/langgraph/langgraph/pregel/write.py +++ b/libs/langgraph/langgraph/pregel/write.py @@ -9,6 +9,7 @@ Sequence, TypeVar, Union, + cast, ) from langchain_core.runnables import Runnable, RunnableConfig @@ -34,7 +35,7 @@ class ChannelWriteEntry(NamedTuple): class ChannelWrite(RunnableCallable): - writes: Sequence[Union[ChannelWriteEntry, Send]] + writes: list[Union[ChannelWriteEntry, Send]] """ Sequence of write entries, each of which is a tuple of: - channel name @@ -54,7 +55,7 @@ def __init__( require_at_least_one_of: Optional[Sequence[str]] = None, ): super().__init__(func=self._write, afunc=self._awrite, name=None, tags=tags) - self.writes = writes + self.writes = cast(list[Union[ChannelWriteEntry, Send]], writes) self.require_at_least_one_of = require_at_least_one_of def get_name( diff --git a/libs/langgraph/langgraph/utils/runnable.py b/libs/langgraph/langgraph/utils/runnable.py index f0b16442d..56d5d5df4 100644 --- a/libs/langgraph/langgraph/utils/runnable.py +++ b/libs/langgraph/langgraph/utils/runnable.py @@ -269,7 +269,7 @@ def __init__( if isinstance(step, RunnableSequence): steps_flat.extend(step.steps) elif isinstance(step, RunnableSeq): - steps_flat.extend(step.steps) # type: ignore[has-type] + steps_flat.extend(step.steps) else: steps_flat.append(coerce_to_runnable(step, name=None, trace=True)) if len(steps_flat) < 2: diff --git a/libs/langgraph/pyproject.toml b/libs/langgraph/pyproject.toml index 954ca2973..4975cabb0 100644 --- a/libs/langgraph/pyproject.toml +++ b/libs/langgraph/pyproject.toml @@ -57,7 +57,7 @@ warn_no_return = "False" warn_unused_ignores = "True" warn_redundant_casts = "True" allow_redefinition = "True" -disable_error_code = "typeddict-item, return-value, override" +disable_error_code = "typeddict-item, return-value, override, has-type" [tool.coverage.run] omit = ["tests/*"] From 95c1ca3adcd1b14f304aecd01140a9dd636f1e8f Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 19 Sep 2024 13:20:45 -0700 Subject: [PATCH 5/9] Fix --- libs/langgraph/Makefile | 2 +- .../channels/dynamic_barrier_value.py | 6 +++--- libs/langgraph/langgraph/graph/graph.py | 18 ++++++++++++------ libs/langgraph/langgraph/graph/state.py | 2 +- 4 files changed, 17 insertions(+), 11 deletions(-) diff --git a/libs/langgraph/Makefile b/libs/langgraph/Makefile index 7ec12f80a..1e249a0cd 100644 --- a/libs/langgraph/Makefile +++ b/libs/langgraph/Makefile @@ -75,7 +75,7 @@ lint lint_diff lint_package lint_tests: [ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff [ "$(PYTHON_FILES)" = "" ] || poetry run ruff check --select I $(PYTHON_FILES) [ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) - [ "$(PYTHON_FILES)" != "langgraph" ] || poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE) + [ "$(PYTHON_FILES)" = "" ] || poetry run mypy langgraph --cache-dir $(MYPY_CACHE) format format_diff: poetry run ruff format $(PYTHON_FILES) diff --git a/libs/langgraph/langgraph/channels/dynamic_barrier_value.py b/libs/langgraph/langgraph/channels/dynamic_barrier_value.py index dfa77f350..f64191e86 100644 --- a/libs/langgraph/langgraph/channels/dynamic_barrier_value.py +++ b/libs/langgraph/langgraph/channels/dynamic_barrier_value.py @@ -1,4 +1,4 @@ -from typing import Generic, NamedTuple, Optional, Sequence, Type, Union +from typing import Any, Generic, NamedTuple, Optional, Sequence, Type, Union from typing_extensions import Self @@ -6,8 +6,8 @@ from langgraph.errors import EmptyChannelError, InvalidUpdateError -class WaitForNames(NamedTuple, Generic[Value]): - names: set[Value] +class WaitForNames(NamedTuple): + names: set[Any] class DynamicBarrierValue( diff --git a/libs/langgraph/langgraph/graph/graph.py b/libs/langgraph/langgraph/graph/graph.py index 93fe28b3d..12f9cf943 100644 --- a/libs/langgraph/langgraph/graph/graph.py +++ b/libs/langgraph/langgraph/graph/graph.py @@ -268,13 +268,15 @@ def add_conditional_edges( path_map_ = path_map.copy() elif isinstance(path_map, list): path_map_ = {name: name for name in path_map} - elif callable(path) and ( - rtn_type := get_type_hints(path.__call__).get("return") - if hasattr(path, "__call__") - else get_type_hints(path).get("return") - ): + elif isinstance(path, Runnable): + path_map_ = None + elif rtn_type := get_type_hints(path.__call__).get( # type: ignore[operator] + "return" + ) or get_type_hints(path).get("return"): if get_origin(rtn_type) is Literal: path_map_ = {name: name for name in get_args(rtn_type)} + else: + path_map_ = None else: path_map_ = None except Exception: @@ -417,6 +419,7 @@ def compile( # create empty compiled graph compiled = CompiledGraph( + builder=self, nodes={}, channels={START: EphemeralValue(Any), END: EphemeralValue(Any)}, input_channels=START, @@ -429,7 +432,6 @@ def compile( auto_validate=False, debug=debug, ) - compiled.builder = self # attach nodes, edges, and branches for key, node in self.nodes.items(): @@ -449,6 +451,10 @@ def compile( class CompiledGraph(Pregel): builder: Graph + def __init__(self, *, builder: Graph, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.builder = builder + def attach_node(self, key: str, node: NodeSpec) -> None: self.channels[key] = EphemeralValue(Any) self.nodes[key] = ( diff --git a/libs/langgraph/langgraph/graph/state.py b/libs/langgraph/langgraph/graph/state.py index 803434316..ace4d2553 100644 --- a/libs/langgraph/langgraph/graph/state.py +++ b/libs/langgraph/langgraph/graph/state.py @@ -453,6 +453,7 @@ def compile( ) compiled = CompiledStateGraph( + builder=self, config_type=self.config_schema, nodes={}, channels={ @@ -471,7 +472,6 @@ def compile( debug=debug, store=store, ) - compiled.builder = self compiled.attach_node(START, None) for key, node in self.nodes.items(): From a47dc2b6340ce83663b5ece438a97ca0a34bc03a Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 19 Sep 2024 14:37:32 -0700 Subject: [PATCH 6/9] Apply suggestions from code review Co-authored-by: William FH <13333726+hinthornw@users.noreply.github.com> --- libs/langgraph/langgraph/graph/graph.py | 2 +- libs/langgraph/langgraph/graph/state.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/langgraph/langgraph/graph/graph.py b/libs/langgraph/langgraph/graph/graph.py index 12f9cf943..2a1e8c8ad 100644 --- a/libs/langgraph/langgraph/graph/graph.py +++ b/libs/langgraph/langgraph/graph/graph.py @@ -194,7 +194,7 @@ def add_node( "Node name must be provided if action is not a function" ) if action is None: - raise RuntimeError + raise RuntimeError(f"Expected a function or Runnable action in add_node. Received None.") if node in self.nodes: raise ValueError(f"Node `{node}` already present.") if node == END or node == START: diff --git a/libs/langgraph/langgraph/graph/state.py b/libs/langgraph/langgraph/graph/state.py index ace4d2553..b759ff4ca 100644 --- a/libs/langgraph/langgraph/graph/state.py +++ b/libs/langgraph/langgraph/graph/state.py @@ -320,7 +320,7 @@ def add_node( ) if not isinstance(node, str): action = node - node = getattr(action, "name", getattr(action, "__name__")) + node = getattr(action, "name", getattr(action, "__name__", None)) if node is None: raise ValueError( "Node name must be provided if action is not a function" From e96934533e48888a648d60f88e585cd54f2f8fce Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 19 Sep 2024 15:06:39 -0700 Subject: [PATCH 7/9] Update graph.py --- libs/langgraph/langgraph/graph/graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langgraph/langgraph/graph/graph.py b/libs/langgraph/langgraph/graph/graph.py index 2a1e8c8ad..418bb3a7e 100644 --- a/libs/langgraph/langgraph/graph/graph.py +++ b/libs/langgraph/langgraph/graph/graph.py @@ -194,7 +194,7 @@ def add_node( "Node name must be provided if action is not a function" ) if action is None: - raise RuntimeError(f"Expected a function or Runnable action in add_node. Received None.") + raise RuntimeError("Expected a function or Runnable action in add_node. Received None.") if node in self.nodes: raise ValueError(f"Node `{node}` already present.") if node == END or node == START: From ed27e54761f840224ffd04f6d06173239ffe4feb Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 19 Sep 2024 16:41:46 -0700 Subject: [PATCH 8/9] Update graph.py --- libs/langgraph/langgraph/graph/graph.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/libs/langgraph/langgraph/graph/graph.py b/libs/langgraph/langgraph/graph/graph.py index 418bb3a7e..7e20d1c26 100644 --- a/libs/langgraph/langgraph/graph/graph.py +++ b/libs/langgraph/langgraph/graph/graph.py @@ -194,7 +194,9 @@ def add_node( "Node name must be provided if action is not a function" ) if action is None: - raise RuntimeError("Expected a function or Runnable action in add_node. Received None.") + raise RuntimeError( + "Expected a function or Runnable action in add_node. Received None." + ) if node in self.nodes: raise ValueError(f"Node `{node}` already present.") if node == END or node == START: From 7643c1171f613acdacd9f855523c5f2e919862f7 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 19 Sep 2024 16:49:05 -0700 Subject: [PATCH 9/9] Lint --- libs/langgraph/langgraph/graph/state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langgraph/langgraph/graph/state.py b/libs/langgraph/langgraph/graph/state.py index b759ff4ca..e5f20a51d 100644 --- a/libs/langgraph/langgraph/graph/state.py +++ b/libs/langgraph/langgraph/graph/state.py @@ -320,7 +320,7 @@ def add_node( ) if not isinstance(node, str): action = node - node = getattr(action, "name", getattr(action, "__name__", None)) + node = cast(str, getattr(action, "name", getattr(action, "__name__", None))) if node is None: raise ValueError( "Node name must be provided if action is not a function"