Skip to content

Commit

Permalink
Merge pull request #1773 from langchain-ai/nc/19sep/return-self
Browse files Browse the repository at this point in the history
Allow chaning add_node/add_edge/etc calls
  • Loading branch information
nfcampos authored Sep 20, 2024
2 parents 0bbe461 + 94e64f8 commit b9fe384
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 13 deletions.
23 changes: 14 additions & 9 deletions libs/langgraph/langgraph/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from langchain_core.runnables.config import RunnableConfig
from langchain_core.runnables.graph import Graph as DrawableGraph
from langchain_core.runnables.graph import Node as DrawableNode
from typing_extensions import Self

from langgraph.channels.ephemeral_value import EphemeralValue
from langgraph.checkpoint.base import BaseCheckpointSaver
Expand Down Expand Up @@ -156,7 +157,7 @@ def add_node(
node: RunnableLike,
*,
metadata: Optional[dict[str, Any]] = None,
) -> None: ...
) -> Self: ...

@overload
def add_node(
Expand All @@ -165,15 +166,15 @@ def add_node(
action: RunnableLike,
*,
metadata: Optional[dict[str, Any]] = None,
) -> None: ...
) -> Self: ...

def add_node(
self,
node: Union[str, RunnableLike],
action: Optional[RunnableLike] = None,
*,
metadata: Optional[dict[str, Any]] = None,
) -> None:
) -> Self:
if isinstance(node, str):
for character in (NS_SEP, NS_END):
if character in node:
Expand Down Expand Up @@ -205,8 +206,9 @@ def add_node(
self.nodes[cast(str, node)] = NodeSpec(
coerce_to_runnable(action, name=cast(str, node), trace=False), metadata
)
return self

def add_edge(self, start_key: str, end_key: str) -> None:
def add_edge(self, start_key: str, end_key: str) -> Self:
if self.compiled:
logger.warning(
"Adding an edge to a graph that has already been compiled. This will "
Expand All @@ -227,6 +229,7 @@ def add_edge(self, start_key: str, end_key: str) -> None:
)

self.edges.add((start_key, end_key))
return self

def add_conditional_edges(
self,
Expand All @@ -238,7 +241,7 @@ def add_conditional_edges(
],
path_map: Optional[Union[dict[Hashable, str], list[str]]] = None,
then: Optional[str] = None,
) -> None:
) -> Self:
"""Add a conditional edge from the starting node to any number of destination nodes.
Args:
Expand Down Expand Up @@ -293,8 +296,9 @@ def add_conditional_edges(
)
# save it
self.branches[source][name] = Branch(path, path_map_, then)
return self

def set_entry_point(self, key: str) -> None:
def set_entry_point(self, key: str) -> Self:
"""Specifies the first node to be called in the graph.
Equivalent to calling `add_edge(START, key)`.
Expand All @@ -316,7 +320,7 @@ def set_conditional_entry_point(
],
path_map: Optional[Union[dict[Hashable, str], list[str]]] = None,
then: Optional[str] = None,
) -> None:
) -> Self:
"""Sets a conditional entry point in the graph.
Args:
Expand All @@ -333,7 +337,7 @@ def set_conditional_entry_point(
"""
return self.add_conditional_edges(START, path, path_map, then)

def set_finish_point(self, key: str) -> None:
def set_finish_point(self, key: str) -> Self:
"""Marks a node as a finish point of the graph.
If the graph reaches this node, it will cease execution.
Expand All @@ -346,7 +350,7 @@ def set_finish_point(self, key: str) -> None:
"""
return self.add_edge(key, END)

def validate(self, interrupt: Optional[Sequence[str]] = None) -> None:
def validate(self, interrupt: Optional[Sequence[str]] = None) -> Self:
# assemble sources
all_sources = {src for src, _ in self._all_edges}
for start, branches in self.branches.items():
Expand Down Expand Up @@ -398,6 +402,7 @@ def validate(self, interrupt: Optional[Sequence[str]] = None) -> None:
raise ValueError(f"Interrupt node `{node}` not found")

self.compiled = True
return self

def compile(
self,
Expand Down
11 changes: 7 additions & 4 deletions libs/langgraph/langgraph/graph/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from langchain_core.runnables.base import RunnableLike
from pydantic import BaseModel
from pydantic.v1 import BaseModel as BaseModelV1
from typing_extensions import Self

from langgraph._api.deprecation import LangGraphDeprecationWarning
from langgraph.channels.base import BaseChannel
Expand Down Expand Up @@ -211,7 +212,7 @@ def add_node(
metadata: Optional[dict[str, Any]] = None,
input: Optional[Type[Any]] = None,
retry: Optional[RetryPolicy] = None,
) -> None:
) -> Self:
"""Adds a new node to the state graph.
Will take the name of the function/runnable as the node name.
Expand All @@ -235,7 +236,7 @@ def add_node(
metadata: Optional[dict[str, Any]] = None,
input: Optional[Type[Any]] = None,
retry: Optional[RetryPolicy] = None,
) -> None:
) -> Self:
"""Adds a new node to the state graph.
Args:
Expand All @@ -258,7 +259,7 @@ def add_node(
metadata: Optional[dict[str, Any]] = None,
input: Optional[Type[Any]] = None,
retry: Optional[RetryPolicy] = None,
) -> None:
) -> Self:
"""Adds a new node to the state graph.
Will take the name of the function/runnable as the node name.
Expand Down Expand Up @@ -359,8 +360,9 @@ def add_node(
input=input or self.schema,
retry_policy=retry,
)
return self

def add_edge(self, start_key: Union[str, list[str]], end_key: str) -> None:
def add_edge(self, start_key: Union[str, list[str]], end_key: str) -> Self:
"""Adds a directed edge from the start node to the end node.
If the graph transitions to the start_key node, it will always transition to the end_key node next.
Expand Down Expand Up @@ -394,6 +396,7 @@ def add_edge(self, start_key: Union[str, list[str]], end_key: str) -> None:
raise ValueError(f"Need to add_node `{end_key}` first")

self.waiting_edges.add((tuple(start_key), end_key))
return self

def compile(
self,
Expand Down

0 comments on commit b9fe384

Please sign in to comment.