Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more tests for async cancellation #2902

Merged
merged 4 commits into from
Dec 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions libs/langgraph/langgraph/pregel/loop.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import asyncio

Check notice on line 1 in libs/langgraph/langgraph/pregel/loop.py

View workflow job for this annotation

GitHub Actions / benchmark

Benchmark results

.........................................
fanout_to_subgraph_10x: Mean +- std dev: 60.4 ms +- 1.1 ms
.........................................
fanout_to_subgraph_10x_sync: Mean +- std dev: 51.8 ms +- 0.8 ms
.........................................
fanout_to_subgraph_10x_checkpoint: Mean +- std dev: 73.4 ms +- 1.5 ms
.........................................
fanout_to_subgraph_10x_checkpoint_sync: Mean +- std dev: 94.0 ms +- 1.2 ms
.........................................
fanout_to_subgraph_100x: Mean +- std dev: 610 ms +- 23 ms
.........................................
fanout_to_subgraph_100x_sync: Mean +- std dev: 507 ms +- 8 ms
.........................................
fanout_to_subgraph_100x_checkpoint: Mean +- std dev: 755 ms +- 12 ms
.........................................
fanout_to_subgraph_100x_checkpoint_sync: Mean +- std dev: 939 ms +- 18 ms
.........................................
react_agent_10x: Mean +- std dev: 30.3 ms +- 0.6 ms
.........................................
react_agent_10x_sync: Mean +- std dev: 22.7 ms +- 0.4 ms
.........................................
react_agent_10x_checkpoint: Mean +- std dev: 37.7 ms +- 0.7 ms
.........................................
react_agent_10x_checkpoint_sync: Mean +- std dev: 36.7 ms +- 0.4 ms
.........................................
react_agent_100x: Mean +- std dev: 337 ms +- 6 ms
.........................................
react_agent_100x_sync: Mean +- std dev: 273 ms +- 4 ms
.........................................
react_agent_100x_checkpoint: Mean +- std dev: 847 ms +- 6 ms
.........................................
react_agent_100x_checkpoint_sync: Mean +- std dev: 840 ms +- 6 ms
.........................................
wide_state_25x300: Mean +- std dev: 22.6 ms +- 0.4 ms
.........................................
wide_state_25x300_sync: Mean +- std dev: 14.7 ms +- 0.1 ms
.........................................
wide_state_25x300_checkpoint: Mean +- std dev: 275 ms +- 13 ms
.........................................
wide_state_25x300_checkpoint_sync: Mean +- std dev: 274 ms +- 13 ms
.........................................
wide_state_15x600: Mean +- std dev: 26.4 ms +- 0.5 ms
.........................................
wide_state_15x600_sync: Mean +- std dev: 17.0 ms +- 0.1 ms
.........................................
wide_state_15x600_checkpoint: Mean +- std dev: 473 ms +- 14 ms
.........................................
wide_state_15x600_checkpoint_sync: Mean +- std dev: 474 ms +- 14 ms
.........................................
wide_state_9x1200: Mean +- std dev: 26.5 ms +- 0.5 ms
.........................................
wide_state_9x1200_sync: Mean +- std dev: 17.1 ms +- 0.1 ms
.........................................
wide_state_9x1200_checkpoint: Mean +- std dev: 309 ms +- 15 ms
.........................................
wide_state_9x1200_checkpoint_sync: Mean +- std dev: 306 ms +- 14 ms

Check notice on line 1 in libs/langgraph/langgraph/pregel/loop.py

View workflow job for this annotation

GitHub Actions / benchmark

Comparison against main

