Skip to content

Commit

Permalink
Add output node if it does not exist in the split module (#1476)
Browse files Browse the repository at this point in the history
  • Loading branch information
kiya00 committed Nov 26, 2024
1 parent cd6977d commit 7951a51
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
13 changes: 13 additions & 0 deletions thunder/dynamo/splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,19 @@ def callback(node) -> int:
original_split_gm: torch.fx.GraphModule = split_module(
gm, root_m=None, split_callback=callback, keep_original_order=True, keep_original_node_name=True
)

def add_output(m):
has_output = False
for node in m.graph.nodes:
if node.op == "call_module":
add_output(getattr(m, node.target))
elif node.op == "output":
has_output = True
if not has_output:
m.graph.output(())

# Workaround for the Torch bug https://github.com/pytorch/pytorch/pull/139275
add_output(original_split_gm)
split_gm = copy.deepcopy(original_split_gm)

def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
Expand Down
4 changes: 0 additions & 4 deletions thunder/tests/test_dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,10 +447,6 @@ def func(x):
IS_WINDOWS,
reason="torch.compile Windows support is still WIP - https://github.com/pytorch/pytorch/issues/122094",
),
pytest.mark.skipif(
LooseVersion(torch.__version__) < LooseVersion("2.6.0"),
reason="Skip until the Torch bug is fixed - https://github.com/pytorch/pytorch/pull/139275",
),
pytest.mark.skipif(
version_between(torch.__version__, min_ver="2.6.0dev0", max_ver="2.6.0a99"),
reason="https://github.com/Lightning-AI/lightning-thunder/issues/1471",
Expand Down

0 comments on commit 7951a51

Please sign in to comment.