Skip to content

Commit

Permalink
[frontend] Add verbose mode for torch dynamo compiler.
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanghb97 committed Feb 11, 2024
1 parent 90ed936 commit 103cbd3
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 0 deletions.
13 changes: 13 additions & 0 deletions frontend/Python/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __init__(
func_name: str = "forward",
primary_registry: Optional[dict] = None,
aot_autograd_decomposition: Optional[dict] = None,
verbose=False,
) -> None:
"""
Initializes the Dynamo Compiler.
Expand All @@ -71,21 +72,29 @@ def __init__(
primary_registry (dict, optional): The primary operations registry.
aot_autograd_decomposition (Optional[dict], optional):
The ahead-of-time autograd decomposition dictionary.
verbose (bool): Controls whether to print additional information for
debugging purposes. The default value is False, indicating that
no extra debug information will be printed.
Attributes:
_func_name: The function name to be used.
_aot_autograd_decomposition (Optional[dict], optional):
The ahead-of-time autograd decomposition dictionary.
_verbose: The option for the verbosity option of output.
_imported_graphs: The buddy graphs from dynamo importer.
_ops_registry (dict, optional): The buddy operations' lower func
registry.
_imported_params: The model params extract from torch.
_ops_map: The torch aten ops map with buddy ops.
"""
# Make custom dynamo compiler take effect.
dynamo.reset()
# Initialize the attributes.
if primary_registry is None:
primary_registry = {}
self._func_name = func_name
self._aot_autograd_decomposition = aot_autograd_decomposition
self._verbose = verbose
self._imported_graphs = []
self._ops_registry = {}
self._imported_params = {}
Expand Down Expand Up @@ -243,6 +252,10 @@ def _compile_fx(
}
params_flat, _ = pytree.tree_flatten(params)

if self._verbose:
print("Graph in tabular form:")
gm.graph.print_tabular()

def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]):
"""Compile a FX graph in Aten/Prims IR to MLIR."""
nonlocal params_flat
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ accelerate
protobuf
pybind11 == 2.11.1
torchvision
tabulate
39 changes: 39 additions & 0 deletions tests/Python/test_verbose_mode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# RUN: %PYTHON %s 2>&1 | FileCheck %s
import torch

from buddy.compiler.frontend import DynamoCompiler


# Define the target function or model.
def foo(x, y):
return x * y + x


# Define the input data.
float32_in1 = torch.randn(10).to(torch.float32)
float32_in2 = torch.randn(10).to(torch.float32)

# Test the default dynamo compiler importer mode.
dynamo_compiler_default = DynamoCompiler()
graphs = dynamo_compiler_default.importer(foo, *(float32_in1, float32_in2))

# Ensure no output is printed in the default mode.
# CHECK-NOT: .

# Test the dynamo compiler verbose mode.
dynamo_compiler_verbose_on = DynamoCompiler(verbose=True)
graphs = dynamo_compiler_verbose_on.importer(foo, *(float32_in1, float32_in2))

# Test output in the verbose mode.
# CHECK: placeholder
# CHECK: placeholder
# CHECK: call_function
# CHECK: call_function
# CHECK: output

# Test the dynamo compiler verbose mode off.
dynamo_compiler_verbose_off = DynamoCompiler(verbose=False)
graphs = dynamo_compiler_verbose_off.importer(foo, *(float32_in1, float32_in2))

# Ensure no output is printed when the verbose mode is off.
# CHECK-NOT: .

0 comments on commit 103cbd3

Please sign in to comment.