diff --git a/mammoth/distributed/components.py b/mammoth/distributed/components.py
index d8733184..48c95e7f 100644
--- a/mammoth/distributed/components.py
+++ b/mammoth/distributed/components.py
@@ -87,7 +87,41 @@ def needs_communication(self) -> bool:
         return self.group is not None
 
 
-# TODO: This is a misnomer: Not an entire XCoder, but just one AttentionLayers block
+@dataclass  # type: ignore
+class DistributedTransformerWrapper(DistributedComponent, ABC):
+    task_id: str
+    side: Side
+
+    def get_name(self) -> str:
+        return f'{self.side.name}_{self.task_id}'
+
+    def get_module(self, model: NMTModel) -> nn.Module:
+        parent = model.encoder if self.side == Side.encoder else model.decoder
+        tw = parent[self.task_id]
+        return tw
+
+    def named_parameters(self, model: NMTModel):
+        module = self.get_module(model)
+        for name, p in module.named_parameters():
+            # TransformerWrapper contains the AttentionLayers and the embs.
+            # however, we want to treat these as distinct DistributedComponents
+            if name.startswith('attn_layers.'):
+                continue
+            if name.startswith('token_emb.'):
+                continue
+            yield name, p
+
+    def state_dict(self, model: NMTModel, prefix='', keep_vars=False) -> Dict[str, Any]:
+        module = self.get_module(model)
+        destination: Dict[str, Any] = OrderedDict()
+        for name, sub_module in module._modules.items():
+            if name.endswith('attn_layers'):
+                # stored separately
+                continue
+            sub_module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
+        return destination
+
+
 @dataclass  # type: ignore
 class DistributedAttentionLayersBlock(DistributedComponent, ABC):
     layer_stack_index: int
@@ -106,8 +140,9 @@ def named_parameters(self, model: NMTModel):
         for name, p in module.named_parameters():
             # encoders and decoders contain embeddings and adapters as submodules
             # however, we want to treat these as distinct DistributedComponents
-            if 'embeddings' not in name and 'adapter' not in name:
-                yield name, p
+            if 'adapter' in name:
+                continue
+            yield name, p
 
     def state_dict(self, model: NMTModel, prefix='', keep_vars=False) -> Dict[str, Any]:
         module = self.get_module(model)
@@ -121,7 +156,7 @@ def state_dict(self, model: NMTModel, prefix='', keep_vars=False) -> Dict[str, A
 
 
 @dataclass
-class DistributedEncoder(DistributedAttentionLayersBlock):
+class DistributedEncoderAttentionLayersBlock(DistributedAttentionLayersBlock):
     @property
     def side(self) -> Side:
         return Side.encoder
@@ -136,7 +171,7 @@ def get_module(self, model: NMTModel) -> nn.Module:
 
 
 @dataclass
-class DistributedDecoder(DistributedAttentionLayersBlock):
+class DistributedDecoderAttentionLayersBlock(DistributedAttentionLayersBlock):
     @property
     def side(self) -> Side:
         return Side.decoder
diff --git a/mammoth/distributed/tasks.py b/mammoth/distributed/tasks.py
index 44358bb6..8471a1a6 100644
--- a/mammoth/distributed/tasks.py
+++ b/mammoth/distributed/tasks.py
@@ -17,9 +17,10 @@
     DistributedComponent,
     DistributedComponentBuilder,
     DistributedComponentGradientSync,
-    DistributedDecoder,
+    DistributedDecoderAttentionLayersBlock,
     DistributedEmbedding,
-    DistributedEncoder,
+    DistributedEncoderAttentionLayersBlock,
+    DistributedTransformerWrapper,
     Side,
 )
 from mammoth.distributed.contexts import DeviceContext, WorldContext
@@ -369,9 +370,27 @@ def create_all_distributed_components(
                     lang=task.tgt_lang,
                 )
             )
