Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

functional api: Add ability to request previous output #3025

Merged
merged 15 commits into from
Jan 16, 2025
2 changes: 2 additions & 0 deletions libs/langgraph/langgraph/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import sys

Check notice on line 1 in libs/langgraph/langgraph/constants.py

View workflow job for this annotation

GitHub Actions / benchmark

Benchmark results

......................................... fanout_to_subgraph_10x: Mean +- std dev: 63.7 ms +- 1.5 ms ......................................... fanout_to_subgraph_10x_sync: Mean +- std dev: 55.0 ms +- 1.1 ms ......................................... fanout_to_subgraph_10x_checkpoint: Mean +- std dev: 78.3 ms +- 1.3 ms ......................................... fanout_to_subgraph_10x_checkpoint_sync: Mean +- std dev: 99.7 ms +- 2.0 ms ......................................... fanout_to_subgraph_100x: Mean +- std dev: 669 ms +- 35 ms ......................................... fanout_to_subgraph_100x_sync: Mean +- std dev: 533 ms +- 18 ms ......................................... fanout_to_subgraph_100x_checkpoint: Mean +- std dev: 853 ms +- 29 ms ......................................... fanout_to_subgraph_100x_checkpoint_sync: Mean +- std dev: 1.00 sec +- 0.02 sec ......................................... react_agent_10x: Mean +- std dev: 30.7 ms +- 0.7 ms ......................................... react_agent_10x_sync: Mean +- std dev: 22.9 ms +- 0.2 ms ......................................... react_agent_10x_checkpoint: Mean +- std dev: 38.5 ms +- 0.8 ms ......................................... react_agent_10x_checkpoint_sync: Mean +- std dev: 37.0 ms +- 0.6 ms ......................................... react_agent_100x: Mean +- std dev: 343 ms +- 6 ms ......................................... react_agent_100x_sync: Mean +- std dev: 274 ms +- 3 ms ......................................... react_agent_100x_checkpoint: Mean +- std dev: 673 ms +- 9 ms ......................................... react_agent_100x_checkpoint_sync: Mean +- std dev: 657 ms +- 9 ms ......................................... wide_state_25x300: Mean +- std dev: 23.7 ms +- 0.5 ms ......................................... wide_state_25x300_sync: Mean +- std dev: 15.6 ms +- 0.3 ms ......................................... wide_state_25x300_checkpoint: Mean +- std dev: 256 ms +- 17 ms ......................................... wide_state_25x300_checkpoint_sync: Mean +- std dev: 254 ms +- 18 ms ......................................... wide_state_15x600: Mean +- std dev: 28.0 ms +- 0.6 ms ......................................... wide_state_15x600_sync: Mean +- std dev: 18.2 ms +- 0.2 ms ......................................... wide_state_15x600_checkpoint: Mean +- std dev: 438 ms +- 18 ms ......................................... wide_state_15x600_checkpoint_sync: Mean +- std dev: 432 ms +- 18 ms ......................................... wide_state_9x1200: Mean +- std dev: 27.4 ms +- 0.7 ms ......................................... wide_state_9x1200_sync: Mean +- std dev: 17.8 ms +- 0.2 ms ......................................... wide_state_9x1200_checkpoint: Mean +- std dev: 285 ms +- 16 ms ......................................... wide_state_9x1200_checkpoint_sync: Mean +- std dev: 284 ms +- 19 ms

Check notice on line 1 in libs/langgraph/langgraph/constants.py

View workflow job for this annotation

GitHub Actions / benchmark

Comparison against main

