Fix conflict between fp16 and deterministic sampling
Due to the removal of the grad hook, MultipleOptimizer no longer has a
step method; it has been replaced with externally_managed_step, which
takes information about which optimizers need to be stepped. This means
that MultipleOptimizer is no longer compatible with
torch.cuda.amp.GradScaler.
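
For context, the standard AMP pattern drives the optimizer through the
scaler, and GradScaler.step(optim) internally delegates to optim.step().
A minimal self-contained sketch of that contract (generic torch code,
not mammoth's):

    import torch

    model = torch.nn.Linear(8, 8)
    optim = torch.optim.SGD(model.parameters(), lr=0.1)
    # GradScaler degrades to a no-op passthrough when CUDA is unavailable
    scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

    loss = model(torch.randn(2, 8)).sum()
    scaler.scale(loss).backward()
    scaler.step(optim)  # delegates to optim.step(); breaks for an object
                        # that only offers externally_managed_step
    scaler.update()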

While fixing this issue, the MultipleOptimizer system was also
refactored.
- MultipleOptimizer and the OpenNMT Optimizer wrapper switched places:
  MultipleOptimizer now wraps the other one, instead of the reverse.
- The OpenNMT Optimizer was renamed to SubOptimizer for clarity.
- SubOptimizer handles learning rate scheduling and grad clipping
  (see the sketch after this list).
- MultipleOptimizer handles creation of multiple optimizers, grad scaling,
  restoring from checkpoint, backward, zero_grad, deciding which
  suboptimizers to step, and reporting.
- Each optimizer now individually controls its learning rate schedule.
  When the curriculum introduces new components with freshly initialized
  parameters, their optimizers now apply warmup to the learning rate of
  these parameters. This should improve stability.
- As each optimizer has its own learning rate, it is not obvious what
  to log in the report_training one-liner, so the learning rate was
  removed from it. Instead, every optimizer logs its own learning rate.
  This is currently log spam, but will be lowered to debug level in #70.
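
To make the division of labor concrete, here is a rough sketch of a
per-suboptimizer warmup schedule. The class shape and the inverse-sqrt
schedule are illustrative assumptions, not the actual mammoth code:

    import torch

    class SubOptimizer:
        """Wraps one torch optimizer; owns its LR schedule and grad clipping."""

        def __init__(self, optimizer, base_lr, warmup_steps, max_grad_norm=1.0):
            self.optimizer = optimizer
            self.base_lr = base_lr
            self.warmup_steps = warmup_steps
            self.max_grad_norm = max_grad_norm
            self._step = 0  # freshly initialized components start at 0, so they warm up

        def learning_rate(self):
            # linear warmup, then inverse-sqrt decay; equals base_lr at warmup_steps
            step = max(self._step, 1)
            scale = self.warmup_steps ** 0.5 * min(step ** -0.5, step * self.warmup_steps ** -1.5)
            return self.base_lr * scale

        def update_lr_and_clip(self):
            self._step += 1
            for group in self.optimizer.param_groups:
                group['lr'] = self.learning_rate()
            params = [p for g in self.optimizer.param_groups for p in g['params']]
            torch.nn.utils.clip_grad_norm_(params, self.max_grad_norm)

Because each SubOptimizer counts its own steps, parameters introduced
mid-training get their own warmup instead of inheriting a large LR from
a global schedule.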

Each sub-optimizer having its own GradScaler leads to multiple backward
passes and a RuntimeError. There can only be one GradScaler, which must
therefore be the responsibility of MultipleOptimizer.
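
A minimal sketch of that arrangement, building on the illustrative
SubOptimizer above (again an assumption, not the actual implementation;
gradient_syncs is reduced here to a list of component names):

    class MultipleOptimizer:
        """Owns the single GradScaler; steps only the requested sub-optimizers."""

        def __init__(self, suboptimizers):
            self.suboptimizers = suboptimizers  # dict: component name -> SubOptimizer
            self.scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
            self.training_step = 1

        def backward(self, loss):
            self.scaler.scale(loss).backward()  # the single backward pass

        def externally_managed_step(self, names_to_step):
            for name in names_to_step:
                sub = self.suboptimizers[name]
                self.scaler.unscale_(sub.optimizer)  # unscale before clipping
                sub.update_lr_and_clip()
                self.scaler.step(sub.optimizer)  # skipped on inf/nan gradients
            self.scaler.update()  # exactly one update per iteration
            self.training_step += 1

        def zero_grad(self):
            for sub in self.suboptimizers.values():
                sub.optimizer.zero_grad()

    # Usage: one scaler, several sub-optimizers, step only the active ones.
    model = torch.nn.ModuleDict({'enc': torch.nn.Linear(8, 8), 'dec': torch.nn.Linear(8, 8)})
    subs = {
        name: SubOptimizer(torch.optim.Adam(m.parameters()), base_lr=2.0, warmup_steps=4000)
        for name, m in model.items()
    }
    optim = MultipleOptimizer(subs)
    optim.backward(model['dec'](model['enc'](torch.randn(2, 8))).sum())
    optim.externally_managed_step(['enc', 'dec'])
    optim.zero_grad()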

Closes: #71
Waino committed May 27, 2024
1 parent 339960f commit 20bdc2f
Showing 7 changed files with 231 additions and 230 deletions.
20 changes: 6 additions & 14 deletions mammoth/__init__.py
@@ -1,22 +1,14 @@
 """ Main entry point of the Mammoth library """
 import mammoth.inputters
 import mammoth.models
 import mammoth.utils
 import mammoth.modules
 import mammoth.opts
 from mammoth.trainer import Trainer
-import sys
-import mammoth.utils.optimizers
+from mammoth.utils import optimizers
 
-mammoth.utils.optimizers.Optim = mammoth.utils.optimizers.Optimizer
-sys.modules["mammoth.Optim"] = mammoth.utils.optimizers
+# FIXME: what is the purpose of this hack?
+# import sys
+# mammoth.utils.optimizers.Optim = mammoth.utils.optimizers.Optimizer
+# sys.modules["mammoth.Optim"] = mammoth.utils.optimizers
 
 __all__ = [
-    mammoth.inputters,
-    mammoth.models,
-    mammoth.utils,
-    mammoth.modules,
-    mammoth.opts,
+    "optimizers",
     "Trainer"
 ]
4 changes: 2 additions & 2 deletions mammoth/train_single.py
@@ -4,7 +4,7 @@
 import time
 
 from mammoth.model_builder import build_model
-from mammoth.utils.optimizers import Optimizer
+from mammoth.utils.optimizers import MultipleOptimizer
 from mammoth.utils.misc import set_random_seed
 from mammoth.trainer import build_trainer
 from mammoth.models import build_model_saver
@@ -119,7 +119,7 @@ def main(
 
     # Build optimizer.
     logger.info("{} - Build optimizer".format(device_context.id))
-    optim = Optimizer.from_opts(
+    optim = MultipleOptimizer.from_opts(
        model,
        opts,
        task_queue_manager=task_queue_manager,
37 changes: 22 additions & 15 deletions mammoth/trainer.py
@@ -195,6 +195,8 @@ def __init__(
         self.model.train()
 
     def _accum_count(self, step):
+        if step == 0:
+            _accum = self.accum_count_l[0]
         for i in range(len(self.accum_steps)):
             if step > self.accum_steps[i]:
                 _accum = self.accum_count_l[i]
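
The guard added above fixes an edge case: at step 0 nothing satisfies
step > accum_steps[i], so _accum could be read before assignment. A
standalone adaptation of the logic (the trailing return is assumed; the
hunk does not show the end of the function):

    def _accum_count(accum_steps, accum_count_l, step):
        if step == 0:
            _accum = accum_count_l[0]  # the guard added in this commit
        for i in range(len(accum_steps)):
            if step > accum_steps[i]:
                _accum = accum_count_l[i]
        return _accum  # UnboundLocalError at step 0 without the guard

    assert _accum_count([0, 10000], [4, 8], step=0) == 4
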
@@ -254,6 +256,7 @@ def train(
         while True:
             i += 1
 
+            # global training step
             step = self.optim.training_step
             self._maybe_update_dropout(step)
@@ -264,14 +267,16 @@
             batch_task_sample = self.task_queue_manager.sample_corpus_ids()
             my_task = batch_task_sample.tasks[self.task_queue_manager.global_rank]
 
+            gradient_syncs = self.task_queue_manager.distributed_component_gradient_sync(batch_task_sample)
+
             self._gradient_accumulation(
                 batches_with_meta,
                 total_stats,
                 report_stats,
                 my_task,
+                gradient_syncs,
             )
 
-            gradient_syncs = self.task_queue_manager.distributed_component_gradient_sync(batch_task_sample)
             for gradient_sync in gradient_syncs:
                 component = gradient_sync.component
                 if not component.needs_communication():
@@ -293,25 +298,25 @@
             self.optim.externally_managed_step(gradient_syncs)
             self.optim.zero_grad()
 
-            if step % 1000 == 0 and step > 0:
-                # TODO: if you are going to uncomment that block, please make it optional
-                # logger.info(f'After gradient sync {step}')
-                # for name, p in self.model.named_parameters():
-                #     logger.info(
-                #         f'{device_context.node_rank}:{device_context.local_rank}'
-                #         f' {name}: {p.flatten()[:10]}'
-                #     )
-                if hasattr(self.optim._optimizer, 'report_steps'):
-                    for line in self.optim._optimizer.report_steps():
-                        logger.info(f'{device_context.node_rank}:{device_context.local_rank} {line}')
+            # if step % 1000 == 0 and step > 0:
+            #     TODO: if you are going to uncomment that block, please make it optional
+            #     logger.info(f'After gradient sync {step}')
+            #     for name, p in self.model.named_parameters():
+            #         logger.info(
+            #             f'{device_context.node_rank}:{device_context.local_rank}'
+            #             f' {name}: {p.flatten()[:10]}'
+            #         )
 
             if self.average_decay > 0 and i % self.average_every == 0:
                 self._update_average(step)
 
+            # Learning rate used to be retrieved with: self.optim.learning_rate()
+            # However, as each optimizer has its own learning rate, it is not obvious what to log here.
+            # We might log the mean or the range of learning rates, but the simplest thing is to log nothing.
             report_stats = self._maybe_report_training(
                 step,
                 train_steps,
-                self.optim.learning_rate(),
+                None,
                 report_stats,
                 sampled_task_counts=self.task_queue_manager.sampled_task_counts,
             )
@@ -330,7 +335,7 @@
                 logger.info(f'{device_context.node_rank}:{device_context.local_rank} report stat step {step}')
             if device_context.is_master():
                 self._report_step(
-                    self.optim.learning_rate(),  # learning_rate_to_show, #self.optim.learning_rate(),
+                    None,
                     step,
                     valid_stats=valid_stats,
                 )
@@ -412,6 +417,7 @@ def _gradient_accumulation(
         total_stats,
         report_stats,
         my_task,
+        gradient_syncs,
     ):
         normalization = 0
         seen_comm_batches = set()
@@ -483,7 +489,7 @@
 
         except Exception:
             traceback.print_exc()
-            logger.info("At step %d, we removed a batch - accum %d", self.training_step_all, k)
+            logger.info("At step %d, we removed a batch - accum %d", self.optim.training_step, k)
         if len(seen_comm_batches) != 1:
             logger.warning('Communication batches out of synch with batch accumulation')
@@ -530,6 +536,7 @@ def _maybe_report_training(self, step, num_steps, learning_rate, report_stats, sampled_task_counts):
             report_stats,
             multigpu=self.device_context.is_distributed(),
             sampled_task_counts=sampled_task_counts,
+            optimizer=self.optim,
         )
 
     def _report_step(self, learning_rate, step, train_stats=None, valid_stats=None):
4 changes: 1 addition & 3 deletions mammoth/utils/__init__.py
@@ -3,7 +3,7 @@
 from mammoth.utils.alignment import make_batch_align_matrix
 from mammoth.utils.report_manager import ReportMgr, build_report_manager
 from mammoth.utils.statistics import Statistics
-from mammoth.utils.optimizers import MultipleOptimizer, Optimizer, AdaFactorFairSeq
+from mammoth.utils.optimizers import MultipleOptimizer
 from mammoth.utils.earlystopping import EarlyStopping, scorers_from_opts
 from mammoth.utils.loss import build_loss_compute
 
@@ -16,8 +16,6 @@
     "build_report_manager",
     "Statistics",
     "MultipleOptimizer",
-    "Optimizer",
-    "AdaFactorFairSeq",
     "EarlyStopping",
     "scorers_from_opts",
     "make_batch_align_matrix",