Skip to content

Commit

Permalink
qx
Browse files Browse the repository at this point in the history
  • Loading branch information
eyurtsev committed Jan 13, 2025
1 parent d7199e5 commit 2d5bf1e
Showing 1 changed file with 50 additions and 0 deletions.
50 changes: 50 additions & 0 deletions libs/langgraph/langgraph/func/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
Awaitable,
Callable,
Optional,
Type,
TypeVar,
Union,
get_type_hints,
overload,
)

Expand All @@ -20,6 +22,7 @@
from langgraph.channels.last_value import LastValue
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.constants import END, START, TAG_HIDDEN
from langgraph.graph.state import StateGraph
from langgraph.pregel import Pregel
from langgraph.pregel.call import get_runnable_for_func
from langgraph.pregel.read import PregelNode
Expand Down Expand Up @@ -47,6 +50,14 @@ def call(
return fut


def get_store() -> BaseStore:
"""Get the current store."""
from langgraph.constants import CONFIG_KEY_STORE
from langgraph.utils.config import get_configurable

return get_configurable()[CONFIG_KEY_STORE]


@overload
def task(
*, retry: Optional[RetryPolicy] = None
Expand Down Expand Up @@ -153,3 +164,42 @@ async def agen_wrapper(
)

return _imp


def g(
*,
checkpointer: Optional[BaseCheckpointSaver] = None,
store: Optional[BaseStore] = None,
input: Optional[Type[Any]] = None,
output: Optional[Type[Any]] = None,
) -> Callable[[types.FunctionType], Pregel]:
"""Generate a single node StateGraph from a callable.
Args:
checkpointer: Checkpointer to use for the graph.
store: Store to use for the graph.
input: Schema projection applied to run input to the graph.
output: Schema projection applied to run output from the graph.
Returns:
Callable: A decorator that can be applied to a function to generate a Pregel graph.
"""

def decorator(func: Callable) -> Pregel:
"""Decorator to generate a Pregel graph from a callable."""
node_name = func.__name__
type_hints = get_type_hints(func)
if not type_hints:
state_schema = dict
else:
state_schema = list(type_hints.values())[0]
builder = (
StateGraph(state_schema=state_schema, input=input, output=output)
.add_node(func)
.set_entry_point(node_name)
)
graph = builder.compile(checkpointer=checkpointer, store=store)
return graph

return decorator

0 comments on commit 2d5bf1e

Please sign in to comment.