From 203d4d5d2ce9c9856c4b8e5cf82d772bf343f4c0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stig-Arne=20Gr=C3=B6nroos?=
Date: Mon, 7 Oct 2024 14:44:05 +0300
Subject: [PATCH 1/3] Bugfix: Statistics inherits n_correct from previous instance

The initial value of n_correct must be either zero or None,
depending on whether accuracy is reported or not.
---
 mammoth/trainer.py              | 7 ++++---
 mammoth/utils/report_manager.py | 8 +++++---
 2 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/mammoth/trainer.py b/mammoth/trainer.py
index a8c043fe..19f94e24 100644
--- a/mammoth/trainer.py
+++ b/mammoth/trainer.py
@@ -239,8 +239,9 @@ def train(
         else:
             logger.info('Start training loop and validate every %d steps...', valid_steps)
 
-        total_stats = mammoth.utils.Statistics()
-        report_stats = mammoth.utils.Statistics()
+        n_correct = 0 if self.report_training_accuracy else None
+        total_stats = mammoth.utils.Statistics(n_correct=n_correct)
+        report_stats = mammoth.utils.Statistics(n_correct=n_correct)
         self._start_report_manager(start_time=total_stats.start_time)
 
         self.optim.zero_grad()
@@ -385,7 +386,7 @@ def validate(self, valid_iter, moving_average=None, task=None):
 
         for batch, metadata, _ in valid_iter:
             if stats is None:
-                stats = mammoth.utils.Statistics()
+                stats = mammoth.utils.Statistics(n_correct=0)
 
             stats.n_src_words += batch.src.mask.sum().item()
             src = batch.src.tensor
diff --git a/mammoth/utils/report_manager.py b/mammoth/utils/report_manager.py
index bf4a276d..9a7d4a91 100644
--- a/mammoth/utils/report_manager.py
+++ b/mammoth/utils/report_manager.py
@@ -95,7 +95,8 @@ def report_training(
             if optimizer is not None:
                 for line in optimizer.report_steps():
                     logger.info(line)
-            return mammoth.utils.Statistics()
+            n_correct = None if report_stats.n_correct is None else 0
+            return mammoth.utils.Statistics(n_correct=n_correct)
         else:
             return report_stats
 
@@ -156,7 +157,8 @@ def _report_training(self, step, num_steps, learning_rate, patience, report_stat
         report_stats.output(step, num_steps, learning_rate, self.start_time)
 
         self.maybe_log_tensorboard(report_stats, "progress", learning_rate, patience, step)
-        report_stats = mammoth.utils.Statistics()
+        n_correct = None if report_stats.n_correct is None else 0
+        report_stats = mammoth.utils.Statistics(n_correct=n_correct)
 
         total = sum(sampled_task_counts.values())
         logger.info(f'Task sampling distribution: (total {total})')
@@ -183,7 +185,7 @@ def _report_step(self, lr, patience, step, train_stats=None, valid_stats=None):
             structured_logging({
                 'type': 'validation',
                 'step': step,
-                'learning_rate': lr,
+                # 'learning_rate': lr,
                 'perplexity': ppl,
                 'accuracy': acc,
                 'crossentropy': valid_stats.xent(),

From 83c8c26c905298ce1841183ca990485649acfe14 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stig-Arne=20Gr=C3=B6nroos?=
Date: Mon, 7 Oct 2024 17:27:56 +0300
Subject: [PATCH 2/3] Pass kwargs also to TransformerWrapper

---
 mammoth/model_builder.py | 39 +++++++++++++++++++++++++++++----------
 mammoth/trainer.py       |  6 +++---
 2 files changed, 32 insertions(+), 13 deletions(-)

diff --git a/mammoth/model_builder.py b/mammoth/model_builder.py
index 8564cf1c..d63a462f 100644
--- a/mammoth/model_builder.py
+++ b/mammoth/model_builder.py
@@ -31,6 +31,14 @@
 from mammoth.utils.logging import logger
 from mammoth.utils.misc import use_gpu
 
+TRANSFORMER_WRAPPER_OPTS = {
+    'post_emb_norm',
+    'tie_embedding',
+    'use_abs_pos_emb',
+    'scaled_sinu_pos_emb',
+    'emb_frac_gradient',
+}
+
 
 def _combine_ordered_dicts(input_dicts: Dict[str, OrderedDict]) -> OrderedDict:
     result = []
@@ -59,6 +67,7 @@ def 
get_attention_layers_kwargs( is_last = layer_stack_index == len(depths) - 1 pre_norm_has_final_norm = is_last kwargs = model_opts.x_transformers_opts if model_opts.x_transformers_opts else dict() + kwargs = {key: val for key, val in kwargs.items() if key not in TRANSFORMER_WRAPPER_OPTS} kwargs.update({ 'dim': model_opts.model_dim, 'depth': depth, @@ -69,6 +78,21 @@ def get_attention_layers_kwargs( return kwargs +def get_transformer_wrapper_kwargs( + side: Side, + model_opts, +): + """Return arguments for x_transformers.TransformerWrapper""" + assert side in {Side.encoder, Side.decoder}, f'Invalid side "{side}"' + kwargs = model_opts.x_transformers_opts if model_opts.x_transformers_opts else dict() + kwargs = {key: val for key, val in kwargs.items() if key in TRANSFORMER_WRAPPER_OPTS} + max_seq_len = 0 if model_opts.max_length is None else model_opts.max_length + kwargs.update({ + 'max_seq_len': max_seq_len, + }) + return kwargs + + def build_xcoder( side: Side, model_opts, @@ -196,6 +220,10 @@ def build_xcoder( if single_task: tasks = [task for task in tasks if task.corpus_id == single_task] transformer_wrappers = dict() + transformer_wrapper_kwargs = get_transformer_wrapper_kwargs( + side=side, + model_opts=model_opts, + ) for task in tasks: if side == Side.encoder: xcoder_ids = task.encoder_id @@ -211,22 +239,13 @@ def build_xcoder( lang = task.src_lang if side == Side.encoder else task.tgt_lang vocab = vocabs_dict[(side_alt_str, lang)] - max_seq_len = 0 if model_opts.max_length is None else model_opts.max_length - post_emb_norm = True - tie_embedding = True - use_abs_pos_emb = True - emb_frac_gradient = 1. # Using custom extended TransformerWrapper to allow passing in an embedding transformer_wrapper = TransformerWrapper( num_tokens=len(vocab), - max_seq_len=max_seq_len, attn_layers=adapted_attention_layers_stack, emb_dim=model_opts.model_dim, - post_emb_norm=post_emb_norm, - tie_embedding=tie_embedding, - use_abs_pos_emb=use_abs_pos_emb, - emb_frac_gradient=emb_frac_gradient, token_emb=token_embs[lang], + **transformer_wrapper_kwargs, ) transformer_wrappers[task.corpus_id] = transformer_wrapper diff --git a/mammoth/trainer.py b/mammoth/trainer.py index 19f94e24..976bffee 100644 --- a/mammoth/trainer.py +++ b/mammoth/trainer.py @@ -471,9 +471,9 @@ def _gradient_accumulation( with torch.cuda.amp.autocast(enabled=self.optim.amp): logits, decoder_output = self.model( - rearrange(src, 't b 1 -> b t'), - rearrange(decoder_input, 't b 1 -> b t'), - rearrange(src_mask, 't b -> b t'), + src=rearrange(src, 't b 1 -> b t'), + decoder_input=rearrange(decoder_input, 't b 1 -> b t'), + src_mask=rearrange(src_mask, 't b -> b t'), metadata=metadata, ) logits = rearrange(logits, 'b t i -> t b i') From 97bc2d90bae8c6bd4dc1173dc600e7cb240e1f6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stig-Arne=20Gr=C3=B6nroos?= Date: Mon, 14 Oct 2024 13:04:34 +0300 Subject: [PATCH 3/3] Distributed component for TransformerWrapper Parameters in the TransformerWrapper, e.g. to_logits, need their own distributed component and optimizer. 
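For reference, a rough sketch (not part of this patch) of which parameters the new
component ends up owning: everything in an x_transformers TransformerWrapper except
the attn_layers and token_emb submodules, which already have their own
DistributedComponents. Exact parameter names depend on the installed x_transformers
version and the options used.

    # Illustrative only; assumes a recent x_transformers release.
    from x_transformers import Decoder, TransformerWrapper

    wrapper = TransformerWrapper(
        num_tokens=1000,
        max_seq_len=256,
        attn_layers=Decoder(dim=512, depth=2, heads=8),
    )
    # Same filtering rule as DistributedTransformerWrapper.named_parameters below.
    owned = [
        name for name, _ in wrapper.named_parameters()
        if not name.startswith(('attn_layers.', 'token_emb.'))
    ]
    print(owned)  # e.g. to_logits.weight plus positional/norm embedding parameters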
--- mammoth/distributed/components.py | 45 +++++++++++++++++++++--- mammoth/distributed/tasks.py | 27 +++++++++++--- mammoth/model_builder.py | 26 +++++++++++--- mammoth/tests/test_task_queue_manager.py | 20 +++++------ mammoth/train_single.py | 3 +- 5 files changed, 97 insertions(+), 24 deletions(-) diff --git a/mammoth/distributed/components.py b/mammoth/distributed/components.py index d8733184..48c95e7f 100644 --- a/mammoth/distributed/components.py +++ b/mammoth/distributed/components.py @@ -87,7 +87,41 @@ def needs_communication(self) -> bool: return self.group is not None -# TODO: This is a misnomer: Not an entire XCoder, but just one AttentionLayers block +@dataclass # type: ignore +class DistributedTransformerWrapper(DistributedComponent, ABC): + task_id: str + side: Side + + def get_name(self) -> str: + return f'{self.side.name}_{self.task_id}' + + def get_module(self, model: NMTModel) -> nn.Module: + parent = model.encoder if self.side == Side.encoder else model.decoder + tw = parent[self.task_id] + return tw + + def named_parameters(self, model: NMTModel): + module = self.get_module(model) + for name, p in module.named_parameters(): + # TransformerWrapper contains the AttentionLayers and the embs. + # however, we want to treat these as distinct DistributedComponents + if name.startswith('attn_layers.'): + continue + if name.startswith('token_emb.'): + continue + yield name, p + + def state_dict(self, model: NMTModel, prefix='', keep_vars=False) -> Dict[str, Any]: + module = self.get_module(model) + destination: Dict[str, Any] = OrderedDict() + for name, sub_module in module._modules.items(): + if name.endswith('attn_layers'): + # stored separately + continue + sub_module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars) + return destination + + @dataclass # type: ignore class DistributedAttentionLayersBlock(DistributedComponent, ABC): layer_stack_index: int @@ -106,8 +140,9 @@ def named_parameters(self, model: NMTModel): for name, p in module.named_parameters(): # encoders and decoders contain embeddings and adapters as submodules # however, we want to treat these as distinct DistributedComponents - if 'embeddings' not in name and 'adapter' not in name: - yield name, p + if 'adapter' in name: + continue + yield name, p def state_dict(self, model: NMTModel, prefix='', keep_vars=False) -> Dict[str, Any]: module = self.get_module(model) @@ -121,7 +156,7 @@ def state_dict(self, model: NMTModel, prefix='', keep_vars=False) -> Dict[str, A @dataclass -class DistributedEncoder(DistributedAttentionLayersBlock): +class DistributedEncoderAttentionLayersBlock(DistributedAttentionLayersBlock): @property def side(self) -> Side: return Side.encoder @@ -136,7 +171,7 @@ def get_module(self, model: NMTModel) -> nn.Module: @dataclass -class DistributedDecoder(DistributedAttentionLayersBlock): +class DistributedDecoderAttentionLayersBlock(DistributedAttentionLayersBlock): @property def side(self) -> Side: return Side.decoder diff --git a/mammoth/distributed/tasks.py b/mammoth/distributed/tasks.py index 44358bb6..8471a1a6 100644 --- a/mammoth/distributed/tasks.py +++ b/mammoth/distributed/tasks.py @@ -17,9 +17,10 @@ DistributedComponent, DistributedComponentBuilder, DistributedComponentGradientSync, - DistributedDecoder, + DistributedDecoderAttentionLayersBlock, DistributedEmbedding, - DistributedEncoder, + DistributedEncoderAttentionLayersBlock, + DistributedTransformerWrapper, Side, ) from mammoth.distributed.contexts import DeviceContext, WorldContext @@ -369,9 
+370,27 @@ def create_all_distributed_components( lang=task.tgt_lang, ) ) + builder.add( + DistributedTransformerWrapper( + global_ranks={global_rank}, + task_ids={task.corpus_id}, + group=None, + side=Side.encoder, + task_id=task.corpus_id, + ) + ) + builder.add( + DistributedTransformerWrapper( + global_ranks={global_rank}, + task_ids={task.corpus_id}, + group=None, + side=Side.decoder, + task_id=task.corpus_id, + ) + ) for layer_stack_index, encoder_id in enumerate(task.encoder_id): builder.add( - DistributedEncoder( + DistributedEncoderAttentionLayersBlock( global_ranks={global_rank}, task_ids={task.corpus_id}, group=None, @@ -381,7 +400,7 @@ def create_all_distributed_components( ) for layer_stack_index, decoder_id in enumerate(task.decoder_id): builder.add( - DistributedDecoder( + DistributedDecoderAttentionLayersBlock( global_ranks={global_rank}, task_ids={task.corpus_id}, group=None, diff --git a/mammoth/model_builder.py b/mammoth/model_builder.py index d63a462f..5f654557 100644 --- a/mammoth/model_builder.py +++ b/mammoth/model_builder.py @@ -14,8 +14,8 @@ from mammoth.distributed.components import ( DistributedAdapter, DistributedComponent, - DistributedDecoder, - DistributedEncoder, + DistributedDecoderAttentionLayersBlock, + DistributedEncoderAttentionLayersBlock, Side, ) from mammoth.modules.adapters import ( @@ -119,10 +119,10 @@ def build_xcoder( ] distributed_xcoder_class: type if side == Side.encoder: - distributed_xcoder_class = DistributedEncoder + distributed_xcoder_class = DistributedEncoderAttentionLayersBlock side_str = 'encoder' else: - distributed_xcoder_class = DistributedDecoder + distributed_xcoder_class = DistributedDecoderAttentionLayersBlock side_str = 'decoder' if single_task: my_components = [ @@ -328,3 +328,21 @@ def build_model( # logger.info(model) logger.info('Building model - done!') return model + + +def validate_optimizer_coverage(model, optimizer): + trainable_model_params = { + name: p for name, p in model.named_parameters() + if p.requires_grad + } + optimized_params = set() + for group in optimizer.param_groups: + optimized_params.update(group['params']) + missing_params = [ + name for name, p in trainable_model_params.items() + if p not in optimized_params + ] + if len(missing_params) > 0: + raise Exception(f'Missing optimizer for params: {sorted(missing_params)}') + else: + logger.info('All non-frozen parameters have an optimizer') diff --git a/mammoth/tests/test_task_queue_manager.py b/mammoth/tests/test_task_queue_manager.py index 30df3202..3ef64f8c 100644 --- a/mammoth/tests/test_task_queue_manager.py +++ b/mammoth/tests/test_task_queue_manager.py @@ -5,8 +5,8 @@ from mammoth.distributed import TaskQueueManager, WorldContext from mammoth.distributed.components import ( Side, - DistributedEncoder, - DistributedDecoder, + DistributedEncoderAttentionLayersBlock, + DistributedDecoderAttentionLayersBlock, DistributedEmbedding, # DistributedAdapter, # DistributedAttentionBridge, @@ -155,35 +155,35 @@ def __call__(self, sorted_global_ranks): use_attention_bridge=False, new_group_func=MockGroup() ) assert all_components == [ - DistributedDecoder( + DistributedDecoderAttentionLayersBlock( global_ranks={0, 2}, task_ids={'train_3_e-b', 'train_0_a-b'}, group="Group 0 with GPU ranks [0, 2]", layer_stack_index=0, xcoder_id="y", ), - DistributedDecoder( + DistributedDecoderAttentionLayersBlock( global_ranks={1}, task_ids={"train_2_a-d", "train_1_c-d"}, group=None, layer_stack_index=0, xcoder_id="yy", ), - DistributedEncoder( + 
DistributedEncoderAttentionLayersBlock( global_ranks={0, 1}, task_ids={"train_2_a-d", "train_0_a-b"}, group="Group 1 with GPU ranks [0, 1]", layer_stack_index=0, xcoder_id="x", ), - DistributedEncoder( + DistributedEncoderAttentionLayersBlock( global_ranks={1}, task_ids={"train_1_c-d"}, group=None, layer_stack_index=0, xcoder_id="xx", ), - DistributedEncoder( + DistributedEncoderAttentionLayersBlock( global_ranks={2}, task_ids={'train_3_e-b'}, group=None, @@ -244,21 +244,21 @@ def __call__(self, sorted_global_ranks): f"my component {component} not in all_components {all_components}" ) assert my_components == [ - DistributedDecoder( + DistributedDecoderAttentionLayersBlock( global_ranks={1}, task_ids={"train_2_a-d", "train_1_c-d"}, group=None, layer_stack_index=0, xcoder_id="yy", ), - DistributedEncoder( + DistributedEncoderAttentionLayersBlock( global_ranks={0, 1}, task_ids={"train_2_a-d", "train_0_a-b"}, group="Group 1 with GPU ranks [0, 1]", layer_stack_index=0, xcoder_id="x", ), - DistributedEncoder( + DistributedEncoderAttentionLayersBlock( global_ranks={1}, task_ids={"train_1_c-d"}, group=None, diff --git a/mammoth/train_single.py b/mammoth/train_single.py index 036cde05..2b603be4 100644 --- a/mammoth/train_single.py +++ b/mammoth/train_single.py @@ -3,7 +3,7 @@ import torch import time -from mammoth.model_builder import build_model +from mammoth.model_builder import build_model, validate_optimizer_coverage from mammoth.utils.optimizers import MultipleOptimizer from mammoth.utils.misc import set_random_seed from mammoth.trainer import build_trainer @@ -135,6 +135,7 @@ def main( device_context.id, optim.count_parameters() )) + validate_optimizer_coverage(model, optim) # Load parameters from checkpoint if opts.train_from:
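A quick illustration (not part of the patches) of the new validate_optimizer_coverage
check. The trainer passes its own optimizer object; a plain torch.optim.SGD stands in
here, since the check only relies on param_groups and named_parameters.

    # Hypothetical toy model; validate_optimizer_coverage comes from patch 3/3.
    import torch
    import torch.nn as nn

    from mammoth.model_builder import validate_optimizer_coverage

    model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))

    # Optimizer deliberately covers only the first layer: the check raises and
    # names the uncovered parameters of the second layer.
    partial_optim = torch.optim.SGD(model[0].parameters(), lr=0.1)
    try:
        validate_optimizer_coverage(model, partial_optim)
    except Exception as exc:
        print(exc)

    # Covering all trainable parameters makes the check pass (it just logs).
    full_optim = torch.optim.SGD(model.parameters(), lr=0.1)
    validate_optimizer_coverage(model, full_optim)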