Count clipping, don't recompute active adapters
Waino committed Oct 28, 2024
1 parent 3a16fa0 commit 630a679
Showing 7 changed files with 30 additions and 16 deletions.
2 changes: 2 additions & 0 deletions mammoth/model_builder.py
@@ -90,6 +90,8 @@ def get_transformer_wrapper_kwargs(
kwargs.update({
'max_seq_len': max_seq_len,
})
if side == Side.encoder:
kwargs['return_only_embed'] = True
return kwargs
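
Only the decoder needs a projection to vocabulary logits; the encoder's output is consumed by the attention bridge and the decoder's cross-attention, so a logit head on the encoder side would be wasted parameters and compute. A minimal sketch of the kwargs pattern, keeping the Side check and flag name from the diff but otherwise illustrative rather than the full mammoth builder:

from enum import Enum

class Side(Enum):
    encoder = 'encoder'
    decoder = 'decoder'

def get_wrapper_kwargs(side: Side, max_seq_len: int) -> dict:
    # Settings shared by encoder- and decoder-side wrappers.
    kwargs = {'max_seq_len': max_seq_len}
    if side == Side.encoder:
        # Encoders return hidden states only; no projection to vocabulary logits.
        kwargs['return_only_embed'] = True
    return kwargs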


8 changes: 4 additions & 4 deletions mammoth/models/model.py
@@ -75,10 +75,10 @@ def forward(self, src, decoder_input, src_mask, metadata=None):
return_embeddings=True,
)

encoder_output, alphas = self.attention_bridge(encoder_output, src_mask)
if self.attention_bridge.is_fixed_length:
# turn off masking in the transformer decoder
src_mask = None
# encoder_output, alphas = self.attention_bridge(encoder_output, src_mask)
# if self.attention_bridge.is_fixed_length:
# # turn off masking in the transformer decoder
# src_mask = None

retval = active_decoder(
decoder_input,
2 changes: 1 addition & 1 deletion mammoth/modules/adapters.py
@@ -192,13 +192,13 @@ def _inject_adapters(self):
adapted_layer_types = []
adapted_layers = nn.ModuleList()
adapted_layer_dropouts = []
adapter_layers_by_index = self._merge_active_adapters()
i = 0
for layer_type, layer_struct, layer_dropout in zip(
self._base_layer_types,
self._base_layers,
self._base_layer_dropouts,
):
adapter_layers_by_index = self._merge_active_adapters()
if layer_type == 'f':
# Adapters apply to feedforward layers
adapter_layers = adapter_layers_by_index[i]
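
The merged adapter mapping does not depend on the loop variable, so computing it once before iterating over the base layers avoids rebuilding it for every layer; with L layers and A active adapters the merge cost drops from O(L·A) to O(A). A generic sketch of the hoisting pattern, with illustrative names rather than the actual mammoth classes:

def merge_active_adapters(active_adapters):
    # Group the active adapters by the layer index they attach to.
    by_index = {}
    for layer_index, adapter in active_adapters:
        by_index.setdefault(layer_index, []).append(adapter)
    return by_index

def inject_adapters(base_layers, active_adapters):
    # Hoisted out of the loop: the mapping is loop-invariant.
    adapter_layers_by_index = merge_active_adapters(active_adapters)
    return [
        (layer, adapter_layers_by_index.get(i, []))
        for i, layer in enumerate(base_layers)
    ]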
2 changes: 1 addition & 1 deletion mammoth/modules/layer_stack.py
@@ -9,7 +9,7 @@
class AdaptedAttentionLayersStack(nn.Module):
"""
Wrapper that allows stacking multiple AdaptedAttentionLayers.
Represents one particular stacking: does not allow switching out entire layers
Represents one particular task-specific stacking: does not allow switching out entire layers
(but does delegate the switching out of adapters to its components)
"""
def __init__(self, attention_layers_stack: Sequence[AdaptedAttentionLayers]):
2 changes: 1 addition & 1 deletion mammoth/trainer.py
@@ -516,13 +516,13 @@ def _gradient_accumulation(
total_stats.update(batch_stats)
report_stats.update(batch_stats)
report_stats.update_task_loss(batch_stats.loss, metadata)

except Exception:
traceback.print_exc()
logger.info("At step %d, we removed a batch - accum %d", self.optim.training_step, k)
self.nan_batches += 1
if self.nan_batches >= self.max_nan_batches:
raise Exception('Exceeded allowed --max_nan_batches.')

if len(seen_comm_batches) != 1:
logger.warning('Communication batches out of synch with batch accumulation')

14 changes: 9 additions & 5 deletions mammoth/utils/model_saver.py
@@ -60,7 +60,9 @@ def explode_model(
# Only the lowest ranked device saves a component
state_dicts[name] = component.state_dict(model)
# The optimizer parameters are distributed the same way as the components
optim_state_dicts[name] = optim.suboptimizers[name].state_dict()
# Not all components have trainable (unfrozen) parameters, though
if name in optim.suboptimizers:
optim_state_dicts[name] = optim.suboptimizers[name].state_dict()
return state_dicts, optim_state_dicts


@@ -290,12 +292,14 @@ def _save(self, step, model, data_state, task_queue_manager):
if os.path.isfile(checkpoint_path):
logger.debug("{} - not saving {} as it is already present".format(device_context.id, checkpoint_path))
else:
logger.info(f'Saving module checkpoint {checkpoint_path} and optimizer {optimizer_path}')
torch.save(state_dict, checkpoint_path)
tmp_checkpoint_paths.append(checkpoint_path)
if key != 'frame':
if key != 'frame' and key in optim_state_dicts:
logger.info(f'Saving module checkpoint {checkpoint_path} and optimizer {optimizer_path}')
torch.save(optim_state_dicts[key], optimizer_path)
tmp_checkpoint_paths.append(optimizer_path)
else:
logger.info(f'Saving module checkpoint {checkpoint_path} (no optimizer to save)')
torch.save(state_dict, checkpoint_path)
tmp_checkpoint_paths.append(checkpoint_path)

return tmp_checkpoint_paths

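
Fully frozen components own no trainable parameters and therefore never get a suboptimizer, so both the export and the per-checkpoint save now guard the optimizer lookup instead of assuming every saved module has optimizer state. A simplified sketch of the save step under that assumption (paths and helper names are illustrative, not the exact mammoth saver):

import os
import torch

def save_component(key, state_dict, optim_state_dicts, save_dir):
    checkpoint_path = os.path.join(save_dir, f'{key}.pt')
    optimizer_path = os.path.join(save_dir, f'{key}_optim.pt')
    saved_paths = []
    # The module weights are always saved.
    torch.save(state_dict, checkpoint_path)
    saved_paths.append(checkpoint_path)
    # Optimizer state exists only for components with unfrozen parameters
    # (and the shared 'frame' component never carries one).
    if key != 'frame' and key in optim_state_dicts:
        torch.save(optim_state_dicts[key], optimizer_path)
        saved_paths.append(optimizer_path)
    return saved_paths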
16 changes: 12 additions & 4 deletions mammoth/utils/optimizers.py
@@ -154,11 +154,12 @@ def __init__(
self._optimizer = optimizer
self._learning_rate = learning_rate
self._learning_rate_decay_fn = learning_rate_decay_fn
self._max_grad_norm = max_grad_norm
self.grad_scaler = grad_scaler
self._training_step = 1
self._decay_step = 1
self._n_clips = 0
self._n_params_tot = self._count_params()
self._max_grad_norm = max_grad_norm * self._n_params_tot

@property
def param_groups(self):
@@ -169,6 +170,12 @@ def training_step(self):
"""The current training step."""
return self._training_step

def _count_params(self):
result = 0
for group in self._optimizer.param_groups:
result += sum(param.numel() for param in group['params'])
return result

def learning_rate(self):
"""Returns the current learning rate."""
if self._learning_rate_decay_fn is None:
@@ -195,7 +202,6 @@ def load_state_dict(self, state_dict):
def zero_grad(self):
"""Zero the gradients of optimized parameters."""
self._optimizer.zero_grad()
self._n_clips = 0

def step(self):
"""Update the model parameters based on current gradients. """
@@ -207,7 +213,7 @@ def step(self):
orig_norm = clip_grad_norm_(group['params'], self._max_grad_norm)
if orig_norm.item() > self._max_grad_norm:
# FIXME: debug. Count and log instead
print(f'Clipping {orig_norm} -> {self._max_grad_norm}')
# print(f'Clipping {orig_norm} -> {self._max_grad_norm}')
self._n_clips += 1

if self.grad_scaler is not None:
@@ -350,7 +356,9 @@ def report_steps(self):
for name, optimizer in self.suboptimizers.items():
count = optimizer.training_step
lr = optimizer.learning_rate()
result.append(f'Optimizer "{name}" has been stepped {count} times and has LR {lr}')
n_clips = optimizer._n_clips
result.append(f'Optimizer "{name}" has been stepped {count} times. LR {lr} n_clips {n_clips}')
optimizer._n_clips = 0
return result

def count_parameters(self):
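
Two things change in the optimizer wrapper: the max_grad_norm threshold is now interpreted per parameter and scaled by the total number of parameters the suboptimizer owns, and clipping events are counted rather than printed on every step, with the counter reported and reset in report_steps instead of zero_grad. A minimal sketch of the same bookkeeping around torch's clip_grad_norm_, assuming a plain torch.optim optimizer rather than mammoth's wrapper:

import torch
from torch.nn.utils import clip_grad_norm_

class ClipCountingOptimizer:
    def __init__(self, optimizer, max_grad_norm):
        self._optimizer = optimizer
        self._n_clips = 0
        n_params_tot = sum(
            param.numel()
            for group in optimizer.param_groups
            for param in group['params']
        )
        # Interpret max_grad_norm per parameter: scale by the parameter count.
        self._max_grad_norm = max_grad_norm * n_params_tot

    def step(self):
        for group in self._optimizer.param_groups:
            # clip_grad_norm_ returns the gradient norm before clipping.
            orig_norm = clip_grad_norm_(group['params'], self._max_grad_norm)
            if orig_norm.item() > self._max_grad_norm:
                self._n_clips += 1
        self._optimizer.step()

    def report_clips(self):
        # Report and reset, mirroring report_steps in the diff.
        n_clips, self._n_clips = self._n_clips, 0
        return n_clips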
