Skip to content

Commit

Permalink
Interrupts shouldn't be retried (#1413)
Browse files Browse the repository at this point in the history
* Interrupts shouldn't be retried

* Add async test

* Lint
  • Loading branch information
nfcampos authored Aug 21, 2024
1 parent 2f41b28 commit 14976d4
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 1 deletion.
7 changes: 7 additions & 0 deletions libs/langgraph/langgraph/pregel/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import time
from typing import Optional

from langgraph.errors import GraphInterrupt
from langgraph.pregel.types import PregelExecutableTask, RetryPolicy

logger = logging.getLogger(__name__)
Expand All @@ -25,6 +26,9 @@ def run_with_retry(
task.proc.invoke(task.input, task.config)
# if successful, end
break
except GraphInterrupt:
# if interrupted, end
raise
except Exception as exc:
if retry_policy is None:
raise
Expand Down Expand Up @@ -75,6 +79,9 @@ async def arun_with_retry(
await task.proc.ainvoke(task.input, task.config)
# if successful, end
break
except GraphInterrupt:
# if interrupted, end
raise
except Exception as exc:
if retry_policy is None:
raise
Expand Down
7 changes: 6 additions & 1 deletion libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6139,20 +6139,25 @@ class State(TypedDict):
my_key: Annotated[str, operator.add]
market: str

tool_two_node_count = 0

def tool_two_node(s: State) -> State:
nonlocal tool_two_node_count
tool_two_node_count += 1
if s["market"] == "DE":
raise NodeInterrupt("Just because...")
return {"my_key": " all good"}

tool_two_graph = StateGraph(State)
tool_two_graph.add_node("tool_two", tool_two_node)
tool_two_graph.add_node("tool_two", tool_two_node, retry=RetryPolicy())
tool_two_graph.add_edge(START, "tool_two")
tool_two = tool_two_graph.compile()

assert tool_two.invoke({"my_key": "value", "market": "DE"}) == {
"my_key": "value",
"market": "DE",
}
assert tool_two_node_count == 1, "interrupts aren't retried"
assert tool_two.invoke({"my_key": "value", "market": "US"}) == {
"my_key": "value all good",
"market": "US",
Expand Down
76 changes: 76 additions & 0 deletions libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,82 @@ async def iambad(input: Any) -> None:
assert inner_task_cancelled


async def test_dynamic_interrupt(snapshot: SnapshotAssertion) -> None:
class State(TypedDict):
my_key: Annotated[str, operator.add]
market: str

tool_two_node_count = 0

async def tool_two_node(s: State) -> State:
nonlocal tool_two_node_count
tool_two_node_count += 1
if s["market"] == "DE":
raise NodeInterrupt("Just because...")
return {"my_key": " all good"}

tool_two_graph = StateGraph(State)
tool_two_graph.add_node("tool_two", tool_two_node, retry=RetryPolicy())
tool_two_graph.add_edge(START, "tool_two")
tool_two = tool_two_graph.compile()

assert await tool_two.ainvoke({"my_key": "value", "market": "DE"}) == {
"my_key": "value",
"market": "DE",
}
assert tool_two_node_count == 1, "interrupts aren't retried"
assert await tool_two.ainvoke({"my_key": "value", "market": "US"}) == {
"my_key": "value all good",
"market": "US",
}

async with AsyncSqliteSaver.from_conn_string(":memory:") as saver:
tool_two = tool_two_graph.compile(checkpointer=saver)

# missing thread_id
with pytest.raises(ValueError, match="thread_id"):
await tool_two.ainvoke({"my_key": "value", "market": "DE"})

thread1 = {"configurable": {"thread_id": "1"}}
# stop when about to enter node
assert await tool_two.ainvoke(
{"my_key": "value ⛰️", "market": "DE"}, thread1
) == {
"my_key": "value ⛰️",
"market": "DE",
}
assert [c.metadata async for c in tool_two.checkpointer.alist(thread1)] == [
{
"source": "loop",
"step": 0,
"writes": None,
},
{
"source": "input",
"step": -1,
"writes": {"my_key": "value ⛰️", "market": "DE"},
},
]
tup = await tool_two.checkpointer.aget_tuple(thread1)
assert await tool_two.aget_state(thread1) == StateSnapshot(
values={"my_key": "value ⛰️", "market": "DE"},
next=("tool_two",),
tasks=(
PregelTask(
AnyStr(),
"tool_two",
interrupts=(Interrupt("during", "Just because..."),),
),
),
config=tup.config,
created_at=tup.checkpoint["ts"],
metadata={"source": "loop", "step": 0, "writes": None},
parent_config=[
c async for c in tool_two.checkpointer.alist(thread1, limit=2)
][-1].config,
)


@pytest.mark.parametrize(
"checkpointer_name",
["memory", "sqlite_aio", "postgres_aio", "postgres_aio_pipe"],
Expand Down

0 comments on commit 14976d4

Please sign in to comment.