From dd4e1ffc1e04ac5586d5d3407abc4582d3f9cff0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stig-Arne=20Gr=C3=B6nroos?=
Date: Mon, 4 Nov 2024 14:02:22 +0200
Subject: [PATCH] NaN loss is instantly fatal

---
 mammoth/opts.py    |  4 ----
 mammoth/trainer.py | 58 ++++++++++++++++++----------------------------
 2 files changed, 23 insertions(+), 39 deletions(-)

diff --git a/mammoth/opts.py b/mammoth/opts.py
index 1a69809b..27abeda6 100644
--- a/mammoth/opts.py
+++ b/mammoth/opts.py
@@ -370,10 +370,6 @@ def _add_train_general_opts(parser):
         default=None,
         help='Criteria to use for early stopping.',
     )
-    group.add(
-        '--max_nan_batches', '-max_nan_batches', type=int, default=0,
-        help='Number of batches that may be skipped due to loss blowout.'
-    )
 
     # GPU
     group = parser.add_argument_group('Computation Environment')
diff --git a/mammoth/trainer.py b/mammoth/trainer.py
index 80552ce5..c1e447a5 100644
--- a/mammoth/trainer.py
+++ b/mammoth/trainer.py
@@ -13,7 +13,6 @@
 import torch
 import torch.distributed
 import torch.nn as nn
-import traceback
 
 from einops import rearrange
 from itertools import islice
@@ -106,7 +105,6 @@ def build_trainer(
         task_queue_manager=task_queue_manager,
         report_stats_from_parameters=opts.report_stats_from_parameters,
         report_training_accuracy=opts.report_training_accuracy,
-        max_nan_batches=opts.max_nan_batches,
     )
     return trainer
 
@@ -154,7 +152,6 @@ def __init__(
         task_queue_manager=None,
         report_stats_from_parameters=False,
         report_training_accuracy=False,
-        max_nan_batches=0,
     ):
         # Basic attributes.
         self.model = model
@@ -177,8 +174,6 @@
         self.earlystopper = earlystopper
         self.dropout = dropout
         self.dropout_steps = dropout_steps
-        self.max_nan_batches = max_nan_batches
-        self.nan_batches = 0
 
         self.task_queue_manager = task_queue_manager
 
@@ -495,37 +490,30 @@ def _gradient_accumulation(
         )
         # logger.info(loss)
-        try:
-            if loss is not None:
-                if torch.isnan(loss):
-                    raise NanLossException('Loss blowout')
-                # loss /= normalization
-                self.optim.backward(loss)
-
-                if self.report_training_accuracy:
-                    # Slow: requires max over logits, eq, masked_select
-                    batch_stats = Statistics.from_loss_logits_target(
-                        loss.item(),
-                        logits,
-                        target,
-                        padding_idx=self.loss_functions[metadata.tgt_lang].ignore_index,
-                    )
-                else:
-                    batch_stats = Statistics(
-                        loss.item(),
-                        num_tokens,
-                        n_correct=None,
-                    )
-
-                total_stats.update(batch_stats)
-                report_stats.update(batch_stats)
-                report_stats.update_task_loss(batch_stats.loss, metadata)
-        except NanLossException:
-            traceback.print_exc()
-            logger.info("At step %d, we removed a batch - accum %d", self.optim.training_step, k)
-            self.nan_batches += 1
-            if self.nan_batches >= self.max_nan_batches:
-                raise NanLossException('Exceeded allowed --max_nan_batches.')
+        if loss is not None:
+            if torch.isnan(loss):
+                raise NanLossException('Loss blowout')
+            # loss /= normalization
+            self.optim.backward(loss)
+
+            if self.report_training_accuracy:
+                # Slow: requires max over logits, eq, masked_select
+                batch_stats = Statistics.from_loss_logits_target(
+                    loss.item(),
+                    logits,
+                    target,
+                    padding_idx=self.loss_functions[metadata.tgt_lang].ignore_index,
+                )
+            else:
+                batch_stats = Statistics(
+                    loss.item(),
+                    num_tokens,
+                    n_correct=None,
+                )
+
+            total_stats.update(batch_stats)
+            report_stats.update(batch_stats)
+            report_stats.update_task_loss(batch_stats.loss, metadata)
 
         if len(seen_comm_batches) != 1:
             logger.warning('Communication batches out of synch with batch accumulation')
 
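
Note on the behaviour change: before this patch, a NaN loss raised NanLossException inside a try/except in _gradient_accumulation, the traceback was printed, the batch was skipped, and training only aborted once the skip count reached --max_nan_batches. After this patch the exception is no longer caught, so the first NaN loss terminates training. Below is a minimal sketch of the retained check; NanLossException is re-declared here as a stand-in for mammoth's own class, and check_loss is an illustrative helper, not a mammoth API.

    import torch

    class NanLossException(Exception):
        """Stand-in for mammoth's NanLossException."""

    def check_loss(loss):
        # Mirrors the check kept in _gradient_accumulation: a NaN loss is
        # unrecoverable, and NanLossException now propagates to the caller
        # instead of being caught, logged, and counted.
        if loss is not None and torch.isnan(loss):
            raise NanLossException('Loss blowout')
        # ...self.optim.backward(loss) would follow here in the trainer...

    try:
        check_loss(torch.tensor(float('nan')))
    except NanLossException as exc:
        print(f'training aborted: {exc}')  # training aborted: Loss blowout

Since the removed option defaulted to 0, the old code already raised on the first NaN unless --max_nan_batches was explicitly set higher; the patch removes that escape hatch entirely.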