Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

functional api: Add ability to request previous output #3025

Merged
merged 15 commits into from
Jan 16, 2025
2 changes: 2 additions & 0 deletions libs/langgraph/langgraph/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import sys

Check notice on line 1 in libs/langgraph/langgraph/constants.py

View workflow job for this annotation

GitHub Actions / benchmark

Benchmark results

......................................... fanout_to_subgraph_10x: Mean +- std dev: 62.0 ms +- 1.7 ms ......................................... fanout_to_subgraph_10x_sync: Mean +- std dev: 53.9 ms +- 0.7 ms ......................................... fanout_to_subgraph_10x_checkpoint: Mean +- std dev: 76.1 ms +- 1.4 ms ......................................... fanout_to_subgraph_10x_checkpoint_sync: Mean +- std dev: 97.6 ms +- 2.2 ms ......................................... fanout_to_subgraph_100x: Mean +- std dev: 646 ms +- 29 ms ......................................... fanout_to_subgraph_100x_sync: Mean +- std dev: 530 ms +- 14 ms ......................................... fanout_to_subgraph_100x_checkpoint: Mean +- std dev: 808 ms +- 15 ms ......................................... fanout_to_subgraph_100x_checkpoint_sync: Mean +- std dev: 984 ms +- 21 ms ......................................... react_agent_10x: Mean +- std dev: 31.3 ms +- 0.7 ms ......................................... react_agent_10x_sync: Mean +- std dev: 23.2 ms +- 0.3 ms ......................................... react_agent_10x_checkpoint: Mean +- std dev: 38.5 ms +- 0.8 ms ......................................... react_agent_10x_checkpoint_sync: Mean +- std dev: 37.3 ms +- 0.6 ms ......................................... react_agent_100x: Mean +- std dev: 347 ms +- 6 ms ......................................... react_agent_100x_sync: Mean +- std dev: 277 ms +- 6 ms ......................................... react_agent_100x_checkpoint: Mean +- std dev: 662 ms +- 11 ms ......................................... react_agent_100x_checkpoint_sync: Mean +- std dev: 657 ms +- 11 ms ......................................... wide_state_25x300: Mean +- std dev: 23.6 ms +- 0.5 ms ......................................... wide_state_25x300_sync: Mean +- std dev: 15.5 ms +- 0.1 ms ......................................... wide_state_25x300_checkpoint: Mean +- std dev: 253 ms +- 15 ms ......................................... wide_state_25x300_checkpoint_sync: Mean +- std dev: 248 ms +- 15 ms ......................................... wide_state_15x600: Mean +- std dev: 27.5 ms +- 0.6 ms ......................................... wide_state_15x600_sync: Mean +- std dev: 18.1 ms +- 0.2 ms ......................................... wide_state_15x600_checkpoint: Mean +- std dev: 436 ms +- 16 ms ......................................... wide_state_15x600_checkpoint_sync: Mean +- std dev: 427 ms +- 15 ms ......................................... wide_state_9x1200: Mean +- std dev: 27.6 ms +- 0.5 ms ......................................... wide_state_9x1200_sync: Mean +- std dev: 18.0 ms +- 0.2 ms ......................................... wide_state_9x1200_checkpoint: Mean +- std dev: 285 ms +- 16 ms ......................................... wide_state_9x1200_checkpoint_sync: Mean +- std dev: 278 ms +- 14 ms

Check notice on line 1 in libs/langgraph/langgraph/constants.py

View workflow job for this annotation

GitHub Actions / benchmark

Comparison against main

