diff --git a/libs/langgraph/langgraph/graph/state.py b/libs/langgraph/langgraph/graph/state.py index 508f4e474..1a7208a2a 100644 --- a/libs/langgraph/langgraph/graph/state.py +++ b/libs/langgraph/langgraph/graph/state.py @@ -829,15 +829,16 @@ def _coerce_state(schema: Type[Any], input: dict[str, Any]) -> dict[str, Any]: def _control_branch(value: Any) -> Sequence[Union[str, Send]]: if isinstance(value, Send): return [value] - if not isinstance(value, GraphCommand): + if not isinstance(value, Command): return EMPTY_SEQ if value.graph == Command.PARENT: raise ParentCommand(value) rtn: list[Union[str, Send]] = [] - if isinstance(value.goto, str): - rtn.append(value.goto) - else: - rtn.extend(value.goto) + if isinstance(value, GraphCommand): + if isinstance(value.goto, str): + rtn.append(value.goto) + else: + rtn.extend(value.goto) if isinstance(value.send, Send): rtn.append(value.send) else: @@ -848,15 +849,16 @@ def _control_branch(value: Any) -> Sequence[Union[str, Send]]: async def _acontrol_branch(value: Any) -> Sequence[Union[str, Send]]: if isinstance(value, Send): return [value] - if not isinstance(value, GraphCommand): + if not isinstance(value, Command): return EMPTY_SEQ if value.graph == Command.PARENT: raise ParentCommand(value) rtn: list[Union[str, Send]] = [] - if isinstance(value.goto, str): - rtn.append(value.goto) - else: - rtn.extend(value.goto) + if isinstance(value, GraphCommand): + if isinstance(value.goto, str): + rtn.append(value.goto) + else: + rtn.extend(value.goto) if isinstance(value.send, Send): rtn.append(value.send) else: diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index 6d1caa342..4ff56ca3e 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -1925,7 +1925,7 @@ def __call__(self, state): def send_for_fun(state): return [ - Send("2", GraphCommand(send=Send("2", 3))), + Send("2", Command(send=Send("2", 3))), Send("2", GraphCommand(send=Send("2", 4))), "3.1", ] diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index 378ca1ac1..cc244c363 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -2573,14 +2573,14 @@ async def __call__(self, state): if isinstance(state, list) # or isinstance(state, Control) else ["|".join((self.name, str(state)))] ) - if isinstance(state, GraphCommand): + if isinstance(state, Command): return replace(state, update=update) else: return update async def send_for_fun(state): return [ - Send("2", GraphCommand(send=Send("2", 3))), + Send("2", Command(send=Send("2", 3))), Send("2", GraphCommand(send=Send("2", 4))), "3.1", ]