Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
eyurtsev committed Jan 14, 2025
1 parent a61ea10 commit 67a800d
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 34 deletions.
2 changes: 2 additions & 0 deletions libs/langgraph/langgraph/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@
# read-only list of existing task writes
CONFIG_KEY_SCRATCHPAD = sys.intern("__pregel_scratchpad")
# holds a mutable dict for temporary storage scoped to the current task
CONFIG_KEY_END = sys.intern("__pregel_previous")
# holds the previous return value from a stateful Pregel graph.

# --- Other constants ---
PUSH = sys.intern("__pregel_push")
Expand Down
23 changes: 2 additions & 21 deletions libs/langgraph/langgraph/func/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,27 +112,8 @@ def entrypoint(
store: Optional[BaseStore] = None,
) -> Callable[[types.FunctionType], Pregel]:
def _imp(func: types.FunctionType) -> Pregel:
if inspect.isgeneratorfunction(func):

def gen_wrapper(*args: Any, writer: StreamWriter, **kwargs: Any) -> Any:
for chunk in func(*args, **kwargs):
writer(chunk)

bound = get_runnable_for_func(gen_wrapper)
stream_mode: StreamMode = "custom"
elif inspect.isasyncgenfunction(func):

async def agen_wrapper(
*args: Any, writer: StreamWriter, **kwargs: Any
) -> Any:
async for chunk in func(*args, **kwargs):
writer(chunk)

bound = get_runnable_for_func(agen_wrapper)
stream_mode = "custom"
else:
bound = get_runnable_for_func(func)
stream_mode = "updates"
bound = get_runnable_for_func(func)
stream_mode: StreamMode = "updates"

return Pregel(
nodes={
Expand Down
10 changes: 10 additions & 0 deletions libs/langgraph/langgraph/pregel/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
CONFIG_KEY_STORE,
CONFIG_KEY_TASK_ID,
CONFIG_KEY_WRITES,
CONFIG_KEY_END,
EMPTY_SEQ,
ERROR,
INTERRUPT,
Expand Down Expand Up @@ -560,6 +561,9 @@ def prepare_single_task(
if w[0] in (NULL_TASK_ID, task_id)
],
CONFIG_KEY_SCRATCHPAD: {},
CONFIG_KEY_END: checkpoint["channel_values"].get(
"__end__", None
),
},
),
triggers,
Expand Down Expand Up @@ -709,6 +713,9 @@ def prepare_single_task(
if w[0] in (NULL_TASK_ID, task_id)
],
CONFIG_KEY_SCRATCHPAD: {},
CONFIG_KEY_END: checkpoint["channel_values"].get(
"__end__", None
),
},
),
triggers,
Expand Down Expand Up @@ -833,6 +840,9 @@ def prepare_single_task(
if w[0] in (NULL_TASK_ID, task_id)
],
CONFIG_KEY_SCRATCHPAD: {},
CONFIG_KEY_END: checkpoint["channel_values"].get(
"__end__", None
),
},
),
triggers,
Expand Down
59 changes: 46 additions & 13 deletions libs/langgraph/langgraph/utils/runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,12 @@
from langchain_core.tracers._streaming import _StreamingCallbackHandler
from typing_extensions import TypeGuard