+            builder.add(
+                DistributedTransformerWrapper(
+                    global_ranks={global_rank},
+                    task_ids={task.corpus_id},
+                    group=None,
+                    side=Side.encoder,
+                    task_id=task.corpus_id,
+                )
+            )
+            builder.add(
+                DistributedTransformerWrapper(
+                    global_ranks={global_rank},
+                    task_ids={task.corpus_id},
+                    group=None,
+                    side=Side.decoder,
+                    task_id=task.corpus_id,
+                )
+            )
             for layer_stack_index, encoder_id in enumerate(task.encoder_id):
                 builder.add(
-                    DistributedEncoder(
+                    DistributedEncoderAttentionLayersBlock(
                         global_ranks={global_rank},
                         task_ids={task.corpus_id},
                         group=None,
@@ -381,7 +400,7 @@
                 )
             for layer_stack_index, decoder_id in enumerate(task.decoder_id):
                 builder.add(
-                    DistributedDecoder(
+                    DistributedDecoderAttentionLayersBlock(
                         global_ranks={global_rank},
                         task_ids={task.corpus_id},
                         group=None,
diff --git a/mammoth/model_builder.py b/mammoth/model_builder.py
index 39896b4a..e8b2b986 100644
--- a/mammoth/model_builder.py
+++ b/mammoth/model_builder.py
@@ -14,8 +14,8 @@
 from mammoth.distributed.components import (
     DistributedAdapter,
     DistributedComponent,
-    DistributedDecoder,
-    DistributedEncoder,
+    DistributedDecoderAttentionLayersBlock,
+    DistributedEncoderAttentionLayersBlock,
     Side,
 )
 from mammoth.modules.adapters import (
@@ -44,6 +44,14 @@ def forward(self, x):
         token_emb = self.emb(x.long())
         return token_emb
 
+TRANSFORMER_WRAPPER_OPTS = {
+    'post_emb_norm',
+    'tie_embedding',
+    'use_abs_pos_emb',
+    'scaled_sinu_pos_emb',
+    'emb_frac_gradient',
+}
+
 
 def _combine_ordered_dicts(input_dicts: Dict[str, OrderedDict]) -> OrderedDict:
     result = []
@@ -72,6 +80,7 @@ def get_attention_layers_kwargs(
     is_last = layer_stack_index == len(depths) - 1
     pre_norm_has_final_norm = is_last
     kwargs = model_opts.x_transformers_opts if model_opts.x_transformers_opts else dict()
+    kwargs = {key: val for key, val in kwargs.items() if key not in TRANSFORMER_WRAPPER_OPTS}
     kwargs.update({
         'dim': model_opts.model_dim,
         'depth': depth,
@@ -82,6 +91,21 @@
     return kwargs
 
 
+def get_transformer_wrapper_kwargs(
+    side: Side,
+    model_opts,
+):
+    """Return arguments for x_transformers.TransformerWrapper"""
+    assert side in {Side.encoder, Side.decoder}, f'Invalid side "{side}"'
+    kwargs = model_opts.x_transformers_opts if model_opts.x_transformers_opts else dict()
+    kwargs = {key: val for key, val in kwargs.items() if key in TRANSFORMER_WRAPPER_OPTS}
+    max_seq_len = 0 if model_opts.max_length is None else model_opts.max_length
+    kwargs.update({
+        'max_seq_len': max_seq_len,
+    })
+    return kwargs
+
+
 def build_xcoder(
     side: Side,
     model_opts,
@@ -108,10 +132,10 @@
     ]
     distributed_xcoder_class: type
     if side == Side.encoder:
-        distributed_xcoder_class = DistributedEncoder
+        distributed_xcoder_class = DistributedEncoderAttentionLayersBlock
         side_str = 'encoder'
     else:
-        distributed_xcoder_class = DistributedDecoder
+        distributed_xcoder_class = DistributedDecoderAttentionLayersBlock
         side_str = 'decoder'
     if single_task:
         my_components = [
@@ -210,6 +234,10 @@
     if single_task:
         tasks = [task for task in tasks if task.corpus_id == single_task]
     transformer_wrappers = dict()
+    transformer_wrapper_kwargs = get_transformer_wrapper_kwargs(
+        side=side,
+        model_opts=model_opts,
+    )
     for task in tasks:
         if side == Side.encoder:
             xcoder_ids = task.encoder_id
@@ -225,21 +253,11 @@
         lang = task.src_lang if side == Side.encoder else task.tgt_lang
         vocab = vocabs_dict[(side_alt_str, lang)]
 
-        max_seq_len = 0 if model_opts.max_length is None else model_opts.max_length
-        post_emb_norm = True
-        tie_embedding = True
-        use_abs_pos_emb = True
-        emb_frac_gradient = 1.
         # Using custom extended TransformerWrapper to allow passing in an embedding
         transformer_wrapper = TransformerWrapper(
             num_tokens=len(vocab),
-            max_seq_len=max_seq_len,
             attn_layers=adapted_attention_layers_stack,
             emb_dim=model_opts.model_dim,
-            post_emb_norm=post_emb_norm,
-            tie_embedding=tie_embedding,
-            use_abs_pos_emb=use_abs_pos_emb,
-            emb_frac_gradient=emb_frac_gradient,
             token_emb=token_embs[lang],
             initialize_embeddings=not (model_opts.use_embeddingless)
         )
@@ -324,3 +342,21 @@
     # logger.info(model)
     logger.info('Building model - done!')
     return model
+
+
+def validate_optimizer_coverage(model, optimizer):
+    trainable_model_params = {
+        name: p for name, p in model.named_parameters()
+        if p.requires_grad
+    }
+    optimized_params = set()
+    for group in optimizer.param_groups:
+        optimized_params.update(group['params'])
+    missing_params = [
+        name for name, p in trainable_model_params.items()
+        if p not in optimized_params
+    ]
+    if len(missing_params) > 0:
+        raise Exception(f'Missing optimizer for params: {sorted(missing_params)}')
+    else:
+        logger.info('All non-frozen parameters have an optimizer')
diff --git a/mammoth/tests/test_task_queue_manager.py b/mammoth/tests/test_task_queue_manager.py
index 30df3202..3ef64f8c 100644
--- a/mammoth/tests/test_task_queue_manager.py
+++ b/mammoth/tests/test_task_queue_manager.py
@@ -5,8 +5,8 @@
 from mammoth.distributed import TaskQueueManager, WorldContext
 from mammoth.distributed.components import (
     Side,
-    DistributedEncoder,
-    DistributedDecoder,
+    DistributedEncoderAttentionLayersBlock,
+    DistributedDecoderAttentionLayersBlock,
     DistributedEmbedding,
     # DistributedAdapter,
     # DistributedAttentionBridge,
@@ -155,35 +155,35 @@ def __call__(self, sorted_global_ranks):
         use_attention_bridge=False, new_group_func=MockGroup()
     )
     assert all_components == [
-        DistributedDecoder(
+        DistributedDecoderAttentionLayersBlock(
             global_ranks={0, 2},
             task_ids={'train_3_e-b', 'train_0_a-b'},
             group="Group 0 with GPU ranks [0, 2]",
             layer_stack_index=0,
             xcoder_id="y",
         ),
-        DistributedDecoder(
+        DistributedDecoderAttentionLayersBlock(
             global_ranks={1},
             task_ids={"train_2_a-d", "train_1_c-d"},
             group=None,
             layer_stack_index=0,
             xcoder_id="yy",
         ),
-        DistributedEncoder(
+        DistributedEncoderAttentionLayersBlock(
            global_ranks={0, 1},
             task_ids={"train_2_a-d", "train_0_a-b"},
             group="Group 1 with GPU ranks [0, 1]",
             layer_stack_index=0,
             xcoder_id="x",
         ),
-        DistributedEncoder(
+        DistributedEncoderAttentionLayersBlock(
             global_ranks={1},
             task_ids={"train_1_c-d"},
             group=None,
             layer_stack_index=0,
             xcoder_id="xx",
         ),
-        DistributedEncoder(
+        DistributedEncoderAttentionLayersBlock(
             global_ranks={2},
             task_ids={'train_3_e-b'},
             group=None,
@@ -244,21 +244,21 @@ def __call__(self, sorted_global_ranks):
             f"my component {component} not in all_components {all_components}"
         )
     assert my_components == [
-        DistributedDecoder(
+        DistributedDecoderAttentionLayersBlock(
             global_ranks={1},
             task_ids={"train_2_a-d", "train_1_c-d"},
             group=None,
             layer_stack_index=0,
             xcoder_id="yy",
         ),
-        DistributedEncoder(
+        DistributedEncoderAttentionLayersBlock(
             global_ranks={0, 1},
             task_ids={"train_2_a-d", "train_0_a-b"},
             group="Group 1 with GPU ranks [0, 1]",
             layer_stack_index=0,
             xcoder_id="x",
         ),
-        DistributedEncoder(
+        DistributedEncoderAttentionLayersBlock(
             global_ranks={1},
             task_ids={"train_1_c-d"},
             group=None,
diff --git a/mammoth/train_single.py b/mammoth/train_single.py
index 90034235..c21491e8 100644
--- a/mammoth/train_single.py
+++ b/mammoth/train_single.py
@@ -3,7 +3,7 @@
 import torch
 import time
 
-from mammoth.model_builder import build_model
+from mammoth.model_builder import build_model, validate_optimizer_coverage
 from mammoth.utils.optimizers import MultipleOptimizer
 from mammoth.utils.misc import set_random_seed
 from mammoth.trainer import build_trainer
@@ -139,6 +139,7 @@ def main(
         device_context.id,
         optim.count_parameters()
     ))
+    validate_optimizer_coverage(model, optim)
 
     # Load parameters from checkpoint
     if opts.train_from:
diff --git a/mammoth/trainer.py b/mammoth/trainer.py
index a8c043fe..976bffee 100644
--- a/mammoth/trainer.py
+++ b/mammoth/trainer.py
@@ -239,8 +239,9 @@ def train(
         else:
             logger.info('Start training loop and validate every %d steps...', valid_steps)
 
-        total_stats = mammoth.utils.Statistics()
-        report_stats = mammoth.utils.Statistics()
+        n_correct = 0 if self.report_training_accuracy else None
+        total_stats = mammoth.utils.Statistics(n_correct=n_correct)
+        report_stats = mammoth.utils.Statistics(n_correct=n_correct)
         self._start_report_manager(start_time=total_stats.start_time)
 
         self.optim.zero_grad()
@@ -385,7 +386,7 @@ def validate(self, valid_iter, moving_average=None, task=None):
 
         for batch, metadata, _ in valid_iter:
             if stats is None:
-                stats = mammoth.utils.Statistics()
+                stats = mammoth.utils.Statistics(n_correct=0)
             stats.n_src_words += batch.src.mask.sum().item()
 
             src = batch.src.tensor
@@ -470,9 +471,9 @@ def _gradient_accumulation(
 
             with torch.cuda.amp.autocast(enabled=self.optim.amp):
                 logits, decoder_output = self.model(
-                    rearrange(src, 't b 1 -> b t'),
-                    rearrange(decoder_input, 't b 1 -> b t'),
-                    rearrange(src_mask, 't b -> b t'),
+                    src=rearrange(src, 't b 1 -> b t'),
+                    decoder_input=rearrange(decoder_input, 't b 1 -> b t'),
+                    src_mask=rearrange(src_mask, 't b -> b t'),
                     metadata=metadata,
                 )
                 logits = rearrange(logits, 'b t i -> t b i')
diff --git a/mammoth/utils/model_saver.py b/mammoth/utils/model_saver.py
index d6f90196..3d7bb4b1 100644
--- a/mammoth/utils/model_saver.py
+++ b/mammoth/utils/model_saver.py
@@ -112,7 +112,7 @@ def load_parameters_from_checkpoint(
             logger.warning(
                 f'Could not find optim checkpoint file {optimizer_path}. Affected parameters are reinitialized.'
             )
-        all_ok = False
+        all_ok = True
     if all_ok:
         if reset_optim:
             logger.info(f'All modules restored from checkpoint {checkpoint_prefix}')
diff --git a/mammoth/utils/report_manager.py b/mammoth/utils/report_manager.py
index bf4a276d..9a7d4a91 100644
--- a/mammoth/utils/report_manager.py
+++ b/mammoth/utils/report_manager.py
@@ -95,7 +95,8 @@ def report_training(
             if optimizer is not None:
                 for line in optimizer.report_steps():
                     logger.info(line)
-            return mammoth.utils.Statistics()
+            n_correct = None if report_stats.n_correct is None else 0
+            return mammoth.utils.Statistics(n_correct=n_correct)
         else:
             return report_stats
 
@@ -156,7 +157,8 @@ def _report_training(self, step, num_steps, learning_rate, patience, report_stat
         report_stats.output(step, num_steps, learning_rate, self.start_time)
 
         self.maybe_log_tensorboard(report_stats, "progress", learning_rate, patience, step)
-        report_stats = mammoth.utils.Statistics()
+        n_correct = None if report_stats.n_correct is None else 0
+        report_stats = mammoth.utils.Statistics(n_correct=n_correct)
 
         total = sum(sampled_task_counts.values())
         logger.info(f'Task sampling distribution: (total {total})')
@@ -183,7 +185,7 @@ def _report_step(self, lr, patience, step, train_stats=None, valid_stats=None):
             structured_logging({
                 'type': 'validation',
                 'step': step,
-                'learning_rate': lr,
+                # 'learning_rate': lr,
                 'perplexity': ppl,
                 'accuracy': acc,
                 'crossentropy': valid_stats.xent(),
diff --git a/mammoth/utils/statistics.py b/mammoth/utils/statistics.py
index 0c680360..098862a8 100644
--- a/mammoth/utils/statistics.py
+++ b/mammoth/utils/statistics.py
@@ -109,6 +109,8 @@ def update(self, stat, update_n_src_words=False):
         if stat.n_words:
             self.n_words += stat.n_words
         if stat.n_correct:
+            if self.n_correct is None:
+                self.n_correct = 0
             self.n_correct += stat.n_correct
 
         if update_n_src_words: