diff --git a/mammoth/__init__.py b/mammoth/__init__.py index fd6ae773..5326c567 100644 --- a/mammoth/__init__.py +++ b/mammoth/__init__.py @@ -1,22 +1,14 @@ """ Main entry point of the Mammoth library """ -import mammoth.inputters -import mammoth.models -import mammoth.utils -import mammoth.modules -import mammoth.opts from mammoth.trainer import Trainer -import sys -import mammoth.utils.optimizers +from mammoth.utils import optimizers -mammoth.utils.optimizers.Optim = mammoth.utils.optimizers.Optimizer -sys.modules["mammoth.Optim"] = mammoth.utils.optimizers +# FIXME: what is the purpose of this hack? +# import sys +# mammoth.utils.optimizers.Optim = mammoth.utils.optimizers.Optimizer +# sys.modules["mammoth.Optim"] = mammoth.utils.optimizers __all__ = [ - mammoth.inputters, - mammoth.models, - mammoth.utils, - mammoth.modules, - mammoth.opts, + "optimizers", "Trainer" ] diff --git a/mammoth/train_single.py b/mammoth/train_single.py index 300ee158..cd6b4156 100644 --- a/mammoth/train_single.py +++ b/mammoth/train_single.py @@ -4,7 +4,7 @@ import time from mammoth.model_builder import build_model -from mammoth.utils.optimizers import Optimizer +from mammoth.utils.optimizers import MultipleOptimizer from mammoth.utils.misc import set_random_seed from mammoth.trainer import build_trainer from mammoth.models import build_model_saver @@ -119,7 +119,7 @@ def main( # Build optimizer. logger.info("{} - Build optimizer".format(device_context.id)) - optim = Optimizer.from_opts( + optim = MultipleOptimizer.from_opts( model, opts, task_queue_manager=task_queue_manager, diff --git a/mammoth/trainer.py b/mammoth/trainer.py index 6fc95d99..cb1a0b50 100644 --- a/mammoth/trainer.py +++ b/mammoth/trainer.py @@ -195,6 +195,8 @@ def __init__( self.model.train() def _accum_count(self, step): + if step == 0: + _accum = self.accum_count_l[0] for i in range(len(self.accum_steps)): if step > self.accum_steps[i]: _accum = self.accum_count_l[i] @@ -254,6 +256,7 @@ def train( while True: i += 1 + # global training step step = self.optim.training_step self._maybe_update_dropout(step) @@ -264,14 +267,16 @@ def train( batch_task_sample = self.task_queue_manager.sample_corpus_ids() my_task = batch_task_sample.tasks[self.task_queue_manager.global_rank] + gradient_syncs = self.task_queue_manager.distributed_component_gradient_sync(batch_task_sample) + self._gradient_accumulation( batches_with_meta, total_stats, report_stats, my_task, + gradient_syncs, ) - gradient_syncs = self.task_queue_manager.distributed_component_gradient_sync(batch_task_sample) for gradient_sync in gradient_syncs: component = gradient_sync.component if not component.needs_communication(): @@ -293,25 +298,25 @@ def train( self.optim.externally_managed_step(gradient_syncs) self.optim.zero_grad() - if step % 1000 == 0 and step > 0: - # TODO: if you are going to uncomment that block, please make it optional - # logger.info(f'After gradient sync {step}') - # for name, p in self.model.named_parameters(): - # logger.info( - # f'{device_context.node_rank}:{device_context.local_rank}' - # f' {name}: {p.flatten()[:10]}' - # ) - if hasattr(self.optim._optimizer, 'report_steps'): - for line in self.optim._optimizer.report_steps(): - logger.info(f'{device_context.node_rank}:{device_context.local_rank} {line}') + # if step % 1000 == 0 and step > 0: + # TODO: if you are going to uncomment that block, please make it optional + # logger.info(f'After gradient sync {step}') + # for name, p in self.model.named_parameters(): + # logger.info( + # 
f'{device_context.node_rank}:{device_context.local_rank}' + # f' {name}: {p.flatten()[:10]}' + # ) if self.average_decay > 0 and i % self.average_every == 0: self._update_average(step) + # Learning rate used to be retrieved with: self.optim.learning_rate() + # However, as each optimizer has its own learning rate, it is not obvious what to log here. + # We might log the mean or the range of learning rates, but the simplest thing is to log nothing. report_stats = self._maybe_report_training( step, train_steps, - self.optim.learning_rate(), + None, report_stats, sampled_task_counts=self.task_queue_manager.sampled_task_counts, ) @@ -330,7 +335,7 @@ def train( logger.info(f'{device_context.node_rank}:{device_context.local_rank} report stat step {step}') if device_context.is_master(): self._report_step( - self.optim.learning_rate(), # learning_rate_to_show, #self.optim.learning_rate(), + None, step, valid_stats=valid_stats, ) @@ -412,6 +417,7 @@ def _gradient_accumulation( total_stats, report_stats, my_task, + gradient_syncs, ): normalization = 0 seen_comm_batches = set() @@ -483,7 +489,7 @@ def _gradient_accumulation( except Exception: traceback.print_exc() - logger.info("At step %d, we removed a batch - accum %d", self.training_step_all, k) + logger.info("At step %d, we removed a batch - accum %d", self.optim.training_step, k) if len(seen_comm_batches) != 1: logger.warning('Communication batches out of synch with batch accumulation') @@ -530,6 +536,7 @@ def _maybe_report_training(self, step, num_steps, learning_rate, report_stats, s report_stats, multigpu=self.device_context.is_distributed(), sampled_task_counts=sampled_task_counts, + optimizer=self.optim, ) def _report_step(self, learning_rate, step, train_stats=None, valid_stats=None): diff --git a/mammoth/utils/__init__.py b/mammoth/utils/__init__.py index 49933156..eafb5432 100644 --- a/mammoth/utils/__init__.py +++ b/mammoth/utils/__init__.py @@ -3,7 +3,7 @@ from mammoth.utils.alignment import make_batch_align_matrix from mammoth.utils.report_manager import ReportMgr, build_report_manager from mammoth.utils.statistics import Statistics -from mammoth.utils.optimizers import MultipleOptimizer, Optimizer, AdaFactorFairSeq +from mammoth.utils.optimizers import MultipleOptimizer from mammoth.utils.earlystopping import EarlyStopping, scorers_from_opts from mammoth.utils.loss import build_loss_compute @@ -16,8 +16,6 @@ "build_report_manager", "Statistics", "MultipleOptimizer", - "Optimizer", - "AdaFactorFairSeq", "EarlyStopping", "scorers_from_opts", "make_batch_align_matrix", diff --git a/mammoth/utils/optimizers.py b/mammoth/utils/optimizers.py index 8d1e7b58..d3f29664 100644 --- a/mammoth/utils/optimizers.py +++ b/mammoth/utils/optimizers.py @@ -3,35 +3,13 @@ import math import torch import torch.optim as optim -from collections import Counter from math import sqrt from torch.nn.utils import clip_grad_norm_ from mammoth.utils.logging import logger -def attention_bridge_optimizer(model, task_queue_manager, base_optimizer): - suboptimizers = {} - # All components on device, in consistent order across devices - my_components = task_queue_manager.get_my_distributed_components() - # Also keeping components that are on a single device - for component in my_components: - name = component.get_name() - params = [ - param for param_name, param in component.named_parameters(model) - if param.requires_grad - ] - if name in suboptimizers: - raise Exception(f'Trying to create second optimizer for "{name}"') - if len(params) != 0: - optimizer = 
base_optimizer(params) - suboptimizers[name] = optimizer - - optimizer = MultipleOptimizer(suboptimizers, None) - return optimizer - - -def build_torch_optimizer(model, opts, task_queue_manager): - """Builds the PyTorch optimizer. +def get_base_optimizer(opts): + """Builds the PyTorch optimizer factory. We use the default parameters for Adam that are suggested by the original paper https://arxiv.org/pdf/1412.6980.pdf @@ -45,11 +23,10 @@ def build_torch_optimizer(model, opts, task_queue_manager): established value, so we use that here as well Args: - model: The model to optimize. opts. The dictionary of options. Returns: - A ``torch.optim.Optimizer`` instance. + A callable that returns ``torch.optim.Optimizer`` instances. """ betas = [opts.adam_beta1, opts.adam_beta2] if opts.optim == 'sgd': @@ -91,12 +68,7 @@ def build_torch_optimizer(model, opts, task_queue_manager): else: raise ValueError('Invalid optimizer type: ' + opts.optim) - optimizer = attention_bridge_optimizer( - model, - task_queue_manager, - base_optimizer, - ) - return optimizer + return base_optimizer def make_learning_rate_decay_fn(opts): @@ -166,227 +138,255 @@ def linear_warmup_decay(step, warmup_steps, rate, train_steps): return max(end_rate, (train_steps - step) / (train_steps - warmup_steps)) -class MultipleOptimizer(object): - """Separate optimizers for each distributed component""" - - def __init__(self, op): - self.optimizers = op - self._steps = Counter() +class SubOptimizer(object): + """ + Wraps a base optimizer (torch.optim.Optimizer). + Handles learning rate scheduling and grad clipping. + """ + def __init__( + self, + optimizer, + learning_rate, + learning_rate_decay_fn, + max_grad_norm=0, + grad_scaler=None, + ): + self._optimizer = optimizer + self._learning_rate = learning_rate + self._learning_rate_decay_fn = learning_rate_decay_fn + self._max_grad_norm = max_grad_norm + self.grad_scaler = grad_scaler + self._training_step = 1 + self._decay_step = 1 @property def param_groups(self): - param_groups = [] - for name in self.optimizers: - optimizer = self.optimizers[name] - param_groups.extend(optimizer.param_groups) - return param_groups - - def zero_grad(self): - """Reset the gradient of all sub-optimizers to zero""" - for name in self.optimizers: - self.optimizers[name].zero_grad() + return self._optimizer.param_groups - def externally_managed_step(self, gradient_syncs, grad_scaler=None): - """Step through only the trained suboptimizers""" - trained_components = { - gradient_sync.component.get_name() for gradient_sync in gradient_syncs - } - for name in self.optimizers: - if name in trained_components: - # logger.warning(f'Stepping {name}') # DEBUG - self._steps[name] += 1 - self.optimizers[name].step() + @property + def training_step(self): + """The current training step.""" + return self._training_step - def report_steps(self): - result = [] - for name in self.optimizers: - count = self._steps[name] - result.append(f'Optimizer "{name}" has been stepped {count} times') - return result + def learning_rate(self): + """Returns the current learning rate.""" + if self._learning_rate_decay_fn is None: + return self._learning_rate + scale = self._learning_rate_decay_fn(self._decay_step) + return scale * self._learning_rate def state_dict(self): - """Returns the state dictionary""" return { - 'optimizers': {k: v.state_dict() for k, v in self.optimizers.items()}, - 'steps': self._steps, + 'training_step': self._training_step, + 'decay_step': self._decay_step, + 'optimizer': self._optimizer.state_dict(), } def 
load_state_dict(self, state_dict): - """Loads the optimizer from the state dictionary""" + self._training_step = state_dict['training_step'] + # State can be partially restored. + if 'decay_step' in state_dict: + self._decay_step = state_dict['decay_step'] + if 'optimizer' in state_dict: + self._optimizer.load_state_dict(state_dict['optimizer']) - # do not load any optimizer state if one component is missing - do_load = True - for k in state_dict["optimizers"].keys(): - if k not in self.optimizers.keys(): - do_load = False + def zero_grad(self): + """Zero the gradients of optimized parameters.""" + self._optimizer.zero_grad() - if do_load is True: - for k in state_dict["optimizers"].keys(): - self.optimizers[k].load_state_dict(state_dict["optimizers"][k]) - else: - logger.info("Some components do not match. Do not load optimizer from checkpoint.") + def step(self): + """Update the model parameters based on current gradients. """ + learning_rate = self.learning_rate() - self._steps = state_dict["steps"] + for group in self._optimizer.param_groups: + group['lr'] = learning_rate + if self._max_grad_norm > 0: + clip_grad_norm_(group['params'], self._max_grad_norm) + + if self.grad_scaler is not None: + self.grad_scaler.step(self._optimizer) + else: + self._optimizer.step() + self._decay_step += 1 + self._training_step += 1 -class Optimizer(object): +class MultipleOptimizer(object): """ - Controller class for optimization. Mostly a thin - wrapper for `optim`, but also useful for implementing - rate scheduling beyond what is currently available. - Also implements necessary methods for training RNNs such - as grad manipulations. + Separate sub-optimizers for each distributed component. + Handles creation of multiple optimizers, grad scaling, + restoring from checkpoint, backward, zero_grad, + deciding which suboptimizers to step, and reporting. """ - - def __init__(self, optimizer, learning_rate, learning_rate_decay_fn=None, max_grad_norm=None): - """Initializes the controller. - - Args: - optimizer: A ``torch.optim.Optimizer`` instance. - learning_rate: The initial learning rate. - learning_rate_decay_fn: An optional callable taking the current step - as argument and return a learning rate scaling factor. - max_grad_norm: Clip gradients to this global norm. - """ - self._optimizer = optimizer - self._learning_rate = learning_rate - self._learning_rate_decay_fn = learning_rate_decay_fn - self._max_grad_norm = max_grad_norm or 0 - self._training_step = 1 - self._decay_step = 1 - self._fp16 = None - self._scaler = None + def __init__( + self, + suboptimizers, + grad_scaler=None, + ): + self.suboptimizers = suboptimizers + self.grad_scaler = grad_scaler + # Global training step is incremented for each minibatch. + # There may not be any actual parameters that have been trained this number of steps, + # or any optimizer stepped this number of times. + self.global_training_step = 1 @classmethod def from_opts(cls, model, opts, task_queue_manager, checkpoint=None): - """Builds the optimizer from options. + optim_opts, optim_state_dict = cls._maybe_restore_from_checkpoint(opts, checkpoint) + base_optimizer = get_base_optimizer(optim_opts) + use_grad_scaler = opts.model_dtype == "fp16" + if use_grad_scaler: + from torch.cuda.amp import GradScaler + grad_scaler = GradScaler() + else: + grad_scaler = None + suboptimizers = cls._get_suboptimizers( + model, + task_queue_manager, + base_optimizer, + optim_opts, + grad_scaler=grad_scaler, + ) - Args: - cls: The ``Optimizer`` class to instantiate. 
-          model: The model to optimize.
-          opts: The dict of user options.
-          checkpoint: An optional checkpoint to load states from.
+        multiple_optimizer = cls(
+            suboptimizers=suboptimizers,
+            grad_scaler=grad_scaler,
+        )
+        if optim_state_dict:
+            # Restore the sub-optimizer states saved in the checkpoint.
+            multiple_optimizer.load_state_dict(optim_state_dict)
+        return multiple_optimizer
 
-        Returns:
-          An ``Optimizer`` instance.
-        """
-        optim_opt = opts
+    @staticmethod
+    def _maybe_restore_from_checkpoint(opts, checkpoint=None):
+        optim_opts = opts
         optim_state_dict = None
 
         if opts.train_from and checkpoint is not None:
-            optim = checkpoint['optim']
-            ckpt_opt = checkpoint['opts']
-            ckpt_state_dict = {}
-            if isinstance(optim, Optimizer):  # Backward compatibility.
-                ckpt_state_dict['training_step'] = optim._training_step
-                ckpt_state_dict['decay_step'] = optim._decay_step
-                ckpt_state_dict['optimizer'] = optim._optimizer.state_dict()
-            else:
-                ckpt_state_dict = optim
+            ckpt_opts = checkpoint['opts']
+            ckpt_state_dict = checkpoint['optim']
 
             if opts.reset_optim == 'none':
                 # Load everything from the checkpoint.
-                optim_opt = ckpt_opt
+                optim_opts = ckpt_opts
                 optim_state_dict = ckpt_state_dict
             elif opts.reset_optim == 'all':
                 # Build everything from scratch.
                 pass
             elif opts.reset_optim == 'states':
                 # Reset optimizer, keep options.
-                optim_opt = ckpt_opt
+                optim_opts = ckpt_opts
                 optim_state_dict = ckpt_state_dict
                 del optim_state_dict['optimizer']
             elif opts.reset_optim == 'keep_states':
                 # Reset options, keep optimizer.
                 optim_state_dict = ckpt_state_dict
 
+        return optim_opts, optim_state_dict
+
+    @staticmethod
+    def _get_suboptimizers(model, task_queue_manager, base_optimizer, optim_opts, grad_scaler=None):
+        suboptimizers = {}
+        # All components on device, in consistent order across devices
+        my_components = task_queue_manager.get_my_distributed_components()
+        # Also keeping components that are on a single device
+        for component in my_components:
+            name = component.get_name()
+            params = [
+                param for param_name, param in component.named_parameters(model)
+                if param.requires_grad
+            ]
+            if name in suboptimizers:
+                raise Exception(f'Trying to create second optimizer for "{name}"')
+            if len(params) != 0:
+                optimizer = SubOptimizer(
+                    optimizer=base_optimizer(params),
+                    learning_rate=optim_opts.learning_rate,
+                    learning_rate_decay_fn=make_learning_rate_decay_fn(optim_opts),
+                    max_grad_norm=optim_opts.max_grad_norm,
+                    grad_scaler=grad_scaler,
+                )
+                suboptimizers[name] = optimizer
+
+        return suboptimizers
 
-        optimizer = cls(
-            build_torch_optimizer(model, optim_opt, task_queue_manager),
-            optim_opt.learning_rate,
-            learning_rate_decay_fn=make_learning_rate_decay_fn(optim_opt),
-            max_grad_norm=optim_opt.max_grad_norm,
-        )
+    @property
+    def param_groups(self):
+        param_groups = []
+        for name in self.suboptimizers:
+            optimizer = self.suboptimizers[name]
+            param_groups.extend(optimizer.param_groups)
+        return param_groups
 
-        if opts.model_dtype == "fp16":
-            optimizer._fp16 = "amp"
-            from torch.cuda.amp import GradScaler
+    def backward(self, loss):
+        if self.grad_scaler is not None:
+            self.grad_scaler.scale(loss).backward()
+        else:
+            loss.backward()
 
-            optimizer._scaler = GradScaler()
+    def zero_grad(self):
+        """Reset the gradients of all sub-optimizers to zero"""
+        for name in self.suboptimizers:
+            self.suboptimizers[name].zero_grad()
 
-        if optim_state_dict:
-            optimizer.load_state_dict(optim_state_dict)
-        return optimizer
+    def externally_managed_step(self, gradient_syncs):
+        """Step only the sub-optimizers of the components trained on this minibatch"""
 
-    @property
-    def training_step(self):
-        """The current training step."""
-        return self._training_step
+        self.global_training_step += 1
+        trained_components = {
+            gradient_sync.component.get_name() for gradient_sync in gradient_syncs
+        }
+        for name, optimizer in self.suboptimizers.items():
+            if name in trained_components:
+                # logger.warning(f'Stepping {name}')  # DEBUG
+                if self.grad_scaler is not None:
+                    # Unscale the wrapped torch optimizer in place, so that gradient clipping
+                    # in SubOptimizer.step sees true magnitudes and GradScaler.step does not
+                    # unscale the same gradients a second time.
+                    self.grad_scaler.unscale_(optimizer._optimizer)
+                optimizer.step()
 
-    @property
-    def amp(self):
-        """True if use torch amp mix precision training."""
-        return self._fp16 == "amp"
+        if self.grad_scaler is not None:
+            # Updates the scale for next iteration.
+            self.grad_scaler.update()
 
-    def learning_rate(self):
-        """Returns the current learning rate."""
-        if self._learning_rate_decay_fn is None:
-            return self._learning_rate
-        scale = self._learning_rate_decay_fn(self._decay_step)
-        return scale * self._learning_rate
+    def report_steps(self):
+        result = []
+        for name, optimizer in self.suboptimizers.items():
+            count = optimizer.training_step
+            lr = optimizer.learning_rate()
+            result.append(f'Optimizer "{name}" is at training step {count} with LR {lr}')
+        return result
 
     def state_dict(self):
+        """Returns the state dictionary"""
         return {
-            'training_step': self._training_step,
-            'decay_step': self._decay_step,
-            'optimizer': self._optimizer.state_dict(),
+            'optimizers': {k: v.state_dict() for k, v in self.suboptimizers.items()},
+            'steps': self.global_training_step,
         }
 
     def load_state_dict(self, state_dict):
-        self._training_step = state_dict['training_step']
-        # State can be partially restored.
-        if 'decay_step' in state_dict:
-            self._decay_step = state_dict['decay_step']
-        if 'optimizer' in state_dict:
-            self._optimizer.load_state_dict(state_dict['optimizer'])
+        """Loads the optimizer state from the state dictionary"""
 
-    def zero_grad(self):
-        """Zero the gradients of optimized parameters."""
-        self._optimizer.zero_grad()
+        # Do not load any optimizer state if the checkpoint contains a component
+        # unknown to this model.
+        do_load = True
+        for k in state_dict["optimizers"].keys():
+            if k not in self.suboptimizers.keys():
+                do_load = False
 
-    def backward(self, loss):
-        """Wrapper for backward pass. Some optimizer requires ownership of the
-        backward pass."""
-        if self.amp:
-            self._scaler.scale(loss).backward()
+        if do_load:
+            for k in state_dict["optimizers"].keys():
+                self.suboptimizers[k].load_state_dict(state_dict["optimizers"][k])
         else:
-            loss.backward()
+            logger.info("Checkpoint optimizer components do not match the model. Not loading optimizer state.")
 
-    def externally_managed_step(self, *args, **kwargs):
-        """Update the model parameters based on current gradients.
+        self.global_training_step = state_dict["steps"]
 
-        Optionally, will employ gradient modification or update learning
-        rate.
+    @property
+    def training_step(self):
         """
-        learning_rate = self.learning_rate()
-
-        if self.amp:
-            for suboptimizer in self._optimizer.optimizers.values():
-                self._scaler.unscale_(suboptimizer)
-
-        for group in self._optimizer.param_groups:
-            group['lr'] = learning_rate
-            if self._max_grad_norm > 0:
-                clip_grad_norm_(group['params'], self._max_grad_norm)
-
-        if self.amp:
-            self._scaler.step(self._optimizer)
+        Global training step, incremented once per minibatch.
+        No individual parameter has necessarily been updated this many times,
+        and no sub-optimizer has necessarily been stepped this many times.
+        """
+        return self.global_training_step
 
-        # Updates the scale for next iteration.
-            self._scaler.update()
-        else:
-            self._optimizer.externally_managed_step(*args, **kwargs)
-        self._decay_step += 1
-        self._training_step += 1
+    @property
+    def amp(self):
+        """True if torch amp mixed precision training is in use."""
+        return self.grad_scaler is not None
 
 
 # Code below is an implementation of https://arxiv.org/pdf/1804.04235.pdf
diff --git a/mammoth/utils/report_manager.py b/mammoth/utils/report_manager.py
index cc3702f1..bf4a276d 100644
--- a/mammoth/utils/report_manager.py
+++ b/mammoth/utils/report_manager.py
@@ -60,6 +60,7 @@ def report_training(
         report_stats,
         multigpu=False,
         sampled_task_counts=None,
+        optimizer=None,
     ):
         """
         This is the user-defined batch-level traing progress
@@ -91,6 +92,9 @@ def report_training(
                 report_stats,
                 sampled_task_counts=sampled_task_counts
             )
+            if optimizer is not None:
+                for line in optimizer.report_steps():
+                    logger.info(line)
             return mammoth.utils.Statistics()
         else:
             return report_stats
diff --git a/mammoth/utils/statistics.py b/mammoth/utils/statistics.py
index ff8292a5..270ee09d 100644
--- a/mammoth/utils/statistics.py
+++ b/mammoth/utils/statistics.py
@@ -149,14 +149,14 @@ def output(self, step, num_steps, learning_rate, start, metadata=None):
         if num_steps > 0:
             step_fmt = "%s/%5d" % (step_fmt, num_steps)
         logger.info(
-            ("%s: Step %s; acc: %6.2f; ppl: %5.2f; xent: %4.2f; " + "lr: %7.5f; %3.0f/%3.0f tok/s; %6.0f sec")
+            ("%s: Step %s; acc: %6.2f; ppl: %5.2f; xent: %4.2f; %3.0f/%3.0f tok/s; %6.0f sec")
             % (
                 meta_str,
                 step_fmt,
                 self.accuracy(),
                 self.ppl(),
                 self.xent(),
-                learning_rate,
+                # learning_rate,  # was "lr: %7.5f;"
                 self.n_src_words / (t + 1e-5),
                 self.n_words / (t + 1e-5),
                 time.time() - start,
@@ -174,7 +174,7 @@ def log_tensorboard(self, prefix, writer, learning_rate, patience, step):
         writer.add_scalar(prefix + "/ppl", self.ppl(), step)
         writer.add_scalar(prefix + "/accuracy", self.accuracy(), step)
         writer.add_scalar(prefix + "/tgtper", self.n_words / t, step)
-        writer.add_scalar(prefix + "/lr", learning_rate, step)
+        # writer.add_scalar(prefix + "/lr", learning_rate, step)
         if patience is not None:
             writer.add_scalar(prefix + "/patience", patience, step)
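
For reference, below is a minimal standalone sketch of the optimizer pattern introduced by this refactor: one torch optimizer per component, stepping only the components trained on the current minibatch, with an optional shared GradScaler that scales the loss, unscales before clipping, and updates its scale once per global step. The component names and toy modules are hypothetical and only plain PyTorch is assumed; this illustrates the pattern, not mammoth's actual API.

# Hypothetical sketch (not mammoth code): per-component sub-optimizers with
# externally decided stepping and an optional shared GradScaler.
import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler

device = "cuda" if torch.cuda.is_available() else "cpu"

# Toy "components" with made-up names; in mammoth these would be the
# distributed components (encoders, decoders, attention bridge, ...).
components = {
    "encoder_xx": nn.Linear(8, 8).to(device),
    "decoder_yy": nn.Linear(8, 4).to(device),
}

# One torch optimizer per component, analogous to SubOptimizer/MultipleOptimizer.
suboptimizers = {
    name: torch.optim.SGD(module.parameters(), lr=0.1)
    for name, module in components.items()
}

# A GradScaler only does useful work on CUDA.
grad_scaler = GradScaler() if device == "cuda" else None
max_grad_norm = 1.0

def training_step(x, trained_components):
    # Backward pass: scale the loss when a GradScaler is in use.
    loss = components["decoder_yy"](components["encoder_xx"](x)).pow(2).mean()
    if grad_scaler is not None:
        grad_scaler.scale(loss).backward()
    else:
        loss.backward()

    # Step only the components that were trained on this minibatch.
    for name in trained_components:
        optimizer = suboptimizers[name]
        if grad_scaler is not None:
            # Unscale in place so that clipping sees true gradient magnitudes.
            grad_scaler.unscale_(optimizer)
        for group in optimizer.param_groups:
            torch.nn.utils.clip_grad_norm_(group["params"], max_grad_norm)
        if grad_scaler is not None:
            grad_scaler.step(optimizer)  # skips the update on inf/nan gradients
        else:
            optimizer.step()
        optimizer.zero_grad()

    if grad_scaler is not None:
        grad_scaler.update()  # adjust the loss scale once per global step

training_step(torch.randn(2, 8, device=device), {"encoder_xx", "decoder_yy"})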