Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Interrupts shouldn't be retried #1413

Merged
merged 3 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions libs/langgraph/langgraph/pregel/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import time
from typing import Optional

from langgraph.errors import GraphInterrupt
from langgraph.pregel.types import PregelExecutableTask, RetryPolicy

logger = logging.getLogger(__name__)
Expand All @@ -25,6 +26,9 @@ def run_with_retry(
task.proc.invoke(task.input, task.config)
# if successful, end
break
except GraphInterrupt:
# if interrupted, end
raise
except Exception as exc:
if retry_policy is None:
raise
Expand Down Expand Up @@ -75,6 +79,9 @@ async def arun_with_retry(
await task.proc.ainvoke(task.input, task.config)
# if successful, end
break
except GraphInterrupt:
# if interrupted, end
raise
except Exception as exc:
if retry_policy is None:
raise
Expand Down
7 changes: 6 additions & 1 deletion libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6139,20 +6139,25 @@ class State(TypedDict):
my_key: Annotated[str, operator.add]
market: str

tool_two_node_count = 0

def tool_two_node(s: State) -> State:
nonlocal tool_two_node_count
tool_two_node_count += 1
if s["market"] == "DE":
raise NodeInterrupt("Just because...")
return {"my_key": " all good"}

tool_two_graph = StateGraph(State)
tool_two_graph.add_node("tool_two", tool_two_node)
tool_two_graph.add_node("tool_two", tool_two_node, retry=RetryPolicy())
tool_two_graph.add_edge(START, "tool_two")
tool_two = tool_two_graph.compile()

assert tool_two.invoke({"my_key": "value", "market": "DE"}) == {
"my_key": "value",
"market": "DE",
}
assert tool_two_node_count == 1, "interrupts aren't retried"
assert tool_two.invoke({"my_key": "value", "market": "US"}) == {
"my_key": "value all good",
"market": "US",
Expand Down
76 changes: 76 additions & 0 deletions libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,82 @@ async def iambad(input: Any) -> None:
assert inner_task_cancelled


async def test_dynamic_interrupt(snapshot: SnapshotAssertion) -> None:
class State(TypedDict):
my_key: Annotated[str, operator.add]
market: str

tool_two_node_count = 0

async def tool_two_node(s: State) -> State:
nonlocal tool_two_node_count
tool_two_node_count += 1
if s["market"] == "DE":
raise NodeInterrupt("Just because...")
return {"my_key": " all good"}

tool_two_graph = StateGraph(State)
tool_two_graph.add_node("tool_two", tool_two_node, retry=RetryPolicy())
tool_two_graph.add_edge(START, "tool_two")
tool_two = tool_two_graph.compile()

assert await tool_two.ainvoke({"my_key": "value", "market": "DE"}) == {
"my_key": "value",
"market": "DE",
}
assert tool_two_node_count == 1, "interrupts aren't retried"
assert await tool_two.ainvoke({"my_key": "value", "market": "US"}) == {
"my_key": "value all good",
"market": "US",
}

async with AsyncSqliteSaver.from_conn_string(":memory:") as saver:
tool_two = tool_two_graph.compile(checkpointer=saver)

# missing thread_id
with pytest.raises(ValueError, match="thread_id"):
await tool_two.ainvoke({"my_key": "value", "market": "DE"})

thread1 = {"configurable": {"thread_id": "1"}}
# stop when about to enter node
assert await tool_two.ainvoke(
{"my_key": "value ⛰️", "market": "DE"}, thread1
) == {
"my_key": "value ⛰️",
"market": "DE",
}
assert [c.metadata async for c in tool_two.checkpointer.alist(thread1)] == [
{
"source": "loop",
"step": 0,
"writes": None,
},
{
"source": "input",
"step": -1,
"writes": {"my_key": "value ⛰️", "market": "DE"},
},
]
tup = await tool_two.checkpointer.aget_tuple(thread1)
assert await tool_two.aget_state(thread1) == StateSnapshot(
values={"my_key": "value ⛰️", "market": "DE"},
next=("tool_two",),
tasks=(
PregelTask(
AnyStr(),
"tool_two",
interrupts=(Interrupt("during", "Just because..."),),
),
),
config=tup.config,
created_at=tup.checkpoint["ts"],
metadata={"source": "loop", "step": 0, "writes": None},
parent_config=[
c async for c in tool_two.checkpointer.alist(thread1, limit=2)
][-1].config,
)


@pytest.mark.parametrize(
"checkpointer_name",
["memory", "sqlite_aio", "postgres_aio", "postgres_aio_pipe"],
Expand Down
Loading