From 83c231e98c2e9d8260fcf7a2e990d69212b49535 Mon Sep 17 00:00:00 2001 From: Yan Wang Date: Wed, 20 Nov 2024 13:42:00 +0100 Subject: [PATCH] follow comments --- thunder/benchmarks/benchmark_litgpt.py | 9 --------- thunder/dynamo/utils.py | 18 +++++++++++------- 2 files changed, 11 insertions(+), 16 deletions(-) diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py index e9a3c202c..8bcaf575e 100644 --- a/thunder/benchmarks/benchmark_litgpt.py +++ b/thunder/benchmarks/benchmark_litgpt.py @@ -241,7 +241,6 @@ def __init__( use_torchao_fp8_allgather: bool = False, use_torchao_fp8_precompute_scale_for_fsdp: bool = False, fp8_shard_intermediate_activation: bool = False, - save_dynamo_repro: str | None = None, ): seed = 1337 torch.manual_seed(seed) @@ -276,11 +275,6 @@ def __init__( self.dump_thunder_traces = dump_thunder_traces self.dump_memory_snapshot = dump_memory_snapshot self.fp8_shard_intermediate_activation = fp8_shard_intermediate_activation - if save_dynamo_repro is not None: - assert ( - "dynamo" in self.compile and "thunder" in self.compile - ), "save_dynamo_repro can only be used if --compile=thunder+dynamo" - self.save_dynamo_repro = save_dynamo_repro if use_torchao_fp8_linear: @@ -898,9 +892,6 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None print(f"##########\n#{i}-th ThunderModule\n##########") print(b_traces[-1]) - if benchmark.save_dynamo_repro: - benchmark.backend.save_reproducer_to_folder(benchmark.save_dynamo_repro) - if global_rank in [0, None]: if return_metrics_as_json: benchmark.add_model_info_to_metrics() diff --git a/thunder/dynamo/utils.py b/thunder/dynamo/utils.py index c1aadadf5..4bab617cd 100644 --- a/thunder/dynamo/utils.py +++ b/thunder/dynamo/utils.py @@ -80,7 +80,7 @@ class SplitReason: exception: Exception | None = None -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class ExampleInputMetaData: """ Describes the metadata of a tensor, used to generate a random tensor with matching properties @@ -464,6 +464,12 @@ def _get_storage_shape(t: torch.Tensor): def _get_example_input_tensor_metadata(t: torch.Tensor) -> ExampleInputMetaData: + min_val = None + max_val = None + if not isinstance(t, FakeTensor) and t.numel() != 0: + minmax: tuple[torch.Tensor, torch.Tensor] = torch.aminmax(t) + min_val = minmax[0].cpu().item() + max_val = minmax[1].cpu().item() meta_ev = ExampleInputMetaData( t.requires_grad, t.layout, @@ -472,11 +478,9 @@ def _get_example_input_tensor_metadata(t: torch.Tensor) -> ExampleInputMetaData: _concrete_value(t.shape), _get_storage_shape(t), _concrete_value(t.stride()), + min_val, + max_val, ) - if not isinstance(t, FakeTensor) and t.numel() != 0: - minmax: tuple[torch.Tensor, torch.Tensor] = torch.aminmax(t) - meta_ev.min_val = minmax[0].cpu().item() - meta_ev.max_val = minmax[1].cpu().item() return meta_ev @@ -768,8 +772,8 @@ def reproducer( code_str += f"{_addindent(input_str, 4)}\n]\n" if not use_pytest_benchmark: - code_str += f"fqn = thunder.jit(DynamoModule(), {thunder_options_str})\n" - code_str += "fqn(*inputs)" + code_str += f"compiled = thunder.jit(DynamoModule(), {thunder_options_str})\n" + code_str += "compiled(*inputs)" else: code_str += "from thunder.dynamo.compiler_graph_benchmark import ThunderCompilerGraphBenchmarking\n" code_str = f"""{code_str}