From c9f976087a9490b8b7326603c1531bb8c75c094a Mon Sep 17 00:00:00 2001
From: Said Taghadouini
Date: Sat, 20 Apr 2024 22:49:39 +0200
Subject: [PATCH] Use CUDA Events for measuring elapsed time

---
 src/nanotron/generation/decode.py | 11 ++++++++---
 src/nanotron/helpers.py           |  8 +++++---
 src/nanotron/trainer.py           |  9 ++++++---
 3 files changed, 19 insertions(+), 9 deletions(-)

diff --git a/src/nanotron/generation/decode.py b/src/nanotron/generation/decode.py
index 48d801cc..6f9be916 100644
--- a/src/nanotron/generation/decode.py
+++ b/src/nanotron/generation/decode.py
@@ -227,13 +227,17 @@ def decode_text(
     )

     if is_bench:
-        start_time, elapsed_time_first_iteration = time.perf_counter(), 0
+        start_time = torch.cuda.Event(enable_timing=True)
+        end_time = torch.cuda.Event(enable_timing=True)
+        start_time.record()
+        elapsed_time_first_iteration = 0

     for generation_iter in range(max_new_tokens):
         if is_bench and generation_iter == 0:
+            end_time.record()
             torch.cuda.synchronize()
-            elapsed_time_first_iteration = start_time - time.perf_counter()
+            elapsed_time_first_iteration = start_time.elapsed_time(end_time) / 1000

         all_new_decoder_input_ids_and_mask_same_rank: List[
             Tuple[Union[torch.LongTensor, TensorPointer], Union[torch.BoolTensor, TensorPointer]]
@@ -384,8 +388,9 @@ def generator():

     if is_bench:
         # Compute throughput (tok/s/gpu). Note that the first generation is done with full seq_len, so we don't count it.
+        end_time.record()
         torch.cuda.synchronize()
-        total_time_sec = time.perf_counter() - start_time - elapsed_time_first_iteration
+        total_time_sec = start_time.elapsed_time(end_time) / 1000 - elapsed_time_first_iteration
         # We generate 1 token per iteration per batch (batch=microbatch)
         # Number of tokens generated every iteration: gbs/iteration_time
         global_batch_size = len(batches) * parallel_context.dp_pg.size()
diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py
index 116cf653..4d3d7b9e 100644
--- a/src/nanotron/helpers.py
+++ b/src/nanotron/helpers.py
@@ -377,14 +377,16 @@ def test_all_pair_to_pair(
                 continue
             test_tensor = torch.zeros((int(throughput_size),), dtype=torch.uint8, device=torch.device("cuda"))
             for k in range(throughput_iters):
-                pre = time.perf_counter()
-                torch.cuda.synchronize()
+                pre = torch.cuda.Event(enable_timing=True)
+                post = torch.cuda.Event(enable_timing=True)
+                pre.record()
                 if wr == a:
                     dist.send(test_tensor, b, group=parallel_context.world_pg, tag=i + k)
                 elif wr == b:
                     dist.recv(test_tensor, a, group=parallel_context.world_pg, tag=i + k)
+                post.record()
                 torch.cuda.synchronize()
-                duration = time.perf_counter() - pre
+                duration = pre.elapsed_time(post) / 1000  # time is reported in milliseconds
             del test_tensor
             gc.collect()
             torch.cuda.empty_cache()
diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py
index 4abf3722..8a504f84 100644
--- a/src/nanotron/trainer.py
+++ b/src/nanotron/trainer.py
@@ -372,7 +372,8 @@ def train(
                 if isinstance(prof, torch.profiler.profile):
                     prof.step()

-                self.iteration_start_time = time.time()
+                self.iteration_start_time = torch.cuda.Event(enable_timing=True)
+                self.iteration_start_time.record()
                 self._update_dataloader_based_on_training_stages(dataloader_or_dls)

                 # Training step
@@ -500,9 +501,11 @@ def train_step_logs(
         loss_avg: Optional[torch.Tensor],
     ) -> None:
         # TODO @nouamanetazi: Megatron-LM seems to be using a barrier to report their interval time. Check if this is necessary. https://github.com/NouamaneTazi/Megatron-LM/blob/e241a96c3085b18e36c6cee1d68a8155de77b5a6/megatron/training.py#L607
-        dist.barrier()
+        iteration_end_time = torch.cuda.Event(enable_timing=True)
+        iteration_end_time.record()
+        # dist.barrier()
         torch.cuda.synchronize()
-        elapsed_time_per_iteration_ms = (time.time() - self.iteration_start_time) * 1000
+        elapsed_time_per_iteration_ms = self.iteration_start_time.elapsed_time(iteration_end_time)  # time reported in milliseconds
         tokens_per_sec = (
             self.global_batch_size * self.sequence_length / (elapsed_time_per_iteration_ms / 1000)
         )  # tokens_per_sec is calculated using sequence_length
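
For reference, below is a minimal standalone sketch of the torch.cuda.Event timing pattern this patch applies (the matmul workload and tensor size are arbitrary placeholders, not nanotron code). CUDA kernels launch asynchronously, so a host-side timer such as time.perf_counter() mostly measures kernel launch unless the host synchronizes first; recording events on the stream and reading elapsed_time() after a synchronize measures the device-side work itself.

    import torch

    # Sketch of the CUDA-event timing pattern used throughout this patch.
    # The matmul below is an arbitrary stand-in for the GPU work being timed.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    x = torch.randn(4096, 4096, device="cuda")

    start.record()              # enqueue the start event on the current stream
    y = x @ x                   # GPU work to be timed
    end.record()                # enqueue the end event after the work
    torch.cuda.synchronize()    # wait until both events have completed

    elapsed_ms = start.elapsed_time(end)  # elapsed_time() returns milliseconds
    print(f"elapsed: {elapsed_ms / 1000:.4f} s")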