thunderFX: pass to remove empty autocast regions #1400

Merged
4 changes: 3 additions & 1 deletion thunder/dynamo/compiler.py
@@ -7,7 +7,7 @@
import torch

from thunder.core.baseutils import run_once
from thunder.dynamo.utils import recompile_graph
from thunder.dynamo.utils import recompile_graph, remove_empty_autocast
from thunder.dynamo.splitter import _splitter

if TYPE_CHECKING:
@@ -72,6 +72,8 @@ def __init__(self, **thunder_options):
        self._torch_compile = partial(torch.compile, **torch_inductor_options)

    def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, torch.Tensor]):
        gm = remove_empty_autocast(gm)

        # Dynamo uses lazy generation of the underlying Python code, so we need to
        # force recompilation of the GraphModule before passing it to Thunder.
        recompile_graph(gm)
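Dynamo records an `_enter_autocast`/`_exit_autocast` node pair even when the body of a `with torch.autocast(...)` block is empty, which is why the graph is cleaned before recompilation. A minimal sketch of what Dynamo hands to a backend for such a region (the `show_graph` backend below is only an illustration, not part of this change):

import torch

def show_graph(gm: torch.fx.GraphModule, example_inputs):
    # Illustrative backend: print the FX graph Dynamo captured. For the empty
    # region in `f` it contains _enter_autocast/_exit_autocast call_function
    # nodes with nothing between them.
    gm.graph.print_tabular()
    return gm  # a GraphModule is callable, so it is a valid backend result

def f(x):
    with torch.autocast(dtype=torch.bfloat16, device_type="cpu"):
        pass  # empty autocast region
    return x + 1

torch.compile(f, backend=show_graph)(torch.randn(2))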
41 changes: 41 additions & 0 deletions thunder/dynamo/utils.py
@@ -512,3 +512,44 @@ def checkpoint_converter(gm: torch.fx.GraphModule, sub_gm: torch.fx.GraphModule)
                else:
                    function_module = getattr(gm, n.args[0].name)
                _checkpoint_function_converter(function_module)


def remove_empty_autocast(graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""
Function to remove empty autocast regions from GraphModule.

Dynamo can provide empty autocast regions in which case, it is more performant to remove them
from the graph than to compile them and pay the cost of calling a wrapped optimized function
which does nothing.

Args:
graph_module: Graph module to which this pass is applied.

"""

    empty_autocast_removed_graph_module = copy.deepcopy(graph_module)

    # Dummy init node.
    prev_node = torch.fx.node.Node(graph_module.graph, "start_node", "call_function", lambda: None, None, None)
    nodes_to_erase = []
    for node in empty_autocast_removed_graph_module.graph.nodes:
        # `_enter_autocast` and `_exit_autocast` delimit the region created by the context manager,
        # so the previous `_enter_autocast` always pairs with the current `_exit_autocast`.
        if (
            prev_node.target == torch.amp.autocast_mode._enter_autocast
            and node.target == torch.amp.autocast_mode._exit_autocast
        ):
            # NOTE: The order in which the nodes are appended matters.
            # The node to be erased has to have zero users.
            # So, we remove `_exit_autocast` first (which consumes output from `_enter_autocast`)
            # and then we can remove the corresponding `_enter_autocast`.
            nodes_to_erase.append(node)
            nodes_to_erase.append(prev_node)

        prev_node = node

    # Erase the marked nodes.
    for node in nodes_to_erase:
        empty_autocast_removed_graph_module.graph.erase_node(node)

    return empty_autocast_removed_graph_module
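The pass can also be exercised on its own, outside of `ThunderCompiler`. A minimal sketch, assuming a hypothetical inspection backend (`cleaning_backend` below is not part of this change) that mirrors the checks in the new test:

import torch
from thunder.dynamo.utils import remove_empty_autocast

AUTOCAST_OPS = (torch.amp.autocast_mode._enter_autocast, torch.amp.autocast_mode._exit_autocast)

def cleaning_backend(gm: torch.fx.GraphModule, example_inputs):
    # Hypothetical backend: apply the pass, recompile the mutated graph, and
    # check that the empty region's enter/exit nodes are gone.
    cleaned = remove_empty_autocast(gm)
    cleaned.recompile()
    assert all(node.target not in AUTOCAST_OPS for node in cleaned.graph.nodes)
    return cleaned

def f(x):
    with torch.autocast(dtype=torch.bfloat16, device_type="cpu"):
        pass  # empty region: removed by the pass
    return x @ x

out = torch.compile(f, backend=cleaning_backend)(torch.randn(3, 3))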
55 changes: 55 additions & 0 deletions thunder/tests/test_dynamo.py
@@ -1,5 +1,6 @@
import pytest
import warnings
import itertools
import torch
import torch.fx
import torch.nn as nn
@@ -515,6 +516,60 @@ def func(x):
    torch.testing.assert_close(actual_grad, expected_grad)


def test_empty_autocast():
    autocast_ops = (torch.amp.autocast_mode._enter_autocast, torch.amp.autocast_mode._exit_autocast)

    def _call_thunder_backend(fn, args):
        backend = ThunderCompiler()
        jf = torch.compile(backend=backend)(fn)
        jf(*args)
        return backend

    # autocast region is removed
    def f():
        with torch.autocast(dtype=torch.bfloat16, device_type="cpu"):
            pass
        return

    backend = _call_thunder_backend(f, ())
    assert all(node.target not in autocast_ops for node in backend.subgraph_infos[0].split_graph_module.graph.nodes)

    # Both autocast regions are removed
    def f(x):
        with torch.autocast(dtype=torch.bfloat16, device_type="cpu"):
            pass
        y = x @ x
        with torch.autocast(dtype=torch.bfloat16, device_type="cpu"):
            pass
        return y

    x = torch.randn(3, 3)
    backend = _call_thunder_backend(f, (x,))

    all_nodes = itertools.chain(
        backend.subgraph_infos[0].split_graph_module.graph.nodes,
        backend.subgraph_infos[0].split_graph_module.thunder_1.graph.nodes,
    )
    assert all(node.target not in autocast_ops for node in all_nodes)

    # First autocast region is removed and second isn't
    def f(x):
        with torch.autocast(dtype=torch.bfloat16, device_type="cpu"):
            pass
        y = x @ x
        with torch.autocast(dtype=torch.bfloat16, device_type="cpu"):
            y = y @ y
        return y

    x = torch.randn(3, 3)
    backend = _call_thunder_backend(f, (x,))
    all_nodes = itertools.chain(
        backend.subgraph_infos[0].split_graph_module.graph.nodes,
        backend.subgraph_infos[0].split_graph_module.thunder_1.graph.nodes,
    )
    assert sum(node.target in autocast_ops for node in all_nodes) == 2


# Sample command to run the benchmark using ThunderCompilerGraphBenchmarking
# pytest thunder/tests/test_dynamo.py -k test_ThunderCompilerGraphBenchmarking_groupby --benchmark-group-by='graph-by-graph:param:GraphID,param:SplitModuleName'
# For more details, see :class:`thunder.dynamo.compiler_graph_benchmark.ThunderCompilerGraphBenchmarking`
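The new test can be run on its own with standard pytest selection, e.g. `pytest thunder/tests/test_dynamo.py -k test_empty_autocast` (assuming the repository's usual test environment).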