Skip to content

Commit

Permalink
Guard
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Jan 16, 2025
1 parent 145220f commit 0e4dbb4
Showing 1 changed file with 21 additions and 60 deletions.
81 changes: 21 additions & 60 deletions libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@

pytestmark = pytest.mark.anyio

NEEDS_CONTEXTVARS = pytest.mark.skipif(
sys.version_info < (3, 11),
reason="Python 3.11+ is required for async contextvars support",
)


async def test_checkpoint_errors() -> None:
class FaultyGetCheckpointer(MemorySaver):
Expand Down Expand Up @@ -501,10 +506,7 @@ async def iambad(input: Any) -> None:
await graph.ainvoke(1)


@pytest.mark.skipif(
sys.version_info < (3, 11),
reason="Python 3.11+ is required for async contextvars support",
)
@NEEDS_CONTEXTVARS
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_dynamic_interrupt(checkpointer_name: str) -> None:
class State(TypedDict):
Expand Down Expand Up @@ -678,10 +680,7 @@ async def tool_two_node(s: State) -> State:
)


@pytest.mark.skipif(
sys.version_info < (3, 11),
reason="Python 3.11+ is required for async contextvars support",
)
@NEEDS_CONTEXTVARS
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_dynamic_interrupt_subgraph(checkpointer_name: str) -> None:
class SubgraphState(TypedDict):
Expand Down Expand Up @@ -872,10 +871,7 @@ class State(TypedDict):
)


@pytest.mark.skipif(
sys.version_info < (3, 11),
reason="Python 3.11+ is required for async contextvars support",
)
@NEEDS_CONTEXTVARS
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_copy_checkpoint(checkpointer_name: str) -> None:
class State(TypedDict):
Expand Down Expand Up @@ -1079,10 +1075,7 @@ def start(state: State) -> list[Union[Send, str]]:
)


@pytest.mark.skipif(
sys.version_info < (3, 11),
reason="Python 3.11+ is required for async contextvars support",
)
@NEEDS_CONTEXTVARS
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_node_not_cancelled_on_other_node_interrupted(
checkpointer_name: str,
Expand Down Expand Up @@ -2442,10 +2435,7 @@ async def route_to_three(state) -> Literal["3"]:
]


@pytest.mark.skipif(
sys.version_info < (3, 11),
reason="Python 3.11+ is required for async contextvars support",
)
@NEEDS_CONTEXTVARS
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_imp_task(checkpointer_name: str) -> None:
async with awith_checkpointer(checkpointer_name) as checkpointer:
Expand Down Expand Up @@ -2493,10 +2483,7 @@ async def graph(input: list[int]) -> list[str]:
assert mapper_calls == 2


@pytest.mark.skipif(
sys.version_info < (3, 11),
reason="Python 3.11+ is required for async contextvars support",
)
@NEEDS_CONTEXTVARS
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_imp_task_cancel(checkpointer_name: str) -> None:
async with awith_checkpointer(checkpointer_name) as checkpointer:
Expand Down Expand Up @@ -2547,10 +2534,7 @@ async def graph(input: list[int]) -> list[str]:
assert mapper_cancels == 2


@pytest.mark.skipif(
sys.version_info < (3, 11),
reason="Python 3.11+ is required for async contextvars support",
)
@NEEDS_CONTEXTVARS
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_imp_sync_from_async(checkpointer_name: str) -> None:
async with awith_checkpointer(checkpointer_name) as checkpointer:
Expand Down Expand Up @@ -2583,10 +2567,7 @@ def graph(state: dict) -> dict:
]


@pytest.mark.skipif(
sys.version_info < (3, 11),
reason="Python 3.11+ is required for async contextvars support",
)
@NEEDS_CONTEXTVARS
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_imp_stream_order(checkpointer_name: str) -> None:
async with awith_checkpointer(checkpointer_name) as checkpointer:
Expand Down Expand Up @@ -6117,10 +6098,7 @@ class CustomParentState(TypedDict):
)


@pytest.mark.skipif(
sys.version_info < (3, 11),
reason="Python 3.11+ is required for async contextvars support",
)
@NEEDS_CONTEXTVARS
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_interrupt_subgraph(checkpointer_name: str):
class State(TypedDict):
Expand Down Expand Up @@ -6153,10 +6131,7 @@ def bar(state):
assert await graph.ainvoke(Command(resume="bar"), thread1)


@pytest.mark.skipif(
sys.version_info < (3, 11),
reason="Python 3.11+ is required for async contextvars support",
)
@NEEDS_CONTEXTVARS
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_interrupt_multiple(checkpointer_name: str):
class State(TypedDict):
Expand Down Expand Up @@ -6220,10 +6195,7 @@ async def node(s: State) -> State:
]


@pytest.mark.skipif(
sys.version_info < (3, 11),
reason="Python 3.11+ is required for async contextvars support",
)
@NEEDS_CONTEXTVARS
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_interrupt_loop(checkpointer_name: str):
class State(TypedDict):
Expand Down Expand Up @@ -6508,10 +6480,7 @@ async def fast_node(state: State):
assert duration < 3.0


@pytest.mark.skipif(
sys.version_info < (3, 11),
reason="Python 3.11+ is required for async contextvars support",
)
@NEEDS_CONTEXTVARS
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_multiple_interrupt_state_persistence(checkpointer_name: str) -> None:
"""Test that state is preserved correctly across multiple interrupts."""
Expand Down Expand Up @@ -6692,10 +6661,7 @@ def node_b(state):
]


@pytest.mark.skipif(
sys.version_info < (3, 11),
reason="Python 3.11+ is required for async contextvars support",
)
@NEEDS_CONTEXTVARS
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_falsy_return_from_task(checkpointer_name: str) -> None:
"""Test with a falsy return from a task."""
Expand All @@ -6717,10 +6683,7 @@ async def graph(state: dict) -> dict:
await graph.ainvoke(Command(resume="123"), configurable)


@pytest.mark.skipif(
sys.version_info < (3, 11),
reason="Python 3.11+ is required for async contextvars support",
)
@NEEDS_CONTEXTVARS
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_multiple_interrupts_imperative(checkpointer_name: str) -> None:
"""Test multiple interrupts with an imperative API."""
Expand Down Expand Up @@ -6760,10 +6723,7 @@ async def graph(state: dict) -> dict:
assert counter == 3


@pytest.mark.skipif(
sys.version_info < (3, 11),
reason="Python 3.11+ is required for async contextvars support",
)
@NEEDS_CONTEXTVARS
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_double_interrupt_subgraph(checkpointer_name: str) -> None:
class AgentState(TypedDict):
Expand Down Expand Up @@ -6877,6 +6837,7 @@ def invoke_sub_agent(state: AgentState):
]


@NEEDS_CONTEXTVARS
async def test_async_streaming_with_functional_api() -> None:
"""Test streaming with functional API.
Expand Down

0 comments on commit 0e4dbb4

Please sign in to comment.