Skip to content

Commit

Permalink
functional api: Add ability to request previous output (#3025)
Browse files Browse the repository at this point in the history
1. The inputs into foo do not affect any state behavior
2. `previous` always reflects the previous return value from the
function
3. Anything can be returned and that will be the new state for the
function on the next iteration
4. This API is not meant to support reducers in the inputs/state

```python
  from langgraph.func import entrypoint

  states = []

  # In this version reducers do not work
  @entrypoint(checkpointer=MemorySaver())
  def foo(inputs, *, previous: Any) -> Any:
      states.append(previous)
      return {"previous": previous, "current": inputs}

  config = {"configurable": {"thread_id": "1"}}

  foo.invoke({"a": "1"}, config)
  foo.invoke({"a": "2"}, config)
  foo.invoke({"a": "3"}, config)
  assert states == [
      None,
      {"current": {"a": "1"}, "previous": None},
      {"current": {"a": "2"}, "previous": {"current": {"a": "1"}, "previous": None}},
  ]
```
  • Loading branch information
eyurtsev authored Jan 16, 2025
2 parents 47122ce + a58f5da commit ce30965
Show file tree
Hide file tree
Showing 7 changed files with 308 additions and 23 deletions.
2 changes: 2 additions & 0 deletions libs/langgraph/langgraph/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@
# holds a callback to be called when a node is finished
CONFIG_KEY_SCRATCHPAD = sys.intern("__pregel_scratchpad")
# holds a mutable dict for temporary storage scoped to the current task
CONFIG_KEY_END = sys.intern("__pregel_previous")
# holds the previous return value from a stateful Pregel graph.

# --- Other constants ---
PUSH = sys.intern("__pregel_push")
Expand Down
96 changes: 85 additions & 11 deletions libs/langgraph/langgraph/func/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,22 +113,96 @@ def entrypoint(
config_schema: Optional[type[Any]] = None,
) -> Callable[[types.FunctionType], Pregel]:
def _imp(func: types.FunctionType) -> Pregel:
# wrap generators in a function that writes to StreamWriter
if inspect.isgeneratorfunction(func):
"""Convert a function into a Pregel graph.
def gen_wrapper(*args: Any, writer: StreamWriter, **kwargs: Any) -> Any:
for chunk in func(*args, **kwargs):
writer(chunk)
Args:
func: The function to convert. Support both sync and async functions, as well
as generator and async generator functions.
Returns:
A Pregel graph.
"""
# wrap generators in a function that writes to StreamWriter
if inspect.isgeneratorfunction(func):
original_sig = inspect.signature(func)
# Check if original signature has a writer argument with a matching type.
# If not, we'll inject it into the decorator, but not pass it
# to the wrapped function.
if "writer" in original_sig.parameters:

@functools.wraps(func)
def gen_wrapper(*args: Any, writer: StreamWriter, **kwargs: Any) -> Any:
chunks = []
for chunk in func(*args, writer=writer, **kwargs):
writer(chunk)
chunks.append(chunk)
return chunks
else:

@functools.wraps(func)
def gen_wrapper(*args: Any, writer: StreamWriter, **kwargs: Any) -> Any:
chunks = []
# Do not pass the writer argument to the wrapped function
# as it does not have a matching parameter
for chunk in func(*args, **kwargs):
writer(chunk)
chunks.append(chunk)
return chunks

# Create a new parameter for the writer argument
extra_param = inspect.Parameter(
"writer",
inspect.Parameter.KEYWORD_ONLY,
# The extra argument is a keyword-only argument
default=lambda _: None,
)
# Update the function's signature to include the extra argument
new_params = list(original_sig.parameters.values()) + [extra_param]
new_sig = original_sig.replace(parameters=new_params)
# Update the signature of the wrapper function
gen_wrapper.__signature__ = new_sig # type: ignore
bound = get_runnable_for_func(gen_wrapper)
stream_mode: StreamMode = "custom"
elif inspect.isasyncgenfunction(func):

async def agen_wrapper(
*args: Any, writer: StreamWriter, **kwargs: Any
) -> Any:
async for chunk in func(*args, **kwargs):
writer(chunk)
original_sig = inspect.signature(func)
# Check if original signature has a writer argument with a matching type.
# If not, we'll inject it into the decorator, but not pass it
# to the wrapped function.
if "writer" in original_sig.parameters:

@functools.wraps(func)
async def agen_wrapper(
*args: Any, writer: StreamWriter, **kwargs: Any
) -> Any:
chunks = []
async for chunk in func(*args, writer=writer, **kwargs):
writer(chunk)
chunks.append(chunk)
return chunks
else:

@functools.wraps(func)
async def agen_wrapper(
*args: Any, writer: StreamWriter, **kwargs: Any
) -> Any:
chunks = []
async for chunk in func(*args, **kwargs):
writer(chunk)
chunks.append(chunk)
return chunks

# Create a new parameter for the writer argument
extra_param = inspect.Parameter(
"writer",
inspect.Parameter.KEYWORD_ONLY,
# The extra argument is a keyword-only argument
default=lambda _: None,
)
# Update the function's signature to include the extra argument
new_params = list(original_sig.parameters.values()) + [extra_param]
new_sig = original_sig.replace(parameters=new_params)
# Update the signature of the wrapper function
agen_wrapper.__signature__ = new_sig # type: ignore

bound = get_runnable_for_func(agen_wrapper)
stream_mode = "custom"
Expand Down
10 changes: 10 additions & 0 deletions libs/langgraph/langgraph/pregel/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
CONFIG_KEY_CHECKPOINT_MAP,
CONFIG_KEY_CHECKPOINT_NS,
CONFIG_KEY_CHECKPOINTER,
CONFIG_KEY_END,
CONFIG_KEY_READ,
CONFIG_KEY_SCRATCHPAD,
CONFIG_KEY_SEND,
Expand Down Expand Up @@ -507,6 +508,9 @@ def prepare_single_task(
pending_writes,
task_id,
),
CONFIG_KEY_END: checkpoint["channel_values"].get(
"__end__", None
),
},
),
triggers,
Expand Down Expand Up @@ -616,6 +620,9 @@ def prepare_single_task(
pending_writes,
task_id,
),
CONFIG_KEY_END: checkpoint["channel_values"].get(
"__end__", None
),
},
),
triggers,
Expand Down Expand Up @@ -737,6 +744,9 @@ def prepare_single_task(
pending_writes,
task_id,
),
CONFIG_KEY_END: checkpoint["channel_values"].get(
"__end__", None
),
},
),
triggers,
Expand Down
50 changes: 38 additions & 12 deletions libs/langgraph/langgraph/utils/runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,12 @@
from langchain_core.tracers._streaming import _StreamingCallbackHandler
from typing_extensions import TypeGuard

