Skip to content

Commit

Permalink
Support multiple args in @task's
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw committed Jan 3, 2025
1 parent d74ec2c commit 9946211
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 30 deletions.
53 changes: 44 additions & 9 deletions libs/langgraph/langgraph/func/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import asyncio

Check notice on line 1 in libs/langgraph/langgraph/func/__init__.py

View workflow job for this annotation

GitHub Actions / benchmark

Benchmark results

......................................... fanout_to_subgraph_10x: Mean +- std dev: 61.1 ms +- 1.3 ms ......................................... fanout_to_subgraph_10x_sync: Mean +- std dev: 52.5 ms +- 1.1 ms ......................................... fanout_to_subgraph_10x_checkpoint: Mean +- std dev: 74.7 ms +- 1.3 ms ......................................... fanout_to_subgraph_10x_checkpoint_sync: Mean +- std dev: 95.1 ms +- 1.0 ms ......................................... fanout_to_subgraph_100x: Mean +- std dev: 636 ms +- 36 ms ......................................... fanout_to_subgraph_100x_sync: Mean +- std dev: 527 ms +- 14 ms ......................................... fanout_to_subgraph_100x_checkpoint: Mean +- std dev: 790 ms +- 25 ms ......................................... fanout_to_subgraph_100x_checkpoint_sync: Mean +- std dev: 955 ms +- 20 ms ......................................... react_agent_10x: Mean +- std dev: 30.6 ms +- 0.6 ms ......................................... react_agent_10x_sync: Mean +- std dev: 22.9 ms +- 0.3 ms ......................................... react_agent_10x_checkpoint: Mean +- std dev: 38.0 ms +- 0.8 ms ......................................... react_agent_10x_checkpoint_sync: Mean +- std dev: 37.1 ms +- 0.5 ms ......................................... react_agent_100x: Mean +- std dev: 342 ms +- 6 ms ......................................... react_agent_100x_sync: Mean +- std dev: 275 ms +- 2 ms ......................................... react_agent_100x_checkpoint: Mean +- std dev: 840 ms +- 11 ms ......................................... react_agent_100x_checkpoint_sync: Mean +- std dev: 834 ms +- 9 ms ......................................... wide_state_25x300: Mean +- std dev: 23.3 ms +- 0.5 ms ......................................... wide_state_25x300_sync: Mean +- std dev: 15.2 ms +- 0.1 ms ......................................... wide_state_25x300_checkpoint: Mean +- std dev: 276 ms +- 15 ms ......................................... wide_state_25x300_checkpoint_sync: Mean +- std dev: 273 ms +- 13 ms ......................................... wide_state_15x600: Mean +- std dev: 27.2 ms +- 0.5 ms ......................................... wide_state_15x600_sync: Mean +- std dev: 17.6 ms +- 0.1 ms ......................................... wide_state_15x600_checkpoint: Mean +- std dev: 474 ms +- 22 ms ......................................... wide_state_15x600_checkpoint_sync: Mean +- std dev: 471 ms +- 14 ms ......................................... wide_state_9x1200: Mean +- std dev: 27.3 ms +- 0.5 ms ......................................... wide_state_9x1200_sync: Mean +- std dev: 17.8 ms +- 0.2 ms ......................................... wide_state_9x1200_checkpoint: Mean +- std dev: 310 ms +- 15 ms ......................................... wide_state_9x1200_checkpoint_sync: Mean +- std dev: 307 ms +- 15 ms

Check notice on line 1 in libs/langgraph/langgraph/func/__init__.py

View workflow job for this annotation

GitHub Actions / benchmark

Comparison against main