+----------------------------------------+---------+-----------------------+ | Benchmark | main | changes | +========================================+=========+=======================+ | wide_state_15x600_checkpoint_sync | 433 ms | 427 ms: 1.01x faster | +----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_sync | 54.1 ms | 53.9 ms: 1.01x faster | +----------------------------------------+---------+-----------------------+ | wide_state_9x1200_sync | 17.9 ms | 18.0 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | wide_state_25x300_sync | 15.4 ms | 15.5 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | react_agent_100x | 344 ms | 347 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | react_agent_10x_checkpoint_sync | 37.0 ms | 37.3 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | wide_state_9x1200 | 27.3 ms | 27.6 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | wide_state_15x600 | 27.3 ms | 27.5 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | react_agent_10x | 30.9 ms | 31.3 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | react_agent_100x_sync | 274 ms | 277 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_sync | 523 ms | 530 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_checkpoint | 75.0 ms | 76.1 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_checkpoint_sync | 96.2 ms | 97.6 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | wide_state_15x600_sync | 17.8 ms | 18.1 ms: 1.02x slower | +----------------------------------------+---------+-----------------------+ | wide_state_25x300_checkpoint | 249 ms | 253 ms: 1.02x slower | +----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_checkpoint | 781 ms | 808 ms: 1.03x slower | +----------------------------------------+---------+-----------------------+ | react_agent_100x_checkpoint | 639 ms | 662 ms: 1.04x slower | +----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x | 623 ms | 646 ms: 1.04x slower | +----------------------------------------+---------+-----------------------+ | react_agent_100x_checkpoint_sync | 624 ms | 657 ms: 1.05x slower | +----------------------------------------+---------+-----------------------+ | Geometric mean | (ref) | 1.01x slower | +----------------------------------------+---------+-----------------------+ Benchmark hidden because not significant (9): fanout_to_subgraph_10x, fanout_to_subgraph_100x_checkpoint_sync, react_agent_10x_sync, wide_state_9x1200_checkpoint_sync, react_agent_10x_checkpoint, wide_state_25x300, wide_state_15x600_checkpoint, wide_state_9x1200_checkpoint, wide_state_25x300_checkpoint_sync
from os import getenv
from types import MappingProxyType
from typing import Any, Literal, Mapping, cast
Expand Down Expand Up @@ -81,6 +81,8 @@
# read-only list of existing task writes
CONFIG_KEY_SCRATCHPAD = sys.intern("__pregel_scratchpad")
# holds a mutable dict for temporary storage scoped to the current task
CONFIG_KEY_END = sys.intern("__pregel_previous")
# holds the previous return value from a stateful Pregel graph.

# --- Other constants ---
PUSH = sys.intern("__pregel_push")
Expand Down
26 changes: 3 additions & 23 deletions libs/langgraph/langgraph/func/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import concurrent
import concurrent.futures
import functools
import inspect
import types
from typing import (
Any,
Expand All @@ -25,7 +24,7 @@
from langgraph.pregel.read import PregelNode
from langgraph.pregel.write import ChannelWrite, ChannelWriteEntry
from langgraph.store.base import BaseStore
from langgraph.types import RetryPolicy, StreamMode, StreamWriter
from langgraph.types import RetryPolicy, StreamMode

P = ParamSpec("P")
P1 = TypeVar("P1")
Expand Down Expand Up @@ -112,27 +111,8 @@ def entrypoint(
store: Optional[BaseStore] = None,
) -> Callable[[types.FunctionType], Pregel]:
def _imp(func: types.FunctionType) -> Pregel:
if inspect.isgeneratorfunction(func):
eyurtsev marked this conversation as resolved.
Show resolved Hide resolved

def gen_wrapper(*args: Any, writer: StreamWriter, **kwargs: Any) -> Any:
for chunk in func(*args, **kwargs):
writer(chunk)

bound = get_runnable_for_func(gen_wrapper)
stream_mode: StreamMode = "custom"
elif inspect.isasyncgenfunction(func):

async def agen_wrapper(
*args: Any, writer: StreamWriter, **kwargs: Any
) -> Any:
async for chunk in func(*args, **kwargs):
writer(chunk)

bound = get_runnable_for_func(agen_wrapper)
stream_mode = "custom"
else:
bound = get_runnable_for_func(func)
stream_mode = "updates"
bound = get_runnable_for_func(func)
stream_mode: StreamMode = "updates"

return Pregel(
nodes={
Expand Down
10 changes: 10 additions & 0 deletions libs/langgraph/langgraph/pregel/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
CONFIG_KEY_CHECKPOINT_MAP,
CONFIG_KEY_CHECKPOINT_NS,
CONFIG_KEY_CHECKPOINTER,
CONFIG_KEY_END,
CONFIG_KEY_READ,
CONFIG_KEY_SCRATCHPAD,
CONFIG_KEY_SEND,
Expand Down Expand Up @@ -560,6 +561,9 @@ def prepare_single_task(
if w[0] in (NULL_TASK_ID, task_id)
],
CONFIG_KEY_SCRATCHPAD: {},
CONFIG_KEY_END: checkpoint["channel_values"].get(
"__end__", None
),
},
),
triggers,
Expand Down Expand Up @@ -709,6 +713,9 @@ def prepare_single_task(
if w[0] in (NULL_TASK_ID, task_id)
],
CONFIG_KEY_SCRATCHPAD: {},
CONFIG_KEY_END: checkpoint["channel_values"].get(
"__end__", None
),
},
),
triggers,
Expand Down Expand Up @@ -833,6 +840,9 @@ def prepare_single_task(
if w[0] in (NULL_TASK_ID, task_id)
],
CONFIG_KEY_SCRATCHPAD: {},
CONFIG_KEY_END: checkpoint["channel_values"].get(
"__end__", None
),
},
),
triggers,
Expand Down
59 changes: 46 additions & 13 deletions libs/langgraph/langgraph/utils/runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,12 @@
from langchain_core.tracers._streaming import _StreamingCallbackHandler
from typing_extensions import TypeGuard

