diff --git a/libs/langgraph/langgraph/graph/graph.py b/libs/langgraph/langgraph/graph/graph.py index e957a15b9..e52ed3729 100644 --- a/libs/langgraph/langgraph/graph/graph.py +++ b/libs/langgraph/langgraph/graph/graph.py @@ -561,7 +561,11 @@ def add_edge( subgraph = ( subgraphs[key].get_graph( config=config, - xray=xray - 1 if isinstance(xray, int) and xray > 0 else xray, + xray=xray - 1 + if isinstance(xray, int) + and not isinstance(xray, bool) + and xray > 0 + else xray, ) if key in subgraphs else node.get_graph(config=config) diff --git a/libs/langgraph/tests/__snapshots__/test_pregel.ambr b/libs/langgraph/tests/__snapshots__/test_pregel.ambr index 2e50030e3..e93e58f35 100644 --- a/libs/langgraph/tests/__snapshots__/test_pregel.ambr +++ b/libs/langgraph/tests/__snapshots__/test_pregel.ambr @@ -5148,6 +5148,42 @@ ''' # --- +# name: test_xray_bool + ''' + %%{init: {'flowchart': {'curve': 'linear'}}}%% + graph TD; + __start__([
__start__
]):::first + gp_one(gp_one) + gp_two___start__(__start__
) + gp_two_p_one(p_one) + gp_two_p_two___start__(__start__
) + gp_two_p_two_c_one(c_one) + gp_two_p_two_c_two(c_two) + gp_two_p_two___end__(__end__
) + gp_two___end__(__end__
) + __end__([__end__
]):::last + __start__ --> gp_one; + gp_two___end__ --> gp_one; + gp_one -. 0 .-> gp_two___start__; + gp_one -. 1 .-> __end__; + subgraph gp_two + gp_two___start__ --> gp_two_p_one; + gp_two_p_two___end__ --> gp_two_p_one; + gp_two_p_one -. 0 .-> gp_two_p_two___start__; + gp_two_p_one -. 1 .-> gp_two___end__; + subgraph p_two + gp_two_p_two___start__ --> gp_two_p_two_c_one; + gp_two_p_two_c_two --> gp_two_p_two_c_one; + gp_two_p_two_c_one -. 0 .-> gp_two_p_two_c_two; + gp_two_p_two_c_one -. 1 .-> gp_two_p_two___end__; + end + end + classDef default fill:#f2f0ff,line-height:1.2 + classDef first fill-opacity:0 + classDef last fill:#bfb6fc + + ''' +# --- # name: test_xray_issue ''' %%{init: {'flowchart': {'curve': 'linear'}}}%% diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index d2943d068..917057b29 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -11358,6 +11358,51 @@ def _node(state: State): assert app.get_graph(xray=True).draw_mermaid() == snapshot +def test_xray_bool(snapshot: SnapshotAssertion) -> None: + class State(TypedDict): + messages: Annotated[list, add_messages] + + def node(name): + def _node(state: State): + return {"messages": [("human", f"entered {name} node")]} + + return _node + + grand_parent = StateGraph(State) + + child = StateGraph(State) + + child.add_node("c_one", node("c_one")) + child.add_node("c_two", node("c_two")) + + child.add_edge("__start__", "c_one") + child.add_edge("c_two", "c_one") + + child.add_conditional_edges( + "c_one", lambda x: str(randrange(0, 2)), {"0": "c_two", "1": "__end__"} + ) + + parent = StateGraph(State) + parent.add_node("p_one", node("p_one")) + parent.add_node("p_two", child.compile()) + parent.add_edge("__start__", "p_one") + parent.add_edge("p_two", "p_one") + parent.add_conditional_edges( + "p_one", lambda x: str(randrange(0, 2)), {"0": "p_two", "1": "__end__"} + ) + + grand_parent.add_node("gp_one", node("gp_one")) + grand_parent.add_node("gp_two", parent.compile()) + grand_parent.add_edge("__start__", "gp_one") + grand_parent.add_edge("gp_two", "gp_one") + grand_parent.add_conditional_edges( + "gp_one", lambda x: str(randrange(0, 2)), {"0": "gp_two", "1": "__end__"} + ) + + app = grand_parent.compile() + assert app.get_graph(xray=True).draw_mermaid() == snapshot + + def test_subgraph_retries(): class State(TypedDict): count: int