From 76ff7c3f1a415d63db0fe2614562df445c040788 Mon Sep 17 00:00:00 2001
From: Nuno Campos
Date: Tue, 17 Sep 2024 10:19:19 -0700
Subject: [PATCH] perf: Implement fast path in Runner for single-task steps

- if the current step has a single task then we can execute it inline, instead of moving it to asyncio task / bg thread
---
 libs/langgraph/langgraph/pregel/__init__.py | 11 ++-
 libs/langgraph/langgraph/pregel/runner.py   | 84 ++++++++++++---------
 2 files changed, 59 insertions(+), 36 deletions(-)

diff --git a/libs/langgraph/langgraph/pregel/__init__.py b/libs/langgraph/langgraph/pregel/__init__.py
index dc803c735..7b6d7d3ee 100644
--- a/libs/langgraph/langgraph/pregel/__init__.py
+++ b/libs/langgraph/langgraph/pregel/__init__.py
@@ -1325,6 +1325,7 @@ async def astream(
         """
 
         stream = Queue()
+        aioloop = asyncio.get_running_loop()
 
         def output() -> Iterator:
             while True:
@@ -1341,7 +1342,13 @@ def output() -> Iterator:
                 else:
                     yield payload
 
-        aioloop = asyncio.get_event_loop()
+        if subgraphs:
+
+            def get_waiter() -> asyncio.Task[None]:
+                return aioloop.create_task(stream.wait())
+        else:
+            get_waiter = None
+
         config = ensure_config(self.config, config)
         callback_manager = get_async_callback_manager_for_config(config)
         run_manager = await callback_manager.on_chain_start(
@@ -1417,7 +1424,7 @@ def output() -> Iterator:
                         loop.tasks.values(),
                         timeout=self.step_timeout,
                         retry_policy=self.retry_policy,
-                        get_waiter=lambda: aioloop.create_task(stream.wait()),
+                        get_waiter=get_waiter,
                     ):
                         # emit output
                         for o in output():
diff --git a/libs/langgraph/langgraph/pregel/runner.py b/libs/langgraph/langgraph/pregel/runner.py
index b3bea32c4..fc2d7e464 100644
--- a/libs/langgraph/langgraph/pregel/runner.py
+++ b/libs/langgraph/langgraph/pregel/runner.py
@@ -33,14 +33,26 @@ def __init__(
 
     def tick(
         self,
-        tasks: list[PregelExecutableTask],
+        tasks: Sequence[PregelExecutableTask],
         *,
         reraise: bool = True,
         timeout: Optional[float] = None,
         retry_policy: Optional[RetryPolicy] = None,
     ) -> Iterator[None]:
+        tasks = tuple(tasks)
         # give control back to the caller
         yield
+        # fast path if single task with no timeout
+        if len(tasks) == 1 and timeout is None:
+            task = tasks[0]
+            try:
+                run_with_retry(task, retry_policy)
+                self.commit(task, None)
+            except Exception as exc:
+                self.commit(task, exc)
+                if reraise:
+                    raise
+            return
         # execute tasks, and wait for one to fail or all to finish.
         # each task is independent from all other concurrent tasks
         # yield updates/debug output as each task finishes
@@ -66,22 +78,8 @@ def tick(
                 break  # timed out
             for fut in done:
                 task = futures.pop(fut)
-                if exc := _exception(fut):
-                    if isinstance(exc, GraphInterrupt):
-                        # save interrupt to checkpointer
-                        if interrupts := [(INTERRUPT, i) for i in exc.args[0]]:
-                            self.put_writes(task.id, interrupts)
-                    elif isinstance(exc, GraphDelegate):
-                        raise exc
-                    else:
-                        # save error to checkpointer
-                        self.put_writes(task.id, [(ERROR, exc)])
-                else:
-                    if not task.writes:
-                        # add no writes marker
-                        task.writes.append((NO_WRITES, None))
-                    # save task writes to checkpointer
-                    self.put_writes(task.id, task.writes)
+                # task finished, commit writes
+                self.commit(task, _exception(fut))
             else:
                 # remove references to loop vars
                 del fut, task
@@ -95,7 +93,7 @@ def tick(
 
     async def atick(
         self,
-        tasks: list[PregelExecutableTask],
+        tasks: Sequence[PregelExecutableTask],
         *,
         reraise: bool = True,
         timeout: Optional[float] = None,
@@ -103,8 +101,20 @@ async def atick(
         get_waiter: Optional[Callable[[], asyncio.Future[None]]] = None,
     ) -> AsyncIterator[None]:
         loop = asyncio.get_event_loop()
+        tasks = tuple(tasks)
         # give control back to the caller
         yield
+        # fast path if single task with no waiter and no timeout
+        if len(tasks) == 1 and get_waiter is None and timeout is None:
+            task = tasks[0]
+            try:
+                await arun_with_retry(task, retry_policy, stream=self.use_astream)
+                self.commit(task, None)
+            except Exception as exc:
+                self.commit(task, exc)
+                if reraise:
+                    raise
+            return
         # add waiter task if requested
         if get_waiter is not None:
             futures: dict[asyncio.Future, Optional[PregelExecutableTask]] = {
@@ -143,23 +153,9 @@ async def atick(
                 if task is None:
                     # waiter task finished, schedule another
                     futures[get_waiter()] = None
-                    continue
-                if exc := _exception(fut):
-                    if isinstance(exc, GraphInterrupt):
-                        # save interrupt to checkpointer
-                        if interrupts := [(INTERRUPT, i) for i in exc.args[0]]:
-                            self.put_writes(task.id, interrupts)
-                    elif isinstance(exc, GraphDelegate):
-                        raise exc
-                    else:
-                        # save error to checkpointer
-                        self.put_writes(task.id, [(ERROR, exc)])
                 else:
-                    if not task.writes:
-                        # add no writes marker
-                        task.writes.append((NO_WRITES, None))
-                    # save task writes to checkpointer
-                    self.put_writes(task.id, task.writes)
+                    # task finished, commit writes
+                    self.commit(task, _exception(fut))
             else:
                 # remove references to loop vars
                 del fut, task
@@ -176,6 +172,26 @@ async def atick(
             all_futures, timeout_exc_cls=asyncio.TimeoutError, panic=reraise
         )
 
+    def commit(
+        self, task: PregelExecutableTask, exception: Optional[BaseException]
+    ) -> None:
+        if exception:
+            if isinstance(exception, GraphInterrupt):
+                # save interrupt to checkpointer
+                if interrupts := [(INTERRUPT, i) for i in exception.args[0]]:
+                    self.put_writes(task.id, interrupts)
+            elif isinstance(exception, GraphDelegate):
+                raise exception
+            else:
+                # save error to checkpointer
+                self.put_writes(task.id, [(ERROR, exception)])
+        else:
+            if not task.writes:
+                # add no writes marker
+                task.writes.append((NO_WRITES, None))
+            # save task writes to checkpointer
+            self.put_writes(task.id, task.writes)
+
 
 def _should_stop_others(
     done: Union[set[concurrent.futures.Future[Any]], set[asyncio.Task[Any]]],