From 103cbd38f53e26886bdefe3e233115a31ecb8a81 Mon Sep 17 00:00:00 2001 From: zhanghb97 Date: Sun, 11 Feb 2024 18:22:57 +0000 Subject: [PATCH] [frontend] Add verbose mode for torch dynamo compiler. --- frontend/Python/frontend.py | 13 +++++++++++ requirements.txt | 1 + tests/Python/test_verbose_mode.py | 39 +++++++++++++++++++++++++++++++ 3 files changed, 53 insertions(+) create mode 100644 tests/Python/test_verbose_mode.py diff --git a/frontend/Python/frontend.py b/frontend/Python/frontend.py index e89597800c..8768308172 100644 --- a/frontend/Python/frontend.py +++ b/frontend/Python/frontend.py @@ -62,6 +62,7 @@ def __init__( func_name: str = "forward", primary_registry: Optional[dict] = None, aot_autograd_decomposition: Optional[dict] = None, + verbose=False, ) -> None: """ Initializes the Dynamo Compiler. @@ -71,10 +72,14 @@ def __init__( primary_registry (dict, optional): The primary operations registry. aot_autograd_decomposition (Optional[dict], optional): The ahead-of-time autograd decomposition dictionary. + verbose (bool): Controls whether to print additional information for + debugging purposes. The default value is False, indicating that + no extra debug information will be printed. Attributes: _func_name: The function name to be used. _aot_autograd_decomposition (Optional[dict], optional): The ahead-of-time autograd decomposition dictionary. + _verbose: The option for the verbosity option of output. _imported_graphs: The buddy graphs from dynamo importer. _ops_registry (dict, optional): The buddy operations' lower func registry. @@ -82,10 +87,14 @@ def __init__( _ops_map: The torch aten ops map with buddy ops. """ + # Make custom dynamo compiler take effect. + dynamo.reset() + # Initialize the attributes. if primary_registry is None: primary_registry = {} self._func_name = func_name self._aot_autograd_decomposition = aot_autograd_decomposition + self._verbose = verbose self._imported_graphs = [] self._ops_registry = {} self._imported_params = {} @@ -243,6 +252,10 @@ def _compile_fx( } params_flat, _ = pytree.tree_flatten(params) + if self._verbose: + print("Graph in tabular form:") + gm.graph.print_tabular() + def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]): """Compile a FX graph in Aten/Prims IR to MLIR.""" nonlocal params_flat diff --git a/requirements.txt b/requirements.txt index 606179eb74..125cc8d4cf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ accelerate protobuf pybind11 == 2.11.1 torchvision +tabulate diff --git a/tests/Python/test_verbose_mode.py b/tests/Python/test_verbose_mode.py new file mode 100644 index 0000000000..82279ac2ed --- /dev/null +++ b/tests/Python/test_verbose_mode.py @@ -0,0 +1,39 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s +import torch + +from buddy.compiler.frontend import DynamoCompiler + + +# Define the target function or model. +def foo(x, y): + return x * y + x + + +# Define the input data. +float32_in1 = torch.randn(10).to(torch.float32) +float32_in2 = torch.randn(10).to(torch.float32) + +# Test the default dynamo compiler importer mode. +dynamo_compiler_default = DynamoCompiler() +graphs = dynamo_compiler_default.importer(foo, *(float32_in1, float32_in2)) + +# Ensure no output is printed in the default mode. +# CHECK-NOT: . + +# Test the dynamo compiler verbose mode. +dynamo_compiler_verbose_on = DynamoCompiler(verbose=True) +graphs = dynamo_compiler_verbose_on.importer(foo, *(float32_in1, float32_in2)) + +# Test output in the verbose mode. +# CHECK: placeholder +# CHECK: placeholder +# CHECK: call_function +# CHECK: call_function +# CHECK: output + +# Test the dynamo compiler verbose mode off. +dynamo_compiler_verbose_off = DynamoCompiler(verbose=False) +graphs = dynamo_compiler_verbose_off.importer(foo, *(float32_in1, float32_in2)) + +# Ensure no output is printed when the verbose mode is off. +# CHECK-NOT: .