diff --git a/libs/langgraph/langgraph/graph/state.py b/libs/langgraph/langgraph/graph/state.py index 6e37c6150..c416d5f6a 100644 --- a/libs/langgraph/langgraph/graph/state.py +++ b/libs/langgraph/langgraph/graph/state.py @@ -613,21 +613,24 @@ def attach_node(self, key: str, node: Optional[StateNodeSpec]) -> None: ] def _get_root(input: Any) -> Optional[Sequence[tuple[str, Any]]]: - if ( + if isinstance(input, Command): + if input.graph == Command.PARENT: + return () + return input._update_as_tuples() + elif ( isinstance(input, (list, tuple)) and input - and all(isinstance(i, Command) for i in input) + and any(isinstance(i, Command) for i in input) ): updates: list[tuple[str, Any]] = [] for i in input: - if i.graph == Command.PARENT: - continue - updates.extend(i._update_as_tuples()) + if isinstance(i, Command): + if i.graph == Command.PARENT: + continue + updates.extend(i._update_as_tuples()) + else: + updates.append(("__root__", i)) return updates - elif isinstance(input, Command): - if input.graph == Command.PARENT: - return () - return input._update_as_tuples() elif input is not None: return [("__root__", input)] @@ -645,13 +648,16 @@ def _get_updates( elif ( isinstance(input, (list, tuple)) and input - and all(isinstance(i, Command) for i in input) + and any(isinstance(i, Command) for i in input) ): updates: list[tuple[str, Any]] = [] for i in input: - if i.graph == Command.PARENT: - continue - updates.extend(i._update_as_tuples()) + if isinstance(i, Command): + if i.graph == Command.PARENT: + continue + updates.extend(i._update_as_tuples()) + else: + updates.extend(_get_updates(i) or ()) return updates elif get_type_hints(type(input)): return [ diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index 2515270a1..a3f5a6d48 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -14807,3 +14807,31 @@ def ask_age(s: State): assert [event for event in graph.stream(Command(resume="19"), thread1)] == [ {"node": {"age": 19}}, ] + + +def test_root_mixed_return() -> None: + def my_node(state: list[str]): + return [Command(update=["a"]), ["b"]] + + graph = StateGraph(Annotated[list[str], operator.add]) + + graph.add_node(my_node) + graph.add_edge(START, "my_node") + graph = graph.compile() + + assert graph.invoke([]) == ["a", "b"] + + +def test_dict_mixed_return() -> None: + class State(TypedDict): + foo: Annotated[str, operator.add] + + def my_node(state: State): + return [Command(update={"foo": "a"}), {"foo": "b"}] + + graph = StateGraph(State) + graph.add_node(my_node) + graph.add_edge(START, "my_node") + graph = graph.compile() + + assert graph.invoke({"foo": ""}) == {"foo": "ab"} diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index 514703781..b278d9f77 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -9670,14 +9670,14 @@ async def outer_2(state: State): ), (FloatBetween(0.2, 0.4), ((), {"outer_1": {"my_key": " and parallel"}})), ( - FloatBetween(0.5, 0.7), + FloatBetween(0.5, 0.8), ( (AnyStr("inner:"),), {"inner_2": {"my_key": " and there", "my_other_key": "got here"}}, ), ), - (FloatBetween(0.5, 0.7), ((), {"inner": {"my_key": "got here and there"}})), - (FloatBetween(0.5, 0.7), ((), {"outer_2": {"my_key": " and back again"}})), + (FloatBetween(0.5, 0.8), ((), {"inner": {"my_key": "got here and there"}})), + (FloatBetween(0.5, 0.8), ((), {"outer_2": {"my_key": " and back again"}})), ]