+-----------------------------------------+----------+------------------------+ | Benchmark | main | changes | +=========================================+==========+========================+ | fanout_to_subgraph_100x_checkpoint_sync | 1.01 sec | 1.00 sec: 1.01x faster | +-----------------------------------------+----------+------------------------+ | wide_state_9x1200_sync | 17.9 ms | 17.8 ms: 1.01x faster | +-----------------------------------------+----------+------------------------+ | fanout_to_subgraph_10x_checkpoint_sync | 99.0 ms | 99.7 ms: 1.01x slower | +-----------------------------------------+----------+------------------------+ | react_agent_100x | 340 ms | 343 ms: 1.01x slower | +-----------------------------------------+----------+------------------------+ | wide_state_25x300_sync | 15.5 ms | 15.6 ms: 1.01x slower | +-----------------------------------------+----------+------------------------+ | wide_state_25x300 | 23.3 ms | 23.7 ms: 1.01x slower | +-----------------------------------------+----------+------------------------+ | fanout_to_subgraph_100x_sync | 526 ms | 533 ms: 1.01x slower | +-----------------------------------------+----------+------------------------+ | react_agent_10x_checkpoint_sync | 36.5 ms | 37.0 ms: 1.01x slower | +-----------------------------------------+----------+------------------------+ | wide_state_15x600_sync | 17.9 ms | 18.2 ms: 1.01x slower | +-----------------------------------------+----------+------------------------+ | react_agent_100x_sync | 270 ms | 274 ms: 1.02x slower | +-----------------------------------------+----------+------------------------+ | wide_state_15x600_checkpoint_sync | 424 ms | 432 ms: 1.02x slower | +-----------------------------------------+----------+------------------------+ | wide_state_15x600_checkpoint | 429 ms | 438 ms: 1.02x slower | +-----------------------------------------+----------+------------------------+ | wide_state_15x600 | 27.2 ms | 28.0 ms: 1.03x slower | +-----------------------------------------+----------+------------------------+ | wide_state_9x1200_checkpoint_sync | 276 ms | 284 ms: 1.03x slower | +-----------------------------------------+----------+------------------------+ | wide_state_25x300_checkpoint | 248 ms | 256 ms: 1.03x slower | +-----------------------------------------+----------+------------------------+ | wide_state_25x300_checkpoint_sync | 245 ms | 254 ms: 1.04x slower | +-----------------------------------------+----------+------------------------+ | fanout_to_subgraph_100x | 641 ms | 669 ms: 1.04x slower | +-----------------------------------------+----------+------------------------+ | react_agent_100x_checkpoint | 639 ms | 673 ms: 1.05x slower | +-----------------------------------------+----------+------------------------+ | react_agent_100x_checkpoint_sync | 619 ms | 657 ms: 1.06x slower | +-----------------------------------------+----------+------------------------+ | fanout_to_subgraph_100x_checkpoint | 803 ms | 853 ms: 1.06x slower | +-----------------------------------------+----------+------------------------+ | Geometric mean | (ref) | 1.02x slower | +-----------------------------------------+----------+------------------------+ Benchmark hidden because not significant (8): fanout_to_subgraph_10x_sync, react_agent_10x, fanout_to_subgraph_10x, fanout_to_subgraph_10x_checkpoint, react_agent_10x_sync, wide_state_9x1200_checkpoint, wide_state_9x1200, react_agent_10x_checkpoint
from types import MappingProxyType
from typing import Any, Literal, Mapping, cast

Expand Down Expand Up @@ -78,6 +78,8 @@
# holds a callback to be called when a node is finished
CONFIG_KEY_SCRATCHPAD = sys.intern("__pregel_scratchpad")
# holds a mutable dict for temporary storage scoped to the current task
CONFIG_KEY_END = sys.intern("__pregel_previous")
# holds the previous return value from a stateful Pregel graph.

# --- Other constants ---
PUSH = sys.intern("__pregel_push")
Expand Down
96 changes: 85 additions & 11 deletions libs/langgraph/langgraph/func/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,22 +113,96 @@ def entrypoint(
config_schema: Optional[type[Any]] = None,
) -> Callable[[types.FunctionType], Pregel]:
def _imp(func: types.FunctionType) -> Pregel:
# wrap generators in a function that writes to StreamWriter
if inspect.isgeneratorfunction(func):
eyurtsev marked this conversation as resolved.
Show resolved Hide resolved
"""Convert a function into a Pregel graph.

def gen_wrapper(*args: Any, writer: StreamWriter, **kwargs: Any) -> Any:
for chunk in func(*args, **kwargs):
writer(chunk)
Args:
func: The function to convert. Support both sync and async functions, as well
as generator and async generator functions.

Returns:
A Pregel graph.
"""
# wrap generators in a function that writes to StreamWriter
if inspect.isgeneratorfunction(func):
original_sig = inspect.signature(func)
# Check if original signature has a writer argument with a matching type.
# If not, we'll inject it into the decorator, but not pass it
# to the wrapped function.
if "writer" in original_sig.parameters:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this changing?


