Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
kiya00 committed Nov 27, 2024
1 parent 7951a51 commit 39a4fd1
Showing 1 changed file with 4 additions and 11 deletions.
15 changes: 4 additions & 11 deletions thunder/dynamo/splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,18 +136,11 @@ def callback(node) -> int:
gm, root_m=None, split_callback=callback, keep_original_order=True, keep_original_node_name=True
)

def add_output(m):
has_output = False
for node in m.graph.nodes:
if node.op == "call_module":
add_output(getattr(m, node.target))
elif node.op == "output":
has_output = True
if not has_output:
m.graph.output(())

# Workaround for the Torch bug https://github.com/pytorch/pytorch/pull/139275
add_output(original_split_gm)
for submodule in original_split_gm.children():
last_node = next(iter(reversed(submodule.graph.nodes)))
if last_node.op != "output":
submodule.graph.output(())
split_gm = copy.deepcopy(original_split_gm)

def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
Expand Down

0 comments on commit 39a4fd1

Please sign in to comment.