Skip to content

Commit

Permalink
Wait until next tick to start send task
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Dec 4, 2024
1 parent e1f6501 commit 9733db0
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 1 deletion.
23 changes: 22 additions & 1 deletion libs/langgraph/langgraph/pregel/executor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import concurrent.futures
import sys
import time
from contextlib import ExitStack
from contextvars import copy_context
from types import TracebackType
Expand Down Expand Up @@ -34,6 +35,7 @@ def __call__(
__name__: Optional[str] = None,
__cancel_on_exit__: bool = False,
__reraise_on_exit__: bool = True,
__next_tick__: bool = False,
**kwargs: P.kwargs,
) -> concurrent.futures.Future[T]: ...

Expand All @@ -58,9 +60,13 @@ def submit( # type: ignore[valid-type]
__name__: Optional[str] = None, # currently not used in sync version
__cancel_on_exit__: bool = False, # for sync, can cancel only if not started
__reraise_on_exit__: bool = True,
__next_tick__: bool = False,
**kwargs: P.kwargs,
) -> concurrent.futures.Future[T]:
task = self.executor.submit(fn, *args, **kwargs)
if __next_tick__:
task = self.executor.submit(next_tick, fn, *args, **kwargs)
else:
task = self.executor.submit(fn, *args, **kwargs)
self.tasks[task] = (__cancel_on_exit__, __reraise_on_exit__)
task.add_done_callback(self.done)
return task
Expand Down Expand Up @@ -137,11 +143,14 @@ def submit( # type: ignore[valid-type]
__name__: Optional[str] = None,
__cancel_on_exit__: bool = False,
__reraise_on_exit__: bool = True,
__next_tick__: bool = False,
**kwargs: P.kwargs,
) -> asyncio.Task[T]:
coro = cast(Coroutine[None, None, T], fn(*args, **kwargs))
if self.semaphore:
coro = gated(self.semaphore, coro)
if __next_tick__:
coro = anext_tick(coro)
if self.context_not_supported:
task = self.loop.create_task(coro, name=__name__)
else:
Expand Down Expand Up @@ -197,3 +206,15 @@ async def gated(semaphore: asyncio.Semaphore, coro: Coroutine[None, None, T]) ->
"""A coroutine that waits for a semaphore before running another coroutine."""
async with semaphore:
return await coro


def next_tick(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
"""A function that yields control to other threads before running another function."""
time.sleep(0)
return fn(*args, **kwargs)


async def anext_tick(coro: Coroutine[None, None, T]) -> T:
"""A coroutine that yields control to event loop before running another coroutine."""
await asyncio.sleep(0)
return await coro
5 changes: 5 additions & 0 deletions libs/langgraph/langgraph/pregel/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def writer(
CONFIG_KEY_CALL: partial(call, next_task),
},
__reraise_on_exit__=reraise,
__next_tick__=True,
)
fut.add_done_callback(partial(self.commit, next_task))
futures[fut] = next_task
Expand Down Expand Up @@ -228,6 +229,10 @@ def call(
break
# give control back to the caller
yield
# wait for pending done callbacks
# if a 2nd future finishes while `wait` is returning, it's possible
# that done callbacks for the 2nd future aren't called until next tick
time.sleep(0)
# panic on failure or timeout
_panic_or_proceed(
done_futures.union(f for f, t in futures.items() if t is not None),
Expand Down
1 change: 1 addition & 0 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5837,6 +5837,7 @@ class AgentState(TypedDict):
@tool()
def search_api(query: str) -> str:
"""Searches the API for the query."""
time.sleep(0.1)
return f"result for {query}"

tools = [search_api]
Expand Down

0 comments on commit 9733db0

Please sign in to comment.