From d163938a59261870b392bf164fe8d31cf41f6e59 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 16 Sep 2024 17:36:53 -0700 Subject: [PATCH 1/5] Stream subgraph output while it executes --- libs/langgraph/langgraph/pregel/__init__.py | 14 ++++-- libs/langgraph/langgraph/pregel/runner.py | 53 ++++++++++++++------- libs/langgraph/langgraph/utils/aio.py | 25 ++++++++++ 3 files changed, 71 insertions(+), 21 deletions(-) create mode 100644 libs/langgraph/langgraph/utils/aio.py diff --git a/libs/langgraph/langgraph/pregel/__init__.py b/libs/langgraph/langgraph/pregel/__init__.py index c0c5e0bcb..338842cde 100644 --- a/libs/langgraph/langgraph/pregel/__init__.py +++ b/libs/langgraph/langgraph/pregel/__init__.py @@ -83,6 +83,7 @@ from langgraph.pregel.validate import validate_graph, validate_keys from langgraph.pregel.write import ChannelWrite, ChannelWriteEntry from langgraph.store.base import BaseStore +from langgraph.utils.aio import Queue from langgraph.utils.config import ( ensure_config, merge_configs, @@ -1323,11 +1324,14 @@ async def astream( ``` """ - stream = deque() + stream = Queue() def output() -> Iterator: - while stream: - ns, mode, payload = stream.popleft() + while True: + try: + ns, mode, payload = stream.get_nowait() + except asyncio.QueueEmpty: + break if subgraphs and isinstance(stream_mode, list): yield (ns, mode, payload) elif isinstance(stream_mode, list): @@ -1337,6 +1341,7 @@ def output() -> Iterator: else: yield payload + aioloop = asyncio.get_event_loop() config = ensure_config(self.config, config) callback_manager = get_async_callback_manager_for_config(config) run_manager = await callback_manager.on_chain_start( @@ -1379,7 +1384,7 @@ def output() -> Iterator: ) async with AsyncPregelLoop( input, - stream=StreamProtocol(stream.append, stream_modes), + stream=StreamProtocol(stream.put_nowait, stream_modes), config=config, store=self.store, checkpointer=checkpointer, @@ -1412,6 +1417,7 @@ def output() -> Iterator: loop.tasks.values(), timeout=self.step_timeout, retry_policy=self.retry_policy, + extra=lambda: aioloop.create_task(stream.wait()), ): # emit output for o in output(): diff --git a/libs/langgraph/langgraph/pregel/runner.py b/libs/langgraph/langgraph/pregel/runner.py index 7a09afde3..1a6e5f84f 100644 --- a/libs/langgraph/langgraph/pregel/runner.py +++ b/libs/langgraph/langgraph/pregel/runner.py @@ -100,29 +100,36 @@ async def atick( reraise: bool = True, timeout: Optional[float] = None, retry_policy: Optional[RetryPolicy] = None, + extra: Optional[Callable[[], asyncio.Future[None]]] = None, ) -> AsyncIterator[None]: loop = asyncio.get_event_loop() # give control back to the caller yield + if extra is not None: + futures: dict[asyncio.Future, Optional[PregelExecutableTask]] = { + extra(): None + } + else: + futures = {} # execute tasks, and wait for one to fail or all to finish. # each task is independent from all other concurrent tasks # yield updates/debug output as each task finishes - futures = { - self.submit( - arun_with_retry, - task, - retry_policy, - stream=self.use_astream, - __name__=task.name, - __cancel_on_exit__=True, - __reraise_on_exit__=reraise, - ): task - for task in tasks - if not task.writes - } + for task in tasks: + if not task.writes: + futures[ + self.submit( + arun_with_retry, + task, + retry_policy, + stream=self.use_astream, + __name__=task.name, + __cancel_on_exit__=True, + __reraise_on_exit__=reraise, + ) + ] = task all_futures = futures.copy() end_time = timeout + loop.time() if timeout else None - while futures: + while len(futures) > (1 if extra is not None else 0): done, _ = await asyncio.wait( futures, return_when=asyncio.FIRST_COMPLETED, @@ -132,6 +139,10 @@ async def atick( break # timed out for fut in done: task = futures.pop(fut) + if task is None: + # extra task finished, schedule another + futures[extra()] = None + continue if exc := _exception(fut): if isinstance(exc, GraphInterrupt): # save interrupt to checkpointer @@ -156,6 +167,9 @@ async def atick( break # give control back to the caller yield + # cancel extra task + for fut in futures: + fut.cancel() # panic on failure or timeout _panic_or_proceed( all_futures, timeout_exc_cls=asyncio.TimeoutError, panic=reraise @@ -187,15 +201,20 @@ def _exception( def _panic_or_proceed( - futs: Union[set[concurrent.futures.Future[Any]], set[asyncio.Task[Any]]], + futs: Union[ + dict[concurrent.futures.Future, Optional[PregelExecutableTask]], + dict[asyncio.Future, Optional[PregelExecutableTask]], + ], *, timeout_exc_cls: Type[Exception] = TimeoutError, panic: bool = True, ) -> None: done: set[Union[concurrent.futures.Future[Any], asyncio.Task[Any]]] = set() inflight: set[Union[concurrent.futures.Future[Any], asyncio.Task[Any]]] = set() - for fut in futs: - if fut.done(): + for fut, val in futs.items(): + if val is None: + continue + elif fut.done(): done.add(fut) else: inflight.add(fut) diff --git a/libs/langgraph/langgraph/utils/aio.py b/libs/langgraph/langgraph/utils/aio.py new file mode 100644 index 000000000..3d0f71e06 --- /dev/null +++ b/libs/langgraph/langgraph/utils/aio.py @@ -0,0 +1,25 @@ +import asyncio + + +class Queue(asyncio.Queue): + async def wait(self): + """If queue is empty, wait until an item is available.""" + while self.empty(): + getter = self._get_loop().create_future() + self._getters.append(getter) + try: + await getter + except: + getter.cancel() # Just in case getter is not done yet. + try: + # Clean self._getters from canceled getters. + self._getters.remove(getter) + except ValueError: + # The getter could be removed from self._getters by a + # previous put_nowait call. + pass + if not self.empty() and not getter.cancelled(): + # We were woken up by put_nowait(), but can't take + # the call. Wake up the next in line. + self._wakeup_next(self._getters) + raise From 54207f4d893bb507d950012cd53100efa811004f Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 16 Sep 2024 18:02:57 -0700 Subject: [PATCH 2/5] Add test --- libs/langgraph/tests/test_pregel_async.py | 76 +++++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index 86617146e..a3a1510c2 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -4,6 +4,7 @@ import sys from collections import Counter from contextlib import asynccontextmanager, contextmanager +from time import perf_counter from typing import ( Annotated, Any, @@ -363,6 +364,7 @@ async def iambad(input: State) -> None: assert awhiles == 2 +@pytest.mark.repeat(10) async def test_step_timeout_on_stream_hang() -> None: inner_task_cancelled = False @@ -6958,6 +6960,80 @@ async def side(state: State): assert times_called == 1 +@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) +async def test_stream_subgraphs_during_execution(checkpointer_name: str) -> None: + class InnerState(TypedDict): + my_key: Annotated[str, operator.add] + my_other_key: str + + async def inner_1(state: InnerState): + return {"my_key": "got here", "my_other_key": state["my_key"]} + + async def inner_2(state: InnerState): + await asyncio.sleep(0.5) + return { + "my_key": " and there", + "my_other_key": state["my_key"], + } + + inner = StateGraph(InnerState) + inner.add_node("inner_1", inner_1) + inner.add_node("inner_2", inner_2) + inner.add_edge("inner_1", "inner_2") + inner.set_entry_point("inner_1") + inner.set_finish_point("inner_2") + + class State(TypedDict): + my_key: Annotated[str, operator.add] + + async def outer_1(state: State): + await asyncio.sleep(0.2) + return {"my_key": " and parallel"} + + async def outer_2(state: State): + return {"my_key": " and back again"} + + graph = StateGraph(State) + graph.add_node("inner", inner.compile()) + graph.add_node("outer_1", outer_1) + graph.add_node("outer_2", outer_2) + + graph.add_edge(START, "inner") + graph.add_edge(START, "outer_1") + graph.add_edge(["inner", "outer_1"], "outer_2") + graph.add_edge("outer_2", END) + + async with awith_checkpointer(checkpointer_name) as checkpointer: + app = graph.compile(checkpointer=checkpointer) + + start = perf_counter() + chunks: list[tuple[float, Any]] = [] + config = {"configurable": {"thread_id": "2"}} + async for c in app.astream({"my_key": ""}, config, subgraphs=True): + chunks.append((round(perf_counter() - start, 1), c)) + + assert chunks == [ + # arrives before "inner" finishes + ( + 0.0, + ( + (AnyStr("inner:"),), + {"inner_1": {"my_key": "got here", "my_other_key": ""}}, + ), + ), + (0.2, ((), {"outer_1": {"my_key": " and parallel"}})), + ( + 0.5, + ( + (AnyStr("inner:"),), + {"inner_2": {"my_key": " and there", "my_other_key": "got here"}}, + ), + ), + (0.5, ((), {"inner": {"my_key": "got here and there"}})), + (0.5, ((), {"outer_2": {"my_key": " and back again"}})), + ] + + @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_nested_graph_interrupts_parallel(checkpointer_name: str) -> None: class InnerState(TypedDict): From c30119824c6261ab90c86682d27479ad70a55396 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 16 Sep 2024 18:06:06 -0700 Subject: [PATCH 3/5] Add comment --- libs/langgraph/langgraph/pregel/runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/libs/langgraph/langgraph/pregel/runner.py b/libs/langgraph/langgraph/pregel/runner.py index 1a6e5f84f..d4064f9f3 100644 --- a/libs/langgraph/langgraph/pregel/runner.py +++ b/libs/langgraph/langgraph/pregel/runner.py @@ -105,6 +105,7 @@ async def atick( loop = asyncio.get_event_loop() # give control back to the caller yield + # add extra task if requested if extra is not None: futures: dict[asyncio.Future, Optional[PregelExecutableTask]] = { extra(): None From bd2ecba62240d68ee3262fc9949ccd963180e02c Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 17 Sep 2024 09:29:16 -0700 Subject: [PATCH 4/5] Fix for py 3.9 --- libs/langgraph/langgraph/utils/aio.py | 8 +++++++- libs/langgraph/poetry.lock | 6 +++--- libs/langgraph/pyproject.toml | 2 +- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/libs/langgraph/langgraph/utils/aio.py b/libs/langgraph/langgraph/utils/aio.py index 3d0f71e06..bcddd11e3 100644 --- a/libs/langgraph/langgraph/utils/aio.py +++ b/libs/langgraph/langgraph/utils/aio.py @@ -1,11 +1,17 @@ import asyncio +import sys + +PY_310 = sys.version_info >= (3, 10) class Queue(asyncio.Queue): async def wait(self): """If queue is empty, wait until an item is available.""" while self.empty(): - getter = self._get_loop().create_future() + if PY_310: + getter = self._get_loop().create_future() + else: + getter = self._loop.create_future() self._getters.append(getter) try: await getter diff --git a/libs/langgraph/poetry.lock b/libs/langgraph/poetry.lock index b6f161afa..ba50565a8 100644 --- a/libs/langgraph/poetry.lock +++ b/libs/langgraph/poetry.lock @@ -1238,7 +1238,7 @@ typing-extensions = ">=4.7" [[package]] name = "langgraph-checkpoint" -version = "1.0.9" +version = "1.0.10" description = "Library with base interfaces for LangGraph checkpoint savers." optional = false python-versions = "^3.9.0,<4.0" @@ -1255,7 +1255,7 @@ url = "../checkpoint" [[package]] name = "langgraph-checkpoint-postgres" -version = "1.0.6" +version = "1.0.7" description = "Library with a Postgres implementation of LangGraph checkpoint saver." optional = false python-versions = "^3.9.0,<4.0" @@ -3202,4 +3202,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.9.0,<4.0" -content-hash = "f72e42e6f957927f9acf19a838ad5331eb2ed11f0e58df3a6d3240eafb3dc057" +content-hash = "73c2dec0a0e833ad8742ebfca86d8e3d602a8a63671a782d21d8e0079a02d448" diff --git a/libs/langgraph/pyproject.toml b/libs/langgraph/pyproject.toml index 655970f60..a583db910 100644 --- a/libs/langgraph/pyproject.toml +++ b/libs/langgraph/pyproject.toml @@ -28,7 +28,7 @@ pytest-repeat = "^0.9.3" langgraph-checkpoint = {path = "../checkpoint", develop = true} langgraph-checkpoint-sqlite = {path = "../checkpoint-sqlite", develop = true} langgraph-checkpoint-postgres = {path = "../checkpoint-postgres", develop = true} -psycopg = {extras = ["binary"], version = ">=3.0.0"} +psycopg = {extras = ["binary"], version = ">=3.0.0", python = ">=3.10"} uvloop = "^0.20.0" pyperf = "^2.7.0" py-spy = "^0.3.14" From f59435a892e96fb9092e0055a6df2e1dfc5111a9 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 17 Sep 2024 09:31:48 -0700 Subject: [PATCH 5/5] Add more comments --- libs/langgraph/langgraph/pregel/__init__.py | 2 +- libs/langgraph/langgraph/pregel/runner.py | 16 ++++++++-------- libs/langgraph/langgraph/utils/aio.py | 6 +++++- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/libs/langgraph/langgraph/pregel/__init__.py b/libs/langgraph/langgraph/pregel/__init__.py index 338842cde..dc803c735 100644 --- a/libs/langgraph/langgraph/pregel/__init__.py +++ b/libs/langgraph/langgraph/pregel/__init__.py @@ -1417,7 +1417,7 @@ def output() -> Iterator: loop.tasks.values(), timeout=self.step_timeout, retry_policy=self.retry_policy, - extra=lambda: aioloop.create_task(stream.wait()), + get_waiter=lambda: aioloop.create_task(stream.wait()), ): # emit output for o in output(): diff --git a/libs/langgraph/langgraph/pregel/runner.py b/libs/langgraph/langgraph/pregel/runner.py index d4064f9f3..b3bea32c4 100644 --- a/libs/langgraph/langgraph/pregel/runner.py +++ b/libs/langgraph/langgraph/pregel/runner.py @@ -100,15 +100,15 @@ async def atick( reraise: bool = True, timeout: Optional[float] = None, retry_policy: Optional[RetryPolicy] = None, - extra: Optional[Callable[[], asyncio.Future[None]]] = None, + get_waiter: Optional[Callable[[], asyncio.Future[None]]] = None, ) -> AsyncIterator[None]: loop = asyncio.get_event_loop() # give control back to the caller yield - # add extra task if requested - if extra is not None: + # add waiter task if requested + if get_waiter is not None: futures: dict[asyncio.Future, Optional[PregelExecutableTask]] = { - extra(): None + get_waiter(): None } else: futures = {} @@ -130,7 +130,7 @@ async def atick( ] = task all_futures = futures.copy() end_time = timeout + loop.time() if timeout else None - while len(futures) > (1 if extra is not None else 0): + while len(futures) > (1 if get_waiter is not None else 0): done, _ = await asyncio.wait( futures, return_when=asyncio.FIRST_COMPLETED, @@ -141,8 +141,8 @@ async def atick( for fut in done: task = futures.pop(fut) if task is None: - # extra task finished, schedule another - futures[extra()] = None + # waiter task finished, schedule another + futures[get_waiter()] = None continue if exc := _exception(fut): if isinstance(exc, GraphInterrupt): @@ -168,7 +168,7 @@ async def atick( break # give control back to the caller yield - # cancel extra task + # cancel waiter task for fut in futures: fut.cancel() # panic on failure or timeout diff --git a/libs/langgraph/langgraph/utils/aio.py b/libs/langgraph/langgraph/utils/aio.py index bcddd11e3..afe02c050 100644 --- a/libs/langgraph/langgraph/utils/aio.py +++ b/libs/langgraph/langgraph/utils/aio.py @@ -6,7 +6,11 @@ class Queue(asyncio.Queue): async def wait(self): - """If queue is empty, wait until an item is available.""" + """If queue is empty, wait until an item is available. + + Copied from Queue.get(), removing the call to .get_nowait(), + ie. this doesn't consume the item, just waits for it. + """ while self.empty(): if PY_310: getter = self._get_loop().create_future()