From 94e64f87b16459ca393ead0895c265f16bb9cee6 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 19 Sep 2024 15:07:45 -0700 Subject: [PATCH] Allow chaning add_node/add_edge/etc calls --- libs/langgraph/langgraph/graph/graph.py | 23 ++++++++++++++--------- libs/langgraph/langgraph/graph/state.py | 11 +++++++---- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/libs/langgraph/langgraph/graph/graph.py b/libs/langgraph/langgraph/graph/graph.py index 7e20d1c26..c5a043ee7 100644 --- a/libs/langgraph/langgraph/graph/graph.py +++ b/libs/langgraph/langgraph/graph/graph.py @@ -23,6 +23,7 @@ from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.graph import Graph as DrawableGraph from langchain_core.runnables.graph import Node as DrawableNode +from typing_extensions import Self from langgraph.channels.ephemeral_value import EphemeralValue from langgraph.checkpoint.base import BaseCheckpointSaver @@ -156,7 +157,7 @@ def add_node( node: RunnableLike, *, metadata: Optional[dict[str, Any]] = None, - ) -> None: ... + ) -> Self: ... @overload def add_node( @@ -165,7 +166,7 @@ def add_node( action: RunnableLike, *, metadata: Optional[dict[str, Any]] = None, - ) -> None: ... + ) -> Self: ... def add_node( self, @@ -173,7 +174,7 @@ def add_node( action: Optional[RunnableLike] = None, *, metadata: Optional[dict[str, Any]] = None, - ) -> None: + ) -> Self: if isinstance(node, str): for character in (NS_SEP, NS_END): if character in node: @@ -205,8 +206,9 @@ def add_node( self.nodes[cast(str, node)] = NodeSpec( coerce_to_runnable(action, name=cast(str, node), trace=False), metadata ) + return self - def add_edge(self, start_key: str, end_key: str) -> None: + def add_edge(self, start_key: str, end_key: str) -> Self: if self.compiled: logger.warning( "Adding an edge to a graph that has already been compiled. This will " @@ -227,6 +229,7 @@ def add_edge(self, start_key: str, end_key: str) -> None: ) self.edges.add((start_key, end_key)) + return self def add_conditional_edges( self, @@ -238,7 +241,7 @@ def add_conditional_edges( ], path_map: Optional[Union[dict[Hashable, str], list[str]]] = None, then: Optional[str] = None, - ) -> None: + ) -> Self: """Add a conditional edge from the starting node to any number of destination nodes. Args: @@ -293,8 +296,9 @@ def add_conditional_edges( ) # save it self.branches[source][name] = Branch(path, path_map_, then) + return self - def set_entry_point(self, key: str) -> None: + def set_entry_point(self, key: str) -> Self: """Specifies the first node to be called in the graph. Equivalent to calling `add_edge(START, key)`. @@ -316,7 +320,7 @@ def set_conditional_entry_point( ], path_map: Optional[Union[dict[Hashable, str], list[str]]] = None, then: Optional[str] = None, - ) -> None: + ) -> Self: """Sets a conditional entry point in the graph. Args: @@ -333,7 +337,7 @@ def set_conditional_entry_point( """ return self.add_conditional_edges(START, path, path_map, then) - def set_finish_point(self, key: str) -> None: + def set_finish_point(self, key: str) -> Self: """Marks a node as a finish point of the graph. If the graph reaches this node, it will cease execution. @@ -346,7 +350,7 @@ def set_finish_point(self, key: str) -> None: """ return self.add_edge(key, END) - def validate(self, interrupt: Optional[Sequence[str]] = None) -> None: + def validate(self, interrupt: Optional[Sequence[str]] = None) -> Self: # assemble sources all_sources = {src for src, _ in self._all_edges} for start, branches in self.branches.items(): @@ -398,6 +402,7 @@ def validate(self, interrupt: Optional[Sequence[str]] = None) -> None: raise ValueError(f"Interrupt node `{node}` not found") self.compiled = True + return self def compile( self, diff --git a/libs/langgraph/langgraph/graph/state.py b/libs/langgraph/langgraph/graph/state.py index e5f20a51d..cee0fb849 100644 --- a/libs/langgraph/langgraph/graph/state.py +++ b/libs/langgraph/langgraph/graph/state.py @@ -23,6 +23,7 @@ from langchain_core.runnables.base import RunnableLike from pydantic import BaseModel from pydantic.v1 import BaseModel as BaseModelV1 +from typing_extensions import Self from langgraph._api.deprecation import LangGraphDeprecationWarning from langgraph.channels.base import BaseChannel @@ -211,7 +212,7 @@ def add_node( metadata: Optional[dict[str, Any]] = None, input: Optional[Type[Any]] = None, retry: Optional[RetryPolicy] = None, - ) -> None: + ) -> Self: """Adds a new node to the state graph. Will take the name of the function/runnable as the node name. @@ -235,7 +236,7 @@ def add_node( metadata: Optional[dict[str, Any]] = None, input: Optional[Type[Any]] = None, retry: Optional[RetryPolicy] = None, - ) -> None: + ) -> Self: """Adds a new node to the state graph. Args: @@ -258,7 +259,7 @@ def add_node( metadata: Optional[dict[str, Any]] = None, input: Optional[Type[Any]] = None, retry: Optional[RetryPolicy] = None, - ) -> None: + ) -> Self: """Adds a new node to the state graph. Will take the name of the function/runnable as the node name. @@ -359,8 +360,9 @@ def add_node( input=input or self.schema, retry_policy=retry, ) + return self - def add_edge(self, start_key: Union[str, list[str]], end_key: str) -> None: + def add_edge(self, start_key: Union[str, list[str]], end_key: str) -> Self: """Adds a directed edge from the start node to the end node. If the graph transitions to the start_key node, it will always transition to the end_key node next. @@ -394,6 +396,7 @@ def add_edge(self, start_key: Union[str, list[str]], end_key: str) -> None: raise ValueError(f"Need to add_node `{end_key}` first") self.waiting_edges.add((tuple(start_key), end_key)) + return self def compile( self,