From 348597fd045903aa232b6811bf6bffa392edbd65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 20 Mar 2024 23:41:26 -0700 Subject: [PATCH] Use the throughput utility for benchmarking (#21) --- examples/lit-gpt/test_parametrized.py | 49 ++++++------- thunder/benchmarks/benchmark_litgpt.py | 99 +++++++++++--------------- 2 files changed, 64 insertions(+), 84 deletions(-) diff --git a/examples/lit-gpt/test_parametrized.py b/examples/lit-gpt/test_parametrized.py index bca55173fa..5e658b6447 100644 --- a/examples/lit-gpt/test_parametrized.py +++ b/examples/lit-gpt/test_parametrized.py @@ -7,12 +7,13 @@ MID_BENCHMARK_OUT - use this env variable to control whether you want to see the combined results between each test. BENCHMARK_OUT_FORMAT - use this env variable to control the format in which the results are presented. - Uses 'xlsx' by default. More format support to come soon. + Uses 'xlsx' by default. Supported: 'none', 'print', 'xlsx'. ''' import torch from absl.testing import parameterized from absl.testing import absltest +from collections import defaultdict import os import subprocess import json @@ -48,6 +49,9 @@ def add_to_dataframe(self): self.dataframe_data.append(self.perf_metrics_dict) def complete_dataframe(self, is_teardown): + if not self.dataframe_data: + # The benchmark probably failed + return #Called when tearing down the parametrized test #This generates a summarized dataframe for each perf metric and saves as a xlsx file df = pd.DataFrame(self.dataframe_data) @@ -59,7 +63,7 @@ def complete_dataframe(self, is_teardown): self.tokens_per_sec_per_gpu_df = df.pivot_table(index=index_list, columns='compiler', values='tokens_per_sec_per_gpu', aggfunc='first').reset_index() self.memory_used_GB_df = df.pivot_table(index=index_list, columns='compiler', values='memory_used_GB', aggfunc='first').reset_index() - if self.output_format not in ('none', 'print'): + if self.output_format == "xlsx": output_ext = {'xlsx': '.xlsx', }[self.output_format] if not is_teardown: filename = 'examples/lit-gpt/mid_output_parameterized_results' + str(output_ext) @@ -84,7 +88,6 @@ def complete_dataframe(self, is_teardown): print(self.memory_used_GB_df) def run_benchmark(self, kwargs): - # benchmark_file = 'thunder/benchmarks/benchmark_litgpt.py' command_list = [] for key, val in kwargs.items(): command_list.append("--" + str(key) + "=" + str(val)) @@ -98,32 +101,26 @@ def run_benchmark(self, kwargs): print(f'Running {" ".join(subprocess_cmd)!r}') proc_output = subprocess.run(subprocess_cmd, capture_output=True, text=True) + + self.perf_metrics_dict = {} + if os.path.exists(self.json_file_path): + with open(self.json_file_path, 'r') as file: + self.perf_metrics_dict = json.load(file) + # Cleanup after the benchmark finishes. It might have failed before creating this + os.remove(self.json_file_path) + if proc_output.returncode: - print(proc_output.stdout) - print(proc_output.stderr) - proc_output.check_returncode() - - with open(self.json_file_path, 'r') as file: - self.perf_metrics_dict = json.load(file) - os.remove(self.json_file_path) #cleanup after test finishes - - if self.perf_metrics_dict['average_iter_time'] is None: - if 'CUDA out of memory' in proc_output.stdout: - self.perf_metrics_dict['average_iter_time'] = 'OOM' - self.perf_metrics_dict['model_flops'] = 'OOM' - self.perf_metrics_dict['model_flop_per_sec'] = 'OOM' - self.perf_metrics_dict['tokens_per_sec'] = 'OOM' - self.perf_metrics_dict['tokens_per_sec_per_gpu'] = 'OOM' - self.perf_metrics_dict['memory_used_GB'] = 'OOM' + if 'CUDA out of memory' in proc_output.stdout or "CUDA error: out of memory" in proc_output.stderr: + defaultdict_oom = defaultdict(lambda: "OOM") + defaultdict_oom.update(self.perf_metrics_dict) + self.perf_metrics_dict = defaultdict_oom pass_str = "TestCase did not finish reporting metrics due to CUDA out of memory error. Reporting OOM and triggering test success." return True, pass_str - else: - print(proc_output.stdout) - print(proc_output.stderr) - fail_str = "Testcase did not finish reporting metrics due to an unknown error. Triggering test failure." - return False, fail_str - else: - return True, "Test passed successfully." + print(proc_output.stdout) + print(proc_output.stderr) + fail_str = "TestCase did not finish reporting metrics due to an unknown error. Triggering test failure." + return False, fail_str + return True, "Test passed successfully." class Test(parameterized.TestCase): diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py index ae9e8b2084..9120584989 100644 --- a/thunder/benchmarks/benchmark_litgpt.py +++ b/thunder/benchmarks/benchmark_litgpt.py @@ -9,13 +9,9 @@ import thunder from thunder.tests.lit_gpt_model import Config, GPT, Block -try: - from lightning.fabric.utilities.throughput import measure_flops +from lightning.fabric.utilities.throughput import measure_flops +from lightning.fabric.utilities import Throughput - # from lightning.fabric.utilities import Throughput - LIGHTNING_AVAILABLE = True -except: - LIGHTNING_AVAILABLE = False world_size = int(os.environ.get("WORLD_SIZE", 1)) local_rank = int(os.environ.get("LOCAL_RANK", 0)) @@ -109,7 +105,10 @@ def __init__( self.config.n_layer = n_layers # Initialize the model + t0 = time.perf_counter() + print(f"Loading model with {self.config.__dict__}") self.model = self.init_model() + print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") # Setup the distributed algorithm choices if self.distributed_mode != "none": @@ -138,14 +137,10 @@ def __init__( } def init_model(self): - print(f"Loading model with {self.config.__dict__}") init_device = torch.device("meta") if self.distributed_mode == "fsdp" else self.device - t0 = time.perf_counter() with init_device: model = GPT(self.config) - model.to(dtype=torch.bfloat16) - print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") - + model.to(dtype=torch.bfloat16) return model def setup_distributed(self): @@ -243,7 +238,7 @@ def pad_collate(batch): y_padded = torch.nn.utils.rnn.pad_sequence(y, batch_first=True, padding_value=-1) return x_padded, y_padded - train_data = DummyDataset(self.model.max_seq_length, self.dynamic) + train_data = DummyDataset(self.config.block_size, self.dynamic) train_dataloader = DataLoader( train_data, batch_size=self.micro_batch_size, num_workers=2, collate_fn=pad_collate ) @@ -251,24 +246,30 @@ def pad_collate(batch): return train_dataloader def calculate_model_flops(self): - input_ids, targets = next(self.train_data_iter) - input_ids = input_ids.to(device=self.device) - targets = targets.to(device=self.device) + meta = torch.device("meta") + device = self.device + self.device = meta + + # calculate flops on a meta-device model because we only care about the shapes and + # because the flops calculator installs hooks on the model + meta_model = self.init_model() - model_fwd = lambda: self.model(input_ids) + x = torch.randint(0, 1, (self.micro_batch_size, meta_model.config.block_size), device=meta) + model_fwd = lambda: meta_model(x) model_loss = lambda y: torch.nn.functional.cross_entropy( - y.reshape(-1, y.size(-1)), targets.reshape(-1), ignore_index=-1 + y.reshape(-1, y.size(-1)), x.reshape(-1), ignore_index=-1 ) - if LIGHTNING_AVAILABLE: - self.perf_metrics["model_flops"] = measure_flops(self.model, model_fwd, model_loss) / 1e12 + self.perf_metrics["model_flops"] = measure_flops(meta_model, model_fwd, model_loss) + + self.device = device def train(self): t0 = None - # if global_rank in [0, None]: - # #Calculate the model FLOPs - # self.calculate_model_flops() - # Setup Perf Collection - # self.throughput = Throughput(window_size=10, world_size=world_size) + if global_rank in [0, None]: + # Calculate the model FLOPs + self.calculate_model_flops() + # Setup throughput Collection + self.throughput = Throughput(window_size=self.max_iters - self.warmup_iter, world_size=world_size) if "transformerengine" in self.compile: import transformer_engine.pytorch as te @@ -326,45 +327,30 @@ def train(self): print( f"iter {i}: loss {loss_item:.4f}, iter time: {(t1 - iter_t0) * 1000:.2f}ms, t: {input_ids.size(1)}" ) - - # if global_rank in [0, None] and i >=warmup_iter: - # self.throughput.update( - # time=(t1-t0), - # flops=self.model_flops, - # batches=i, - # samples=(i * self.micro_batch_size * self.gradient_accumulation_steps), - # lengths=(i * self.micro_batch_size * self.gradient_accumulation_steps * self.model.max_seq_length), - # ) - - # metrics = self.throughput.compute() - # if i % 10 == 0: - # print(metrics) + if i >= self.warmup_iter: + self.throughput.update( + time=(t1 - t0), + flops=self.perf_metrics["model_flops"], + batches=i, + samples=(i * self.micro_batch_size * self.gradient_accumulation_steps), + lengths=(i * self.micro_batch_size * self.gradient_accumulation_steps * self.config.block_size), + ) if global_rank in [0, None]: # print(f"Total time: {(t1 - t0):.2f}s") - # print(f"Average time per iter: {((t1 - t0)*1000)/(max_iters-warmup_iter):.2f}ms") self.perf_metrics["average_iter_time"] = ((t1 - t0) * 1000) / (self.max_iters - self.warmup_iter) def add_perf_metrics(self): - # tokens_per_sec = total number of benchmarked iterations x global BS x block_size / total elapsed time (s) - # = global BS x block_size / (total elapsed time (s)/total number of benchmarked iterations) - # = global BS x block_size / average iter time (s) - self.perf_metrics["tokens_per_sec"] = ( - self.global_batch_size * self.model.max_seq_length * 1000 / self.perf_metrics["average_iter_time"] - ) # tokens/s - if self.perf_metrics["model_flops"] is not None: - self.perf_metrics["model_flop_per_sec"] = ( - self.perf_metrics["model_flops"] * 1000 / self.perf_metrics["average_iter_time"] - ) - if world_size is not None: - self.perf_metrics["model_flop_per_sec"] *= world_size + metrics = self.throughput.compute() + self.perf_metrics["tokens_per_sec"] = metrics.get("items_per_sec", metrics["device/items_per_sec"]) + self.perf_metrics["model_flop_per_sec"] = metrics.get("flops_per_sec", metrics["device/flops_per_sec"]) self.perf_metrics["memory_used_GB"] = torch.cuda.max_memory_allocated() / 1e9 def add_model_info_to_metrics(self): if global_rank in [0, None]: self.perf_metrics["model_name"] = self.model_name self.perf_metrics["Num GPUS"] = world_size - self.perf_metrics["Seq Len"] = self.model.max_seq_length + self.perf_metrics["Seq Len"] = self.config.block_size self.perf_metrics["Micro BS"] = self.micro_batch_size self.perf_metrics["Global BS"] = self.global_batch_size self.perf_metrics["GA"] = self.gradient_accumulation_steps @@ -416,7 +402,7 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None benchmark.add_perf_metrics() print( - f"Model name: {benchmark.model_name}\nSeq Length: {benchmark.model.max_seq_length}\nMicro BS: {benchmark.micro_batch_size}\nGlobal BS: {benchmark.global_batch_size}" + f"Model name: {benchmark.model_name}\nSeq Length: {benchmark.config.block_size}\nMicro BS: {benchmark.micro_batch_size}\nGlobal BS: {benchmark.global_batch_size}" ) print( f"Number of Layers: {benchmark.config.n_layer}\nNumber of parameters: {sum(p.numel() for p in benchmark.model.parameters() if p.requires_grad) / 1e9:.02f}B" @@ -429,12 +415,9 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None print(f"Compiler: {benchmark.compile}") print(f"Average iter time: {benchmark.perf_metrics['average_iter_time']:.2f} ms") print(f"Memory used: {benchmark.perf_metrics['memory_used_GB']:.02f} GB") - print(f"Throughput (Tokens/s): {benchmark.perf_metrics['tokens_per_sec']:.02f} tokens/s") - print( - f"Normalized Throughput (Tokens/s/GPU): {(benchmark.perf_metrics['tokens_per_sec']/world_size):.02f} tokens/s/gpu" - ) - if benchmark.perf_metrics["model_flop_per_sec"] is not None: - print(f"Model TFLOP/s: {benchmark.perf_metrics['model_flop_per_sec']:.02f} TFLOP/s") + print(f"Tokens/s: {benchmark.perf_metrics['tokens_per_sec']:.02f}") + print(f"Tokens/s/GPU: {(benchmark.perf_metrics['tokens_per_sec']/world_size):.02f}") + print(f"TFLOP/s: {benchmark.perf_metrics['model_flop_per_sec'] / 1e12:.02f}") except Exception as error: # Helps catch OutOfMemory Errors and post processing of errors