Skip to content

Commit

Permalink
Support multiple args in @task's
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw committed Jan 2, 2025
1 parent d74ec2c commit dc94d93
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 31 deletions.
53 changes: 44 additions & 9 deletions libs/langgraph/langgraph/func/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import asyncio

Check notice on line 1 in libs/langgraph/langgraph/func/__init__.py

View workflow job for this annotation

GitHub Actions / benchmark

Benchmark results

......................................... fanout_to_subgraph_10x: Mean +- std dev: 59.8 ms +- 1.6 ms ......................................... fanout_to_subgraph_10x_sync: Mean +- std dev: 51.0 ms +- 1.1 ms ......................................... fanout_to_subgraph_10x_checkpoint: Mean +- std dev: 72.2 ms +- 1.7 ms ......................................... fanout_to_subgraph_10x_checkpoint_sync: Mean +- std dev: 91.2 ms +- 2.0 ms ......................................... fanout_to_subgraph_100x: Mean +- std dev: 580 ms +- 22 ms ......................................... fanout_to_subgraph_100x_sync: Mean +- std dev: 503 ms +- 9 ms ......................................... fanout_to_subgraph_100x_checkpoint: Mean +- std dev: 715 ms +- 14 ms ......................................... fanout_to_subgraph_100x_checkpoint_sync: Mean +- std dev: 916 ms +- 18 ms ......................................... react_agent_10x: Mean +- std dev: 30.2 ms +- 0.7 ms ......................................... react_agent_10x_sync: Mean +- std dev: 22.2 ms +- 0.4 ms ......................................... react_agent_10x_checkpoint: Mean +- std dev: 36.6 ms +- 0.9 ms ......................................... react_agent_10x_checkpoint_sync: Mean +- std dev: 35.9 ms +- 0.8 ms ......................................... react_agent_100x: Mean +- std dev: 332 ms +- 8 ms ......................................... react_agent_100x_sync: Mean +- std dev: 268 ms +- 4 ms ......................................... react_agent_100x_checkpoint: Mean +- std dev: 810 ms +- 9 ms ......................................... react_agent_100x_checkpoint_sync: Mean +- std dev: 804 ms +- 8 ms ......................................... wide_state_25x300: Mean +- std dev: 22.6 ms +- 0.6 ms ......................................... wide_state_25x300_sync: Mean +- std dev: 14.9 ms +- 0.4 ms ......................................... wide_state_25x300_checkpoint: Mean +- std dev: 268 ms +- 13 ms ......................................... wide_state_25x300_checkpoint_sync: Mean +- std dev: 265 ms +- 12 ms ......................................... wide_state_15x600: Mean +- std dev: 26.5 ms +- 0.7 ms ......................................... wide_state_15x600_sync: Mean +- std dev: 17.3 ms +- 0.6 ms ......................................... wide_state_15x600_checkpoint: Mean +- std dev: 458 ms +- 14 ms ......................................... wide_state_15x600_checkpoint_sync: Mean +- std dev: 459 ms +- 15 ms ......................................... wide_state_9x1200: Mean +- std dev: 26.3 ms +- 0.6 ms ......................................... wide_state_9x1200_sync: Mean +- std dev: 17.3 ms +- 0.3 ms ......................................... wide_state_9x1200_checkpoint: Mean +- std dev: 299 ms +- 13 ms ......................................... wide_state_9x1200_checkpoint_sync: Mean +- std dev: 296 ms +- 12 ms

Check notice on line 1 in libs/langgraph/langgraph/func/__init__.py

View workflow job for this annotation

GitHub Actions / benchmark

Comparison against main

