From 630a679438a582f495354af6965a692b349c7b7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stig-Arne=20Gr=C3=B6nroos?= Date: Mon, 28 Oct 2024 14:20:17 +0200 Subject: [PATCH] Count clipping, don't recompute active adapters --- mammoth/model_builder.py | 2 ++ mammoth/models/model.py | 8 ++++---- mammoth/modules/adapters.py | 2 +- mammoth/modules/layer_stack.py | 2 +- mammoth/trainer.py | 2 +- mammoth/utils/model_saver.py | 14 +++++++++----- mammoth/utils/optimizers.py | 16 ++++++++++++---- 7 files changed, 30 insertions(+), 16 deletions(-) diff --git a/mammoth/model_builder.py b/mammoth/model_builder.py index 5f654557..8eecdfd8 100644 --- a/mammoth/model_builder.py +++ b/mammoth/model_builder.py @@ -90,6 +90,8 @@ def get_transformer_wrapper_kwargs( kwargs.update({ 'max_seq_len': max_seq_len, }) + if side == Side.encoder: + kwargs['return_only_embed'] = True return kwargs diff --git a/mammoth/models/model.py b/mammoth/models/model.py index 912945fc..5e950b36 100644 --- a/mammoth/models/model.py +++ b/mammoth/models/model.py @@ -75,10 +75,10 @@ def forward(self, src, decoder_input, src_mask, metadata=None): return_embeddings=True, ) - encoder_output, alphas = self.attention_bridge(encoder_output, src_mask) - if self.attention_bridge.is_fixed_length: - # turn off masking in the transformer decoder - src_mask = None + # encoder_output, alphas = self.attention_bridge(encoder_output, src_mask) + # if self.attention_bridge.is_fixed_length: + # # turn off masking in the transformer decoder + # src_mask = None retval = active_decoder( decoder_input, diff --git a/mammoth/modules/adapters.py b/mammoth/modules/adapters.py index 0f28d9d9..763eebd8 100644 --- a/mammoth/modules/adapters.py +++ b/mammoth/modules/adapters.py @@ -192,13 +192,13 @@ def _inject_adapters(self): adapted_layer_types = [] adapted_layers = nn.ModuleList() adapted_layer_dropouts = [] + adapter_layers_by_index = self._merge_active_adapters() i = 0 for layer_type, layer_struct, layer_dropout in zip( self._base_layer_types, self._base_layers, self._base_layer_dropouts, ): - adapter_layers_by_index = self._merge_active_adapters() if layer_type == 'f': # Adapters apply to feedforward layers adapter_layers = adapter_layers_by_index[i] diff --git a/mammoth/modules/layer_stack.py b/mammoth/modules/layer_stack.py index 1ceff784..67dabae6 100644 --- a/mammoth/modules/layer_stack.py +++ b/mammoth/modules/layer_stack.py @@ -9,7 +9,7 @@ class AdaptedAttentionLayersStack(nn.Module): """ Wrapper that allows stacking multiple AdaptedAttentionLayers. 
- Represents one particular stacking: does not allow switching out entire layers + Represents one particular task-specific stacking: does not allow switching out entire layers (but does delegate the switching out of adapters to its components) """ def __init__(self, attention_layers_stack: Sequence[AdaptedAttentionLayers]): diff --git a/mammoth/trainer.py b/mammoth/trainer.py index 82900140..0ce48935 100644 --- a/mammoth/trainer.py +++ b/mammoth/trainer.py @@ -516,13 +516,13 @@ def _gradient_accumulation( total_stats.update(batch_stats) report_stats.update(batch_stats) report_stats.update_task_loss(batch_stats.loss, metadata) - except Exception: traceback.print_exc() logger.info("At step %d, we removed a batch - accum %d", self.optim.training_step, k) self.nan_batches += 1 if self.nan_batches >= self.max_nan_batches: raise Exception('Exceeded allowed --max_nan_batches.') + if len(seen_comm_batches) != 1: logger.warning('Communication batches out of synch with batch accumulation') diff --git a/mammoth/utils/model_saver.py b/mammoth/utils/model_saver.py index bdaab236..cedef2b8 100644 --- a/mammoth/utils/model_saver.py +++ b/mammoth/utils/model_saver.py @@ -60,7 +60,9 @@ def explode_model( # Only the lowest ranked device saves a component state_dicts[name] = component.state_dict(model) # The optimizer parameters are distributed the same way as the components - optim_state_dicts[name] = optim.suboptimizers[name].state_dict() + # Not all components have trainable (unfrozen) parameters, though + if name in optim.suboptimizers: + optim_state_dicts[name] = optim.suboptimizers[name].state_dict() return state_dicts, optim_state_dicts @@ -290,12 +292,14 @@ def _save(self, step, model, data_state, task_queue_manager): if os.path.isfile(checkpoint_path): logger.debug("{} - not saving {} as it is already present".format(device_context.id, checkpoint_path)) else: - logger.info(f'Saving module checkpoint {checkpoint_path} and optimizer {optimizer_path}') - torch.save(state_dict, checkpoint_path) - tmp_checkpoint_paths.append(checkpoint_path) - if key != 'frame': + if key != 'frame' and key in optim_state_dicts: + logger.info(f'Saving module checkpoint {checkpoint_path} and optimizer {optimizer_path}') torch.save(optim_state_dicts[key], optimizer_path) tmp_checkpoint_paths.append(optimizer_path) + else: + logger.info(f'Saving module checkpoint {checkpoint_path} (no optimizer to save)') + torch.save(state_dict, checkpoint_path) + tmp_checkpoint_paths.append(checkpoint_path) return tmp_checkpoint_paths diff --git a/mammoth/utils/optimizers.py b/mammoth/utils/optimizers.py index 6d588e3a..117b33a9 100644 --- a/mammoth/utils/optimizers.py +++ b/mammoth/utils/optimizers.py @@ -154,11 +154,12 @@ def __init__( self._optimizer = optimizer self._learning_rate = learning_rate self._learning_rate_decay_fn = learning_rate_decay_fn - self._max_grad_norm = max_grad_norm self.grad_scaler = grad_scaler self._training_step = 1 self._decay_step = 1 self._n_clips = 0 + self._n_params_tot = self._count_params() + self._max_grad_norm = max_grad_norm * self._n_params_tot @property def param_groups(self): @@ -169,6 +170,12 @@ def training_step(self): """The current training step.""" return self._training_step + def _count_params(self): + result = 0 + for group in self._optimizer.param_groups: + result += sum(param.numel() for param in group['params']) + return result + def learning_rate(self): """Returns the current learning rate.""" if self._learning_rate_decay_fn is None: @@ -195,7 +202,6 @@ def load_state_dict(self, 
state_dict): def zero_grad(self): """Zero the gradients of optimized parameters.""" self._optimizer.zero_grad() - self._n_clips = 0 def step(self): """Update the model parameters based on current gradients. """ @@ -207,7 +213,7 @@ def step(self): orig_norm = clip_grad_norm_(group['params'], self._max_grad_norm) if orig_norm.item() > self._max_grad_norm: # FIXME: debug. Count and log instead - print(f'Clipping {orig_norm} -> {self._max_grad_norm}') + # print(f'Clipping {orig_norm} -> {self._max_grad_norm}') self._n_clips += 1 if self.grad_scaler is not None: @@ -350,7 +356,9 @@ def report_steps(self): for name, optimizer in self.suboptimizers.items(): count = optimizer.training_step lr = optimizer.learning_rate() - result.append(f'Optimizer "{name}" has been stepped {count} times and has LR {lr}') + n_clips = optimizer._n_clips + result.append(f'Optimizer "{name}" has been stepped {count} times. LR {lr} n_clips {n_clips}') + optimizer._n_clips = 0 return result def count_parameters(self):
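
The optimizers.py hunks are the "count clipping" half of this commit: the configured max_grad_norm is rescaled by the total number of optimized parameters, clip events are counted in step() instead of being printed, and report_steps() now logs and resets the per-optimizer count. Below is a minimal, self-contained sketch of that pattern, not the MAMMOTH implementation: it drops the grad scaler, learning-rate decay and suboptimizer machinery, and the names ClipCountingOptimizer and report_and_reset_clips are invented for illustration; only torch.nn.utils.clip_grad_norm_ is the real API used by the patch.

import torch
from torch.nn.utils import clip_grad_norm_


class ClipCountingOptimizer:
    """Illustrative wrapper (not the MAMMOTH class): scale the clipping
    threshold by parameter count, count clip events during step(), and
    report/reset the counter at reporting time."""

    def __init__(self, optimizer, max_grad_norm):
        self._optimizer = optimizer
        # Total number of scalar parameters across all param groups
        self._n_params_tot = sum(
            p.numel() for group in optimizer.param_groups for p in group['params']
        )
        # As in the patch: the configured norm is multiplied by the parameter count
        self._max_grad_norm = max_grad_norm * self._n_params_tot
        self._n_clips = 0

    def step(self):
        if self._max_grad_norm > 0:
            for group in self._optimizer.param_groups:
                orig_norm = clip_grad_norm_(group['params'], self._max_grad_norm)
                if orig_norm.item() > self._max_grad_norm:
                    # Count instead of printing; the counter is read out later
                    self._n_clips += 1
        self._optimizer.step()
        self._optimizer.zero_grad()

    def report_and_reset_clips(self):
        n_clips, self._n_clips = self._n_clips, 0
        return n_clips


# Toy usage: a tiny threshold forces clipping so the counter visibly increments
model = torch.nn.Linear(4, 2)
wrapped = ClipCountingOptimizer(torch.optim.SGD(model.parameters(), lr=0.1), max_grad_norm=1e-8)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
wrapped.step()
print('clip events this reporting interval:', wrapped.report_and_reset_clips())

Resetting the counter at report time, rather than in zero_grad() as before the patch, is what lets it accumulate over a whole reporting interval instead of a single accumulation step.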
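
The adapters.py hunk is the "don't recompute active adapters" half: _merge_active_adapters() does not depend on the loop variable, so it is now called once before the loop in _inject_adapters instead of once per base layer. A simplified sketch of the same hoisting, using plain dicts and lists as stand-ins (merge_active_adapters, inject_adapters and the 'layer_index'/'name' fields are invented for the example; the real module builds nn.ModuleList structures):

from collections import defaultdict


def merge_active_adapters(active_adapters):
    # Group the names of active adapters by the base-layer index they attach to
    by_index = defaultdict(list)
    for adapter in active_adapters:
        by_index[adapter['layer_index']].append(adapter['name'])
    return by_index


def inject_adapters(base_layer_types, active_adapters):
    # Hoisted out of the loop, as in the patch: the merged mapping is identical
    # on every iteration, so it is computed exactly once
    adapter_layers_by_index = merge_active_adapters(active_adapters)
    adapted = []
    i = 0
    for layer_type in base_layer_types:
        if layer_type == 'f':
            # Adapters apply to feedforward layers
            adapted.append((layer_type, adapter_layers_by_index.get(i, [])))
            i += 1
        else:
            adapted.append((layer_type, []))
    return adapted


print(inject_adapters(
    ['a', 'f', 'a', 'f'],
    [{'layer_index': 0, 'name': 'adapter_enc_low'},
     {'layer_index': 1, 'name': 'adapter_enc_high'}],
))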