From 67a800dd988905dafe9af112114c4183ec1424e4 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Tue, 14 Jan 2025 15:37:10 -0500 Subject: [PATCH 01/13] update --- libs/langgraph/langgraph/constants.py | 2 + libs/langgraph/langgraph/func/__init__.py | 23 +-------- libs/langgraph/langgraph/pregel/algo.py | 10 ++++ libs/langgraph/langgraph/utils/runnable.py | 59 +++++++++++++++++----- libs/langgraph/tests/test_pregel.py | 24 +++++++++ 5 files changed, 84 insertions(+), 34 deletions(-) diff --git a/libs/langgraph/langgraph/constants.py b/libs/langgraph/langgraph/constants.py index cd847f9be..caba214cb 100644 --- a/libs/langgraph/langgraph/constants.py +++ b/libs/langgraph/langgraph/constants.py @@ -81,6 +81,8 @@ # read-only list of existing task writes CONFIG_KEY_SCRATCHPAD = sys.intern("__pregel_scratchpad") # holds a mutable dict for temporary storage scoped to the current task +CONFIG_KEY_END = sys.intern("__pregel_previous") +# holds the previous return value from a stateful Pregel graph. # --- Other constants --- PUSH = sys.intern("__pregel_push") diff --git a/libs/langgraph/langgraph/func/__init__.py b/libs/langgraph/langgraph/func/__init__.py index 62bf138e6..76e97d7e5 100644 --- a/libs/langgraph/langgraph/func/__init__.py +++ b/libs/langgraph/langgraph/func/__init__.py @@ -112,27 +112,8 @@ def entrypoint( store: Optional[BaseStore] = None, ) -> Callable[[types.FunctionType], Pregel]: def _imp(func: types.FunctionType) -> Pregel: - if inspect.isgeneratorfunction(func): - - def gen_wrapper(*args: Any, writer: StreamWriter, **kwargs: Any) -> Any: - for chunk in func(*args, **kwargs): - writer(chunk) - - bound = get_runnable_for_func(gen_wrapper) - stream_mode: StreamMode = "custom" - elif inspect.isasyncgenfunction(func): - - async def agen_wrapper( - *args: Any, writer: StreamWriter, **kwargs: Any - ) -> Any: - async for chunk in func(*args, **kwargs): - writer(chunk) - - bound = get_runnable_for_func(agen_wrapper) - stream_mode = "custom" - else: - bound = get_runnable_for_func(func) - stream_mode = "updates" + bound = get_runnable_for_func(func) + stream_mode: StreamMode = "updates" return Pregel( nodes={ diff --git a/libs/langgraph/langgraph/pregel/algo.py b/libs/langgraph/langgraph/pregel/algo.py index 3adea073a..f532b7bc8 100644 --- a/libs/langgraph/langgraph/pregel/algo.py +++ b/libs/langgraph/langgraph/pregel/algo.py @@ -42,6 +42,7 @@ CONFIG_KEY_STORE, CONFIG_KEY_TASK_ID, CONFIG_KEY_WRITES, + CONFIG_KEY_END, EMPTY_SEQ, ERROR, INTERRUPT, @@ -560,6 +561,9 @@ def prepare_single_task( if w[0] in (NULL_TASK_ID, task_id) ], CONFIG_KEY_SCRATCHPAD: {}, + CONFIG_KEY_END: checkpoint["channel_values"].get( + "__end__", None + ), }, ), triggers, @@ -709,6 +713,9 @@ def prepare_single_task( if w[0] in (NULL_TASK_ID, task_id) ], CONFIG_KEY_SCRATCHPAD: {}, + CONFIG_KEY_END: checkpoint["channel_values"].get( + "__end__", None + ), }, ), triggers, @@ -833,6 +840,9 @@ def prepare_single_task( if w[0] in (NULL_TASK_ID, task_id) ], CONFIG_KEY_SCRATCHPAD: {}, + CONFIG_KEY_END: checkpoint["channel_values"].get( + "__end__", None + ), }, ), triggers, diff --git a/libs/langgraph/langgraph/utils/runnable.py b/libs/langgraph/langgraph/utils/runnable.py index 7cd6a85b9..e5109443f 100644 --- a/libs/langgraph/langgraph/utils/runnable.py +++ b/libs/langgraph/langgraph/utils/runnable.py @@ -34,7 +34,12 @@ from langchain_core.tracers._streaming import _StreamingCallbackHandler from typing_extensions import TypeGuard -from langgraph.constants import CONF, CONFIG_KEY_STORE, CONFIG_KEY_STREAM_WRITER +from langgraph.constants import ( + CONF, + CONFIG_KEY_STORE, + CONFIG_KEY_END, + CONFIG_KEY_STREAM_WRITER, +) from langgraph.store.base import BaseStore from langgraph.types import StreamWriter from langgraph.utils.config import ( @@ -58,6 +63,10 @@ class StrEnum(str, enum.Enum): """A string enum.""" +# Special type to denote any type is accepted +ANY_TYPE = object() + + ASYNCIO_ACCEPTS_CONTEXT = sys.version_info >= (3, 11) KWARGS_CONFIG_KEYS: tuple[tuple[str, tuple[Any, ...], str, Any], ...] = ( @@ -73,9 +82,22 @@ class StrEnum(str, enum.Enum): CONFIG_KEY_STORE, inspect.Parameter.empty, ), + ( + sys.intern("previous"), + (ANY_TYPE, ), + CONFIG_KEY_END, + inspect.Parameter.empty, + ), ) """List of kwargs that can be passed to functions, and their corresponding -config keys, default values and type annotations.""" +config keys, default values and type annotations. + +Each tuple contains: +- the name of the kwarg in the function signature +- the type annotation(s) for the kwarg +- the config key to look for the value in +- the default value for the kwarg +""" VALID_KINDS = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY) @@ -122,9 +144,12 @@ def __init__( self.func_accepts: dict[str, bool] = {} for kw, typ, _, _ in KWARGS_CONFIG_KEYS: p = params.get(kw) - self.func_accepts[kw] = ( - p is not None and p.annotation in typ and p.kind in VALID_KINDS - ) + if typ == (ANY_TYPE, ): + self.func_accepts[kw] = p is not None and p.kind in VALID_KINDS + else: + self.func_accepts[kw] = ( + p is not None and p.annotation in typ and p.kind in VALID_KINDS + ) def __repr__(self) -> str: repr_args = { @@ -149,16 +174,20 @@ def invoke( if self.func_accepts_config: kwargs["config"] = config _conf = config[CONF] - for kw, _, ck, defv in KWARGS_CONFIG_KEYS: + for kw, _, config_key, default_value in KWARGS_CONFIG_KEYS: if not self.func_accepts[kw]: continue - if defv is inspect.Parameter.empty and kw not in kwargs and ck not in _conf: + if ( + default_value is inspect.Parameter.empty + and kw not in kwargs + and config_key not in _conf + ): raise ValueError( - f"Missing required config key '{ck}' for '{self.name}'." + f"Missing required config key '{config_key}' for '{self.name}'." ) elif kwargs.get(kw) is None: - kwargs[kw] = _conf.get(ck, defv) + kwargs[kw] = _conf.get(config_key, default_value) context = copy_context() if self.trace: @@ -197,16 +226,20 @@ async def ainvoke( if self.func_accepts_config: kwargs["config"] = config _conf = config[CONF] - for kw, _, ck, defv in KWARGS_CONFIG_KEYS: + for kw, _, config_key, default_value in KWARGS_CONFIG_KEYS: if not self.func_accepts[kw]: continue - if defv is inspect.Parameter.empty and kw not in kwargs and ck not in _conf: + if ( + default_value is inspect.Parameter.empty + and kw not in kwargs + and config_key not in _conf + ): raise ValueError( - f"Missing required config key '{ck}' for '{self.name}'." + f"Missing required config key '{config_key}' for '{self.name}'." ) elif kwargs.get(kw) is None: - kwargs[kw] = _conf.get(ck, defv) + kwargs[kw] = _conf.get(config_key, default_value) context = copy_context() if self.trace: callback_manager = get_async_callback_manager_for_config(config, self.tags) diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index d1f1c66eb..00568da18 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -5300,3 +5300,27 @@ def node_b(state): {"node_a": [{"foo": "a1"}, {"foo": "a2"}]}, {"node_b": {"foo": "b"}}, ] + + +def test_version_1_of_entrypoint() -> None: + from langgraph.func import entrypoint + from typing import TypedDict, Annotated, NotRequired, Any + + states = [] + + # In this version reducers do not work + @entrypoint(checkpointer=MemorySaver()) + def foo(inputs, *, previous: Any) -> Any: + states.append(previous) + return {"previous": previous, "current": inputs} + + config = {"configurable": {"thread_id": "1"}} + + foo.invoke({"a": "1"}, config) + foo.invoke({"a": "2"}, config) + foo.invoke({"a": "3"}, config) + assert states == [ + None, + {"current": {"a": "1"}, "previous": None}, + {"current": {"a": "2"}, "previous": {"current": {"a": "1"}, "previous": None}}, + ] From 311fe3970ca7c124f076df376c39860b9c5075a6 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Tue, 14 Jan 2025 15:39:57 -0500 Subject: [PATCH 02/13] x --- libs/langgraph/langgraph/func/__init__.py | 3 +-- libs/langgraph/langgraph/pregel/algo.py | 2 +- libs/langgraph/langgraph/utils/runnable.py | 6 +++--- libs/langgraph/tests/test_pregel.py | 1 - 4 files changed, 5 insertions(+), 7 deletions(-) diff --git a/libs/langgraph/langgraph/func/__init__.py b/libs/langgraph/langgraph/func/__init__.py index 76e97d7e5..86c5403f9 100644 --- a/libs/langgraph/langgraph/func/__init__.py +++ b/libs/langgraph/langgraph/func/__init__.py @@ -2,7 +2,6 @@ import concurrent import concurrent.futures import functools -import inspect import types from typing import ( Any, @@ -25,7 +24,7 @@ from langgraph.pregel.read import PregelNode from langgraph.pregel.write import ChannelWrite, ChannelWriteEntry from langgraph.store.base import BaseStore -from langgraph.types import RetryPolicy, StreamMode, StreamWriter +from langgraph.types import RetryPolicy, StreamMode P = ParamSpec("P") P1 = TypeVar("P1") diff --git a/libs/langgraph/langgraph/pregel/algo.py b/libs/langgraph/langgraph/pregel/algo.py index f532b7bc8..0d2553240 100644 --- a/libs/langgraph/langgraph/pregel/algo.py +++ b/libs/langgraph/langgraph/pregel/algo.py @@ -36,13 +36,13 @@ CONFIG_KEY_CHECKPOINT_MAP, CONFIG_KEY_CHECKPOINT_NS, CONFIG_KEY_CHECKPOINTER, + CONFIG_KEY_END, CONFIG_KEY_READ, CONFIG_KEY_SCRATCHPAD, CONFIG_KEY_SEND, CONFIG_KEY_STORE, CONFIG_KEY_TASK_ID, CONFIG_KEY_WRITES, - CONFIG_KEY_END, EMPTY_SEQ, ERROR, INTERRUPT, diff --git a/libs/langgraph/langgraph/utils/runnable.py b/libs/langgraph/langgraph/utils/runnable.py index e5109443f..d5e6cf1d2 100644 --- a/libs/langgraph/langgraph/utils/runnable.py +++ b/libs/langgraph/langgraph/utils/runnable.py @@ -36,8 +36,8 @@ from langgraph.constants import ( CONF, - CONFIG_KEY_STORE, CONFIG_KEY_END, + CONFIG_KEY_STORE, CONFIG_KEY_STREAM_WRITER, ) from langgraph.store.base import BaseStore @@ -84,7 +84,7 @@ class StrEnum(str, enum.Enum): ), ( sys.intern("previous"), - (ANY_TYPE, ), + (ANY_TYPE,), CONFIG_KEY_END, inspect.Parameter.empty, ), @@ -144,7 +144,7 @@ def __init__( self.func_accepts: dict[str, bool] = {} for kw, typ, _, _ in KWARGS_CONFIG_KEYS: p = params.get(kw) - if typ == (ANY_TYPE, ): + if typ == (ANY_TYPE,): self.func_accepts[kw] = p is not None and p.kind in VALID_KINDS else: self.func_accepts[kw] = ( diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index 00568da18..8cf2e6b4f 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -5304,7 +5304,6 @@ def node_b(state): def test_version_1_of_entrypoint() -> None: from langgraph.func import entrypoint - from typing import TypedDict, Annotated, NotRequired, Any states = [] From 7603809a9fcdd2782fb93e314e1d336f6e97dd96 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Tue, 14 Jan 2025 18:06:09 -0500 Subject: [PATCH 03/13] x --- libs/langgraph/langgraph/func/__init__.py | 10 +++- libs/langgraph/tests/test_pregel.py | 72 +++++++++++++++++++++-- 2 files changed, 75 insertions(+), 7 deletions(-) diff --git a/libs/langgraph/langgraph/func/__init__.py b/libs/langgraph/langgraph/func/__init__.py index 86c5403f9..c7b4d1bb4 100644 --- a/libs/langgraph/langgraph/func/__init__.py +++ b/libs/langgraph/langgraph/func/__init__.py @@ -2,6 +2,7 @@ import concurrent import concurrent.futures import functools +import inspect import types from typing import ( Any, @@ -111,8 +112,13 @@ def entrypoint( store: Optional[BaseStore] = None, ) -> Callable[[types.FunctionType], Pregel]: def _imp(func: types.FunctionType) -> Pregel: - bound = get_runnable_for_func(func) - stream_mode: StreamMode = "updates" + if inspect.isgeneratorfunction(func): + raise TypeError("@entrypoint does not support generator functions.") + elif inspect.isasyncgenfunction(func): + raise TypeError("@entrypoint does not support async generator functions.") + else: + bound = get_runnable_for_func(func) + stream_mode: StreamMode = "updates" return Pregel( nodes={ diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index 8cf2e6b4f..432eea069 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -15,6 +15,7 @@ Any, Dict, Generator, + Iterable, Iterator, List, Literal, @@ -5302,9 +5303,32 @@ def node_b(state): ] -def test_version_1_of_entrypoint() -> None: - from langgraph.func import entrypoint +def test_entrypoint_without_checkpointer() -> None: + """Test no checkpointer.""" + states = [] + config = {"configurable": {"thread_id": "1"}} + + # Test without previous + @entrypoint() + def foo(inputs: Any) -> Any: + states.append(inputs) + return inputs + + assert foo.invoke({"a": "1"}, config) == {"a": "1"} + + @entrypoint() + def foo(inputs: Any, *, previous: Any) -> Any: + states.append(previous) + return {"previous": previous, "current": inputs} + + assert foo.invoke({"a": "1"}, config) == {"current": {"a": "1"}, "previous": None} + assert foo.invoke({"a": "1"}, config) == {"current": {"a": "1"}, "previous": None} + +def test_entrypoint_stateful() -> None: + """Test stateful entrypoint invoke.""" + + # Test invoke states = [] # In this version reducers do not work @@ -5315,11 +5339,49 @@ def foo(inputs, *, previous: Any) -> Any: config = {"configurable": {"thread_id": "1"}} - foo.invoke({"a": "1"}, config) - foo.invoke({"a": "2"}, config) - foo.invoke({"a": "3"}, config) + assert foo.invoke({"a": "1"}, config) == {"current": {"a": "1"}, "previous": None} + assert foo.invoke({"a": "2"}, config) == { + "current": {"a": "2"}, + "previous": {"current": {"a": "1"}, "previous": None}, + } + assert foo.invoke({"a": "3"}, config) == { + "current": {"a": "3"}, + "previous": { + "current": {"a": "2"}, + "previous": {"current": {"a": "1"}, "previous": None}, + }, + } assert states == [ None, {"current": {"a": "1"}, "previous": None}, {"current": {"a": "2"}, "previous": {"current": {"a": "1"}, "previous": None}}, ] + + # Test stream + @entrypoint(checkpointer=MemorySaver()) + def foo(inputs, *, previous: Any) -> Any: + return {"previous": previous, "current": inputs} + + config = {"configurable": {"thread_id": "1"}} + items = [item for item in foo.stream({"a": "1"}, config)] + assert items == [{"foo": {"current": {"a": "1"}, "previous": None}}] + + +async def test_entrypoint_from_generator() -> None: + """@entrypoint does not support sync generators.""" + + with pytest.raises(TypeError): + + @entrypoint(checkpointer=MemorySaver()) + def foo(inputs: Any) -> Iterable[dict]: + yield "a" + + +async def test_entrypoint_from_async_generator() -> None: + """@entrypoint does not support async generators.""" + + with pytest.raises(TypeError): + + @entrypoint(checkpointer=MemorySaver()) + def foo(inputs: Any) -> Iterable[dict]: + yield "a" From 6b707cbfc5366130da17a562fc8d952ad15cb671 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 15 Jan 2025 17:15:56 -0500 Subject: [PATCH 04/13] x --- libs/langgraph/langgraph/func/__init__.py | 98 +++++++++++++++++++++-- libs/langgraph/tests/test_pregel.py | 71 +++++++++++++--- 2 files changed, 153 insertions(+), 16 deletions(-) diff --git a/libs/langgraph/langgraph/func/__init__.py b/libs/langgraph/langgraph/func/__init__.py index c7b4d1bb4..14c69ac89 100644 --- a/libs/langgraph/langgraph/func/__init__.py +++ b/libs/langgraph/langgraph/func/__init__.py @@ -19,13 +19,13 @@ from langgraph.channels.ephemeral_value import EphemeralValue from langgraph.channels.last_value import LastValue from langgraph.checkpoint.base import BaseCheckpointSaver -from langgraph.constants import END, START, TAG_HIDDEN +from langgraph.constants import CONF, END, START, TAG_HIDDEN from langgraph.pregel import Pregel from langgraph.pregel.call import get_runnable_for_func from langgraph.pregel.read import PregelNode from langgraph.pregel.write import ChannelWrite, ChannelWriteEntry from langgraph.store.base import BaseStore -from langgraph.types import RetryPolicy, StreamMode +from langgraph.types import RetryPolicy, StreamMode, StreamWriter P = ParamSpec("P") P1 = TypeVar("P1") @@ -112,13 +112,101 @@ def entrypoint( store: Optional[BaseStore] = None, ) -> Callable[[types.FunctionType], Pregel]: def _imp(func: types.FunctionType) -> Pregel: + """Convert a function into a Pregel graph. + + Args: + func: The function to convert. Support both sync and async functions, as well + as generator and async generator functions. + + Returns: + A Pregel graph. + """ if inspect.isgeneratorfunction(func): - raise TypeError("@entrypoint does not support generator functions.") + original_sig = inspect.signature(func) + # Check if original signature has a writer argument with a matching type. + # If not, we'll inject it into the decorator, but not pass it + # to the wrapped function. + if "writer" in original_sig.parameters: + + @functools.wraps(func) + def gen_wrapper(*args: Any, writer: StreamWriter, **kwargs: Any) -> Any: + chunks = [] + for chunk in func(*args, writer=writer, **kwargs): + writer(chunk) + chunks.append(chunk) + return chunks + else: + + @functools.wraps(func) + def gen_wrapper(*args: Any, writer: StreamWriter, **kwargs: Any) -> Any: + chunks = [] + # Do not pass the writer argument to the wrapped function + # as it does not have a matching parameter + for chunk in func(*args, **kwargs): + writer(chunk) + chunks.append(chunk) + return chunks + + # Create a new parameter for the writer argument + extra_param = inspect.Parameter( + "writer", + inspect.Parameter.KEYWORD_ONLY, + # The extra argument is a keyword-only argument + default=lambda _: None, + ) + # Update the function's signature to include the extra argument + new_params = list(original_sig.parameters.values()) + [extra_param] + new_sig = original_sig.replace(parameters=new_params) + # Update the signature of the wrapper function + gen_wrapper.__signature__ = new_sig + bound = get_runnable_for_func(gen_wrapper) + stream_mode: StreamMode = "custom" elif inspect.isasyncgenfunction(func): - raise TypeError("@entrypoint does not support async generator functions.") + original_sig = inspect.signature(func) + # Check if original signature has a writer argument with a matching type. + # If not, we'll inject it into the decorator, but not pass it + # to the wrapped function. + if "writer" in original_sig.parameters: + + @functools.wraps(func) + async def agen_wrapper( + *args: Any, writer: StreamWriter, **kwargs: Any + ) -> Any: + chunks = [] + async for chunk in func(*args, writer=writer, **kwargs): + writer(chunk) + chunks.append(chunk) + return chunks + else: + + @functools.wraps(func) + async def agen_wrapper( + *args: Any, writer: StreamWriter, **kwargs: Any + ) -> Any: + chunks = [] + async for chunk in func(*args, **kwargs): + writer(chunk) + chunks.append(chunk) + return chunks + + # Create a new parameter for the writer argument + extra_param = inspect.Parameter( + "writer", + inspect.Parameter.KEYWORD_ONLY, + # The extra argument is a keyword-only argument + default=lambda _: None, + ) + # Update the function's signature to include the extra argument + new_params = list(original_sig.parameters.values()) + [extra_param] + new_sig = original_sig.replace(parameters=new_params) + # Update the signature of the wrapper function + agen_wrapper.__signature__ = new_sig + + bound = get_runnable_for_func(agen_wrapper) + stream_mode = "custom" else: bound = get_runnable_for_func(func) - stream_mode: StreamMode = "updates" + stream_mode = "updates" return Pregel( nodes={ diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index 432eea069..f776a2580 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -5331,7 +5331,6 @@ def test_entrypoint_stateful() -> None: # Test invoke states = [] - # In this version reducers do not work @entrypoint(checkpointer=MemorySaver()) def foo(inputs, *, previous: Any) -> Any: states.append(previous) @@ -5367,21 +5366,71 @@ def foo(inputs, *, previous: Any) -> Any: assert items == [{"foo": {"current": {"a": "1"}, "previous": None}}] -async def test_entrypoint_from_generator() -> None: +def test_entrypoint_from_sync_generator() -> None: """@entrypoint does not support sync generators.""" + previous_return_values = [] - with pytest.raises(TypeError): + @entrypoint(checkpointer=MemorySaver()) + def foo(inputs, previous=None) -> Any: + previous_return_values.append(previous) + yield "a" + yield "b" + + config = {"configurable": {"thread_id": "1"}} + + assert foo.invoke({"a": "1"}, config) == ["a", "b"] + assert previous_return_values == [None] + assert foo.invoke({"a": "2"}, config) == ["a", "b"] + assert previous_return_values == [None, ["a", "b"]] + + +def test_entrypoint_request_stream_writer() -> None: + """Test using a stream writer with an entrypoint.""" + + @entrypoint(checkpointer=MemorySaver()) + def foo(inputs, writer: StreamWriter) -> Any: + writer("a") + yield "b" + + config = {"configurable": {"thread_id": "1"}} + + # Different invocations + # Are any of these confusing or unexpected? + assert list(foo.invoke({}, config)) == ["b"] + assert list(foo.stream({}, config)) == ["a", "b"] - @entrypoint(checkpointer=MemorySaver()) - def foo(inputs: Any) -> Iterable[dict]: - yield "a" + # Stream modes + assert list(foo.stream({}, config, stream_mode=["updates"])) == [ + ("updates", {"foo": ["b"]}) + ] + assert list(foo.stream({}, config, stream_mode=["values"])) == [("values", ["b"])] + assert list(foo.stream({}, config, stream_mode=["custom"])) == [ + ( + "custom", + "a", + ), + ( + "custom", + "b", + ), + ] async def test_entrypoint_from_async_generator() -> None: - """@entrypoint does not support async generators.""" + """@entrypoint does not support sync generators.""" + # Test invoke + previous_return_values = [] - with pytest.raises(TypeError): + # In this version reducers do not work + @entrypoint(checkpointer=MemorySaver()) + async def foo(inputs, previous=None) -> Any: + previous_return_values.append(previous) + yield "a" + yield "b" + + config = {"configurable": {"thread_id": "1"}} - @entrypoint(checkpointer=MemorySaver()) - def foo(inputs: Any) -> Iterable[dict]: - yield "a" + assert list(await foo.ainvoke({"a": "1"}, config)) == ["a", "b"] + assert previous_return_values == [None] + assert list(foo.invoke({"a": "2"}, config)) == ["a", "b"] + assert previous_return_values == [None, ["a", "b"]] From 421f7c0238b3a8740ee48c5cdd99e7185ea1e2e5 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Thu, 16 Jan 2025 13:55:41 -0500 Subject: [PATCH 05/13] x --- libs/langgraph/langgraph/pregel/loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/langgraph/langgraph/pregel/loop.py b/libs/langgraph/langgraph/pregel/loop.py index 1f2fe91a3..e92f0c267 100644 --- a/libs/langgraph/langgraph/pregel/loop.py +++ b/libs/langgraph/langgraph/pregel/loop.py @@ -345,11 +345,11 @@ def accept_push( (PUSH, task.path, write_idx, task.id, call), None, checkpoint=self.checkpoint, - pending_writes=[(task.id, *w) for w in task.writes], + pending_writes=self.checkpoint_pending_writes, processes=self.nodes, channels=self.channels, managed=self.managed, - config=self.config, + config=task.config, step=self.step, for_execution=True, store=self.store, From e476897177daf306b8a0e280971c27f00bd63b6c Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Thu, 16 Jan 2025 13:58:55 -0500 Subject: [PATCH 06/13] fix merge error --- libs/langgraph/langgraph/pregel/algo.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/libs/langgraph/langgraph/pregel/algo.py b/libs/langgraph/langgraph/pregel/algo.py index 08dcf77e6..d1247dbc4 100644 --- a/libs/langgraph/langgraph/pregel/algo.py +++ b/libs/langgraph/langgraph/pregel/algo.py @@ -507,6 +507,7 @@ def prepare_single_task( CONFIG_KEY_SCRATCHPAD: _scratchpad( pending_writes, task_id, + ), CONFIG_KEY_END: checkpoint["channel_values"].get( "__end__", None ), @@ -745,7 +746,7 @@ def prepare_single_task( ), CONFIG_KEY_END: checkpoint["channel_values"].get( "__end__", None - + ), }, ), triggers, From 582856b30c3c6ec0b9494c291bc9cfb4bac34218 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Thu, 16 Jan 2025 14:00:42 -0500 Subject: [PATCH 07/13] x --- libs/langgraph/tests/test_pregel.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index 75958221f..56d18e1b2 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -5598,6 +5598,34 @@ def foo(inputs: Any, *, previous: Any) -> Any: assert foo.invoke({"a": "1"}, config) == {"current": {"a": "1"}, "previous": None} +async def test_async_entrypoint_without_checkpointer() -> None: + """Test no checkpointer.""" + states = [] + config = {"configurable": {"thread_id": "1"}} + + # Test without previous + @entrypoint() + async def foo(inputs: Any) -> Any: + states.append(inputs) + return inputs + + assert (await foo.ainvoke({"a": "1"}, config)) == {"a": "1"} + + @entrypoint() + async def foo(inputs: Any, *, previous: Any) -> Any: + states.append(previous) + return {"previous": previous, "current": inputs} + + assert (await foo.ainvoke({"a": "1"}, config)) == { + "current": {"a": "1"}, + "previous": None, + } + assert (await foo.ainvoke({"a": "1"}, config)) == { + "current": {"a": "1"}, + "previous": None, + } + + def test_entrypoint_stateful() -> None: """Test stateful entrypoint invoke.""" From 46907b6cf9f619d0685201e947606ec9a79e26a9 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Thu, 16 Jan 2025 14:04:00 -0500 Subject: [PATCH 08/13] x --- libs/langgraph/tests/test_pregel.py | 1 - 1 file changed, 1 deletion(-) diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index 56d18e1b2..c1575dd7f 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -15,7 +15,6 @@ Any, Dict, Generator, - Iterable, Iterator, List, Literal, From d021f476db8cb0227971d84b215df36a696207c0 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Thu, 16 Jan 2025 14:23:32 -0500 Subject: [PATCH 09/13] x --- libs/langgraph/langgraph/func/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langgraph/langgraph/func/__init__.py b/libs/langgraph/langgraph/func/__init__.py index d873fdaed..65b4eb76e 100644 --- a/libs/langgraph/langgraph/func/__init__.py +++ b/libs/langgraph/langgraph/func/__init__.py @@ -122,7 +122,7 @@ def _imp(func: types.FunctionType) -> Pregel: Returns: A Pregel graph. """ - # wrap generators in a function that writes to StreamWriter + # wrap generators in a function that writes to StreamWriter if inspect.isgeneratorfunction(func): original_sig = inspect.signature(func) # Check if original signature has a writer argument with a matching type. From b218cc76a7366e13e0020f7aa20ddb30c6b52689 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Thu, 16 Jan 2025 14:50:40 -0500 Subject: [PATCH 10/13] type ignore for now --- libs/langgraph/langgraph/func/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/langgraph/langgraph/func/__init__.py b/libs/langgraph/langgraph/func/__init__.py index 65b4eb76e..a051e465e 100644 --- a/libs/langgraph/langgraph/func/__init__.py +++ b/libs/langgraph/langgraph/func/__init__.py @@ -160,7 +160,7 @@ def gen_wrapper(*args: Any, writer: StreamWriter, **kwargs: Any) -> Any: new_params = list(original_sig.parameters.values()) + [extra_param] new_sig = original_sig.replace(parameters=new_params) # Update the signature of the wrapper function - gen_wrapper.__signature__ = new_sig + gen_wrapper.__signature__ = new_sig # type: ignore bound = get_runnable_for_func(gen_wrapper) stream_mode: StreamMode = "custom" elif inspect.isasyncgenfunction(func): @@ -202,7 +202,7 @@ async def agen_wrapper( new_params = list(original_sig.parameters.values()) + [extra_param] new_sig = original_sig.replace(parameters=new_params) # Update the signature of the wrapper function - agen_wrapper.__signature__ = new_sig + agen_wrapper.__signature__ = new_sig # type: ignore bound = get_runnable_for_func(agen_wrapper) stream_mode = "custom" From e5f0db0af3177d227db7738f3eafe7fee0c14caf Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Thu, 16 Jan 2025 16:09:11 -0500 Subject: [PATCH 11/13] x --- libs/scheduler-kafka/tests/test_subgraph.py | 4 ++++ libs/scheduler-kafka/tests/test_subgraph_sync.py | 6 ++++++ 2 files changed, 10 insertions(+) diff --git a/libs/scheduler-kafka/tests/test_subgraph.py b/libs/scheduler-kafka/tests/test_subgraph.py index 053eaedd0..d0588a4fa 100644 --- a/libs/scheduler-kafka/tests/test_subgraph.py +++ b/libs/scheduler-kafka/tests/test_subgraph.py @@ -195,6 +195,7 @@ async def test_subgraph_w_interrupt( "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": False, + "__pregel_previous": None, "__pregel_store": None, "__pregel_task_id": history[0].tasks[0].id, "__pregel_scratchpad": { @@ -267,6 +268,7 @@ async def test_subgraph_w_interrupt( "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": False, + "__pregel_previous": None, "__pregel_store": None, "__pregel_task_id": history[0].tasks[0].id, "__pregel_scratchpad": { @@ -481,6 +483,7 @@ async def test_subgraph_w_interrupt( "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": True, + "__pregel_previous": None, "__pregel_store": None, "__pregel_task_id": history[1].tasks[0].id, "__pregel_scratchpad": { @@ -548,6 +551,7 @@ async def test_subgraph_w_interrupt( "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": True, + "__pregel_previous": None, "__pregel_store": None, "__pregel_task_id": history[1].tasks[0].id, "__pregel_scratchpad": { diff --git a/libs/scheduler-kafka/tests/test_subgraph_sync.py b/libs/scheduler-kafka/tests/test_subgraph_sync.py index 7d5de920c..c2c9a8fc1 100644 --- a/libs/scheduler-kafka/tests/test_subgraph_sync.py +++ b/libs/scheduler-kafka/tests/test_subgraph_sync.py @@ -194,6 +194,7 @@ def test_subgraph_w_interrupt( "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": False, + "__pregel_previous": None, "__pregel_store": None, "__pregel_task_id": history[0].tasks[0].id, "__pregel_scratchpad": { @@ -267,6 +268,7 @@ def test_subgraph_w_interrupt( "__pregel_store": None, "__pregel_dedupe_tasks": True, "__pregel_resuming": False, + "__pregel_previous": None, "__pregel_task_id": history[0].tasks[0].id, "__pregel_scratchpad": { "subgraph_counter": AnyInt(), @@ -370,6 +372,7 @@ def test_subgraph_w_interrupt( "__pregel_store": None, "__pregel_resuming": False, "__pregel_task_id": history[0].tasks[0].id, + "__pregel_previous": None, "__pregel_scratchpad": { "subgraph_counter": AnyInt(), "call_counter": 0, @@ -480,6 +483,7 @@ def test_subgraph_w_interrupt( "__pregel_dedupe_tasks": True, "__pregel_store": None, "__pregel_resuming": True, + "__pregel_previous": None, "__pregel_task_id": history[1].tasks[0].id, "__pregel_scratchpad": { "subgraph_counter": AnyInt(), @@ -547,6 +551,7 @@ def test_subgraph_w_interrupt( "__pregel_dedupe_tasks": True, "__pregel_store": None, "__pregel_resuming": True, + "__pregel_previous": None, "__pregel_task_id": history[1].tasks[0].id, "__pregel_scratchpad": { "subgraph_counter": AnyInt(), @@ -669,6 +674,7 @@ def test_subgraph_w_interrupt( "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": True, + "__pregel_previous": None, "__pregel_store": None, "__pregel_task_id": history[1].tasks[0].id, "__pregel_scratchpad": { From 3a48049194d84fd8e79a291fab7acc8ba279ad1b Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Thu, 16 Jan 2025 16:19:23 -0500 Subject: [PATCH 12/13] x --- libs/scheduler-kafka/tests/test_subgraph.py | 1 + 1 file changed, 1 insertion(+) diff --git a/libs/scheduler-kafka/tests/test_subgraph.py b/libs/scheduler-kafka/tests/test_subgraph.py index d0588a4fa..c85f335ae 100644 --- a/libs/scheduler-kafka/tests/test_subgraph.py +++ b/libs/scheduler-kafka/tests/test_subgraph.py @@ -371,6 +371,7 @@ async def test_subgraph_w_interrupt( "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": False, + "__pregel_previous": None, "__pregel_store": None, "__pregel_task_id": history[0].tasks[0].id, "__pregel_scratchpad": { From a58f5dacca4cead9f0cec5ab5c1805f355e473c3 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Thu, 16 Jan 2025 16:26:54 -0500 Subject: [PATCH 13/13] x --- libs/scheduler-kafka/tests/test_subgraph.py | 1 + 1 file changed, 1 insertion(+) diff --git a/libs/scheduler-kafka/tests/test_subgraph.py b/libs/scheduler-kafka/tests/test_subgraph.py index c85f335ae..2a6c9992a 100644 --- a/libs/scheduler-kafka/tests/test_subgraph.py +++ b/libs/scheduler-kafka/tests/test_subgraph.py @@ -676,6 +676,7 @@ async def test_subgraph_w_interrupt( "__pregel_ensure_latest": True, "__pregel_dedupe_tasks": True, "__pregel_resuming": True, + "__pregel_previous": None, "__pregel_store": None, "__pregel_task_id": history[1].tasks[0].id, "__pregel_scratchpad": {