diff --git a/libs/langgraph/langgraph/func/__init__.py b/libs/langgraph/langgraph/func/__init__.py index 62bf138e6..7b27c78ee 100644 --- a/libs/langgraph/langgraph/func/__init__.py +++ b/libs/langgraph/langgraph/func/__init__.py @@ -9,8 +9,10 @@ Awaitable, Callable, Optional, + Type, TypeVar, Union, + get_type_hints, overload, ) @@ -20,6 +22,7 @@ from langgraph.channels.last_value import LastValue from langgraph.checkpoint.base import BaseCheckpointSaver from langgraph.constants import END, START, TAG_HIDDEN +from langgraph.graph.state import StateGraph from langgraph.pregel import Pregel from langgraph.pregel.call import get_runnable_for_func from langgraph.pregel.read import PregelNode @@ -47,6 +50,14 @@ def call( return fut +def get_store() -> BaseStore: + """Get the current store.""" + from langgraph.constants import CONFIG_KEY_STORE + from langgraph.utils.config import get_configurable + + return get_configurable()[CONFIG_KEY_STORE] + + @overload def task( *, retry: Optional[RetryPolicy] = None @@ -153,3 +164,42 @@ async def agen_wrapper( ) return _imp + + +def g( + *, + checkpointer: Optional[BaseCheckpointSaver] = None, + store: Optional[BaseStore] = None, + input: Optional[Type[Any]] = None, + output: Optional[Type[Any]] = None, +) -> Callable[[types.FunctionType], Pregel]: + """Generate a single node StateGraph from a callable. + + Args: + checkpointer: Checkpointer to use for the graph. + store: Store to use for the graph. + input: Schema projection applied to run input to the graph. + output: Schema projection applied to run output from the graph. + + + Returns: + Callable: A decorator that can be applied to a function to generate a Pregel graph. + """ + + def decorator(func: Callable) -> Pregel: + """Decorator to generate a Pregel graph from a callable.""" + node_name = func.__name__ + type_hints = get_type_hints(func) + if not type_hints: + state_schema = dict + else: + state_schema = list(type_hints.values())[0] + builder = ( + StateGraph(state_schema=state_schema, input=input, output=output) + .add_node(func) + .set_entry_point(node_name) + ) + graph = builder.compile(checkpointer=checkpointer, store=store) + return graph + + return decorator