Skip to content

Commit

Permalink
follow comments
Browse files — browse the repository at this point in the history
  • Loading branch information
kiya00 committed Nov 20, 2024
1 parent d457c19 commit 83c231e
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 16 deletions.
9 changes: 0 additions & 9 deletions thunder/benchmarks/benchmark_litgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,6 @@ def __init__(
use_torchao_fp8_allgather: bool = False,
use_torchao_fp8_precompute_scale_for_fsdp: bool = False,
fp8_shard_intermediate_activation: bool = False,
save_dynamo_repro: str | None = None,
):
seed = 1337
torch.manual_seed(seed)
Expand Down Expand Up @@ -276,11 +275,6 @@ def __init__(
self.dump_thunder_traces = dump_thunder_traces
self.dump_memory_snapshot = dump_memory_snapshot
self.fp8_shard_intermediate_activation = fp8_shard_intermediate_activation
if save_dynamo_repro is not None:
assert (
"dynamo" in self.compile and "thunder" in self.compile
), "save_dynamo_repro can only be used if --compile=thunder+dynamo"
self.save_dynamo_repro = save_dynamo_repro

if use_torchao_fp8_linear:

Expand Down Expand Up @@ -898,9 +892,6 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None
print(f"##########\n#{i}-th ThunderModule\n##########")
print(b_traces[-1])

if benchmark.save_dynamo_repro:
benchmark.backend.save_reproducer_to_folder(benchmark.save_dynamo_repro)

if global_rank in [0, None]:
if return_metrics_as_json:
benchmark.add_model_info_to_metrics()
Expand Down
18 changes: 11 additions & 7 deletions thunder/dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class SplitReason:
exception: Exception | None = None


@dataclasses.dataclass
@dataclasses.dataclass(frozen=True)
class ExampleInputMetaData:
"""
Describes the metadata of a tensor, used to generate a random tensor with matching properties
Expand Down Expand Up @@ -464,6 +464,12 @@ def _get_storage_shape(t: torch.Tensor):


def _get_example_input_tensor_metadata(t: torch.Tensor) -> ExampleInputMetaData:
min_val = None
max_val = None
if not isinstance(t, FakeTensor) and t.numel() != 0:
minmax: tuple[torch.Tensor, torch.Tensor] = torch.aminmax(t)
min_val = minmax[0].cpu().item()
max_val = minmax[1].cpu().item()
meta_ev = ExampleInputMetaData(
t.requires_grad,
t.layout,
Expand All @@ -472,11 +478,9 @@ def _get_example_input_tensor_metadata(t: torch.Tensor) -> ExampleInputMetaData:
_concrete_value(t.shape),
_get_storage_shape(t),
_concrete_value(t.stride()),
min_val,
max_val,
)
if not isinstance(t, FakeTensor) and t.numel() != 0:
minmax: tuple[torch.Tensor, torch.Tensor] = torch.aminmax(t)
meta_ev.min_val = minmax[0].cpu().item()
meta_ev.max_val = minmax[1].cpu().item()
return meta_ev


Expand Down Expand Up @@ -768,8 +772,8 @@ def reproducer(
code_str += f"{_addindent(input_str, 4)}\n]\n"

if not use_pytest_benchmark:
code_str += f"fqn = thunder.jit(DynamoModule(), {thunder_options_str})\n"
code_str += "fqn(*inputs)"
code_str += f"compiled = thunder.jit(DynamoModule(), {thunder_options_str})\n"
code_str += "compiled(*inputs)"
else:
code_str += "from thunder.dynamo.compiler_graph_benchmark import ThunderCompilerGraphBenchmarking\n"
code_str = f"""{code_str}
Expand Down

0 comments on commit 83c231e

Please sign in to comment.