From dcfc7999947b1357c3cea1979828d8fe9eaafab8 Mon Sep 17 00:00:00 2001
From: Yan Wang
Date: Fri, 18 Oct 2024 19:37:07 +0200
Subject: [PATCH] Improve graph-by-graph bench (#1317)

---
 thunder/dynamo/compiler_graph_benchmark.py | 69 +++++++++++-----------
 thunder/tests/test_dynamo.py               | 29 ++++++++-
 2 files changed, 62 insertions(+), 36 deletions(-)

diff --git a/thunder/dynamo/compiler_graph_benchmark.py b/thunder/dynamo/compiler_graph_benchmark.py
index 7f7fbccf6f..eafd30ce0e 100644
--- a/thunder/dynamo/compiler_graph_benchmark.py
+++ b/thunder/dynamo/compiler_graph_benchmark.py
@@ -10,7 +10,7 @@
 
 if TYPE_CHECKING:
-    from collections.abc import Sequence
+    from collections.abc import Callable
 
 
 GRAPH_BY_GRAPH_BENCHMARK_PARAMS_KEYS = ("GraphID", "SplitModuleName", "executor")
@@ -25,8 +25,8 @@ class ThunderCompilerGraphBenchmarking(ThunderCompiler):
     def __init__(
         self,
         bench: BenchmarkFixture,
-        executors: Sequence[str],
-        **thunder_options,
+        executors: dict[str, Callable],
+        **debug_options,
     ):
         """
         This class acts as a backend for the :func:`torch.compile` function, facilitating the benchmarking of each :class:`torch.fx.GraphModule` produced by the Thunder dynamo splitter.
@@ -34,16 +34,17 @@ def __init__(
 
         Args:
             bench: the BenchmarkFixture created by ``pytest_benchmark``
-            executors: list of executors to compare. Supported executors include: 'eager', 'inductor', and 'thunder'. If None, defaults to all available executors.
-            **thunder_options: a dictionary of options to pass to :func:`thunder.jit`. Besides all the arguments to :func:`thunder.jit`,
-                it accepts `torch_inductor_options` which are passed to :func:`torch.compile` if part of the graph
-                is not supported by thunder.
+            executors: A dictionary of callables to compare.
+                - Key: The name of the executor, displayed in the test name.
+                - Value: A callable representing the compile function to be applied to the GraphModule.
+                    If the value is None, no compilation is performed and the GraphModule runs in Torch eager mode.
 
         Example:
             .. code-block:: python
 
                 # script.py
                 import torch
+                import thunder
                 from thunder.dynamo.compiler_graph_benchmark import ThunderCompilerGraphBenchmarking
 
                 def func(x):
                     x = torch.sin(x)
                     if x.sum() > 0:
                         return x + 1
                     else:
                         return x - 1
 
                 def test_func(benchmark):
-                    backend = ThunderCompilerGraphBenchmarking(benchmark, executors=["eager", "thunder"])
+                    backend = ThunderCompilerGraphBenchmarking(benchmark, executors={"eager": None, "thunder": thunder.jit})
                     compiled = torch.compile(backend=backend)(func)
                     x = torch.ones(2, requires_grad=True).cuda()
                     compiled(x)
@@ -72,41 +73,41 @@ def test_func(benchmark):
         With `--benchmark-group-by='graph-by-graph:param:GraphID,param:SplitModuleName'`, the test cases are grouped
         based on GraphID and SplitModuleName, allowing for performance comparison between different executors (e.g., 'eager' vs. 'thunder').
         """
-        super().__init__(**thunder_options)
+        super().__init__()
         self.bench = bench
-        if not executors:
-            self.executors = ThunderCompilerGraphBenchmarking._executors
-        else:
-            check(
-                all(ex in ThunderCompilerGraphBenchmarking._executors for ex in executors),
-                lambda: f"ThunderCompilerGraphBenchmarking only supports the following executor names: {ThunderCompilerGraphBenchmarking._executors} ",
-            )
-            self.executors = executors
+        check(isinstance(executors, dict) and executors, lambda: "'executors' must be a non-empty dictionary.")
+        check(
+            not any("-" in k for k in executors.keys()),
+            lambda: "Executor names cannot contain '-' as it conflicts with the 'benchmark-group-by' function. "
+            "Please rename it using a different character.",
+        )
+        self.executors = executors
+        self._get_debug_options(**debug_options)
 
         self.graph_idx = 0
 
+    def _get_debug_options(self, **debug_options):
+        self.post_graph = debug_options.get("post_graph", False)
+
     def run_bench(self, gm: torch.fx.GraphModule, name: str, *sample_args):
         from thunder.benchmarks.targets import record_peak_allocated_memory, MAX_ALLOCATED_MEMORY_KEYWORD
 
-        for ex in self.executors:
-            # Uses the already compiled module if it is compiled with the expected executor
-            if name.startswith(ex):
-                fn = self.subgraph_infos[self.graph_idx].submodule_to_compiled_functions[gm].compiled_fn
+        for ex_name, ex in self.executors.items():
+            if ex is None:
+                compiled_fn = gm
             else:
-                if ex == "thunder":
-                    # The subgraph whose name starts with "inductor" is not supported by the Thunder backend.
-                    if name.startswith("inductor"):
-                        continue
-                    fn = self._thunder_jit(gm)
-                elif ex == "inductor":
-                    fn = self._torch_compile(gm)
-                else:
-                    fn = gm
+                try:
+                    compiled_fn = ex(gm)
+                except Exception as e:
+                    raise RuntimeError(f"The input executor {ex_name} failed to compile {gm}") from e
+            if self.post_graph:
+                compiled_fn = self.post_graph(compiled_fn, sample_args)
+
             with record_peak_allocated_memory(self.bench):
-                self.bench(fn, *sample_args)
+                self.bench(compiled_fn, *sample_args)
             # BenchmarkFixture.stats is created each time bench is called (ref: https://github.com/pybenchmark/pytest-benchmark/blob/8c9a5faa1dd178b53ab7b2a66f5364a77e903d74/src/pytest_benchmark/fixture.py#L150)
             # Adds the graph number, split module name, and executor suffix to the name string
             gid_key, module_name_key, ex_key = GRAPH_BY_GRAPH_BENCHMARK_PARAMS_KEYS
-            self.bench.stats.name += f"-{gid_key}[{self.graph_idx+1}]-{module_name_key}[{name}]-{ex_key}[{ex}]"
+            self.bench.stats.name += f"-{gid_key}[{self.graph_idx+1}]-{module_name_key}[{name}]-{ex_key}[{ex_name}]"
             assert MAX_ALLOCATED_MEMORY_KEYWORD in self.bench.extra_info
             assert f"{self.bench.stats.name}_{MAX_ALLOCATED_MEMORY_KEYWORD}" not in self.bench.extra_info
             # NOTE: A benchmark can include multiple stats, but only one extra_info field is allowed per benchmark.
@@ -128,8 +129,8 @@ def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, torch.Tensor]):
             }
             for node in split_module.graph.nodes:
                 target = node.target
-                # Benchmarks the modules produced by the splitter.
-                if isinstance(target, str) and target.startswith(("thunder_", "inductor_")):
+                # Benchmarks the modules that are produced by the splitter and supported by Thunder.
+                if isinstance(target, str) and target.startswith("thunder_"):
                     check(
                         hasattr(split_module, target),
                         lambda: f"the submodule {target} does not exist in {split_module}",
diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py
index cae17adf65..869422359a 100644
--- a/thunder/tests/test_dynamo.py
+++ b/thunder/tests/test_dynamo.py
@@ -475,7 +475,11 @@ def func(x):
 
 # It must be located in the same folder as the test file to ensure the configuration takes effect.
 @requiresCUDA
 def test_ThunderCompilerGraphBenchmarking_LlamaMLPBenchmark(benchmark):
-    backend = ThunderCompilerGraphBenchmarking(benchmark, executors=["thunder", "inductor", "eager"])
+    import thunder
+
+    backend = ThunderCompilerGraphBenchmarking(
+        benchmark, executors={"thunder": thunder.jit, "inductor": torch.compile, "eager": None}
+    )
     from thunder.benchmarks import LlamaMLPBenchmark, Benchmark
 
     bench: Benchmark = LlamaMLPBenchmark(
@@ -505,8 +509,29 @@ def f(x, y):
         x = torch.sinc(y) + torch.cos(x)
         return x - 1
 
-    backend = ThunderCompilerGraphBenchmarking(benchmark, executors=["thunder", "inductor", "eager"])
+    import thunder
+
+    backend = ThunderCompilerGraphBenchmarking(
+        benchmark, executors={"thunder": thunder.jit, "inductor": torch.compile, "eager": None}
+    )
     compiled = torch.compile(backend=backend)(f)
     x = torch.ones(2, requires_grad=True).cuda()
     y = torch.ones(2, requires_grad=True).cuda()
     compiled(x, y)
+
+
+@requiresCUDA
+def test_ThunderCompilerGraphBenchmarking_post_graph(benchmark):
+    def f(x):
+        return torch.sin(x)
+
+    import thunder
+    from functools import partial
+
+    x = torch.randn((2, 2), device="cuda").requires_grad_()
+    post_gp = partial(torch.cuda.make_graphed_callables, num_warmup_iters=1, allow_unused_input=True)
+    backend = ThunderCompilerGraphBenchmarking(
+        benchmark, executors={"inductor": torch.compile, "thunder": thunder.jit}, post_graph=post_gp
+    )
+    compiled = torch.compile(backend=backend)(f)
+    compiled(x)
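
For reference, a minimal sketch of how the reworked API fits together after this patch: the keys of `executors` label each run in the benchmark name, the values compile each split `GraphModule` (with `None` meaning Torch eager), and the optional `post_graph` debug option is applied to the compiled function and the sample arguments before timing (as the new `test_ThunderCompilerGraphBenchmarking_post_graph` test does with `torch.cuda.make_graphed_callables`). The file name `test_bench.py` and the function names below are illustrative, not part of the patch; a CUDA device and `pytest-benchmark` are assumed:

    # test_bench.py -- illustrative file name
    import torch
    import thunder
    from thunder.dynamo.compiler_graph_benchmark import ThunderCompilerGraphBenchmarking

    def func(x):
        x = torch.sin(x)
        # The data-dependent branch causes a graph break, so the splitter
        # produces more than one subgraph and each is benchmarked separately.
        if x.sum() > 0:
            return x + 1
        return x - 1

    def test_func(benchmark):
        backend = ThunderCompilerGraphBenchmarking(
            benchmark,
            executors={"eager": None, "thunder": thunder.jit, "inductor": torch.compile},
        )
        compiled = torch.compile(backend=backend)(func)
        compiled(torch.ones(2, requires_grad=True).cuda())

Grouping by GraphID and SplitModuleName, as described in the docstring, then lines up the executors per subgraph:

    pytest test_bench.py --benchmark-group-by='graph-by-graph:param:GraphID,param:SplitModuleName'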