from langgraph.constants import CONF, CONFIG_KEY_STORE, CONFIG_KEY_STREAM_WRITER
from langgraph.constants import (
CONF,
CONFIG_KEY_END,
CONFIG_KEY_STORE,
CONFIG_KEY_STREAM_WRITER,
)
from langgraph.store.base import BaseStore
from langgraph.types import StreamWriter
from langgraph.utils.config import (
Expand All @@ -58,6 +63,10 @@ class StrEnum(str, enum.Enum):
"""A string enum."""


# Special type to denote any type is accepted
ANY_TYPE = object()


ASYNCIO_ACCEPTS_CONTEXT = sys.version_info >= (3, 11)

KWARGS_CONFIG_KEYS: tuple[tuple[str, tuple[Any, ...], str, Any], ...] = (
Expand All @@ -73,9 +82,22 @@ class StrEnum(str, enum.Enum):
CONFIG_KEY_STORE,
inspect.Parameter.empty,
),
(
sys.intern("previous"),
(ANY_TYPE,),
CONFIG_KEY_END,
inspect.Parameter.empty,
),
)
"""List of kwargs that can be passed to functions, and their corresponding
config keys, default values and type annotations."""
config keys, default values and type annotations.

Each tuple contains:
- the name of the kwarg in the function signature
- the type annotation(s) for the kwarg
- the config key to look for the value in
- the default value for the kwarg
"""

VALID_KINDS = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)

Expand Down Expand Up @@ -122,9 +144,12 @@ def __init__(
self.func_accepts: dict[str, bool] = {}
for kw, typ, _, _ in KWARGS_CONFIG_KEYS:
p = params.get(kw)
self.func_accepts[kw] = (
p is not None and p.annotation in typ and p.kind in VALID_KINDS
)
if typ == (ANY_TYPE,):
self.func_accepts[kw] = p is not None and p.kind in VALID_KINDS
else:
self.func_accepts[kw] = (
p is not None and p.annotation in typ and p.kind in VALID_KINDS
)

def __repr__(self) -> str:
repr_args = {
Expand All @@ -149,16 +174,20 @@ def invoke(
if self.func_accepts_config:
kwargs["config"] = config
_conf = config[CONF]
for kw, _, ck, defv in KWARGS_CONFIG_KEYS:
for kw, _, config_key, default_value in KWARGS_CONFIG_KEYS:
if not self.func_accepts[kw]:
continue

if defv is inspect.Parameter.empty and kw not in kwargs and ck not in _conf:
if (
default_value is inspect.Parameter.empty
and kw not in kwargs
and config_key not in _conf
):
raise ValueError(
f"Missing required config key '{ck}' for '{self.name}'."
f"Missing required config key '{config_key}' for '{self.name}'."
)
elif kwargs.get(kw) is None:
kwargs[kw] = _conf.get(ck, defv)
kwargs[kw] = _conf.get(config_key, default_value)

context = copy_context()
if self.trace:
Expand Down Expand Up @@ -197,16 +226,20 @@ async def ainvoke(
if self.func_accepts_config:
kwargs["config"] = config
_conf = config[CONF]
for kw, _, ck, defv in KWARGS_CONFIG_KEYS:
for kw, _, config_key, default_value in KWARGS_CONFIG_KEYS:
if not self.func_accepts[kw]:
continue

if defv is inspect.Parameter.empty and kw not in kwargs and ck not in _conf:
if (
default_value is inspect.Parameter.empty
and kw not in kwargs
and config_key not in _conf
):
raise ValueError(
f"Missing required config key '{ck}' for '{self.name}'."
f"Missing required config key '{config_key}' for '{self.name}'."
)
elif kwargs.get(kw) is None:
kwargs[kw] = _conf.get(ck, defv)
kwargs[kw] = _conf.get(config_key, default_value)
context = copy_context()
if self.trace:
callback_manager = get_async_callback_manager_for_config(config, self.tags)
Expand Down
23 changes: 23 additions & 0 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5300,3 +5300,26 @@ def node_b(state):
{"node_a": [{"foo": "a1"}, {"foo": "a2"}]},
{"node_b": {"foo": "b"}},
]


def test_version_1_of_entrypoint() -> None:
from langgraph.func import entrypoint

states = []

# In this version reducers do not work
@entrypoint(checkpointer=MemorySaver())
def foo(inputs, *, previous: Any) -> Any:
states.append(previous)
return {"previous": previous, "current": inputs}

config = {"configurable": {"thread_id": "1"}}

foo.invoke({"a": "1"}, config)
foo.invoke({"a": "2"}, config)
foo.invoke({"a": "3"}, config)
assert states == [
None,
{"current": {"a": "1"}, "previous": None},
{"current": {"a": "2"}, "previous": {"current": {"a": "1"}, "previous": None}},
]
Loading