@functools.wraps(func)
def gen_wrapper(*args: Any, writer: StreamWriter, **kwargs: Any) -> Any:
chunks = []
for chunk in func(*args, writer=writer, **kwargs):
writer(chunk)
chunks.append(chunk)
return chunks
else:

@functools.wraps(func)
def gen_wrapper(*args: Any, writer: StreamWriter, **kwargs: Any) -> Any:
chunks = []
# Do not pass the writer argument to the wrapped function
# as it does not have a matching parameter
for chunk in func(*args, **kwargs):
writer(chunk)
chunks.append(chunk)
return chunks

# Create a new parameter for the writer argument
extra_param = inspect.Parameter(
"writer",
inspect.Parameter.KEYWORD_ONLY,
# The extra argument is a keyword-only argument
default=lambda _: None,
)
# Update the function's signature to include the extra argument
new_params = list(original_sig.parameters.values()) + [extra_param]
new_sig = original_sig.replace(parameters=new_params)
# Update the signature of the wrapper function
gen_wrapper.__signature__ = new_sig # type: ignore
bound = get_runnable_for_func(gen_wrapper)
stream_mode: StreamMode = "custom"
elif inspect.isasyncgenfunction(func):

async def agen_wrapper(
*args: Any, writer: StreamWriter, **kwargs: Any
) -> Any:
async for chunk in func(*args, **kwargs):
writer(chunk)
original_sig = inspect.signature(func)
# Check if original signature has a writer argument with a matching type.
# If not, we'll inject it into the decorator, but not pass it
# to the wrapped function.
if "writer" in original_sig.parameters:

@functools.wraps(func)
async def agen_wrapper(
*args: Any, writer: StreamWriter, **kwargs: Any
) -> Any:
chunks = []
async for chunk in func(*args, writer=writer, **kwargs):
writer(chunk)
chunks.append(chunk)
return chunks
else:

@functools.wraps(func)
async def agen_wrapper(
*args: Any, writer: StreamWriter, **kwargs: Any
) -> Any:
chunks = []
async for chunk in func(*args, **kwargs):
writer(chunk)
chunks.append(chunk)
return chunks

# Create a new parameter for the writer argument
extra_param = inspect.Parameter(
"writer",
inspect.Parameter.KEYWORD_ONLY,
# The extra argument is a keyword-only argument
default=lambda _: None,
)
# Update the function's signature to include the extra argument
new_params = list(original_sig.parameters.values()) + [extra_param]
new_sig = original_sig.replace(parameters=new_params)
# Update the signature of the wrapper function
agen_wrapper.__signature__ = new_sig # type: ignore

bound = get_runnable_for_func(agen_wrapper)
stream_mode = "custom"
Expand Down
10 changes: 10 additions & 0 deletions libs/langgraph/langgraph/pregel/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
CONFIG_KEY_CHECKPOINT_MAP,
CONFIG_KEY_CHECKPOINT_NS,
CONFIG_KEY_CHECKPOINTER,
CONFIG_KEY_END,
CONFIG_KEY_READ,
CONFIG_KEY_SCRATCHPAD,
CONFIG_KEY_SEND,
Expand Down Expand Up @@ -507,6 +508,9 @@ def prepare_single_task(
pending_writes,
task_id,
),
CONFIG_KEY_END: checkpoint["channel_values"].get(
"__end__", None
),
},
),
triggers,
Expand Down Expand Up @@ -616,6 +620,9 @@ def prepare_single_task(
pending_writes,
task_id,
),
CONFIG_KEY_END: checkpoint["channel_values"].get(
"__end__", None
),
},
),
triggers,
Expand Down Expand Up @@ -737,6 +744,9 @@ def prepare_single_task(
pending_writes,
task_id,
),
CONFIG_KEY_END: checkpoint["channel_values"].get(
"__end__", None
),
},
),
triggers,
Expand Down
50 changes: 38 additions & 12 deletions libs/langgraph/langgraph/utils/runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,12 @@
from langchain_core.tracers._streaming import _StreamingCallbackHandler
from typing_extensions import TypeGuard