from langgraph.constants import CONF, CONFIG_KEY_STORE, CONFIG_KEY_STREAM_WRITER
from langgraph.constants import (
CONF,
CONFIG_KEY_STORE,
CONFIG_KEY_END,
CONFIG_KEY_STREAM_WRITER,
)
from langgraph.store.base import BaseStore
from langgraph.types import StreamWriter
from langgraph.utils.config import (
Expand All @@ -58,6 +63,10 @@ class StrEnum(str, enum.Enum):
"""A string enum."""


# Special type to denote any type is accepted
ANY_TYPE = object()


ASYNCIO_ACCEPTS_CONTEXT = sys.version_info >= (3, 11)

KWARGS_CONFIG_KEYS: tuple[tuple[str, tuple[Any, ...], str, Any], ...] = (
Expand All @@ -73,9 +82,22 @@ class StrEnum(str, enum.Enum):
CONFIG_KEY_STORE,
inspect.Parameter.empty,
),
(
sys.intern("previous"),
(ANY_TYPE, ),
CONFIG_KEY_END,
inspect.Parameter.empty,
),
)
"""List of kwargs that can be passed to functions, and their corresponding
config keys, default values and type annotations."""
config keys, default values and type annotations.
Each tuple contains:
- the name of the kwarg in the function signature
- the type annotation(s) for the kwarg
- the config key to look for the value in
- the default value for the kwarg
"""

VALID_KINDS = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)

Expand Down Expand Up @@ -122,9 +144,12 @@ def __init__(
self.func_accepts: dict[str, bool] = {}
for kw, typ, _, _ in KWARGS_CONFIG_KEYS:
p = params.get(kw)
self.func_accepts[kw] = (
p is not None and p.annotation in typ and p.kind in VALID_KINDS
)
if typ == (ANY_TYPE, ):
self.func_accepts[kw] = p is not None and p.kind in VALID_KINDS
else:
self.func_accepts[kw] = (
p is not None and p.annotation in typ and p.kind in VALID_KINDS
)

def __repr__(self) -> str:
repr_args = {
Expand All @@ -149,16 +174,20 @@ def invoke(
if self.func_accepts_config:
kwargs["config"] = config
_conf = config[CONF]
for kw, _, ck, defv in KWARGS_CONFIG_KEYS:
for kw, _, config_key, default_value in KWARGS_CONFIG_KEYS:
if not self.func_accepts[kw]:
continue

if defv is inspect.Parameter.empty and kw not in kwargs and ck not in _conf:
if (
default_value is inspect.Parameter.empty
and kw not in kwargs
and config_key not in _conf
):
raise ValueError(
f"Missing required config key '{ck}' for '{self.name}'."
f"Missing required config key '{config_key}' for '{self.name}'."
)
elif kwargs.get(kw) is None:
kwargs[kw] = _conf.get(ck, defv)
kwargs[kw] = _conf.get(config_key, default_value)

context = copy_context()
if self.trace:
Expand Down Expand Up @@ -197,16 +226,20 @@ async def ainvoke(
if self.func_accepts_config:
kwargs["config"] = config
_conf = config[CONF]
for kw, _, ck, defv in KWARGS_CONFIG_KEYS:
for kw, _, config_key, default_value in KWARGS_CONFIG_KEYS:
if not self.func_accepts[kw]:
continue

if defv is inspect.Parameter.empty and kw not in kwargs and ck not in _conf:
if (
default_value is inspect.Parameter.empty
and kw not in kwargs
and config_key not in _conf
):
raise ValueError(
f"Missing required config key '{ck}' for '{self.name}'."
f"Missing required config key '{config_key}' for '{self.name}'."
)
elif kwargs.get(kw) is None:
kwargs[kw] = _conf.get(ck, defv)
kwargs[kw] = _conf.get(config_key, default_value)
context = copy_context()
if self.trace:
callback_manager = get_async_callback_manager_for_config(config, self.tags)
Expand Down
24 changes: 24 additions & 0 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5300,3 +5300,27 @@ def node_b(state):
{"node_a": [{"foo": "a1"}, {"foo": "a2"}]},
{"node_b": {"foo": "b"}},
]


def test_version_1_of_entrypoint() -> None:
from langgraph.func import entrypoint
from typing import TypedDict, Annotated, NotRequired, Any

Check failure on line 5307 in libs/langgraph/tests/test_pregel.py

View workflow job for this annotation

GitHub Actions / cd libs/langgraph / lint #3.12

Ruff (TID251)

tests/test_pregel.py:5307:24: TID251 `typing.TypedDict` is banned: Use typing_extensions.TypedDict instead.

Check failure on line 5307 in libs/langgraph/tests/test_pregel.py

View workflow job for this annotation

GitHub Actions / cd libs/langgraph / lint #3.12

Ruff (F401)

tests/test_pregel.py:5307:24: F401 `typing.TypedDict` imported but unused

Check failure on line 5307 in libs/langgraph/tests/test_pregel.py

View workflow job for this annotation

GitHub Actions / cd libs/langgraph / lint #3.12

Ruff (F401)

tests/test_pregel.py:5307:35: F401 `typing.Annotated` imported but unused

Check failure on line 5307 in libs/langgraph/tests/test_pregel.py

View workflow job for this annotation

GitHub Actions / cd libs/langgraph / lint #3.12

Ruff (F401)

tests/test_pregel.py:5307:46: F401 `typing.NotRequired` imported but unused

Check failure on line 5308 in libs/langgraph/tests/test_pregel.py

View workflow job for this annotation

GitHub Actions / cd libs/langgraph / lint #3.12

Ruff (I001)

tests/test_pregel.py:5306:1: I001 Import block is un-sorted or un-formatted
states = []

# In this version reducers do not work
@entrypoint(checkpointer=MemorySaver())
def foo(inputs, *, previous: Any) -> Any:
states.append(previous)
return {"previous": previous, "current": inputs}

config = {"configurable": {"thread_id": "1"}}

foo.invoke({"a": "1"}, config)
foo.invoke({"a": "2"}, config)
foo.invoke({"a": "3"}, config)
assert states == [
None,
{"current": {"a": "1"}, "previous": None},
{"current": {"a": "2"}, "previous": {"current": {"a": "1"}, "previous": None}},
]

0 comments on commit 67a800d

Please sign in to comment.