+----------------------------------------+---------+-----------------------+ | Benchmark | main | changes | +========================================+=========+=======================+ | fanout_to_subgraph_10x_checkpoint_sync | 96.0 ms | 95.1 ms: 1.01x faster | +----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x | 61.6 ms | 61.1 ms: 1.01x faster | +----------------------------------------+---------+-----------------------+ | react_agent_10x_checkpoint_sync | 37.3 ms | 37.1 ms: 1.01x faster | +----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_sync | 52.8 ms | 52.5 ms: 1.01x faster | +----------------------------------------+---------+-----------------------+ | react_agent_10x_sync | 23.0 ms | 22.9 ms: 1.00x faster | +----------------------------------------+---------+-----------------------+ | wide_state_15x600_sync | 17.7 ms | 17.6 ms: 1.00x faster | +----------------------------------------+---------+-----------------------+ | wide_state_25x300_sync | 15.2 ms | 15.2 ms: 1.00x faster | +----------------------------------------+---------+-----------------------+ | wide_state_9x1200_sync | 17.7 ms | 17.8 ms: 1.00x slower | +----------------------------------------+---------+-----------------------+ | react_agent_100x_checkpoint | 836 ms | 840 ms: 1.00x slower | +----------------------------------------+---------+-----------------------+ | react_agent_100x | 340 ms | 342 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_checkpoint | 74.2 ms | 74.7 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | react_agent_10x_checkpoint | 37.7 ms | 38.0 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | wide_state_9x1200 | 27.0 ms | 27.3 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_sync | 516 ms | 527 ms: 1.02x slower | +----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x | 607 ms | 636 ms: 1.05x slower | +----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_checkpoint | 751 ms | 790 ms: 1.05x slower | +----------------------------------------+---------+-----------------------+ | Geometric mean | (ref) | 1.00x slower | +----------------------------------------+---------+-----------------------+ Benchmark hidden because not significant (12): wide_state_9x1200_checkpoint, wide_state_15x600_checkpoint_sync, react_agent_10x, react_agent_100x_sync, wide_state_25x300_checkpoint_sync, wide_state_25x300_checkpoint, react_agent_100x_checkpoint_sync, wide_state_15x600, fanout_to_subgraph_100x_checkpoint_sync, wide_state_15x600_checkpoint, wide_state_9x1200_checkpoint_sync, wide_state_25x300
import concurrent
import concurrent.futures
import functools
import inspect
import types
from functools import partial, update_wrapper
from typing import (
Any,
Awaitable,
Expand Down Expand Up @@ -33,17 +33,17 @@


def call(
func: Callable[[P1], T],
input: P1,
*,
func: Callable[P, T],
*args: Any,
retry: Optional[RetryPolicy] = None,
**kwargs: Any,
) -> concurrent.futures.Future[T]:
from langgraph.constants import CONFIG_KEY_CALL
from langgraph.utils.config import get_configurable

conf = get_configurable()
impl = conf[CONFIG_KEY_CALL]
fut = impl(func, input, retry=retry)
fut = impl(func, (args, kwargs), retry=retry)
return fut


Expand All @@ -59,16 +59,51 @@ def task( # type: ignore[overload-cannot-match]
) -> Callable[[Callable[P, T]], Callable[P, concurrent.futures.Future[T]]]: ...


@overload
def task(
*, retry: Optional[RetryPolicy] = None
__func_or_none__: Callable[P, T],
) -> Callable[P, concurrent.futures.Future[T]]: ...


@overload
def task(
__func_or_none__: Callable[P, Awaitable[T]],
) -> Callable[P, asyncio.Future[T]]: ...


def task(
__func_or_none__: Optional[Union[Callable[P, T], Callable[P, Awaitable[T]]]] = None,
*,
retry: Optional[RetryPolicy] = None,
) -> Union[
Callable[[Callable[P, Awaitable[T]]], Callable[P, asyncio.Future[T]]],
Callable[[Callable[P, T]], Callable[P, concurrent.futures.Future[T]]],
Callable[P, asyncio.Future[T]],
Callable[P, concurrent.futures.Future[T]],
]:
def _task(func: Callable[P, T]) -> Callable[P, concurrent.futures.Future[T]]:
return update_wrapper(partial(call, func, retry=retry), func)
def decorator(
func: Union[Callable[P, Awaitable[T]], Callable[P, T]],
) -> Callable[P, concurrent.futures.Future[T]]:
if asyncio.iscoroutinefunction(func):

@functools.wraps(func)
async def _tick(__allargs__: tuple) -> T:
return await func(*__allargs__[0], **__allargs__[1])

else:

@functools.wraps(func)
def _tick(__allargs__: tuple) -> T:
return func(*__allargs__[0], **__allargs__[1])

return functools.update_wrapper(
functools.partial(call, _tick, retry=retry), func
)

if __func_or_none__ is not None:
return decorator(__func_or_none__)

return _task
return decorator


def entrypoint(
Expand Down
24 changes: 13 additions & 11 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1515,21 +1515,21 @@ def test_imp_stream_order(
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")

@task()
def foo(state: dict) -> dict:
return {"a": state["a"] + "foo", "b": "bar"}
def foo(state: dict) -> tuple:
return state["a"] + "foo", "bar"

@task()
def bar(state: dict) -> dict:
return {"a": state["a"] + state["b"], "c": "bark"}
@task
def bar(a: str, b: str, c: Optional[str] = None) -> dict:
return {"a": a + b, "c": (c or "") + "bark"}

@task()
@task
def baz(state: dict) -> dict:
return {"a": state["a"] + "baz", "c": "something else"}

@entrypoint(checkpointer=checkpointer)
def graph(state: dict) -> dict:
fut_foo = foo(state)
fut_bar = bar(fut_foo.result())
fut_bar = bar(*fut_foo.result())
fut_baz = baz(fut_bar.result())
return fut_baz.result()

Expand Down Expand Up @@ -4168,10 +4168,12 @@ def __init__(self, i: Optional[int] = None):
def __call__(self, inputs: State, config: RunnableConfig, store: BaseStore):
assert isinstance(store, BaseStore)
store.put(
namespace
if self.i is not None
and config["configurable"]["thread_id"] in (thread_1, thread_2)
else (f"foo_{self.i}", "bar"),
(
namespace
if self.i is not None
and config["configurable"]["thread_id"] in (thread_1, thread_2)
else (f"foo_{self.i}", "bar")
),
doc_id,
{
**doc,
Expand Down
21 changes: 11 additions & 10 deletions libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -2571,18 +2571,18 @@ async def test_imp_sync_from_async(checkpointer_name: str) -> None:
def foo(state: dict) -> dict:
return {"a": state["a"] + "foo", "b": "bar"}

@task()
def bar(state: dict) -> dict:
return {"a": state["a"] + state["b"], "c": "bark"}
@task
def bar(a: str, b: str, c: Optional[str] = None) -> dict:
return {"a": a + b, "c": (c or "") + "bark"}

@task()
def baz(state: dict) -> dict:
return {"a": state["a"] + "baz", "c": "something else"}

@entrypoint(checkpointer=checkpointer)
def graph(state: dict) -> dict:
fut_foo = foo(state)
fut_bar = bar(fut_foo.result())
foo_result = foo(state).result()
fut_bar = bar(foo_result["a"], foo_result["b"])
fut_baz = baz(fut_bar.result())
return fut_baz.result()

Expand All @@ -2607,18 +2607,19 @@ async def test_imp_stream_order(checkpointer_name: str) -> None:
async def foo(state: dict) -> dict:
return {"a": state["a"] + "foo", "b": "bar"}

@task()
async def bar(state: dict) -> dict:
return {"a": state["a"] + state["b"], "c": "bark"}
@task
async def bar(a: str, b: str, c: Optional[str] = None) -> dict:
return {"a": a + b, "c": (c or "") + "bark"}

@task()
async def baz(state: dict) -> dict:
return {"a": state["a"] + "baz", "c": "something else"}

@entrypoint(checkpointer=checkpointer)
async def graph(state: dict) -> dict:
fut_foo = foo(state)
fut_bar = bar(await fut_foo)
foo_res = await foo(state)

fut_bar = bar(foo_res["a"], foo_res["b"])
fut_baz = baz(await fut_bar)
return await fut_baz

Expand Down

0 comments on commit 9946211

Please sign in to comment.