Merge branch 'check_nan_in_grad' into 'main'
Fix NaN checking in grads: should be performed before data-parallel all-reduce

See merge request ADLR/megatron-lm!989
deepakn94 committed Feb 28, 2024
2 parents 1dada7e + d668077 commit 53a350e
Showing 9 changed files with 45 additions and 67 deletions.
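
The gist of the change, as a minimal standalone sketch (illustrative only, not the Megatron-LM code; `reduce_bucket_grads` and its arguments are hypothetical): validate each rank's local gradient bucket before the data-parallel collective, so the rank that produced a NaN can still be identified.

```python
import os

import torch
import torch.distributed as dist


def reduce_bucket_grads(bucket_data: torch.Tensor, process_group, check_for_nan: bool = False):
    """Hypothetical helper: validate local grads, then launch the data-parallel collective."""
    if check_for_nan:
        local_norm = bucket_data.norm(p=2)
        # Assert *before* communication: after the all-reduce, a NaN produced on any
        # single rank has propagated to every rank, which hides the culprit.
        assert not local_norm.isnan(), (
            f'Rank {dist.get_rank()}: found NaN in local grad norm before '
            f'data-parallel communication (node: {os.uname()[1]})'
        )
    dist.all_reduce(bucket_data, op=dist.ReduceOp.SUM, group=process_group)
```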
7 changes: 7 additions & 0 deletions megatron/arguments.py
@@ -181,6 +181,13 @@ def validate_args(args, defaults={}):
if args.fp16:
assert not args.bf16
args.params_dtype = torch.half
# Turn off checking for NaNs in loss and grads if using dynamic loss scaling,
# where NaNs in grads / loss are a signal to the loss scaler.
if not args.loss_scale:
args.check_for_nan_in_loss_and_grad = False
if args.rank == 0:
print('WARNING: Setting args.check_for_nan_in_loss_and_grad to False since '
'dynamic loss scaling is being used')
if args.bf16:
assert not args.fp16
args.params_dtype = torch.bfloat16
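
Why the new guard in validate_args is needed: under fp16 training with dynamic loss scaling (no static loss scale), a NaN/Inf in the scaled gradients is the loss scaler's skip signal rather than an error, so asserting on it would abort healthy runs. A rough sketch of that scaler behavior, with assumed names (this is not Megatron's loss-scaler implementation):

```python
import torch


class DynamicLossScalerSketch:
    """Illustrative only: NaN/Inf in fp16 grads is an expected, recoverable event."""

    def __init__(self, initial_scale=2.0 ** 16, backoff=0.5, growth_interval=1000):
        self.scale = initial_scale
        self.backoff = backoff
        self.growth_interval = growth_interval
        self.good_steps = 0

    def step(self, optimizer, grads):
        if any(not torch.isfinite(g).all() for g in grads):
            # Expected under dynamic scaling: shrink the scale and skip this update
            # instead of treating the NaN/Inf as a fatal error.
            self.scale *= self.backoff
            self.good_steps = 0
            return False
        for g in grads:
            g.div_(self.scale)  # unscale before the optimizer update
        optimizer.step()
        self.good_steps += 1
        if self.good_steps % self.growth_interval == 0:
            self.scale *= 2.0
        return True
```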
8 changes: 7 additions & 1 deletion megatron/core/distributed/distributed_data_parallel.py
@@ -33,6 +34,7 @@ class DistributedDataParallel(MegatronModule):
disable_bucketing: If true, force assign all parameters to a single bucket. If false,
use standard bucketing policy: assign parameters to smaller buckets and all-reduce
per bucket _if_ overlap_grad_reduce is True and pp_rank is 0.
check_for_nan_in_grad: If true, check if local grad norm is NaN.
"""

@@ -46,6 +47,7 @@ def __init__(
use_distributed_optimizer: bool,
expert_data_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
disable_bucketing: bool = False,
check_for_nan_in_grad: bool = False,
bucket_size: int = 40000000,
):
super().__init__(config=config)
@@ -66,6 +68,8 @@ def __init__(
bucket_size = None
if disable_bucketing:
bucket_size = None

self.check_for_nan_in_grad = check_for_nan_in_grad
self.bucket_size = bucket_size

self.module = module
@@ -115,7 +119,8 @@ def allocate_grad_buffers_for_parameters(
param_to_name,
self.overlap_grad_reduce,
self.use_distributed_optimizer,
gradient_scaling_factor=gradient_scaling_factor,
gradient_scaling_factor,
self.check_for_nan_in_grad,
)
)
for param in params:
@@ -176,6 +181,7 @@ def param_hook(*unused):
):
param.main_grad.add_(param.grad.data)
param.grad = None

if self.overlap_grad_reduce:
param_to_grad_buffer[param].register_grad_ready(param)

19 changes: 19 additions & 0 deletions megatron/core/distributed/grad_buffer.py
@@ -1,6 +1,7 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import math
import os
from logging import getLogger
from typing import Dict, List

@@ -44,6 +45,7 @@ class Bucket:
gradient_scaling_factor: This factor is utilized to scale gradients prior to their
communication. Its application is twofold: it facilitates the averaging of gradients
and the scaling of gradients in the context of the Mixture of Experts (MoE) model.
check_for_nan_in_grad: If true, check if local grad norm is NaN.
"""

def __init__(
@@ -57,6 +59,7 @@ def __init__(
overlap_grad_reduce: bool,
use_distributed_optimizer: bool,
gradient_scaling_factor: float,
check_for_nan_in_grad: bool,
):
# State for bookkeeping: params is the set of parameters this bucket is
# responsible for, params_with_grad is the set of parameters with grads
@@ -76,6 +79,7 @@ def __init__(
self.overlap_grad_reduce = overlap_grad_reduce
self.use_distributed_optimizer = use_distributed_optimizer
self.gradient_scaling_factor = gradient_scaling_factor
self.check_for_nan_in_grad = check_for_nan_in_grad

self.reset()

@@ -100,6 +104,17 @@ def start_grad_sync(self):
self.communication_handle is None and not self.communication_issued
), 'Should not have multiple communication calls in flight at once'

# Make sure norm of grads in bucket is not NaN
# prior to data-parallel all-reduce / reduce-scatter.
if self.check_for_nan_in_grad:
global_rank = torch.distributed.get_rank()
norm = self.data.norm(p=2)
assert not norm.isnan(), (
f'Rank {global_rank}: found NaN in local grad norm in '
f'backward pass before data-parallel communication collective. '
f'Device: {torch.cuda.current_device()}, node: {os.uname()[1]}'
)

self.data *= self.gradient_scaling_factor
# Use async_op only when overlap_grad_reduce is True.
if self.use_distributed_optimizer:
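
A small single-process illustration of why the check has to run before the collective (plain tensors standing in for two data-parallel ranks; the summed tensor mimics the all-reduce):

```python
import torch

# Two "data-parallel ranks"; only rank 1 produced a NaN in its local grads.
rank_grads = [torch.ones(4), torch.tensor([1.0, float('nan'), 1.0, 1.0])]

# Check BEFORE the (simulated) all-reduce: only the offending rank trips.
for rank, grads in enumerate(rank_grads):
    if grads.norm(p=2).isnan():
        print(f'rank {rank}: NaN in local grad norm')      # prints for rank 1 only

# Simulated all-reduce (element-wise sum): the NaN is now in every rank's buffer,
# so a check AFTER the collective fires on all ranks and hides the culprit.
reduced = torch.stack(rank_grads).sum(dim=0)
for rank in range(len(rank_grads)):
    if reduced.norm(p=2).isnan():
        print(f'rank {rank}: NaN in reduced grad norm')     # prints for both ranks
```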
@@ -173,6 +188,7 @@ class GradBuffer:
gradient_scaling_factor: This factor is utilized to scale gradients prior to their
communication. Its application is twofold: it facilitates the averaging of gradients
and the scaling of gradients in the context of the Mixture of Experts (MoE) model.
check_for_nan_in_grad: If true, check if local grad norm is NaN.
"""

def __init__(
@@ -185,6 +201,7 @@ def __init__(
overlap_grad_reduce: bool,
use_distributed_optimizer: bool,
gradient_scaling_factor: float,
check_for_nan_in_grad: bool,
):

# Check that params are unique.
@@ -203,6 +220,7 @@ def __init__(
self.overlap_grad_reduce = overlap_grad_reduce
self.use_distributed_optimizer = use_distributed_optimizer
self.gradient_scaling_factor = gradient_scaling_factor
self.check_for_nan_in_grad = check_for_nan_in_grad
self.is_last_microbatch = True

# Data structures to store underlying buckets and relevant indexing data.
@@ -399,6 +417,7 @@ def _set_bucket(
overlap_grad_reduce=self.overlap_grad_reduce,
use_distributed_optimizer=self.use_distributed_optimizer,
gradient_scaling_factor=self.gradient_scaling_factor,
check_for_nan_in_grad=self.check_for_nan_in_grad,
)
self.buckets.append(bucket)
for bucket_param in bucket_params:
7 changes: 1 addition & 6 deletions megatron/core/optimizer/__init__.py
@@ -162,7 +162,6 @@ def get_megatron_optimizer_based_on_param_groups(
optimizer,
config.clip_grad,
config.log_num_zeros_in_grad,
config.check_for_nan_in_loss_and_grad,
params_have_main_grad,
config.fp16,
config.bf16,
@@ -184,11 +183,7 @@

# FP32.
return FP32Optimizer(
optimizer,
config.clip_grad,
config.log_num_zeros_in_grad,
config.check_for_nan_in_loss_and_grad,
params_have_main_grad,
optimizer, config.clip_grad, config.log_num_zeros_in_grad, params_have_main_grad,
)


18 changes: 1 addition & 17 deletions megatron/core/optimizer/clip_grads.py
@@ -14,12 +14,7 @@


def clip_grad_norm_fp32(
parameters,
grads_for_norm,
max_norm,
check_for_nan_in_grad,
norm_type=2,
model_parallel_group=None,
parameters, grads_for_norm, max_norm, norm_type=2, model_parallel_group=None,
):
"""Clips gradient norm of an iterable of parameters whose gradients
are in fp32.
@@ -34,7 +29,6 @@ def clip_grad_norm_fp32(
grads_for_norm (Iterable[Tensor]): an iterable of Tensors or a single
Tensor that will be used for calculating the grad norm.
max_norm (float or int): max norm of the gradients.
check_for_nan_in_grad (bool): check if gradients have a NaN.
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
model_parallel_group (group): given the nature of the distributed
@@ -95,16 +89,6 @@ def clip_grad_norm_fp32(
grad_norm = torch.norm(grad, norm_type)
total_norm += grad_norm ** norm_type

# Check individual rank grad norms are not NaN
# prior to model-parallel all-reduce.
if check_for_nan_in_grad:
global_rank = torch.distributed.get_rank()
assert not total_norm.isnan(), (
f'Rank {global_rank}: found NaN in local grad norm in '
f'backwards pass. Device: {torch.cuda.current_device()}, '
f'node: {os.uname()[1]}'
)

# Sum across all model-parallel GPUs.
torch.distributed.all_reduce(
total_norm, op=torch.distributed.ReduceOp.SUM, group=model_parallel_group
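
For context, the surviving norm computation in clip_grad_norm_fp32 looks roughly like the sketch below (simplified; names follow the hunk above). The removed assert sat between the local accumulation and the model-parallel all-reduce, but by that point the gradients had already been reduced across data-parallel ranks, so a NaN caught there no longer pins down the rank that produced it; the new per-bucket check in grad_buffer.py runs before any cross-rank mixing.

```python
import torch


def global_grad_norm_sketch(grads_for_norm, model_parallel_group, norm_type=2.0):
    """Simplified sketch of the L2-norm path kept in clip_grad_norm_fp32."""
    device = grads_for_norm[0].device
    total_norm = torch.zeros(1, dtype=torch.float32, device=device)
    for grad in grads_for_norm:
        grad_norm = torch.norm(grad, norm_type)
        total_norm += grad_norm ** norm_type
    # The removed check_for_nan_in_grad assert used to live here, i.e. after the
    # data-parallel reduction had already mixed gradients across ranks.
    torch.distributed.all_reduce(
        total_norm, op=torch.distributed.ReduceOp.SUM, group=model_parallel_group
    )
    return total_norm.item() ** (1.0 / norm_type)
```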
3 changes: 0 additions & 3 deletions megatron/core/optimizer/distrib_optimizer.py
@@ -45,7 +45,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
clip_grad: clip gradients with this global L2 norm. Note
that clipping is ignored if clip_grad == 0
log_num_zeros_in_grad: return number of zeros in the gradients.
check_for_nan_in_grad: check if gradients have a NaN.
params_have_main_grad: flag indicating if parameters have
a `main_grad` field. If this is set, we are assuming
that the model parameters are stored in the `main_grad`
@@ -374,7 +373,6 @@ def __init__(
optimizer,
clip_grad,
log_num_zeros_in_grad,
check_for_nan_in_grad,
params_have_main_grad,
fp16,
bf16,
@@ -399,7 +397,6 @@ def __init__(
optimizer,
clip_grad,
log_num_zeros_in_grad,
check_for_nan_in_grad,
params_have_main_grad,
fp16,
bf16,
44 changes: 8 additions & 36 deletions megatron/core/optimizer/optimizer.py
@@ -51,12 +51,7 @@ def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):

class MegatronOptimizer(ABC):
def __init__(
self,
optimizer,
clip_grad,
log_num_zeros_in_grad,
check_for_nan_in_grad,
params_have_main_grad,
self, optimizer, clip_grad, log_num_zeros_in_grad, params_have_main_grad,
):

"""Input optimizer is the base optimizer for example Adam."""
@@ -65,7 +60,6 @@ def __init__(
# Set gradient clipping and logging params.
self.clip_grad = clip_grad
self.log_num_zeros_in_grad = log_num_zeros_in_grad
self.check_for_nan_in_grad = check_for_nan_in_grad
self.params_have_main_grad = params_have_main_grad

def get_parameters(self):
@@ -97,15 +91,11 @@ def get_model_parallel_group(self):
"""Default returned here, but the distributed optimizer overrides this."""
return parallel_state.get_model_parallel_group()

def clip_grad_norm(self, clip_grad, check_for_nan_in_grad):
def clip_grad_norm(self, clip_grad):
params = self.get_parameters()
grads_for_norm = self.get_main_grads_for_grad_norm()
return clip_grad_norm_fp32(
params,
grads_for_norm,
clip_grad,
check_for_nan_in_grad,
model_parallel_group=self.get_model_parallel_group(),
params, grads_for_norm, clip_grad, model_parallel_group=self.get_model_parallel_group(),
)

def count_zeros(self):
@@ -176,7 +166,6 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
clip_grad: clip gradients with this global L2 norm. Note
that clipping is ignored if clip_grad == 0
log_num_zeros_in_grad: return number of zeros in the gradients.
check_for_nan_in_grad: check if gradients have a NaN.
params_have_main_grad: flag indicating if parameters have
a `main_grad` field. If this is set, we are assuming
that the model parameters are stored in the `main_grad`
@@ -201,7 +190,6 @@ def __init__(
optimizer,
clip_grad,
log_num_zeros_in_grad,
check_for_nan_in_grad,
params_have_main_grad,
fp16,
bf16,
@@ -210,11 +198,7 @@
):

super().__init__(
optimizer,
clip_grad,
log_num_zeros_in_grad,
check_for_nan_in_grad,
params_have_main_grad,
optimizer, clip_grad, log_num_zeros_in_grad, params_have_main_grad,
)

self.fp16 = fp16
Expand Down Expand Up @@ -307,7 +291,7 @@ def step(self, args, timers):
timers('optimizer-clip-main-grad', log_level=1).start(barrier=args.barrier_with_L1_time)
grad_norm = None
if self.clip_grad > 0.0:
grad_norm = self.clip_grad_norm(self.clip_grad, self.check_for_nan_in_grad)
grad_norm = self.clip_grad_norm(self.clip_grad)
timers('optimizer-clip-main-grad').stop()

# Count the zeros in the grads.
@@ -339,7 +323,6 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
clip_grad: clip gradients with this global L2 norm. Note
that clipping is ignored if clip_grad == 0
log_num_zeros_in_grad: return number of zeros in the gradients.
check_for_nan_in_grad: check if gradients have a NaN.
params_have_main_grad: flag indicating if parameters have
a `main_grad` field. If this is set, we are assuming
that the model parameters are stored in the `main_grad`
@@ -363,7 +346,6 @@ def __init__(
optimizer,
clip_grad,
log_num_zeros_in_grad,
check_for_nan_in_grad,
params_have_main_grad,
fp16,
bf16,
@@ -375,7 +357,6 @@ def __init__(
optimizer,
clip_grad,
log_num_zeros_in_grad,
check_for_nan_in_grad,
params_have_main_grad,
fp16,
bf16,
@@ -558,20 +539,11 @@ def load_state_dict(self, state_dict):

class FP32Optimizer(MegatronOptimizer):
def __init__(
self,
optimizer,
clip_grad,
log_num_zeros_in_grad,
check_for_nan_in_grad,
params_have_main_grad,
self, optimizer, clip_grad, log_num_zeros_in_grad, params_have_main_grad,
):

super(FP32Optimizer, self).__init__(
optimizer,
clip_grad,
log_num_zeros_in_grad,
check_for_nan_in_grad,
params_have_main_grad,
optimizer, clip_grad, log_num_zeros_in_grad, params_have_main_grad,
)

self._scale = torch.tensor([1.0], dtype=torch.float, device='cuda')
@@ -603,7 +575,7 @@ def step(self, args, timers):
timers('optimizer-clip-main-grad', log_level=1).start(barrier=args.barrier_with_L1_time)
grad_norm = None
if self.clip_grad > 0.0:
grad_norm = self.clip_grad_norm(self.clip_grad, self.check_for_nan_in_grad)
grad_norm = self.clip_grad_norm(self.clip_grad)
timers('optimizer-clip-main-grad').stop()

# count the zeros in the grads
3 changes: 0 additions & 3 deletions megatron/core/optimizer/optimizer_config.py
@@ -78,8 +78,6 @@ class OptimizerConfig:
clip_grad (float): Gradient clipping based on global L2 norm.
log_num_zeros_in_grad (bool): If true, calculate and log the number of zeros in gradient.
check_for_nan_in_loss_and_grad (bool): If true, check for NaNs in loss and gradient.
"""

# Precision.
@@ -113,4 +111,3 @@ class OptimizerConfig:
# Miscellaneous.
clip_grad: float = 1.0
log_num_zeros_in_grad: bool = False
check_for_nan_in_loss_and_grad: bool = False
3 changes: 2 additions & 1 deletion megatron/training.py
@@ -413,7 +413,8 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
use_distributed_optimizer=args.use_distributed_optimizer,
# Turn off bucketing for model_chunk 2 onwards, since communication for these
# model chunks is overlapped with compute anyway.
disable_bucketing=(model_chunk_idx > 0))
disable_bucketing=(model_chunk_idx > 0),
check_for_nan_in_grad=args.check_for_nan_in_loss_and_grad)
for (model_chunk_idx, model_chunk) in enumerate(model)]

# Broadcast params from data parallel src rank to other data parallel ranks.
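
Usage note (a hedged sketch; only the argument names come from the hunks above, and the fp16 placement of the validate_args guard is assumed): the value that ultimately reaches DDP's check_for_nan_in_grad is the user's check_for_nan_in_loss_and_grad setting, except that validate_args forces it off when fp16 dynamic loss scaling is active.

```python
def effective_nan_check(check_for_nan_in_loss_and_grad: bool, fp16: bool, loss_scale):
    """Sketch of what ends up being passed as DDP(check_for_nan_in_grad=...)."""
    if fp16 and not loss_scale:
        # Dynamic loss scaling: NaN/Inf grads are the scaler's skip signal.
        return False
    return check_for_nan_in_loss_and_grad


# e.g. effective_nan_check(True, fp16=True, loss_scale=None)    -> False
#      effective_nan_check(True, fp16=True, loss_scale=4096.0)  -> True
```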
