Skip to content

Commit

Permalink
Add checkpointer=True mode for subgraphs that want to keep state betw…
Browse files Browse the repository at this point in the history
…eenn turns
  • Loading branch information
nfcampos committed Jan 15, 2025
1 parent cd64075 commit 5375af7
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 3 deletions.
9 changes: 9 additions & 0 deletions libs/langgraph/langgraph/pregel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1455,6 +1455,8 @@ def _defaults(
checkpointer: Optional[BaseCheckpointSaver] = None
elif CONFIG_KEY_CHECKPOINTER in config.get(CONF, {}):
checkpointer = config[CONF][CONFIG_KEY_CHECKPOINTER]
elif self.checkpointer is True:
raise RuntimeError("checkpointer=True cannot be used for root graphs.")
else:
checkpointer = self.checkpointer
if checkpointer and not config.get(CONF):
Expand Down Expand Up @@ -1598,6 +1600,12 @@ def output() -> Iterator:
interrupt_after=interrupt_after,
debug=debug,
)
# set up subgraph checkpointing
if self.checkpointer is True:
ns = cast(str, config[CONF][CONFIG_KEY_CHECKPOINT_NS])
config[CONF][CONFIG_KEY_CHECKPOINT_NS] = NS_SEP.join(
part.split(NS_END)[0] for part in ns.split(NS_SEP)
)
# set up messages stream mode
if "messages" in stream_modes:
run_manager.inheritable_handlers.append(
Expand All @@ -1622,6 +1630,7 @@ def output() -> Iterator:
interrupt_after=interrupt_after_,
manager=run_manager,
debug=debug,
check_subgraphs=self.checkpointer is not True,
) as loop:
# create runner
runner = PregelRunner(
Expand Down
8 changes: 5 additions & 3 deletions libs/langgraph/langgraph/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,11 @@ class ToolOutputMixin: # type: ignore[no-redef]
All = Literal["*"]
"""Special value to indicate that graph should interrupt on all nodes."""

Checkpointer = Union[None, Literal[False], BaseCheckpointSaver]
"""Type of the checkpointer to use for a subgraph. False disables checkpointing,
even if the parent graph has a checkpointer. None inherits checkpointer."""
Checkpointer = Union[None, bool, BaseCheckpointSaver]
"""Type of the checkpointer to use for a subgraph.
- True enables persistent checkpointing for this subgraph.
- False disables checkpointing, even if the parent graph has a checkpointer.
- None inherits checkpointer from the parent graph."""

StreamMode = Literal["values", "updates", "debug", "messages", "custom"]
"""How the stream method should emit outputs.
Expand Down
60 changes: 60 additions & 0 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3189,6 +3189,66 @@ def side(state: State):
]


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_subgraph_checkpoint_true(
request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
checkpointer = request.getfixturevalue("checkpointer_" + checkpointer_name)

class InnerState(TypedDict):
my_key: Annotated[str, operator.add]
my_other_key: str

def inner_1(state: InnerState):
return {"my_key": " got here", "my_other_key": state["my_key"]}

def inner_2(state: InnerState):
return {"my_key": " and there"}

inner = StateGraph(InnerState)
inner.add_node("inner_1", inner_1)
inner.add_node("inner_2", inner_2)
inner.add_edge("inner_1", "inner_2")
inner.set_entry_point("inner_1")
inner.set_finish_point("inner_2")

class State(TypedDict):
my_key: str

graph = StateGraph(State)
graph.add_node("inner", inner.compile(checkpointer=True))
graph.add_edge(START, "inner")
graph.add_conditional_edges(
"inner", lambda s: "inner" if s["my_key"].count("there") < 2 else END
)
app = graph.compile(checkpointer=checkpointer)

config = {"configurable": {"thread_id": "2"}}
assert [c for c in app.stream({"my_key": ""}, config, subgraphs=True)] == [
(("inner",), {"inner_1": {"my_key": " got here", "my_other_key": ""}}),
(("inner",), {"inner_2": {"my_key": " and there"}}),
((), {"inner": {"my_key": " got here and there"}}),
(
("inner",),
{
"inner_1": {
"my_key": " got here",
"my_other_key": " got here and there got here and there",
}
},
),
(("inner",), {"inner_2": {"my_key": " and there"}}),
(
(),
{
"inner": {
"my_key": " got here and there got here and there got here and there"
}
},
),
]


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_stream_subgraphs_during_execution(
request: pytest.FixtureRequest, checkpointer_name: str
Expand Down

0 comments on commit 5375af7

Please sign in to comment.