diff --git a/langgraph/checkpoint/mysql/aio.py b/langgraph/checkpoint/mysql/aio.py index b7ed841..6177b1c 100644 --- a/langgraph/checkpoint/mysql/aio.py +++ b/langgraph/checkpoint/mysql/aio.py @@ -377,6 +377,18 @@ def list( Yields: Iterator[CheckpointTuple]: An iterator of matching checkpoint tuples. """ + try: + # check if we are in the main thread, only bg threads can block + # we don't check in other methods to avoid the overhead + if asyncio.get_running_loop() is self.loop: + raise asyncio.InvalidStateError( + "Synchronous calls to AsyncSqliteSaver are only allowed from a " + "different thread. From the main thread, use the async interface. " + "For example, use `checkpointer.alist(...)` or `await " + "graph.ainvoke(...)`." + ) + except RuntimeError: + pass aiter_ = self.alist(config, filter=filter, before=before, limit=limit) while True: try: @@ -407,7 +419,7 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: if asyncio.get_running_loop() is self.loop: raise asyncio.InvalidStateError( "Synchronous calls to AIOMySQLSaver are only allowed from a " - "different thread. From the main thread, use the async interface." + "different thread. From the main thread, use the async interface. " "For example, use `await checkpointer.aget_tuple(...)` or `await " "graph.ainvoke(...)`." ) diff --git a/langgraph/checkpoint/mysql/shallow.py b/langgraph/checkpoint/mysql/shallow.py index ac59229..c77f1d1 100644 --- a/langgraph/checkpoint/mysql/shallow.py +++ b/langgraph/checkpoint/mysql/shallow.py @@ -720,6 +720,18 @@ def list( on the provided config. For shallow savers, this method returns a list with ONLY the most recent checkpoint. """ + try: + # check if we are in the main thread, only bg threads can block + # we don't check in other methods to avoid the overhead + if asyncio.get_running_loop() is self.loop: + raise asyncio.InvalidStateError( + "Synchronous calls to AsyncSqliteSaver are only allowed from a " + "different thread. From the main thread, use the async interface. " + "For example, use `checkpointer.alist(...)` or `await " + "graph.ainvoke(...)`." + ) + except RuntimeError: + pass aiter_ = self.alist(config, filter=filter, before=before, limit=limit) while True: try: @@ -745,7 +757,7 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: if asyncio.get_running_loop() is self.loop: raise asyncio.InvalidStateError( "Synchronous calls to asynchronous shallow savers are only allowed from a " - "different thread. From the main thread, use the async interface." + "different thread. From the main thread, use the async interface. " "For example, use `await checkpointer.aget_tuple(...)` or `await " "graph.ainvoke(...)`." ) diff --git a/tests/test_async.py b/tests/test_async.py index 7818c52..5f12ebf 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -1,3 +1,4 @@ +import asyncio from collections.abc import AsyncIterator from contextlib import asynccontextmanager from copy import deepcopy @@ -17,6 +18,7 @@ ) from langgraph.checkpoint.mysql.aio import AIOMySQLSaver, ShallowAIOMySQLSaver from langgraph.checkpoint.serde.types import TASKS +from langgraph.graph import END, START, MessagesState, StateGraph from tests.conftest import DEFAULT_BASE_URI @@ -369,3 +371,22 @@ async def test_write_with_same_checkpoint_ns_updates( results = [c async for c in saver.alist({})] assert len(results) == 1 + + +@pytest.mark.parametrize("saver_name", ["base", "pool", "shallow"]) +async def test_graph_sync_get_state_history_raises(saver_name: str) -> None: + """Regression test for https://github.com/langchain-ai/langgraph/issues/2992""" + + builder = StateGraph(MessagesState) + builder.add_node("foo", lambda _: None) + builder.add_edge(START, "foo") + builder.add_edge("foo", END) + + async with _saver(saver_name) as saver: + graph = builder.compile(checkpointer=saver) + config: RunnableConfig = {"configurable": {"thread_id": "1"}} + await graph.ainvoke({"messages": []}, config) + + # this method should not hang + with pytest.raises(asyncio.exceptions.InvalidStateError): + next(graph.get_state_history(config))