from langgraph.constants import CONF, CONFIG_KEY_STORE, CONFIG_KEY_STREAM_WRITER
from langgraph.constants import (
CONF,
CONFIG_KEY_END,
CONFIG_KEY_STORE,
CONFIG_KEY_STREAM_WRITER,
)
from langgraph.store.base import BaseStore
from langgraph.types import StreamWriter
from langgraph.utils.config import (
Expand All @@ -58,6 +63,10 @@ class StrEnum(str, enum.Enum):
"""A string enum."""


# Special type to denote any type is accepted
ANY_TYPE = object()


ASYNCIO_ACCEPTS_CONTEXT = sys.version_info >= (3, 11)

KWARGS_CONFIG_KEYS: tuple[tuple[str, tuple[Any, ...], str, Any], ...] = (
Expand All @@ -73,6 +82,12 @@ class StrEnum(str, enum.Enum):
CONFIG_KEY_STORE,
inspect.Parameter.empty,
),
(
sys.intern("previous"),
(ANY_TYPE,),
CONFIG_KEY_END,
inspect.Parameter.empty,
),
)
"""List of kwargs that can be passed to functions, and their corresponding
config keys, default values and type annotations.
Expand Down Expand Up @@ -135,9 +150,12 @@ def __init__(
self.func_accepts: dict[str, bool] = {}
for kw, typ, _, _ in KWARGS_CONFIG_KEYS:
p = params.get(kw)
self.func_accepts[kw] = (
p is not None and p.annotation in typ and p.kind in VALID_KINDS
)
if typ == (ANY_TYPE,):
self.func_accepts[kw] = p is not None and p.kind in VALID_KINDS
else:
self.func_accepts[kw] = (
p is not None and p.annotation in typ and p.kind in VALID_KINDS
)

def __repr__(self) -> str:
repr_args = {
Expand All @@ -162,16 +180,20 @@ def invoke(
if self.func_accepts_config:
kwargs["config"] = config
_conf = config[CONF]
for kw, _, ck, defv in KWARGS_CONFIG_KEYS:
for kw, _, config_key, default_value in KWARGS_CONFIG_KEYS:
if not self.func_accepts[kw]:
continue

if defv is inspect.Parameter.empty and kw not in kwargs and ck not in _conf:
if (
default_value is inspect.Parameter.empty
and kw not in kwargs
and config_key not in _conf
):
raise ValueError(
f"Missing required config key '{ck}' for '{self.name}'."
f"Missing required config key '{config_key}' for '{self.name}'."
)
elif kwargs.get(kw) is None:
kwargs[kw] = _conf.get(ck, defv)
kwargs[kw] = _conf.get(config_key, default_value)

context = copy_context()
if self.trace:
Expand Down Expand Up @@ -210,16 +232,20 @@ async def ainvoke(
if self.func_accepts_config:
kwargs["config"] = config
_conf = config[CONF]
for kw, _, ck, defv in KWARGS_CONFIG_KEYS:
for kw, _, config_key, default_value in KWARGS_CONFIG_KEYS:
if not self.func_accepts[kw]:
continue

if defv is inspect.Parameter.empty and kw not in kwargs and ck not in _conf:
if (
default_value is inspect.Parameter.empty
and kw not in kwargs
and config_key not in _conf
):
raise ValueError(
f"Missing required config key '{ck}' for '{self.name}'."
f"Missing required config key '{config_key}' for '{self.name}'."
)
elif kwargs.get(kw) is None:
kwargs[kw] = _conf.get(ck, defv)
kwargs[kw] = _conf.get(config_key, default_value)
context = copy_context()
if self.trace:
callback_manager = get_async_callback_manager_for_config(config, self.tags)
Expand Down
Loading
Loading