+-----------------------------------------+---------+-----------------------+
| Benchmark                               | main    | changes               |
+=========================================+=========+=======================+
| wide_state_15x600_sync                  | 17.3 ms | 17.0 ms: 1.02x faster |
+-----------------------------------------+---------+-----------------------+
| wide_state_15x600                       | 26.7 ms | 26.4 ms: 1.01x faster |
+-----------------------------------------+---------+-----------------------+
| wide_state_15x600_checkpoint            | 477 ms  | 473 ms: 1.01x faster  |
+-----------------------------------------+---------+-----------------------+
| wide_state_25x300_sync                  | 14.7 ms | 14.7 ms: 1.01x faster |
+-----------------------------------------+---------+-----------------------+
| wide_state_25x300                       | 22.8 ms | 22.6 ms: 1.01x faster |
+-----------------------------------------+---------+-----------------------+
| fanout_to_subgraph_10x_sync             | 51.6 ms | 51.8 ms: 1.00x slower |
+-----------------------------------------+---------+-----------------------+
| fanout_to_subgraph_100x_sync            | 505 ms  | 507 ms: 1.00x slower  |
+-----------------------------------------+---------+-----------------------+
| react_agent_10x_checkpoint_sync         | 36.5 ms | 36.7 ms: 1.00x slower |
+-----------------------------------------+---------+-----------------------+
| fanout_to_subgraph_10x                  | 60.2 ms | 60.4 ms: 1.00x slower |
+-----------------------------------------+---------+-----------------------+
| react_agent_100x                        | 335 ms  | 337 ms: 1.01x slower  |
+-----------------------------------------+---------+-----------------------+
| fanout_to_subgraph_10x_checkpoint_sync  | 93.2 ms | 94.0 ms: 1.01x slower |
+-----------------------------------------+---------+-----------------------+
| fanout_to_subgraph_100x_checkpoint_sync | 929 ms  | 939 ms: 1.01x slower  |
+-----------------------------------------+---------+-----------------------+
| react_agent_100x_checkpoint_sync        | 830 ms  | 840 ms: 1.01x slower  |
+-----------------------------------------+---------+-----------------------+
| react_agent_100x_checkpoint             | 836 ms  | 847 ms: 1.01x slower  |
+-----------------------------------------+---------+-----------------------+
| fanout_to_subgraph_100x_checkpoint      | 730 ms  | 755 ms: 1.03x slower  |
+-----------------------------------------+---------+-----------------------+
| fanout_to_subgraph_100x                 | 589 ms  | 610 ms: 1.04x slower  |
+-----------------------------------------+---------+-----------------------+
| Geometric mean                          | (ref)   | 1.00x slower          |
+-----------------------------------------+---------+-----------------------+

Benchmark hidden because not significant (12): wide_state_25x300_checkpoint_sync, react_agent_10x, react_agent_10x_checkpoint, wide_state_9x1200, wide_state_25x300_checkpoint, wide_state_15x600_checkpoint_sync, wide_state_9x1200_sync, react_agent_10x_sync, fanout_to_subgraph_10x_checkpoint, react_agent_100x_sync, wide_state_9x1200_checkpoint_sync, wide_state_9x1200_checkpoint
import concurrent.futures
from collections import defaultdict, deque
from contextlib import AsyncExitStack, ExitStack
Expand Down Expand Up @@ -1032,6 +1032,4 @@
traceback: Optional[TracebackType],
) -> Optional[bool]:
# unwind stack
return await asyncio.shield(
self.stack.__aexit__(exc_type, exc_value, traceback)
)
return await self.stack.__aexit__(exc_type, exc_value, traceback)
256 changes: 256 additions & 0 deletions libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,262 @@ def logic(inp: str) -> str:
pass


async def test_py_async_with_cancel_behavior() -> None:
    """Confirm that ``__aexit__`` is not cancelled along with its task.

    This test confirms that in all versions of Python we support, ``__aexit__``
    is not cancelled when the coroutine containing the ``async with`` block is
    cancelled. A task is cancelled while suspended inside the ``async with``
    body; the context manager's exit (which itself awaits) must then run to
    completion before the ``CancelledError`` reaches the code around the
    ``async with``.
    """

    logs: list[str] = []

    class MyContextManager:
        async def __aenter__(self):
            logs.append("Entering")
            return self

        async def __aexit__(self, exc_type, exc_val, exc_tb):
            logs.append("Starting exit")
            try:
                # Simulate some cleanup work; this await is where a leaked
                # cancellation would surface.
                await asyncio.sleep(2)
                logs.append("Cleanup completed")
            except asyncio.CancelledError:
                logs.append("Cleanup was cancelled!")
                raise
            # Falls through returning None, so the in-flight CancelledError
            # is not suppressed and propagates out of the ``async with``.
            logs.append("Exit finished")

    async def main():
        try:
            async with MyContextManager():
                logs.append("In context")
                await asyncio.sleep(1)
                logs.append("This won't print if cancelled")
        except asyncio.CancelledError:
            logs.append("Context was cancelled")
            raise

    # create task
    t = asyncio.create_task(main())
    # cancel after 0.2 seconds, while main() is suspended inside the body
    await asyncio.sleep(0.2)
    t.cancel()
    # check logs before cancellation is handled: cancel() only schedules the
    # CancelledError, so nothing new has been logged yet
    assert logs == [
        "Entering",
        "In context",
    ], "Cancelled before cleanup started"
    # wait for task to finish
    try:
        await t
    except asyncio.CancelledError:
        # check logs after cancellation is handled
        assert logs == [
            "Entering",
            "In context",
            "Starting exit",
            "Cleanup completed",
            "Exit finished",
            "Context was cancelled",
        ], "Cleanup started and finished after cancellation"
    else:
        # Raise explicitly instead of ``assert False`` so this guard is not
        # stripped when Python runs with -O.
        raise AssertionError("Task should be cancelled")