from langgraph.constants import CONF, CONFIG_KEY_STORE, CONFIG_KEY_STREAM_WRITER
from langgraph.constants import (
CONF,
CONFIG_KEY_END,
CONFIG_KEY_STORE,
CONFIG_KEY_STREAM_WRITER,
)
from langgraph.store.base import BaseStore
from langgraph.types import StreamWriter
from langgraph.utils.config import (
Expand All @@ -58,6 +63,10 @@ class StrEnum(str, enum.Enum):
"""A string enum."""


# Special type to denote any type is accepted
ANY_TYPE = object()


ASYNCIO_ACCEPTS_CONTEXT = sys.version_info >= (3, 11)

KWARGS_CONFIG_KEYS: tuple[tuple[str, tuple[Any, ...], str, Any], ...] = (
Expand All @@ -73,6 +82,12 @@ class StrEnum(str, enum.Enum):
CONFIG_KEY_STORE,
inspect.Parameter.empty,
),
(
sys.intern("previous"),
(ANY_TYPE,),
CONFIG_KEY_END,
inspect.Parameter.empty,
),
)
"""List of kwargs that can be passed to functions, and their corresponding
config keys, default values and type annotations.
Expand Down Expand Up @@ -135,9 +150,12 @@ def __init__(
self.func_accepts: dict[str, bool] = {}
for kw, typ, _, _ in KWARGS_CONFIG_KEYS:
p = params.get(kw)
self.func_accepts[kw] = (
p is not None and p.annotation in typ and p.kind in VALID_KINDS
)
if typ == (ANY_TYPE,):
self.func_accepts[kw] = p is not None and p.kind in VALID_KINDS
else:
self.func_accepts[kw] = (
p is not None and p.annotation in typ and p.kind in VALID_KINDS
)

def __repr__(self) -> str:
repr_args = {
Expand All @@ -162,16 +180,20 @@ def invoke(
if self.func_accepts_config:
kwargs["config"] = config
_conf = config[CONF]
for kw, _, ck, defv in KWARGS_CONFIG_KEYS:
for kw, _, config_key, default_value in KWARGS_CONFIG_KEYS:
if not self.func_accepts[kw]:
continue

if defv is inspect.Parameter.empty and kw not in kwargs and ck not in _conf:
if (
default_value is inspect.Parameter.empty
and kw not in kwargs
and config_key not in _conf
):
raise ValueError(
f"Missing required config key '{ck}' for '{self.name}'."
f"Missing required config key '{config_key}' for '{self.name}'."
)
elif kwargs.get(kw) is None:
kwargs[kw] = _conf.get(ck, defv)
kwargs[kw] = _conf.get(config_key, default_value)

context = copy_context()
if self.trace:
Expand Down Expand Up @@ -210,16 +232,20 @@ async def ainvoke(
if self.func_accepts_config:
kwargs["config"] = config
_conf = config[CONF]
for kw, _, ck, defv in KWARGS_CONFIG_KEYS:
for kw, _, config_key, default_value in KWARGS_CONFIG_KEYS:
if not self.func_accepts[kw]:
continue

if defv is inspect.Parameter.empty and kw not in kwargs and ck not in _conf:
if (
default_value is inspect.Parameter.empty
and kw not in kwargs
and config_key not in _conf
):
raise ValueError(
f"Missing required config key '{ck}' for '{self.name}'."
f"Missing required config key '{config_key}' for '{self.name}'."
)
elif kwargs.get(kw) is None:
kwargs[kw] = _conf.get(ck, defv)
kwargs[kw] = _conf.get(config_key, default_value)
context = copy_context()
if self.trace:
callback_manager = get_async_callback_manager_for_config(config, self.tags)
Expand Down
Loading

0 comments on commit ce30965

Please sign in to comment.