Skip to content

Commit

Permalink
Merge pull request #2432 from langchain-ai/nc/15nov/copy-checkpoint
Browse files Browse the repository at this point in the history
lib: Restore prev behavior for update_state(None)
  • Loading branch information
nfcampos authored Nov 15, 2024
2 parents 0388534 + 1dbdd7d commit 07c6532
Show file tree
Hide file tree
Showing 2 changed files with 185 additions and 4 deletions.
24 changes: 22 additions & 2 deletions libs/langgraph/langgraph/pregel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
CONFIG_KEY_STREAM,
CONFIG_KEY_STREAM_WRITER,
CONFIG_KEY_TASK_ID,
END,
ERROR,
INPUT,
INTERRUPT,
Expand Down Expand Up @@ -901,8 +902,8 @@ def update_state(
checkpoint,
LoopProtocol(config=config, step=step + 1, stop=step + 2),
) as (channels, managed):
# no values, just clear all tasks
if values is None and as_node is None:
# no values as END, just clear all tasks
if values is None and as_node == END:
if saved is not None:
# tasks for this checkpoint
next_tasks = prepare_next_tasks(
Expand Down Expand Up @@ -955,6 +956,25 @@ def update_state(
return patch_checkpoint_map(
next_config, saved.metadata if saved else None
)
# no values, copy checkpoint
if values is None and as_node is None:
next_checkpoint = create_checkpoint(checkpoint, None, step)
# copy checkpoint
next_config = checkpointer.put(
checkpoint_config,
next_checkpoint,
{
**checkpoint_metadata,
"source": "update",
"step": step + 1,
"writes": {},
"parents": saved.metadata.get("parents", {}) if saved else {},
},
{},
)
return patch_checkpoint_map(
next_config, saved.metadata if saved else None
)
# apply pending writes, if not on specific checkpoint
if (
CONFIG_KEY_CHECKPOINT_ID not in config[CONF]
Expand Down
165 changes: 163 additions & 2 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8556,8 +8556,8 @@ def tool_two_node(s: State) -> State:
parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config,
)
# clear the interrupt and next tasks
tool_two.update_state(thread1, None)
# interrupt is cleared, task will still run next
tool_two.update_state(thread1, None, as_node=END)
# interrupt and next tasks are cleared
assert tool_two.get_state(thread1) == StateSnapshot(
values={"my_key": "value ⛰️", "market": "DE"},
next=(),
Expand All @@ -8575,6 +8575,167 @@ def tool_two_node(s: State) -> State:
)


@pytest.mark.skipif(not FF_SEND_V2, reason="send v2 is not enabled")
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_copy_checkpoint(
request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")

class State(TypedDict):
my_key: Annotated[str, operator.add]
market: str

def tool_one(s: State) -> State:
return {"my_key": " one"}

tool_two_node_count = 0

def tool_two_node(s: State) -> State:
nonlocal tool_two_node_count
tool_two_node_count += 1
if s["market"] == "DE":
answer = interrupt("Just because...")
else:
answer = " all good"
return {"my_key": answer}

def start(state: State) -> list[Union[Send, str]]:
return ["tool_two", Send("tool_one", state)]

tool_two_graph = StateGraph(State)
tool_two_graph.add_node("tool_two", tool_two_node, retry=RetryPolicy())
tool_two_graph.add_node("tool_one", tool_one)
tool_two_graph.set_conditional_entry_point(start)
tool_two = tool_two_graph.compile()

tracer = FakeTracer()
assert tool_two.invoke(
{"my_key": "value", "market": "DE"}, {"callbacks": [tracer]}
) == {
"my_key": "value one",
"market": "DE",
}
assert tool_two_node_count == 1, "interrupts aren't retried"
assert len(tracer.runs) == 1
run = tracer.runs[0]
assert run.end_time is not None
assert run.error is None
assert run.outputs == {"market": "DE", "my_key": "value one"}

assert tool_two.invoke({"my_key": "value", "market": "US"}) == {
"my_key": "value one all good",
"market": "US",
}

tool_two = tool_two_graph.compile(checkpointer=checkpointer)

# missing thread_id
with pytest.raises(ValueError, match="thread_id"):
tool_two.invoke({"my_key": "value", "market": "DE"})

# flow: interrupt -> resume with answer
thread2 = {"configurable": {"thread_id": "2"}}
# stop when about to enter node
assert [
c for c in tool_two.stream({"my_key": "value ⛰️", "market": "DE"}, thread2)
] == [
{
"tool_one": {"my_key": " one"},
},
{
"__interrupt__": (
Interrupt(
value="Just because...",
resumable=True,
ns=[AnyStr("tool_two:")],
),
)
},
]
# resume with answer
assert [c for c in tool_two.stream(Command(resume=" my answer"), thread2)] == [
{"tool_two": {"my_key": " my answer"}},
]

# flow: interrupt -> clear tasks
thread1 = {"configurable": {"thread_id": "1"}}
# stop when about to enter node
assert tool_two.invoke({"my_key": "value ⛰️", "market": "DE"}, thread1) == {
"my_key": "value ⛰️ one",
"market": "DE",
}
assert [c.metadata for c in tool_two.checkpointer.list(thread1)] == [
{
"parents": {},
"source": "loop",
"step": 0,
"writes": {"tool_one": {"my_key": " one"}},
"thread_id": "1",
},
{
"parents": {},
"source": "input",
"step": -1,
"writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}},
"thread_id": "1",
},
]
assert tool_two.get_state(thread1) == StateSnapshot(
values={"my_key": "value ⛰️ one", "market": "DE"},
next=("tool_two",),
tasks=(
PregelTask(
AnyStr(),
"tool_two",
(PULL, "tool_two"),
interrupts=(
Interrupt(
value="Just because...",
resumable=True,
ns=[AnyStr("tool_two:")],
),
),
),
),
config=tool_two.checkpointer.get_tuple(thread1).config,
created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 0,
"writes": {"tool_one": {"my_key": " one"}},
"thread_id": "1",
},
parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config,
)
# clear the interrupt and next tasks
tool_two.update_state(thread1, None)
# interrupt is cleared, next task is kept
assert tool_two.get_state(thread1) == StateSnapshot(
values={"my_key": "value ⛰️ one", "market": "DE"},
next=("tool_two",),
tasks=(
PregelTask(
AnyStr(),
"tool_two",
(PULL, "tool_two"),
interrupts=(),
),
),
config=tool_two.checkpointer.get_tuple(thread1).config,
created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 1,
"writes": {},
"thread_id": "1",
},
parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config,
)


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_start_branch_then(
snapshot: SnapshotAssertion, request: pytest.FixtureRequest, checkpointer_name: str
Expand Down

0 comments on commit 07c6532

Please sign in to comment.