Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multiple args in @task's #2923

Merged
merged 1 commit into from
Jan 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 44 additions & 9 deletions libs/langgraph/langgraph/func/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import asyncio

Check notice on line 1 in libs/langgraph/langgraph/func/__init__.py

View workflow job for this annotation

GitHub Actions / benchmark

Benchmark results

......................................... fanout_to_subgraph_10x: Mean +- std dev: 61.8 ms +- 1.3 ms ......................................... fanout_to_subgraph_10x_sync: Mean +- std dev: 53.3 ms +- 1.2 ms ......................................... fanout_to_subgraph_10x_checkpoint: Mean +- std dev: 76.0 ms +- 1.8 ms ......................................... fanout_to_subgraph_10x_checkpoint_sync: Mean +- std dev: 96.7 ms +- 2.6 ms ......................................... fanout_to_subgraph_100x: Mean +- std dev: 662 ms +- 42 ms ......................................... fanout_to_subgraph_100x_sync: Mean +- std dev: 526 ms +- 13 ms ......................................... fanout_to_subgraph_100x_checkpoint: Mean +- std dev: 820 ms +- 45 ms ......................................... fanout_to_subgraph_100x_checkpoint_sync: Mean +- std dev: 964 ms +- 27 ms ......................................... react_agent_10x: Mean +- std dev: 31.1 ms +- 0.9 ms ......................................... react_agent_10x_sync: Mean +- std dev: 23.1 ms +- 0.3 ms ......................................... react_agent_10x_checkpoint: Mean +- std dev: 38.4 ms +- 0.9 ms ......................................... react_agent_10x_checkpoint_sync: Mean +- std dev: 37.3 ms +- 0.7 ms ......................................... react_agent_100x: Mean +- std dev: 346 ms +- 7 ms ......................................... react_agent_100x_sync: Mean +- std dev: 281 ms +- 5 ms ......................................... react_agent_100x_checkpoint: Mean +- std dev: 896 ms +- 17 ms ......................................... react_agent_100x_checkpoint_sync: Mean +- std dev: 872 ms +- 23 ms ......................................... wide_state_25x300: Mean +- std dev: 23.5 ms +- 0.5 ms ......................................... wide_state_25x300_sync: Mean +- std dev: 15.3 ms +- 0.2 ms ......................................... wide_state_25x300_checkpoint: Mean +- std dev: 279 ms +- 15 ms ......................................... wide_state_25x300_checkpoint_sync: Mean +- std dev: 275 ms +- 14 ms ......................................... wide_state_15x600: Mean +- std dev: 27.3 ms +- 0.6 ms ......................................... wide_state_15x600_sync: Mean +- std dev: 17.9 ms +- 0.3 ms ......................................... wide_state_15x600_checkpoint: Mean +- std dev: 478 ms +- 18 ms ......................................... wide_state_15x600_checkpoint_sync: Mean +- std dev: 475 ms +- 17 ms ......................................... wide_state_9x1200: Mean +- std dev: 27.4 ms +- 0.6 ms ......................................... wide_state_9x1200_sync: Mean +- std dev: 17.8 ms +- 0.3 ms ......................................... wide_state_9x1200_checkpoint: Mean +- std dev: 315 ms +- 16 ms ......................................... wide_state_9x1200_checkpoint_sync: Mean +- std dev: 309 ms +- 15 ms

Check notice on line 1 in libs/langgraph/langgraph/func/__init__.py

View workflow job for this annotation

GitHub Actions / benchmark

Comparison against main

