From 39a4fd1aa9908f50e51ae768f2ca035f4622cac2 Mon Sep 17 00:00:00 2001 From: Yan Wang Date: Wed, 27 Nov 2024 09:48:59 +0100 Subject: [PATCH] fix --- thunder/dynamo/splitter.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/thunder/dynamo/splitter.py b/thunder/dynamo/splitter.py index f52dff0f7b..435098028b 100644 --- a/thunder/dynamo/splitter.py +++ b/thunder/dynamo/splitter.py @@ -136,18 +136,11 @@ def callback(node) -> int: gm, root_m=None, split_callback=callback, keep_original_order=True, keep_original_node_name=True ) - def add_output(m): - has_output = False - for node in m.graph.nodes: - if node.op == "call_module": - add_output(getattr(m, node.target)) - elif node.op == "output": - has_output = True - if not has_output: - m.graph.output(()) - # Workaround for the Torch bug https://github.com/pytorch/pytorch/pull/139275 - add_output(original_split_gm) + for submodule in original_split_gm.children(): + last_node = next(iter(reversed(submodule.graph.nodes))) + if last_node.op != "output": + submodule.graph.output(()) split_gm = copy.deepcopy(original_split_gm) def is_thunder_supported_partition(node: torch.fx.Node) -> bool: