diff --git a/libs/langgraph/langgraph/constants.py b/libs/langgraph/langgraph/constants.py index cd847f9be..caba214cb 100644 --- a/libs/langgraph/langgraph/constants.py +++ b/libs/langgraph/langgraph/constants.py @@ -81,6 +81,8 @@ # read-only list of existing task writes CONFIG_KEY_SCRATCHPAD = sys.intern("__pregel_scratchpad") # holds a mutable dict for temporary storage scoped to the current task +CONFIG_KEY_END = sys.intern("__pregel_previous") +# holds the previous return value from a stateful Pregel graph. # --- Other constants --- PUSH = sys.intern("__pregel_push") diff --git a/libs/langgraph/langgraph/func/__init__.py b/libs/langgraph/langgraph/func/__init__.py index 62bf138e6..76e97d7e5 100644 --- a/libs/langgraph/langgraph/func/__init__.py +++ b/libs/langgraph/langgraph/func/__init__.py @@ -112,27 +112,8 @@ def entrypoint( store: Optional[BaseStore] = None, ) -> Callable[[types.FunctionType], Pregel]: def _imp(func: types.FunctionType) -> Pregel: - if inspect.isgeneratorfunction(func): - - def gen_wrapper(*args: Any, writer: StreamWriter, **kwargs: Any) -> Any: - for chunk in func(*args, **kwargs): - writer(chunk) - - bound = get_runnable_for_func(gen_wrapper) - stream_mode: StreamMode = "custom" - elif inspect.isasyncgenfunction(func): - - async def agen_wrapper( - *args: Any, writer: StreamWriter, **kwargs: Any - ) -> Any: - async for chunk in func(*args, **kwargs): - writer(chunk) - - bound = get_runnable_for_func(agen_wrapper) - stream_mode = "custom" - else: - bound = get_runnable_for_func(func) - stream_mode = "updates" + bound = get_runnable_for_func(func) + stream_mode: StreamMode = "updates" return Pregel( nodes={ diff --git a/libs/langgraph/langgraph/pregel/algo.py b/libs/langgraph/langgraph/pregel/algo.py index 3adea073a..f532b7bc8 100644 --- a/libs/langgraph/langgraph/pregel/algo.py +++ b/libs/langgraph/langgraph/pregel/algo.py @@ -42,6 +42,7 @@ CONFIG_KEY_STORE, CONFIG_KEY_TASK_ID, CONFIG_KEY_WRITES, + CONFIG_KEY_END, EMPTY_SEQ, ERROR, INTERRUPT, @@ -560,6 +561,9 @@ def prepare_single_task( if w[0] in (NULL_TASK_ID, task_id) ], CONFIG_KEY_SCRATCHPAD: {}, + CONFIG_KEY_END: checkpoint["channel_values"].get( + "__end__", None + ), }, ), triggers, @@ -709,6 +713,9 @@ def prepare_single_task( if w[0] in (NULL_TASK_ID, task_id) ], CONFIG_KEY_SCRATCHPAD: {}, + CONFIG_KEY_END: checkpoint["channel_values"].get( + "__end__", None + ), }, ), triggers, @@ -833,6 +840,9 @@ def prepare_single_task( if w[0] in (NULL_TASK_ID, task_id) ], CONFIG_KEY_SCRATCHPAD: {}, + CONFIG_KEY_END: checkpoint["channel_values"].get( + "__end__", None + ), }, ), triggers, diff --git a/libs/langgraph/langgraph/utils/runnable.py b/libs/langgraph/langgraph/utils/runnable.py index 7cd6a85b9..e5109443f 100644 --- a/libs/langgraph/langgraph/utils/runnable.py +++ b/libs/langgraph/langgraph/utils/runnable.py @@ -34,7 +34,12 @@ from langchain_core.tracers._streaming import _StreamingCallbackHandler from typing_extensions import TypeGuard -from langgraph.constants import CONF, CONFIG_KEY_STORE, CONFIG_KEY_STREAM_WRITER +from langgraph.constants import ( + CONF, + CONFIG_KEY_STORE, + CONFIG_KEY_END, + CONFIG_KEY_STREAM_WRITER, +) from langgraph.store.base import BaseStore from langgraph.types import StreamWriter from langgraph.utils.config import ( @@ -58,6 +63,10 @@ class StrEnum(str, enum.Enum): """A string enum.""" +# Special type to denote any type is accepted +ANY_TYPE = object() + + ASYNCIO_ACCEPTS_CONTEXT = sys.version_info >= (3, 11) KWARGS_CONFIG_KEYS: tuple[tuple[str, tuple[Any, ...], str, Any], ...] = ( @@ -73,9 +82,22 @@ class StrEnum(str, enum.Enum): CONFIG_KEY_STORE, inspect.Parameter.empty, ), + ( + sys.intern("previous"), + (ANY_TYPE, ), + CONFIG_KEY_END, + inspect.Parameter.empty, + ), ) """List of kwargs that can be passed to functions, and their corresponding -config keys, default values and type annotations.""" +config keys, default values and type annotations. + +Each tuple contains: +- the name of the kwarg in the function signature +- the type annotation(s) for the kwarg +- the config key to look for the value in +- the default value for the kwarg +""" VALID_KINDS = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY) @@ -122,9 +144,12 @@ def __init__( self.func_accepts: dict[str, bool] = {} for kw, typ, _, _ in KWARGS_CONFIG_KEYS: p = params.get(kw) - self.func_accepts[kw] = ( - p is not None and p.annotation in typ and p.kind in VALID_KINDS - ) + if typ == (ANY_TYPE, ): + self.func_accepts[kw] = p is not None and p.kind in VALID_KINDS + else: + self.func_accepts[kw] = ( + p is not None and p.annotation in typ and p.kind in VALID_KINDS + ) def __repr__(self) -> str: repr_args = { @@ -149,16 +174,20 @@ def invoke( if self.func_accepts_config: kwargs["config"] = config _conf = config[CONF] - for kw, _, ck, defv in KWARGS_CONFIG_KEYS: + for kw, _, config_key, default_value in KWARGS_CONFIG_KEYS: if not self.func_accepts[kw]: continue - if defv is inspect.Parameter.empty and kw not in kwargs and ck not in _conf: + if ( + default_value is inspect.Parameter.empty + and kw not in kwargs + and config_key not in _conf + ): raise ValueError( - f"Missing required config key '{ck}' for '{self.name}'." + f"Missing required config key '{config_key}' for '{self.name}'." ) elif kwargs.get(kw) is None: - kwargs[kw] = _conf.get(ck, defv) + kwargs[kw] = _conf.get(config_key, default_value) context = copy_context() if self.trace: @@ -197,16 +226,20 @@ async def ainvoke( if self.func_accepts_config: kwargs["config"] = config _conf = config[CONF] - for kw, _, ck, defv in KWARGS_CONFIG_KEYS: + for kw, _, config_key, default_value in KWARGS_CONFIG_KEYS: if not self.func_accepts[kw]: continue - if defv is inspect.Parameter.empty and kw not in kwargs and ck not in _conf: + if ( + default_value is inspect.Parameter.empty + and kw not in kwargs + and config_key not in _conf + ): raise ValueError( - f"Missing required config key '{ck}' for '{self.name}'." + f"Missing required config key '{config_key}' for '{self.name}'." ) elif kwargs.get(kw) is None: - kwargs[kw] = _conf.get(ck, defv) + kwargs[kw] = _conf.get(config_key, default_value) context = copy_context() if self.trace: callback_manager = get_async_callback_manager_for_config(config, self.tags) diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index d1f1c66eb..00568da18 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -5300,3 +5300,27 @@ def node_b(state): {"node_a": [{"foo": "a1"}, {"foo": "a2"}]}, {"node_b": {"foo": "b"}}, ] + + +def test_version_1_of_entrypoint() -> None: + from langgraph.func import entrypoint + from typing import TypedDict, Annotated, NotRequired, Any + + states = [] + + # In this version reducers do not work + @entrypoint(checkpointer=MemorySaver()) + def foo(inputs, *, previous: Any) -> Any: + states.append(previous) + return {"previous": previous, "current": inputs} + + config = {"configurable": {"thread_id": "1"}} + + foo.invoke({"a": "1"}, config) + foo.invoke({"a": "2"}, config) + foo.invoke({"a": "3"}, config) + assert states == [ + None, + {"current": {"a": "1"}, "previous": None}, + {"current": {"a": "2"}, "previous": {"current": {"a": "1"}, "previous": None}}, + ]