async def test_checkpoint_put_after_cancellation() -> None:
    """Cancel a running ``ainvoke`` while a slow checkpoint put is in flight.

    A single-node graph is compiled with a checkpointer whose ``aput`` sleeps
    before delegating to ``MemorySaver.aput``. The ``ainvoke`` task is
    cancelled after 0.2 seconds; by the time awaiting it raises
    ``CancelledError``, both the node and the checkpoint put must have logged
    their end markers.
    """
    logs: list[str] = []

    class LongPutCheckpointer(MemorySaver):
        async def aput(
            self,
            config: RunnableConfig,
            checkpoint: Checkpoint,
            metadata: CheckpointMetadata,
            new_versions: ChannelVersions,
        ) -> RunnableConfig:
            logs.append("checkpoint.aput.start")
            try:
                await asyncio.sleep(1)
                return await super().aput(config, checkpoint, metadata, new_versions)
            finally:
                # NOTE(review): this also runs if aput itself is cancelled, so
                # the end marker alone does not prove the put completed --
                # consider asserting on the stored checkpoint as well.
                logs.append("checkpoint.aput.end")

    async def awhile(input: Any) -> None:
        logs.append("awhile.start")
        try:
            await asyncio.sleep(1)
        finally:
            logs.append("awhile.end")

    builder = Graph()
    builder.add_node("agent", awhile)
    builder.set_entry_point("agent")
    builder.set_finish_point("agent")

    graph = builder.compile(checkpointer=LongPutCheckpointer())
    thread1 = {"configurable": {"thread_id": "1"}}

    # start the task
    t = asyncio.create_task(graph.ainvoke(1, thread1))
    # cancel after 0.2 seconds
    await asyncio.sleep(0.2)
    t.cancel()
    # check logs before cancellation is handled (sorted: the relative order of
    # the two start markers is not asserted here)
    assert sorted(logs) == [
        "awhile.start",
        "checkpoint.aput.start",
    ], "Cancelled before checkpoint put started"
    # wait for task to finish
    try:
        await t
    except asyncio.CancelledError:
        # check logs after cancellation is handled
        assert sorted(logs) == [
            "awhile.end",
            "awhile.start",
            "checkpoint.aput.end",
            "checkpoint.aput.start",
        ], "Checkpoint put is not cancelled"
    else:
        # Raise explicitly instead of ``assert False`` so this guard is not
        # stripped when Python runs with -O.
        raise AssertionError("Task should be cancelled")