+-----------------------------------------+---------+-----------------------+ | Benchmark | main | changes | +=========================================+=========+=======================+ | react_agent_10x_sync | 23.0 ms | 23.1 ms: 1.00x slower | +-----------------------------------------+---------+-----------------------+ | wide_state_25x300_sync | 15.2 ms | 15.3 ms: 1.01x slower | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_checkpoint_sync | 96.0 ms | 96.7 ms: 1.01x slower | +-----------------------------------------+---------+-----------------------+ | wide_state_15x600 | 27.1 ms | 27.3 ms: 1.01x slower | +-----------------------------------------+---------+-----------------------+ | wide_state_15x600_sync | 17.7 ms | 17.9 ms: 1.01x slower | +-----------------------------------------+---------+-----------------------+ | wide_state_9x1200_sync | 17.7 ms | 17.8 ms: 1.01x slower | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_sync | 52.8 ms | 53.3 ms: 1.01x slower | +-----------------------------------------+---------+-----------------------+ | wide_state_25x300 | 23.2 ms | 23.5 ms: 1.01x slower | +-----------------------------------------+---------+-----------------------+ | wide_state_15x600_checkpoint | 473 ms | 478 ms: 1.01x slower | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_checkpoint_sync | 953 ms | 964 ms: 1.01x slower | +-----------------------------------------+---------+-----------------------+ | react_agent_10x | 30.7 ms | 31.1 ms: 1.01x slower | +-----------------------------------------+---------+-----------------------+ | wide_state_9x1200 | 27.0 ms | 27.4 ms: 1.02x slower | +-----------------------------------------+---------+-----------------------+ | react_agent_10x_checkpoint | 37.7 ms | 38.4 ms: 1.02x slower | +-----------------------------------------+---------+-----------------------+ | react_agent_100x | 340 ms | 346 ms: 1.02x slower | +-----------------------------------------+---------+-----------------------+ | react_agent_100x_sync | 276 ms | 281 ms: 1.02x slower | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_sync | 516 ms | 526 ms: 1.02x slower | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_checkpoint | 74.2 ms | 76.0 ms: 1.02x slower | +-----------------------------------------+---------+-----------------------+ | react_agent_100x_checkpoint_sync | 832 ms | 872 ms: 1.05x slower | +-----------------------------------------+---------+-----------------------+ | react_agent_100x_checkpoint | 836 ms | 896 ms: 1.07x slower | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x | 607 ms | 662 ms: 1.09x slower | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_checkpoint | 751 ms | 820 ms: 1.09x slower | +-----------------------------------------+---------+-----------------------+ | Geometric mean | (ref) | 1.02x slower | +-----------------------------------------+---------+-----------------------+ Benchmark hidden because not significant (7): react_agent_10x_checkpoint_sync, fanout_to_subgraph_10x, wide_state_15x600_checkpoint_sync, wide_state_25x300_checkpoint_sync, wide_state_9x1200_checkpoint, wide_state_9x1200_checkpoint_sync, wide_state_25x300_checkpoint
import concurrent
import concurrent.futures
import functools
import inspect
import types
from functools import partial, update_wrapper
from typing import (
Any,
Awaitable,
Expand Down Expand Up @@ -33,17 +33,17 @@


def call(
func: Callable[[P1], T],
input: P1,
*,
func: Callable[P, T],
*args: Any,
retry: Optional[RetryPolicy] = None,
**kwargs: Any,
) -> concurrent.futures.Future[T]:
from langgraph.constants import CONFIG_KEY_CALL
from langgraph.utils.config import get_configurable

conf = get_configurable()
impl = conf[CONFIG_KEY_CALL]
fut = impl(func, input, retry=retry)
fut = impl(func, (args, kwargs), retry=retry)
return fut


Expand All @@ -59,16 +59,51 @@
) -> Callable[[Callable[P, T]], Callable[P, concurrent.futures.Future[T]]]: ...


@overload
def task(
*, retry: Optional[RetryPolicy] = None
__func_or_none__: Callable[P, T],
) -> Callable[P, concurrent.futures.Future[T]]: ...


@overload
def task(
__func_or_none__: Callable[P, Awaitable[T]],
) -> Callable[P, asyncio.Future[T]]: ...


def task(
__func_or_none__: Optional[Union[Callable[P, T], Callable[P, Awaitable[T]]]] = None,
*,
retry: Optional[RetryPolicy] = None,
) -> Union[
Callable[[Callable[P, Awaitable[T]]], Callable[P, asyncio.Future[T]]],
Callable[[Callable[P, T]], Callable[P, concurrent.futures.Future[T]]],
Callable[P, asyncio.Future[T]],
Callable[P, concurrent.futures.Future[T]],
]:
def _task(func: Callable[P, T]) -> Callable[P, concurrent.futures.Future[T]]:
return update_wrapper(partial(call, func, retry=retry), func)
def decorator(
func: Union[Callable[P, Awaitable[T]], Callable[P, T]],
) -> Callable[P, concurrent.futures.Future[T]]:
if asyncio.iscoroutinefunction(func):

@functools.wraps(func)
async def _tick(__allargs__: tuple) -> T:
return await func(*__allargs__[0], **__allargs__[1])

