NaN loss is instantly fatal
Waino committed Nov 4, 2024
1 parent e3672b0 commit 223bed4
Showing 2 changed files with 23 additions and 38 deletions.
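In short: the gradient-accumulation loop previously caught NanLossException, logged and skipped the offending batch, and re-raised only once a counter reached --max_nan_batches; this commit removes that escape hatch, so the first NaN loss aborts training immediately.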
4 changes: 0 additions & 4 deletions mammoth/opts.py
@@ -370,10 +370,6 @@ def _add_train_general_opts(parser):
         default=None,
         help='Criteria to use for early stopping.',
     )
-    group.add(
-        '--max_nan_batches', '-max_nan_batches', type=int, default=0,
-        help='Number of batches that may be skipped due to loss blowout.'
-    )
 
     # GPU
     group = parser.add_argument_group('Computation Environment')
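For reference, here is a minimal standalone sketch of the tolerate-then-fail policy that the removed option configured. The run_batches helper and the losses iterable are hypothetical stand-ins for the trainer's gradient-accumulation loop; only the isnan check, the counter, and the exception message mirror the deleted code:

import torch

class NanLossException(Exception):
    """Raised when a batch's loss has blown out to NaN."""

def run_batches(losses, max_nan_batches=0):
    # Old policy: count each NaN batch and fail only once the count
    # reaches the --max_nan_batches allowance; otherwise skip it.
    nan_batches = 0
    for loss in losses:
        if loss is not None and torch.isnan(loss):
            nan_batches += 1
            if nan_batches >= max_nan_batches:
                raise NanLossException('Exceeded allowed --max_nan_batches.')
            continue  # batch skipped, training continues
        # ... backward pass and statistics updates would go here

Note that with the default of 0 the very first NaN already satisfied nan_batches >= max_nan_batches, so batches were only ever skipped when the flag was set to a positive value; removing the option makes that fail-fast default the only behavior.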
57 changes: 23 additions & 34 deletions mammoth/trainer.py
Expand Up @@ -106,7 +106,6 @@ def build_trainer(
task_queue_manager=task_queue_manager,
report_stats_from_parameters=opts.report_stats_from_parameters,
report_training_accuracy=opts.report_training_accuracy,
max_nan_batches=opts.max_nan_batches,
)
return trainer

@@ -154,7 +153,6 @@ def __init__(
         task_queue_manager=None,
         report_stats_from_parameters=False,
         report_training_accuracy=False,
-        max_nan_batches=0,
     ):
         # Basic attributes.
         self.model = model

@@ -177,8 +175,6 @@ def __init__(
         self.earlystopper = earlystopper
         self.dropout = dropout
         self.dropout_steps = dropout_steps
-        self.max_nan_batches = max_nan_batches
-        self.nan_batches = 0
 
         self.task_queue_manager = task_queue_manager
 
@@ -495,37 +491,30 @@ def _gradient_accumulation(
             )
             # logger.info(loss)
 
-            try:
-                if loss is not None:
-                    if torch.isnan(loss):
-                        raise NanLossException('Loss blowout')
-                    # loss /= normalization
-                    self.optim.backward(loss)
-
-                    if self.report_training_accuracy:
-                        # Slow: requires max over logits, eq, masked_select
-                        batch_stats = Statistics.from_loss_logits_target(
-                            loss.item(),
-                            logits,
-                            target,
-                            padding_idx=self.loss_functions[metadata.tgt_lang].ignore_index,
-                        )
-                    else:
-                        batch_stats = Statistics(
-                            loss.item(),
-                            num_tokens,
-                            n_correct=None,
-                        )
+            if loss is not None:
+                if torch.isnan(loss):
+                    raise NanLossException('Loss blowout')
+                # loss /= normalization
+                self.optim.backward(loss)
+
+                if self.report_training_accuracy:
+                    # Slow: requires max over logits, eq, masked_select
+                    batch_stats = Statistics.from_loss_logits_target(
+                        loss.item(),
+                        logits,
+                        target,
+                        padding_idx=self.loss_functions[metadata.tgt_lang].ignore_index,
+                    )
+                else:
+                    batch_stats = Statistics(
+                        loss.item(),
+                        num_tokens,
+                        n_correct=None,
+                    )
 
-                    total_stats.update(batch_stats)
-                    report_stats.update(batch_stats)
-                    report_stats.update_task_loss(batch_stats.loss, metadata)
-            except NanLossException:
-                traceback.print_exc()
-                logger.info("At step %d, we removed a batch - accum %d", self.optim.training_step, k)
-                self.nan_batches += 1
-                if self.nan_batches >= self.max_nan_batches:
-                    raise NanLossException('Exceeded allowed --max_nan_batches.')
+                total_stats.update(batch_stats)
+                report_stats.update(batch_stats)
+                report_stats.update_task_loss(batch_stats.loss, metadata)
 
             if len(seen_comm_batches) != 1:
                 logger.warning('Communication batches out of synch with batch accumulation')
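The surviving behavior in isolation: a self-contained sketch, with a hypothetical train_step wrapper standing in for the trainer's _gradient_accumulation plumbing; the isnan guard and the 'Loss blowout' exception mirror the diff above:

import torch

class NanLossException(Exception):
    """Raised when a batch's loss has blown out to NaN."""

def train_step(loss: torch.Tensor) -> None:
    # New policy: no counter, no skipping -- the first NaN is fatal.
    if torch.isnan(loss):
        raise NanLossException('Loss blowout')
    loss.backward()  # stand-in for self.optim.backward(loss)

With the except NanLossException handler gone, the exception propagates out of the training loop and the run terminates, rather than silently dropping batches.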