async def test_checkpoint_put_after_cancellation_stream_anext() -> None:
    """Cancel a pending ``astream().__anext__()`` during a slow checkpoint put.

    A single-node graph is compiled with a checkpointer whose ``aput`` sleeps
    before delegating to ``MemorySaver.aput``. The task awaiting the first
    streamed item is cancelled after 0.2 seconds; by the time awaiting it
    raises ``CancelledError``, both the node and the checkpoint put must have
    logged their end markers.
    """
    logs: list[str] = []

    class LongPutCheckpointer(MemorySaver):
        async def aput(
            self,
            config: RunnableConfig,
            checkpoint: Checkpoint,
            metadata: CheckpointMetadata,
            new_versions: ChannelVersions,
        ) -> RunnableConfig:
            logs.append("checkpoint.aput.start")
            try:
                await asyncio.sleep(1)
                return await super().aput(config, checkpoint, metadata, new_versions)
            finally:
                # NOTE(review): this also runs if aput itself is cancelled, so
                # the end marker alone does not prove the put completed --
                # consider asserting on the stored checkpoint as well.
                logs.append("checkpoint.aput.end")

    async def awhile(input: Any) -> None:
        logs.append("awhile.start")
        try:
            await asyncio.sleep(1)
        finally:
            logs.append("awhile.end")

    builder = Graph()
    builder.add_node("agent", awhile)
    builder.set_entry_point("agent")
    builder.set_finish_point("agent")

    graph = builder.compile(checkpointer=LongPutCheckpointer())
    thread1 = {"configurable": {"thread_id": "1"}}

    # start the task: the graph only runs once the stream is advanced
    s = graph.astream(1, thread1)
    t = asyncio.create_task(s.__anext__())
    # cancel after 0.2 seconds
    await asyncio.sleep(0.2)
    t.cancel()
    # check logs before cancellation is handled (sorted: the relative order of
    # the two start markers is not asserted here)
    assert sorted(logs) == [
        "awhile.start",
        "checkpoint.aput.start",
    ], "Cancelled before checkpoint put started"
    # wait for task to finish
    try:
        await t
    except asyncio.CancelledError:
        # check logs after cancellation is handled
        assert sorted(logs) == [
            "awhile.end",
            "awhile.start",
            "checkpoint.aput.end",
            "checkpoint.aput.start",
        ], "Checkpoint put is not cancelled"
    else:
        # Raise explicitly instead of ``assert False`` so this guard is not
        # stripped when Python runs with -O.
        raise AssertionError("Task should be cancelled")


async def test_checkpoint_put_after_cancellation_stream_events_anext() -> None:
    """Cancel a pending ``astream_events().__anext__()`` during a slow put.

    A single-node graph is compiled with a checkpointer whose ``aput`` sleeps
    before delegating to ``MemorySaver.aput``. After consuming the first
    event, the task awaiting the second event is cancelled after 0.2 seconds;
    by the time awaiting it raises ``CancelledError``, the node and then the
    checkpoint put must have logged their end markers, in that order.
    """
    logs: list[str] = []

    class LongPutCheckpointer(MemorySaver):
        async def aput(
            self,
            config: RunnableConfig,
            checkpoint: Checkpoint,
            metadata: CheckpointMetadata,
            new_versions: ChannelVersions,
        ) -> RunnableConfig:
            logs.append("checkpoint.aput.start")
            try:
                await asyncio.sleep(1)
                return await super().aput(config, checkpoint, metadata, new_versions)
            finally:
                # NOTE(review): this also runs if aput itself is cancelled, so
                # the end marker alone does not prove the put completed --
                # consider asserting on the stored checkpoint as well.
                logs.append("checkpoint.aput.end")

    async def awhile(input: Any) -> None:
        logs.append("awhile.start")
        try:
            await asyncio.sleep(1)
        finally:
            logs.append("awhile.end")

    builder = Graph()
    builder.add_node("agent", awhile)
    builder.set_entry_point("agent")
    builder.set_finish_point("agent")

    graph = builder.compile(checkpointer=LongPutCheckpointer())
    thread1 = {"configurable": {"thread_id": "1"}}

    # start the task
    s = graph.astream_events(1, thread1, version="v2", include_names=["LangGraph"])
    # skip first event (happens right away)
    await s.__anext__()
    # start the task for 2nd event
    t = asyncio.create_task(s.__anext__())
    # cancel after 0.2 seconds
    await asyncio.sleep(0.2)
    t.cancel()
    # check logs before cancellation is handled (exact order is asserted here)
    assert logs == [
        "checkpoint.aput.start",
        "awhile.start",
    ], "Cancelled before checkpoint put started"
    # wait for task to finish
    try:
        await t
    except asyncio.CancelledError:
        # check logs after cancellation is handled
        assert logs == [
            "checkpoint.aput.start",
            "awhile.start",
            "awhile.end",
            "checkpoint.aput.end",
        ], "Checkpoint put is not cancelled"
    else:
        # Raise explicitly instead of ``assert False`` so this guard is not
        # stripped when Python runs with -O.
        raise AssertionError("Task should be cancelled")


async def test_node_cancellation_on_external_cancel() -> None:
inner_task_cancelled = False

Expand Down
Loading