+-----------------------------------------+---------+-----------------------+ | Benchmark | main | changes | +=========================================+=========+=======================+ | fanout_to_subgraph_10x_checkpoint_sync | 96.0 ms | 91.2 ms: 1.05x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_checkpoint | 751 ms | 715 ms: 1.05x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x | 607 ms | 580 ms: 1.05x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_9x1200_checkpoint | 312 ms | 299 ms: 1.04x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_checkpoint_sync | 953 ms | 916 ms: 1.04x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_10x_checkpoint_sync | 37.3 ms | 35.9 ms: 1.04x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_10x_sync | 23.0 ms | 22.2 ms: 1.04x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_100x_checkpoint_sync | 832 ms | 804 ms: 1.04x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_9x1200_checkpoint_sync | 306 ms | 296 ms: 1.04x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_sync | 52.8 ms | 51.0 ms: 1.03x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_25x300_checkpoint_sync | 273 ms | 265 ms: 1.03x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_15x600_checkpoint | 473 ms | 458 ms: 1.03x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_100x_checkpoint | 836 ms | 810 ms: 1.03x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x | 61.6 ms | 59.8 ms: 1.03x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_10x_checkpoint | 37.7 ms | 36.6 ms: 1.03x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_15x600_checkpoint_sync | 472 ms | 459 ms: 1.03x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_25x300_checkpoint | 276 ms | 268 ms: 1.03x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_checkpoint | 74.2 ms | 72.2 ms: 1.03x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_100x_sync | 276 ms | 268 ms: 1.03x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_9x1200 | 27.0 ms | 26.3 ms: 1.03x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_25x300 | 23.2 ms | 22.6 ms: 1.03x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_sync | 516 ms | 503 ms: 1.03x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_25x300_sync | 15.2 ms | 14.9 ms: 1.03x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_9x1200_sync | 17.7 ms | 17.3 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_15x600 | 27.1 ms | 26.5 ms: 1.02x faster | +---------------------------------------
import concurrent
import concurrent.futures
import functools
import inspect
import types
from functools import partial, update_wrapper
from typing import (
Any,
Awaitable,
Expand Down Expand Up @@ -33,17 +33,17 @@


def call(
func: Callable[[P1], T],
input: P1,
*,
func: Callable[P, T],
*args: Any,
retry: Optional[RetryPolicy] = None,
**kwargs: Any,
) -> concurrent.futures.Future[T]:
from langgraph.constants import CONFIG_KEY_CALL
from langgraph.utils.config import get_configurable

conf = get_configurable()
impl = conf[CONFIG_KEY_CALL]
fut = impl(func, input, retry=retry)
fut = impl(func, (args, kwargs), retry=retry)
return fut


Expand All @@ -59,16 +59,51 @@ def task( # type: ignore[overload-cannot-match]
) -> Callable[[Callable[P, T]], Callable[P, concurrent.futures.Future[T]]]: ...


@overload
def task(
*, retry: Optional[RetryPolicy] = None
__func_or_none__: Callable[P, T],
) -> Callable[P, concurrent.futures.Future[T]]: ...


@overload
def task(
__func_or_none__: Callable[P, Awaitable[T]],
) -> Callable[P, asyncio.Future[T]]: ...


def task(
__func_or_none__: Optional[Union[Callable[P, T], Callable[P, Awaitable[T]]]] = None,
*,
retry: Optional[RetryPolicy] = None,
) -> Union[
Callable[[Callable[P, Awaitable[T]]], Callable[P, asyncio.Future[T]]],
Callable[[Callable[P, T]], Callable[P, concurrent.futures.Future[T]]],
Callable[P, asyncio.Future[T]],
Callable[P, concurrent.futures.Future[T]],
]:
def _task(func: Callable[P, T]) -> Callable[P, concurrent.futures.Future[T]]:
return update_wrapper(partial(call, func, retry=retry), func)
def decorator(
func: Union[Callable[P, Awaitable[T]], Callable[P, T]],
) -> Callable[P, concurrent.futures.Future[T]]:
if asyncio.iscoroutinefunction(func):

@functools.wraps(func)
async def _tick(__allargs__: tuple) -> T:
return await func(*__allargs__[0], **__allargs__[1])

else:

@functools.wraps(func)
def _tick(__allargs__: tuple) -> T:
return func(*__allargs__[0], **__allargs__[1])

return functools.update_wrapper(
functools.partial(call, _tick, retry=retry), func
)

if __func_or_none__ is not None:
return decorator(__func_or_none__)

return _task
return decorator


def entrypoint(
Expand Down
26 changes: 14 additions & 12 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1515,21 +1515,21 @@ def test_imp_stream_order(
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")

@task()
def foo(state: dict) -> dict:
return {"a": state["a"] + "foo", "b": "bar"}
def foo(state: dict) -> tuple:
return state["a"] + "foo", "bar"

@task()
def bar(state: dict) -> dict:
return {"a": state["a"] + state["b"], "c": "bark"}
@task
def bar(a: str, b: str, c: Optional[str] = None) -> dict:
return {"a": a + b, "c": (c or "") + "bark"}

@task()
@task
def baz(state: dict) -> dict:
return {"a": state["a"] + "baz", "c": "something else"}

@entrypoint(checkpointer=checkpointer)
def graph(state: dict) -> dict:
fut_foo = foo(state)
fut_bar = bar(fut_foo.result())
fut_foo = foo(state).result()
fut_bar = bar(*fut_foo.result())
fut_baz = baz(fut_bar.result())
return fut_baz.result()

Expand Down Expand Up @@ -4168,10 +4168,12 @@ def __init__(self, i: Optional[int] = None):
def __call__(self, inputs: State, config: RunnableConfig, store: BaseStore):
assert isinstance(store, BaseStore)
store.put(
namespace
if self.i is not None
and config["configurable"]["thread_id"] in (thread_1, thread_2)
else (f"foo_{self.i}", "bar"),
(
namespace
if self.i is not None
and config["configurable"]["thread_id"] in (thread_1, thread_2)
else (f"foo_{self.i}", "bar")
),
doc_id,
{
**doc,
Expand Down
21 changes: 11 additions & 10 deletions libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -2571,18 +2571,18 @@ async def test_imp_sync_from_async(checkpointer_name: str) -> None:
def foo(state: dict) -> dict:
return {"a": state["a"] + "foo", "b": "bar"}

@task()
def bar(state: dict) -> dict:
return {"a": state["a"] + state["b"], "c": "bark"}
@task
def bar(a: str, b: str, c: Optional[str] = None) -> dict:
return {"a": a + b, "c": (c or "") + "bark"}

@task()
def baz(state: dict) -> dict:
return {"a": state["a"] + "baz", "c": "something else"}

@entrypoint(checkpointer=checkpointer)
def graph(state: dict) -> dict:
fut_foo = foo(state)
fut_bar = bar(fut_foo.result())
foo_result = foo(state).result()
fut_bar = bar(foo_result["a"], foo_result["b"])
fut_baz = baz(fut_bar.result())
return fut_baz.result()

Expand All @@ -2607,18 +2607,19 @@ async def test_imp_stream_order(checkpointer_name: str) -> None:
async def foo(state: dict) -> dict:
return {"a": state["a"] + "foo", "b": "bar"}

@task()
async def bar(state: dict) -> dict:
return {"a": state["a"] + state["b"], "c": "bark"}
@task
async def bar(a: str, b: str, c: Optional[str] = None) -> dict:
return {"a": a + b, "c": (c or "") + "bark"}

@task()
async def baz(state: dict) -> dict:
return {"a": state["a"] + "baz", "c": "something else"}

@entrypoint(checkpointer=checkpointer)
async def graph(state: dict) -> dict:
fut_foo = foo(state)
fut_bar = bar(await fut_foo)
foo_res = await foo(state)

fut_bar = bar(foo_res["a"], foo_res["b"])
fut_baz = baz(await fut_bar)
return await fut_baz

Expand Down

0 comments on commit dc94d93

Please sign in to comment.