diff --git a/libs/langgraph/langgraph/graph/state.py b/libs/langgraph/langgraph/graph/state.py index d5e4a54809..7b950556be 100644 --- a/libs/langgraph/langgraph/graph/state.py +++ b/libs/langgraph/langgraph/graph/state.py @@ -361,7 +361,11 @@ def add_node( ends = EMPTY_SEQ try: - if (isfunction(action) or ismethod(getattr(action, "__call__", None))) and ( + if ( + isfunction(action) + or ismethod(action) + or ismethod(getattr(action, "__call__", None)) + ) and ( hints := get_type_hints(getattr(action, "__call__")) or get_type_hints(action) ): diff --git a/libs/langgraph/tests/test_state.py b/libs/langgraph/tests/test_state.py index 0a4a8725ac..5af4066699 100644 --- a/libs/langgraph/tests/test_state.py +++ b/libs/langgraph/tests/test_state.py @@ -79,11 +79,19 @@ def miss_all_hint(state, config): def pre_foo(_) -> FooState: return {"foo": "bar"} + def pre_bar(_) -> FooState: + return {"foo": "bar"} + class Foo: def __call__(self, state: FooState) -> OutputState: assert state.pop("foo") == "bar" return {"input_state": state} + class Bar: + def my_node(self, state: FooState) -> OutputState: + assert state.pop("foo") == "bar" + return {"input_state": state} + graph = StateGraph(InputState, output=OutputState) actions = [ complete_hint, @@ -92,6 +100,8 @@ def __call__(self, state: FooState) -> OutputState: miss_all_hint, pre_foo, Foo(), + pre_bar, + Bar().my_node, ] for action in actions: @@ -112,7 +122,7 @@ def get_name(action) -> str: foo_state = FooState(foo="bar") for i, c in enumerate(graph.stream(input_state, stream_mode="updates")): node_name = get_name(actions[i]) - if node_name == get_name(pre_foo): + if node_name in {"pre_foo", "pre_bar"}: assert c[node_name] == foo_state else: assert c[node_name] == output_state