From d66807759cb070b09c5eda9f31f0663a9858b50e Mon Sep 17 00:00:00 2001 From: Deepak Narayanan Date: Tue, 5 Dec 2023 10:56:54 -0800 Subject: [PATCH] Fix NaN checking in grads: should be performed before data-parallel communication Compute norm once per batch (instead of once per microbatch) and once per bucket (instead of once per param) --- megatron/arguments.py | 7 +++ .../distributed/distributed_data_parallel.py | 8 +++- megatron/core/distributed/grad_buffer.py | 19 ++++++++ megatron/core/optimizer/__init__.py | 7 +-- megatron/core/optimizer/clip_grads.py | 18 +------- megatron/core/optimizer/distrib_optimizer.py | 3 -- megatron/core/optimizer/optimizer.py | 44 ++++--------------- megatron/core/optimizer/optimizer_config.py | 3 -- megatron/training.py | 3 +- 9 files changed, 45 insertions(+), 67 deletions(-) diff --git a/megatron/arguments.py b/megatron/arguments.py index bffb098818..d481a0781c 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -181,6 +181,13 @@ def validate_args(args, defaults={}): if args.fp16: assert not args.bf16 args.params_dtype = torch.half + # Turn off checking for NaNs in loss and grads if using dynamic loss scaling, + # where NaNs in grads / loss are signal to the loss scaler. + if not args.loss_scale: + args.check_for_nan_in_loss_and_grad = False + if args.rank == 0: + print('WARNING: Setting args.check_for_nan_in_loss_and_grad to False since ' + 'dynamic loss scaling is being used') if args.bf16: assert not args.fp16 args.params_dtype = torch.bfloat16 diff --git a/megatron/core/distributed/distributed_data_parallel.py b/megatron/core/distributed/distributed_data_parallel.py index e3c8ece83a..d8cc637236 100644 --- a/megatron/core/distributed/distributed_data_parallel.py +++ b/megatron/core/distributed/distributed_data_parallel.py @@ -33,6 +33,7 @@ class DistributedDataParallel(MegatronModule): disable_bucketing: If true, force assign all parameters to a single bucket. If false, use standard bucketing policy: assign parameters to smaller buckets and all-reduce per bucket _if_ overlap_grad_reduce is True and pp_rank is 0. + check_for_nan_in_grad: If true, check if local grad norm is NaN. """ @@ -46,6 +47,7 @@ def __init__( use_distributed_optimizer: bool, expert_data_parallel_group: Optional[torch.distributed.ProcessGroup] = None, disable_bucketing: bool = False, + check_for_nan_in_grad: bool = False, bucket_size: int = 40000000, ): super().__init__(config=config) @@ -66,6 +68,8 @@ def __init__( bucket_size = None if disable_bucketing: bucket_size = None + + self.check_for_nan_in_grad = check_for_nan_in_grad self.bucket_size = bucket_size self.module = module @@ -115,7 +119,8 @@ def allocate_grad_buffers_for_parameters( param_to_name, self.overlap_grad_reduce, self.use_distributed_optimizer, - gradient_scaling_factor=gradient_scaling_factor, + gradient_scaling_factor, + self.check_for_nan_in_grad, ) ) for param in params: @@ -176,6 +181,7 @@ def param_hook(*unused): ): param.main_grad.add_(param.grad.data) param.grad = None + if self.overlap_grad_reduce: param_to_grad_buffer[param].register_grad_ready(param) diff --git a/megatron/core/distributed/grad_buffer.py b/megatron/core/distributed/grad_buffer.py index 949bc9468c..17d77c270d 100644 --- a/megatron/core/distributed/grad_buffer.py +++ b/megatron/core/distributed/grad_buffer.py @@ -1,6 +1,7 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
import math +import os from logging import getLogger from typing import Dict, List @@ -44,6 +45,7 @@ class Bucket: gradient_scaling_factor: This factor is utilized to scale gradients prior to their communication. Its application is twofold: it facilitates the averaging of gradients and the scaling of gradients in the context of the Mixture of Experts (MoE) model. + check_for_nan_in_grad: If true, check if local grad norm is NaN. """ def __init__( @@ -57,6 +59,7 @@ def __init__( overlap_grad_reduce: bool, use_distributed_optimizer: bool, gradient_scaling_factor: float, + check_for_nan_in_grad: bool, ): # State for bookkeeping: params is the set of parameters this bucket is # responsible for, params_with_grad is the set of parameters with grads @@ -76,6 +79,7 @@ def __init__( self.overlap_grad_reduce = overlap_grad_reduce self.use_distributed_optimizer = use_distributed_optimizer self.gradient_scaling_factor = gradient_scaling_factor + self.check_for_nan_in_grad = check_for_nan_in_grad self.reset() @@ -100,6 +104,17 @@ def start_grad_sync(self): self.communication_handle is None and not self.communication_issued ), 'Should not have multiple communication calls in flight at once' + # Make sure norm of grads in bucket are not NaN + # prior to data-parallel all-reduce / reduce-scatter. + if self.check_for_nan_in_grad: + global_rank = torch.distributed.get_rank() + norm = self.data.norm(p=2) + assert not norm.isnan(), ( + f'Rank {global_rank}: found NaN in local grad norm in ' + f'backward pass before data-parallel communication collective. ' + f'Device: {torch.cuda.current_device()}, node: {os.uname()[1]}' + ) + self.data *= self.gradient_scaling_factor # Use async_op only when overlap_grad_reduce is True. if self.use_distributed_optimizer: @@ -173,6 +188,7 @@ class GradBuffer: gradient_scaling_factor: This factor is utilized to scale gradients prior to their communication. Its application is twofold: it facilitates the averaging of gradients and the scaling of gradients in the context of the Mixture of Experts (MoE) model. + check_for_nan_in_grad: If true, check if local grad norm is NaN. """ def __init__( @@ -185,6 +201,7 @@ def __init__( overlap_grad_reduce: bool, use_distributed_optimizer: bool, gradient_scaling_factor: float, + check_for_nan_in_grad: bool, ): # Check that params are unique. @@ -203,6 +220,7 @@ def __init__( self.overlap_grad_reduce = overlap_grad_reduce self.use_distributed_optimizer = use_distributed_optimizer self.gradient_scaling_factor = gradient_scaling_factor + self.check_for_nan_in_grad = check_for_nan_in_grad self.is_last_microbatch = True # Data structures to store underlying buckets and relevant indexing data. @@ -384,6 +402,7 @@ def _set_bucket( overlap_grad_reduce=self.overlap_grad_reduce, use_distributed_optimizer=self.use_distributed_optimizer, gradient_scaling_factor=self.gradient_scaling_factor, + check_for_nan_in_grad=self.check_for_nan_in_grad, ) self.buckets.append(bucket) for bucket_param in bucket_params: diff --git a/megatron/core/optimizer/__init__.py b/megatron/core/optimizer/__init__.py index b3461f9032..231d986fb7 100644 --- a/megatron/core/optimizer/__init__.py +++ b/megatron/core/optimizer/__init__.py @@ -162,7 +162,6 @@ def get_megatron_optimizer_based_on_param_groups( optimizer, config.clip_grad, config.log_num_zeros_in_grad, - config.check_for_nan_in_loss_and_grad, params_have_main_grad, config.fp16, config.bf16, @@ -184,11 +183,7 @@ def get_megatron_optimizer_based_on_param_groups( # FP32. 
return FP32Optimizer( - optimizer, - config.clip_grad, - config.log_num_zeros_in_grad, - config.check_for_nan_in_loss_and_grad, - params_have_main_grad, + optimizer, config.clip_grad, config.log_num_zeros_in_grad, params_have_main_grad, ) diff --git a/megatron/core/optimizer/clip_grads.py b/megatron/core/optimizer/clip_grads.py index 4ad2445a89..0f94754c9d 100644 --- a/megatron/core/optimizer/clip_grads.py +++ b/megatron/core/optimizer/clip_grads.py @@ -14,12 +14,7 @@ def clip_grad_norm_fp32( - parameters, - grads_for_norm, - max_norm, - check_for_nan_in_grad, - norm_type=2, - model_parallel_group=None, + parameters, grads_for_norm, max_norm, norm_type=2, model_parallel_group=None, ): """Clips gradient norm of an iterable of parameters whose gradients are in fp32. @@ -34,7 +29,6 @@ def clip_grad_norm_fp32( grads_for_norm (Iterable[Tensor]): an iterable of Tensors or a single Tensor that will be used for calculating the grad norm. max_norm (float or int): max norm of the gradients. - check_for_nan_in_grad (bool): check if gradients have a NaN. norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm. model_parallel_group (group): given the nature of the distributed @@ -95,16 +89,6 @@ def clip_grad_norm_fp32( grad_norm = torch.norm(grad, norm_type) total_norm += grad_norm ** norm_type - # Check individual rank grad norms are not NaN - # prior to model-parallel all-reduce. - if check_for_nan_in_grad: - global_rank = torch.distributed.get_rank() - assert not total_norm.isnan(), ( - f'Rank {global_rank}: found NaN in local grad norm in ' - f'backwards pass. Device: {torch.cuda.current_device()}, ' - f'node: {os.uname()[1]}' - ) - # Sum across all model-parallel GPUs. torch.distributed.all_reduce( total_norm, op=torch.distributed.ReduceOp.SUM, group=model_parallel_group diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py index 1423a6abb6..3eb66d7b90 100644 --- a/megatron/core/optimizer/distrib_optimizer.py +++ b/megatron/core/optimizer/distrib_optimizer.py @@ -45,7 +45,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer): clip_grad: clip gradeints with this global L2 norm. Note that clipping is ignored if clip_grad == 0 log_num_zeros_in_grad: return number of zeros in the gradients. - check_for_nan_in_grad: check if gradients have a NaN. params_have_main_grad: flag indicating if parameters have a `main_grad` field. If this is set, we are assuming that the model parameters are store in the `main_grad` @@ -374,7 +373,6 @@ def __init__( optimizer, clip_grad, log_num_zeros_in_grad, - check_for_nan_in_grad, params_have_main_grad, fp16, bf16, @@ -399,7 +397,6 @@ def __init__( optimizer, clip_grad, log_num_zeros_in_grad, - check_for_nan_in_grad, params_have_main_grad, fp16, bf16, diff --git a/megatron/core/optimizer/optimizer.py b/megatron/core/optimizer/optimizer.py index a3a431d6ae..5caa6b96d5 100644 --- a/megatron/core/optimizer/optimizer.py +++ b/megatron/core/optimizer/optimizer.py @@ -51,12 +51,7 @@ def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None): class MegatronOptimizer(ABC): def __init__( - self, - optimizer, - clip_grad, - log_num_zeros_in_grad, - check_for_nan_in_grad, - params_have_main_grad, + self, optimizer, clip_grad, log_num_zeros_in_grad, params_have_main_grad, ): """Input optimizer is the base optimizer for example Adam.""" @@ -65,7 +60,6 @@ def __init__( # Set gradient clipping and logging params. 
self.clip_grad = clip_grad self.log_num_zeros_in_grad = log_num_zeros_in_grad - self.check_for_nan_in_grad = check_for_nan_in_grad self.params_have_main_grad = params_have_main_grad def get_parameters(self): @@ -97,15 +91,11 @@ def get_model_parallel_group(self): """Default returned here, but the distributed optimizer overrides this.""" return parallel_state.get_model_parallel_group() - def clip_grad_norm(self, clip_grad, check_for_nan_in_grad): + def clip_grad_norm(self, clip_grad): params = self.get_parameters() grads_for_norm = self.get_main_grads_for_grad_norm() return clip_grad_norm_fp32( - params, - grads_for_norm, - clip_grad, - check_for_nan_in_grad, - model_parallel_group=self.get_model_parallel_group(), + params, grads_for_norm, clip_grad, model_parallel_group=self.get_model_parallel_group(), ) def count_zeros(self): @@ -176,7 +166,6 @@ class MixedPrecisionOptimizer(MegatronOptimizer): clip_grad: clip gradeints with this global L2 norm. Note that clipping is ignored if clip_grad == 0 log_num_zeros_in_grad: return number of zeros in the gradients. - check_for_nan_in_grad: check if gradients have a NaN. params_have_main_grad: flag indicating if parameters have a `main_grad` field. If this is set, we are assuming that the model parameters are store in the `main_grad` @@ -201,7 +190,6 @@ def __init__( optimizer, clip_grad, log_num_zeros_in_grad, - check_for_nan_in_grad, params_have_main_grad, fp16, bf16, @@ -210,11 +198,7 @@ def __init__( ): super().__init__( - optimizer, - clip_grad, - log_num_zeros_in_grad, - check_for_nan_in_grad, - params_have_main_grad, + optimizer, clip_grad, log_num_zeros_in_grad, params_have_main_grad, ) self.fp16 = fp16 @@ -307,7 +291,7 @@ def step(self, args, timers): timers('optimizer-clip-main-grad', log_level=1).start(barrier=args.barrier_with_L1_time) grad_norm = None if self.clip_grad > 0.0: - grad_norm = self.clip_grad_norm(self.clip_grad, self.check_for_nan_in_grad) + grad_norm = self.clip_grad_norm(self.clip_grad) timers('optimizer-clip-main-grad').stop() # Count the zeros in the grads. @@ -339,7 +323,6 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer): clip_grad: clip gradeints with this global L2 norm. Note that clipping is ignored if clip_grad == 0 log_num_zeros_in_grad: return number of zeros in the gradients. - check_for_nan_in_grad: check if gradients have a NaN. params_have_main_grad: flag indicating if parameters have a `main_grad` field. 
If this is set, we are assuming that the model parameters are store in the `main_grad` @@ -363,7 +346,6 @@ def __init__( optimizer, clip_grad, log_num_zeros_in_grad, - check_for_nan_in_grad, params_have_main_grad, fp16, bf16, @@ -375,7 +357,6 @@ def __init__( optimizer, clip_grad, log_num_zeros_in_grad, - check_for_nan_in_grad, params_have_main_grad, fp16, bf16, @@ -558,20 +539,11 @@ def load_state_dict(self, state_dict): class FP32Optimizer(MegatronOptimizer): def __init__( - self, - optimizer, - clip_grad, - log_num_zeros_in_grad, - check_for_nan_in_grad, - params_have_main_grad, + self, optimizer, clip_grad, log_num_zeros_in_grad, params_have_main_grad, ): super(FP32Optimizer, self).__init__( - optimizer, - clip_grad, - log_num_zeros_in_grad, - check_for_nan_in_grad, - params_have_main_grad, + optimizer, clip_grad, log_num_zeros_in_grad, params_have_main_grad, ) self._scale = torch.tensor([1.0], dtype=torch.float, device='cuda') @@ -603,7 +575,7 @@ def step(self, args, timers): timers('optimizer-clip-main-grad', log_level=1).start(barrier=args.barrier_with_L1_time) grad_norm = None if self.clip_grad > 0.0: - grad_norm = self.clip_grad_norm(self.clip_grad, self.check_for_nan_in_grad) + grad_norm = self.clip_grad_norm(self.clip_grad) timers('optimizer-clip-main-grad').stop() # count the zeros in the grads diff --git a/megatron/core/optimizer/optimizer_config.py b/megatron/core/optimizer/optimizer_config.py index 2689d667bd..664e7c9036 100644 --- a/megatron/core/optimizer/optimizer_config.py +++ b/megatron/core/optimizer/optimizer_config.py @@ -78,8 +78,6 @@ class OptimizerConfig: clip_grad (float): Gradient clipping based on global L2 norm. log_num_zeros_in_grad (bool): If true, calculate and log the number of zeros in gradient. - - check_for_nan_in_loss_and_grad (bool): If true, check for NaNs in loss and gradient. """ # Precision. @@ -113,4 +111,3 @@ class OptimizerConfig: # Miscellaneous. clip_grad: float = 1.0 log_num_zeros_in_grad: bool = False - check_for_nan_in_loss_and_grad: bool = False diff --git a/megatron/training.py b/megatron/training.py index d604e6c489..e39d13e2e7 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -413,7 +413,8 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap use_distributed_optimizer=args.use_distributed_optimizer, # Turn off bucketing for model_chunk 2 onwards, since communication for these # model chunks is overlapped with compute anyway. - disable_bucketing=(model_chunk_idx > 0)) + disable_bucketing=(model_chunk_idx > 0), + check_for_nan_in_grad=args.check_for_nan_in_loss_and_grad) for (model_chunk_idx, model_chunk) in enumerate(model)] # Broadcast params from data parallel src rank to other data parallel ranks.
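Note on the new per-bucket NaN check (illustrative sketch, not part of the patch): the hunk added to Bucket.start_grad_sync() in megatron/core/distributed/grad_buffer.py computes a single L2 norm over the bucket's flattened gradient buffer, asserts it is not NaN before the data-parallel all-reduce / reduce-scatter, and only then applies gradient_scaling_factor. The standalone function below sketches that flow under stated assumptions: the name sync_bucket, its signature, and the __main__ demo are hypothetical, while the norm / assert / scale / all-reduce sequence follows what the patch adds.

import torch


def sync_bucket(bucket_data: torch.Tensor,
                gradient_scaling_factor: float,
                check_for_nan_in_grad: bool = True,
                process_group=None) -> torch.Tensor:
    """Scale and (optionally) all-reduce one flattened gradient bucket."""
    if check_for_nan_in_grad:
        # One L2 norm over the contiguous bucket: a NaN anywhere in the buffer
        # propagates into the norm, so a single isnan() test suffices and runs
        # once per bucket per grad sync (not once per parameter).
        norm = bucket_data.norm(p=2)
        assert not torch.isnan(norm), (
            'found NaN in local grad norm before data-parallel communication'
        )

    # Scale before communication (e.g. 1 / data_parallel_world_size to average).
    bucket_data *= gradient_scaling_factor

    # Communicate only if torch.distributed has actually been initialized.
    if process_group is not None and torch.distributed.is_initialized():
        torch.distributed.all_reduce(bucket_data, group=process_group)
    return bucket_data


if __name__ == '__main__':
    grads = torch.randn(1 << 20)                       # stand-in for one bucket
    sync_bucket(grads, gradient_scaling_factor=0.25)   # clean bucket passes
    grads[123] = float('nan')
    try:
        sync_bucket(grads, gradient_scaling_factor=0.25)
    except AssertionError as err:
        print('caught bad bucket:', err)

Catching the NaN here, rather than inside clip_grad_norm_fp32(), flags the failing rank before its bad values are mixed into other data-parallel ranks by the collective, and checking one norm per bucket keeps the overhead low. Note that validate_args() now disables the check automatically when dynamic loss scaling is in use, since NaN/Inf grads are the loss scaler's signal to back off rather than an error.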