From 3a16fa0b7c102f46d43efe4267523027706a1e9b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stig-Arne=20Gr=C3=B6nroos?=
Date: Mon, 14 Oct 2024 17:26:14 +0300
Subject: [PATCH] State dict fixes

The adapter injection code was causing parameter duplication.

Another issue: to normalize or not to normalize? We compute a normalization
based on either tokens or sents, but never apply it. The effect can be
compensated for with the learning rate, as long as batches are approximately
the same size. Overly high learning rates trigger gradient clipping, which is
especially detrimental because each component is clipped individually.

Deterministic clipping requires one of the following:
- access to the gradients of all parameters of the entire model (infeasible)
- component-local clipping (current approach)
- communicating a clipping factor across devices (maybe we should do this?)
---
 mammoth/distributed/components.py        | 20 +++++-
 mammoth/modules/adapters.py              |  7 ++
 mammoth/tests/test_task_queue_manager.py | 87 +++++++++++++++++++++++-
 mammoth/trainer.py                       |  6 ++
 mammoth/utils/optimizers.py              | 11 ++-
 tools/generate_synth_data.py             |  2 +-
 6 files changed, 129 insertions(+), 4 deletions(-)

diff --git a/mammoth/distributed/components.py b/mammoth/distributed/components.py
index 48c95e7f..baa24440 100644
--- a/mammoth/distributed/components.py
+++ b/mammoth/distributed/components.py
@@ -121,6 +121,15 @@ def state_dict(self, model: NMTModel, prefix='', keep_vars=False) -> Dict[str, A
             sub_module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
         return destination
 
+    def load_state_dict(self, model: NMTModel, state_dict: Dict[str, Any]):
+        module = self.get_module(model)
+        mismatch = module.load_state_dict(state_dict, strict=False)
+        missing_keys = [
+            name for name in mismatch.missing_keys
+            if not (name.startswith('attn_layers.') or name.startswith('token_emb.'))
+        ]
+        return mismatch._replace(missing_keys=missing_keys)
+
 
 @dataclass  # type: ignore
 class DistributedAttentionLayersBlock(DistributedComponent, ABC):
@@ -147,13 +156,22 @@ def named_parameters(self, model: NMTModel):
     def state_dict(self, model: NMTModel, prefix='', keep_vars=False) -> Dict[str, Any]:
         module = self.get_module(model)
         destination: Dict[str, Any] = OrderedDict()
-        for name, sub_module in module._modules.items():
+        for name, sub_module in module.get_sub_modules().items():
             if name == 'adapters':
                 # Adapters are stored separately
                 continue
             sub_module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
         return destination
 
+    def load_state_dict(self, model: NMTModel, state_dict: Dict[str, Any]):
+        module = self.get_module(model)
+        mismatch = module.load_state_dict(state_dict, strict=False)
+        missing_keys = [
+            name for name in mismatch.missing_keys
+            if not name.startswith('layers.')
+        ]
+        return mismatch._replace(missing_keys=missing_keys)
+
 
 @dataclass
 class DistributedEncoderAttentionLayersBlock(DistributedAttentionLayersBlock):
diff --git a/mammoth/modules/adapters.py b/mammoth/modules/adapters.py
index 3c6d3d68..0f28d9d9 100644
--- a/mammoth/modules/adapters.py
+++ b/mammoth/modules/adapters.py
@@ -225,3 +225,10 @@ def _inject_adapters(self):
     def forward(self, *args, **kwargs):
         self._inject_adapters()
         return super().forward(*args, **kwargs)
+
+    def get_sub_modules(self):
+        omit_submodules = {'layers'}
+        return {
+            name: sub_module for name, sub_module in self._modules.items()
+            if name not in omit_submodules
+        }
diff --git a/mammoth/tests/test_task_queue_manager.py b/mammoth/tests/test_task_queue_manager.py
index 3ef64f8c..7c92ac03 100644
--- a/mammoth/tests/test_task_queue_manager.py
+++ b/mammoth/tests/test_task_queue_manager.py
@@ -5,9 +5,10 @@
 from mammoth.distributed import TaskQueueManager, WorldContext
 from mammoth.distributed.components import (
     Side,
-    DistributedEncoderAttentionLayersBlock,
     DistributedDecoderAttentionLayersBlock,
     DistributedEmbedding,
+    DistributedEncoderAttentionLayersBlock,
+    DistributedTransformerWrapper,
     # DistributedAdapter,
     # DistributedAttentionBridge,
     # DistributedComponentAction,
@@ -169,6 +170,34 @@ def __call__(self, sorted_global_ranks):
             layer_stack_index=0,
             xcoder_id="yy",
         ),
+        DistributedTransformerWrapper(
+            global_ranks={0},
+            task_ids={'train_0_a-b'},
+            group=None,
+            task_id='train_0_a-b',
+            side=Side.decoder,
+        ),
+        DistributedTransformerWrapper(
+            global_ranks={1},
+            task_ids={'train_1_c-d'},
+            group=None,
+            task_id='train_1_c-d',
+            side=Side.decoder,
+        ),
+        DistributedTransformerWrapper(
+            global_ranks={1},
+            task_ids={'train_2_a-d'},
+            group=None,
+            task_id='train_2_a-d',
+            side=Side.decoder,
+        ),
+        DistributedTransformerWrapper(
+            global_ranks={2},
+            task_ids={'train_3_e-b'},
+            group=None,
+            task_id='train_3_e-b',
+            side=Side.decoder,
+        ),
         DistributedEncoderAttentionLayersBlock(
             global_ranks={0, 1},
             task_ids={"train_2_a-d", "train_0_a-b"},
@@ -190,6 +219,34 @@ def __call__(self, sorted_global_ranks):
             layer_stack_index=0,
             xcoder_id="xxx",
         ),
+        DistributedTransformerWrapper(
+            global_ranks={0},
+            task_ids={'train_0_a-b'},
+            group=None,
+            task_id='train_0_a-b',
+            side=Side.encoder,
+        ),
+        DistributedTransformerWrapper(
+            global_ranks={1},
+            task_ids={'train_1_c-d'},
+            group=None,
+            task_id='train_1_c-d',
+            side=Side.encoder,
+        ),
+        DistributedTransformerWrapper(
+            global_ranks={1},
+            task_ids={'train_2_a-d'},
+            group=None,
+            task_id='train_2_a-d',
+            side=Side.encoder,
+        ),
+        DistributedTransformerWrapper(
+            global_ranks={2},
+            task_ids={'train_3_e-b'},
+            group=None,
+            task_id='train_3_e-b',
+            side=Side.encoder,
+        ),
         DistributedEmbedding(
             global_ranks={0, 1},
             task_ids={"train_0_a-b", "train_2_a-d"},
@@ -251,6 +308,20 @@ def __call__(self, sorted_global_ranks):
             layer_stack_index=0,
             xcoder_id="yy",
         ),
+        DistributedTransformerWrapper(
+            global_ranks={1},
+            task_ids={'train_1_c-d'},
+            group=None,
+            task_id='train_1_c-d',
+            side=Side.decoder,
+        ),
+        DistributedTransformerWrapper(
+            global_ranks={1},
+            task_ids={'train_2_a-d'},
+            group=None,
+            task_id='train_2_a-d',
+            side=Side.decoder,
+        ),
         DistributedEncoderAttentionLayersBlock(
             global_ranks={0, 1},
             task_ids={"train_2_a-d", "train_0_a-b"},
@@ -265,6 +336,20 @@ def __call__(self, sorted_global_ranks):
             layer_stack_index=0,
             xcoder_id="xx",
         ),
+        DistributedTransformerWrapper(
+            global_ranks={1},
+            task_ids={'train_1_c-d'},
+            group=None,
+            task_id='train_1_c-d',
+            side=Side.encoder,
+        ),
+        DistributedTransformerWrapper(
+            global_ranks={1},
+            task_ids={'train_2_a-d'},
+            group=None,
+            task_id='train_2_a-d',
+            side=Side.encoder,
+        ),
         DistributedEmbedding(
             global_ranks={0, 1},
             task_ids={"train_0_a-b", "train_2_a-d"},
diff --git a/mammoth/trainer.py b/mammoth/trainer.py
index 976bffee..82900140 100644
--- a/mammoth/trainer.py
+++ b/mammoth/trainer.py
@@ -393,6 +393,10 @@ def validate(self, valid_iter, moving_average=None, task=None):
             src_mask = batch.src.mask
             decoder_input = batch.tgt.tensor[:-1]
             target = batch.tgt.tensor[1:]
+            # if self.norm_method == "tokens":
+            #     normalization = batch.tgt.mask.sum().item()
+            # else:
+            #     normalization = batch.batch_size
 
             with torch.cuda.amp.autocast(enabled=self.optim.amp):
                 # F-prop through the model.
@@ -410,6 +414,7 @@
                     rearrange(logits, 't b i -> (t b) i'),
                     rearrange(target, 't b 1 -> (t b)'),
                 )
+                # loss /= normalization
 
                 # Update statistics.
                 padding_idx = self.loss_functions[metadata.tgt_lang].ignore_index
@@ -490,6 +495,7 @@ def _gradient_accumulation(
                 if loss is not None:
                     if torch.isnan(loss):
                         raise Exception('Loss blowout')
+                    # loss /= normalization
                     self.optim.backward(loss)
 
                 if self.report_training_accuracy:
diff --git a/mammoth/utils/optimizers.py b/mammoth/utils/optimizers.py
index 76fff002..6d588e3a 100644
--- a/mammoth/utils/optimizers.py
+++ b/mammoth/utils/optimizers.py
@@ -158,6 +158,7 @@ def __init__(
         self.grad_scaler = grad_scaler
         self._training_step = 1
         self._decay_step = 1
+        self._n_clips = 0
 
     @property
     def param_groups(self):
@@ -194,6 +195,7 @@ def load_state_dict(self, state_dict):
     def zero_grad(self):
         """Zero the gradients of optimized parameters."""
         self._optimizer.zero_grad()
+        self._n_clips = 0
 
     def step(self):
         """Update the model parameters based on current gradients. """
@@ -202,7 +204,11 @@
         for group in self._optimizer.param_groups:
             group['lr'] = learning_rate
             if self._max_grad_norm > 0:
-                clip_grad_norm_(group['params'], self._max_grad_norm)
+                orig_norm = clip_grad_norm_(group['params'], self._max_grad_norm)
+                if orig_norm.item() > self._max_grad_norm:
+                    # FIXME: debug. Count and log instead
+                    print(f'Clipping {orig_norm} -> {self._max_grad_norm}')
+                    self._n_clips += 1
 
         if self.grad_scaler is not None:
             self.grad_scaler.step(self._optimizer)
@@ -396,6 +402,9 @@ def amp(self):
         """True if use torch amp mix precision training."""
         return self.grad_scaler is not None
 
+    def n_clips(self):
+        return sum(optimizer._n_clips for optimizer in self.suboptimizers.values())
+
 
 # Code below is an implementation of https://arxiv.org/pdf/1804.04235.pdf
 # inspired but modified from https://github.com/DeadAt0m/adafactor-pytorch
diff --git a/tools/generate_synth_data.py b/tools/generate_synth_data.py
index 686216b5..e4c6c6ff 100644
--- a/tools/generate_synth_data.py
+++ b/tools/generate_synth_data.py
@@ -355,7 +355,7 @@ def generate_from_config(
 @click.command(context_settings={'show_default': True})
 @click.option('--config_path', type=Path, required=True)
 @click.option('--vocab_size', type=int, default=300)
-@click.option('--num_examples_train', type=int, default=10000)
+@click.option('--num_examples_train', type=int, default=1000000)
 @click.option('--num_examples_test', type=int, default=100)
 @click.option('--start_seed', type=int, default=1)
 @click.option(
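
Note on the normalization question in the commit message: the commented-out lines in trainer.py hint at how the computed normalization would be applied if it were enabled. The sketch below is only an illustration of that idea, not code from this patch; it assumes the loss function sums over tokens (reduction='sum'), and the names norm_method, tgt_mask and batch_size are stand-ins for whatever the trainer actually passes around.

    # Illustrative only: apply the token/sent normalization before backward.
    def normalized_loss(loss_fn, logits, target, tgt_mask, batch_size, norm_method="tokens"):
        loss = loss_fn(logits, target)  # assumed reduction='sum'
        if norm_method == "tokens":
            # number of real (non-padding) target tokens in the batch
            normalization = tgt_mask.sum().item()
        else:  # norm_method == "sents"
            normalization = batch_size
        return loss / max(normalization, 1)

With the division in place, the learning rate no longer has to absorb differences in batch size, which is the compensation the commit message describes.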
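For the third clipping option listed in the commit message (communicating a clipping factor across devices), a rough sketch could look like the following. It assumes torch.distributed is already initialized and that each rank can enumerate its local parameters; it also ignores parameters replicated on several ranks, which this naive version would double-count.

    # Illustrative only: derive one clip coefficient from the gradient norm over all devices.
    import torch
    import torch.distributed as dist

    def global_clip_coefficient(local_params, max_grad_norm, eps=1e-6):
        local_sq = sum(
            float(p.grad.detach().float().pow(2).sum())
            for p in local_params
            if p.grad is not None
        )
        # CPU tensor works with the gloo backend; move to the local GPU for NCCL.
        total_sq = torch.tensor([local_sq])
        dist.all_reduce(total_sq, op=dist.ReduceOp.SUM)  # sum of squares across devices
        total_norm = float(total_sq.item()) ** 0.5
        return min(1.0, max_grad_norm / (total_norm + eps))

    # Every rank then applies the same deterministic scaling:
    #     coef = global_clip_coefficient(params, max_grad_norm)
    #     for p in params:
    #         if p.grad is not None:
    #             p.grad.mul_(coef)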