Skip to content

Commit

Permalink
Merge pull request #1743 from langchain-ai/nc/17sep/runner-fast-path
Browse files Browse the repository at this point in the history
perf: Implement fast path in Runner for single-task steps
  • Loading branch information
nfcampos authored Sep 17, 2024
2 parents 47a68c1 + 76ff7c3 commit f03dc2f
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 36 deletions.
11 changes: 9 additions & 2 deletions libs/langgraph/langgraph/pregel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1325,6 +1325,7 @@ async def astream(
"""

stream = Queue()
aioloop = asyncio.get_running_loop()

def output() -> Iterator:
while True:
Expand All @@ -1341,7 +1342,13 @@ def output() -> Iterator:
else:
yield payload

aioloop = asyncio.get_event_loop()
if subgraphs:

def get_waiter() -> asyncio.Task[None]:
return aioloop.create_task(stream.wait())
else:
get_waiter = None

config = ensure_config(self.config, config)
callback_manager = get_async_callback_manager_for_config(config)
run_manager = await callback_manager.on_chain_start(
Expand Down Expand Up @@ -1417,7 +1424,7 @@ def output() -> Iterator:
loop.tasks.values(),
timeout=self.step_timeout,
retry_policy=self.retry_policy,
get_waiter=lambda: aioloop.create_task(stream.wait()),
get_waiter=get_waiter,
):
# emit output
for o in output():
Expand Down
84 changes: 50 additions & 34 deletions libs/langgraph/langgraph/pregel/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,26 @@ def __init__(

def tick(
self,
tasks: list[PregelExecutableTask],
tasks: Sequence[PregelExecutableTask],
*,
reraise: bool = True,
timeout: Optional[float] = None,
retry_policy: Optional[RetryPolicy] = None,
) -> Iterator[None]:
tasks = tuple(tasks)
# give control back to the caller
yield
# fast path if single task with no timeout
if len(tasks) == 1 and timeout is None:
task = tasks[0]
try:
run_with_retry(task, retry_policy)
self.commit(task, None)
except Exception as exc:
self.commit(task, exc)
if reraise:
raise
return
# execute tasks, and wait for one to fail or all to finish.
# each task is independent from all other concurrent tasks
# yield updates/debug output as each task finishes
Expand All @@ -66,22 +78,8 @@ def tick(
break # timed out
for fut in done:
task = futures.pop(fut)
if exc := _exception(fut):
if isinstance(exc, GraphInterrupt):
# save interrupt to checkpointer
if interrupts := [(INTERRUPT, i) for i in exc.args[0]]:
self.put_writes(task.id, interrupts)
elif isinstance(exc, GraphDelegate):
raise exc
else:
# save error to checkpointer
self.put_writes(task.id, [(ERROR, exc)])
else:
if not task.writes:
# add no writes marker
task.writes.append((NO_WRITES, None))
# save task writes to checkpointer
self.put_writes(task.id, task.writes)
# task finished, commit writes
self.commit(task, _exception(fut))
else:
# remove references to loop vars
del fut, task
Expand All @@ -95,16 +93,28 @@ def tick(

async def atick(
self,
tasks: list[PregelExecutableTask],
tasks: Sequence[PregelExecutableTask],
*,
reraise: bool = True,
timeout: Optional[float] = None,
retry_policy: Optional[RetryPolicy] = None,
get_waiter: Optional[Callable[[], asyncio.Future[None]]] = None,
) -> AsyncIterator[None]:
loop = asyncio.get_event_loop()
tasks = tuple(tasks)
# give control back to the caller
yield
# fast path if single task with no waiter and no timeout
if len(tasks) == 1 and get_waiter is None and timeout is None:
task = tasks[0]
try:
await arun_with_retry(task, retry_policy, stream=self.use_astream)
self.commit(task, None)
except Exception as exc:
self.commit(task, exc)
if reraise:
raise
return
# add waiter task if requested
if get_waiter is not None:
futures: dict[asyncio.Future, Optional[PregelExecutableTask]] = {
Expand Down Expand Up @@ -143,23 +153,9 @@ async def atick(
if task is None:
# waiter task finished, schedule another
futures[get_waiter()] = None
continue
if exc := _exception(fut):
if isinstance(exc, GraphInterrupt):
# save interrupt to checkpointer
if interrupts := [(INTERRUPT, i) for i in exc.args[0]]:
self.put_writes(task.id, interrupts)
elif isinstance(exc, GraphDelegate):
raise exc
else:
# save error to checkpointer
self.put_writes(task.id, [(ERROR, exc)])
else:
if not task.writes:
# add no writes marker
task.writes.append((NO_WRITES, None))
# save task writes to checkpointer
self.put_writes(task.id, task.writes)
# task finished, commit writes
self.commit(task, _exception(fut))
else:
# remove references to loop vars
del fut, task
Expand All @@ -176,6 +172,26 @@ async def atick(
all_futures, timeout_exc_cls=asyncio.TimeoutError, panic=reraise
)

def commit(
self, task: PregelExecutableTask, exception: Optional[BaseException]
) -> None:
if exception:
if isinstance(exception, GraphInterrupt):
# save interrupt to checkpointer
if interrupts := [(INTERRUPT, i) for i in exception.args[0]]:
self.put_writes(task.id, interrupts)
elif isinstance(exception, GraphDelegate):
raise exception
else:
# save error to checkpointer
self.put_writes(task.id, [(ERROR, exception)])
else:
if not task.writes:
# add no writes marker
task.writes.append((NO_WRITES, None))
# save task writes to checkpointer
self.put_writes(task.id, task.writes)


def _should_stop_others(
done: Union[set[concurrent.futures.Future[Any]], set[asyncio.Task[Any]]],
Expand Down

0 comments on commit f03dc2f

Please sign in to comment.