From 994621134d44a118cdec870c17a12af0723db130 Mon Sep 17 00:00:00 2001 From: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com> Date: Thu, 2 Jan 2025 14:50:22 -0800 Subject: [PATCH] Support multiple args in @task's --- libs/langgraph/langgraph/func/__init__.py | 53 +++++++++++++++++++---- libs/langgraph/tests/test_pregel.py | 24 +++++----- libs/langgraph/tests/test_pregel_async.py | 21 ++++----- 3 files changed, 68 insertions(+), 30 deletions(-) diff --git a/libs/langgraph/langgraph/func/__init__.py b/libs/langgraph/langgraph/func/__init__.py index 26d583ed2f..62bf138e64 100644 --- a/libs/langgraph/langgraph/func/__init__.py +++ b/libs/langgraph/langgraph/func/__init__.py @@ -1,9 +1,9 @@ import asyncio import concurrent import concurrent.futures +import functools import inspect import types -from functools import partial, update_wrapper from typing import ( Any, Awaitable, @@ -33,17 +33,17 @@ def call( - func: Callable[[P1], T], - input: P1, - *, + func: Callable[P, T], + *args: Any, retry: Optional[RetryPolicy] = None, + **kwargs: Any, ) -> concurrent.futures.Future[T]: from langgraph.constants import CONFIG_KEY_CALL from langgraph.utils.config import get_configurable conf = get_configurable() impl = conf[CONFIG_KEY_CALL] - fut = impl(func, input, retry=retry) + fut = impl(func, (args, kwargs), retry=retry) return fut @@ -59,16 +59,51 @@ def task( # type: ignore[overload-cannot-match] ) -> Callable[[Callable[P, T]], Callable[P, concurrent.futures.Future[T]]]: ... +@overload def task( - *, retry: Optional[RetryPolicy] = None + __func_or_none__: Callable[P, T], +) -> Callable[P, concurrent.futures.Future[T]]: ... + + +@overload +def task( + __func_or_none__: Callable[P, Awaitable[T]], +) -> Callable[P, asyncio.Future[T]]: ... + + +def task( + __func_or_none__: Optional[Union[Callable[P, T], Callable[P, Awaitable[T]]]] = None, + *, + retry: Optional[RetryPolicy] = None, ) -> Union[ Callable[[Callable[P, Awaitable[T]]], Callable[P, asyncio.Future[T]]], Callable[[Callable[P, T]], Callable[P, concurrent.futures.Future[T]]], + Callable[P, asyncio.Future[T]], + Callable[P, concurrent.futures.Future[T]], ]: - def _task(func: Callable[P, T]) -> Callable[P, concurrent.futures.Future[T]]: - return update_wrapper(partial(call, func, retry=retry), func) + def decorator( + func: Union[Callable[P, Awaitable[T]], Callable[P, T]], + ) -> Callable[P, concurrent.futures.Future[T]]: + if asyncio.iscoroutinefunction(func): + + @functools.wraps(func) + async def _tick(__allargs__: tuple) -> T: + return await func(*__allargs__[0], **__allargs__[1]) + + else: + + @functools.wraps(func) + def _tick(__allargs__: tuple) -> T: + return func(*__allargs__[0], **__allargs__[1]) + + return functools.update_wrapper( + functools.partial(call, _tick, retry=retry), func + ) + + if __func_or_none__ is not None: + return decorator(__func_or_none__) - return _task + return decorator def entrypoint( diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index c069293a8e..aec27067fd 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -1515,21 +1515,21 @@ def test_imp_stream_order( checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") @task() - def foo(state: dict) -> dict: - return {"a": state["a"] + "foo", "b": "bar"} + def foo(state: dict) -> tuple: + return state["a"] + "foo", "bar" - @task() - def bar(state: dict) -> dict: - return {"a": state["a"] + state["b"], "c": "bark"} + @task + def bar(a: str, b: str, c: Optional[str] = None) -> dict: + return {"a": a + b, "c": (c or "") + "bark"} - @task() + @task def baz(state: dict) -> dict: return {"a": state["a"] + "baz", "c": "something else"} @entrypoint(checkpointer=checkpointer) def graph(state: dict) -> dict: fut_foo = foo(state) - fut_bar = bar(fut_foo.result()) + fut_bar = bar(*fut_foo.result()) fut_baz = baz(fut_bar.result()) return fut_baz.result() @@ -4168,10 +4168,12 @@ def __init__(self, i: Optional[int] = None): def __call__(self, inputs: State, config: RunnableConfig, store: BaseStore): assert isinstance(store, BaseStore) store.put( - namespace - if self.i is not None - and config["configurable"]["thread_id"] in (thread_1, thread_2) - else (f"foo_{self.i}", "bar"), + ( + namespace + if self.i is not None + and config["configurable"]["thread_id"] in (thread_1, thread_2) + else (f"foo_{self.i}", "bar") + ), doc_id, { **doc, diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index 38cbef2a85..7cd0091e5e 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -2571,9 +2571,9 @@ async def test_imp_sync_from_async(checkpointer_name: str) -> None: def foo(state: dict) -> dict: return {"a": state["a"] + "foo", "b": "bar"} - @task() - def bar(state: dict) -> dict: - return {"a": state["a"] + state["b"], "c": "bark"} + @task + def bar(a: str, b: str, c: Optional[str] = None) -> dict: + return {"a": a + b, "c": (c or "") + "bark"} @task() def baz(state: dict) -> dict: @@ -2581,8 +2581,8 @@ def baz(state: dict) -> dict: @entrypoint(checkpointer=checkpointer) def graph(state: dict) -> dict: - fut_foo = foo(state) - fut_bar = bar(fut_foo.result()) + foo_result = foo(state).result() + fut_bar = bar(foo_result["a"], foo_result["b"]) fut_baz = baz(fut_bar.result()) return fut_baz.result() @@ -2607,9 +2607,9 @@ async def test_imp_stream_order(checkpointer_name: str) -> None: async def foo(state: dict) -> dict: return {"a": state["a"] + "foo", "b": "bar"} - @task() - async def bar(state: dict) -> dict: - return {"a": state["a"] + state["b"], "c": "bark"} + @task + async def bar(a: str, b: str, c: Optional[str] = None) -> dict: + return {"a": a + b, "c": (c or "") + "bark"} @task() async def baz(state: dict) -> dict: @@ -2617,8 +2617,9 @@ async def baz(state: dict) -> dict: @entrypoint(checkpointer=checkpointer) async def graph(state: dict) -> dict: - fut_foo = foo(state) - fut_bar = bar(await fut_foo) + foo_res = await foo(state) + + fut_bar = bar(foo_res["a"], foo_res["b"]) fut_baz = baz(await fut_bar) return await fut_baz