diff --git a/libs/langgraph/langgraph/func/__init__.py b/libs/langgraph/langgraph/func/__init__.py index 26d583ed2..62bf138e6 100644 --- a/libs/langgraph/langgraph/func/__init__.py +++ b/libs/langgraph/langgraph/func/__init__.py @@ -1,9 +1,9 @@ import asyncio import concurrent import concurrent.futures +import functools import inspect import types -from functools import partial, update_wrapper from typing import ( Any, Awaitable, @@ -33,17 +33,17 @@ def call( - func: Callable[[P1], T], - input: P1, - *, + func: Callable[P, T], + *args: Any, retry: Optional[RetryPolicy] = None, + **kwargs: Any, ) -> concurrent.futures.Future[T]: from langgraph.constants import CONFIG_KEY_CALL from langgraph.utils.config import get_configurable conf = get_configurable() impl = conf[CONFIG_KEY_CALL] - fut = impl(func, input, retry=retry) + fut = impl(func, (args, kwargs), retry=retry) return fut @@ -59,16 +59,51 @@ def task( # type: ignore[overload-cannot-match] ) -> Callable[[Callable[P, T]], Callable[P, concurrent.futures.Future[T]]]: ... +@overload def task( - *, retry: Optional[RetryPolicy] = None + __func_or_none__: Callable[P, T], +) -> Callable[P, concurrent.futures.Future[T]]: ... + + +@overload +def task( + __func_or_none__: Callable[P, Awaitable[T]], +) -> Callable[P, asyncio.Future[T]]: ... + + +def task( + __func_or_none__: Optional[Union[Callable[P, T], Callable[P, Awaitable[T]]]] = None, + *, + retry: Optional[RetryPolicy] = None, ) -> Union[ Callable[[Callable[P, Awaitable[T]]], Callable[P, asyncio.Future[T]]], Callable[[Callable[P, T]], Callable[P, concurrent.futures.Future[T]]], + Callable[P, asyncio.Future[T]], + Callable[P, concurrent.futures.Future[T]], ]: - def _task(func: Callable[P, T]) -> Callable[P, concurrent.futures.Future[T]]: - return update_wrapper(partial(call, func, retry=retry), func) + def decorator( + func: Union[Callable[P, Awaitable[T]], Callable[P, T]], + ) -> Callable[P, concurrent.futures.Future[T]]: + if asyncio.iscoroutinefunction(func): + + @functools.wraps(func) + async def _tick(__allargs__: tuple) -> T: + return await func(*__allargs__[0], **__allargs__[1]) + + else: + + @functools.wraps(func) + def _tick(__allargs__: tuple) -> T: + return func(*__allargs__[0], **__allargs__[1]) + + return functools.update_wrapper( + functools.partial(call, _tick, retry=retry), func + ) + + if __func_or_none__ is not None: + return decorator(__func_or_none__) - return _task + return decorator def entrypoint( diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index c069293a8..33d8bea18 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -1515,27 +1515,32 @@ def test_imp_stream_order( checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") @task() - def foo(state: dict) -> dict: - return {"a": state["a"] + "foo", "b": "bar"} + def foo(state: dict) -> tuple: + return state["a"] + "foo", "bar" - @task() - def bar(state: dict) -> dict: - return {"a": state["a"] + state["b"], "c": "bark"} + @task + def bar(a: str, b: str, c: Optional[str] = None) -> dict: + return {"a": a + b, "c": (c or "") + "bark"} - @task() + @task def baz(state: dict) -> dict: return {"a": state["a"] + "baz", "c": "something else"} @entrypoint(checkpointer=checkpointer) def graph(state: dict) -> dict: fut_foo = foo(state) - fut_bar = bar(fut_foo.result()) + fut_bar = bar(*fut_foo.result()) fut_baz = baz(fut_bar.result()) return fut_baz.result() thread1 = {"configurable": {"thread_id": "1"}} assert [c for c in graph.stream({"a": "0"}, thread1)] == [ - {"foo": {"a": "0foo", "b": "bar"}}, + { + "foo": ( + "0foo", + "bar", + ) + }, {"bar": {"a": "0foobar", "c": "bark"}}, {"baz": {"a": "0foobarbaz", "c": "something else"}}, {"graph": {"a": "0foobarbaz", "c": "something else"}}, @@ -4168,10 +4173,12 @@ def __init__(self, i: Optional[int] = None): def __call__(self, inputs: State, config: RunnableConfig, store: BaseStore): assert isinstance(store, BaseStore) store.put( - namespace - if self.i is not None - and config["configurable"]["thread_id"] in (thread_1, thread_2) - else (f"foo_{self.i}", "bar"), + ( + namespace + if self.i is not None + and config["configurable"]["thread_id"] in (thread_1, thread_2) + else (f"foo_{self.i}", "bar") + ), doc_id, { **doc, diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index 38cbef2a8..7cd0091e5 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -2571,9 +2571,9 @@ async def test_imp_sync_from_async(checkpointer_name: str) -> None: def foo(state: dict) -> dict: return {"a": state["a"] + "foo", "b": "bar"} - @task() - def bar(state: dict) -> dict: - return {"a": state["a"] + state["b"], "c": "bark"} + @task + def bar(a: str, b: str, c: Optional[str] = None) -> dict: + return {"a": a + b, "c": (c or "") + "bark"} @task() def baz(state: dict) -> dict: @@ -2581,8 +2581,8 @@ def baz(state: dict) -> dict: @entrypoint(checkpointer=checkpointer) def graph(state: dict) -> dict: - fut_foo = foo(state) - fut_bar = bar(fut_foo.result()) + foo_result = foo(state).result() + fut_bar = bar(foo_result["a"], foo_result["b"]) fut_baz = baz(fut_bar.result()) return fut_baz.result() @@ -2607,9 +2607,9 @@ async def test_imp_stream_order(checkpointer_name: str) -> None: async def foo(state: dict) -> dict: return {"a": state["a"] + "foo", "b": "bar"} - @task() - async def bar(state: dict) -> dict: - return {"a": state["a"] + state["b"], "c": "bark"} + @task + async def bar(a: str, b: str, c: Optional[str] = None) -> dict: + return {"a": a + b, "c": (c or "") + "bark"} @task() async def baz(state: dict) -> dict: @@ -2617,8 +2617,9 @@ async def baz(state: dict) -> dict: @entrypoint(checkpointer=checkpointer) async def graph(state: dict) -> dict: - fut_foo = foo(state) - fut_bar = bar(await fut_foo) + foo_res = await foo(state) + + fut_bar = bar(foo_res["a"], foo_res["b"]) fut_baz = baz(await fut_bar) return await fut_baz