else:

@functools.wraps(func)
def _tick(__allargs__: tuple) -> T:
return func(*__allargs__[0], **__allargs__[1])

return functools.update_wrapper(
functools.partial(call, _tick, retry=retry), func
)

if __func_or_none__ is not None:
return decorator(__func_or_none__)

return _task
return decorator


def entrypoint(
Expand Down
31 changes: 19 additions & 12 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1515,27 +1515,32 @@ def test_imp_stream_order(
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")

@task()
def foo(state: dict) -> dict:
return {"a": state["a"] + "foo", "b": "bar"}
def foo(state: dict) -> tuple:
return state["a"] + "foo", "bar"

@task()
def bar(state: dict) -> dict:
return {"a": state["a"] + state["b"], "c": "bark"}
@task
def bar(a: str, b: str, c: Optional[str] = None) -> dict:
return {"a": a + b, "c": (c or "") + "bark"}

@task()
@task
def baz(state: dict) -> dict:
return {"a": state["a"] + "baz", "c": "something else"}

@entrypoint(checkpointer=checkpointer)
def graph(state: dict) -> dict:
fut_foo = foo(state)
fut_bar = bar(fut_foo.result())
fut_bar = bar(*fut_foo.result())
fut_baz = baz(fut_bar.result())
return fut_baz.result()

thread1 = {"configurable": {"thread_id": "1"}}
assert [c for c in graph.stream({"a": "0"}, thread1)] == [
{"foo": {"a": "0foo", "b": "bar"}},
{
"foo": (
"0foo",
"bar",
)
},
{"bar": {"a": "0foobar", "c": "bark"}},
{"baz": {"a": "0foobarbaz", "c": "something else"}},
{"graph": {"a": "0foobarbaz", "c": "something else"}},
Expand Down Expand Up @@ -4168,10 +4173,12 @@ def __init__(self, i: Optional[int] = None):
def __call__(self, inputs: State, config: RunnableConfig, store: BaseStore):
assert isinstance(store, BaseStore)
store.put(
namespace
if self.i is not None
and config["configurable"]["thread_id"] in (thread_1, thread_2)
else (f"foo_{self.i}", "bar"),
(
namespace
if self.i is not None
and config["configurable"]["thread_id"] in (thread_1, thread_2)
else (f"foo_{self.i}", "bar")
),
doc_id,
{
**doc,
Expand Down
21 changes: 11 additions & 10 deletions libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -2571,18 +2571,18 @@ async def test_imp_sync_from_async(checkpointer_name: str) -> None:
def foo(state: dict) -> dict:
return {"a": state["a"] + "foo", "b": "bar"}

@task()
def bar(state: dict) -> dict:
return {"a": state["a"] + state["b"], "c": "bark"}
@task
def bar(a: str, b: str, c: Optional[str] = None) -> dict:
return {"a": a + b, "c": (c or "") + "bark"}

@task()
def baz(state: dict) -> dict:
return {"a": state["a"] + "baz", "c": "something else"}

@entrypoint(checkpointer=checkpointer)
def graph(state: dict) -> dict:
fut_foo = foo(state)
fut_bar = bar(fut_foo.result())
foo_result = foo(state).result()
fut_bar = bar(foo_result["a"], foo_result["b"])
fut_baz = baz(fut_bar.result())
return fut_baz.result()

Expand All @@ -2607,18 +2607,19 @@ async def test_imp_stream_order(checkpointer_name: str) -> None:
async def foo(state: dict) -> dict:
return {"a": state["a"] + "foo", "b": "bar"}

@task()
async def bar(state: dict) -> dict:
return {"a": state["a"] + state["b"], "c": "bark"}
@task
async def bar(a: str, b: str, c: Optional[str] = None) -> dict:
return {"a": a + b, "c": (c or "") + "bark"}

@task()
async def baz(state: dict) -> dict:
return {"a": state["a"] + "baz", "c": "something else"}

@entrypoint(checkpointer=checkpointer)
async def graph(state: dict) -> dict:
fut_foo = foo(state)
fut_bar = bar(await fut_foo)
foo_res = await foo(state)

fut_bar = bar(foo_res["a"], foo_res["b"])
fut_baz = baz(await fut_bar)
return await fut_baz

Expand Down
Loading