diff --git a/config/example.hamburger.yaml b/config/example.hamburger.yaml index 859fc90b..fdfb4a4e 100644 --- a/config/example.hamburger.yaml +++ b/config/example.hamburger.yaml @@ -20,35 +20,35 @@ adapters: enc_group: layer_stack_index: 0 layers: [0, 1] - hidden_size: 8 # 512 (rnn_size) / 64 (reduction factor) + hidden_dim: 8 # 512 (rnn_size) / 64 (reduction factor) ids: - foo - bar enc_highresource: layer_stack_index: 0 layers: [0, 1] - hidden_size: 8 + hidden_dim: 8 ids: - en - de enc_lowresource: layer_stack_index: 0 layers: [0] - hidden_size: 8 + hidden_dim: 8 ids: - uu decoder: dec_group: layer_stack_index: 0 layers: [0] - hidden_size: 8 + hidden_dim: 8 ids: - foo - bar dec_highresource: layer_stack_index: 1 layers: [0, 1] - hidden_size: 16 + hidden_dim: 16 ids: - en - de @@ -56,7 +56,7 @@ adapters: dec_lowresource: layer_stack_index: 1 layers: [0] - hidden_size: 8 + hidden_dim: 8 ids: - vv diff --git a/examples/config_config.yaml b/examples/config_config.yaml index 6ff337b7..eaee57a1 100644 --- a/examples/config_config.yaml +++ b/examples/config_config.yaml @@ -70,28 +70,28 @@ adapters: enc_lang_bottom: layer_stack_index: 0 layers: [0, 1, 2] - hidden_size: 8 + hidden_dim: 8 ids: LANGUAGE enc_lang_top: layer_stack_index: 1 layers: [0, 1, 2] - hidden_size: 8 + hidden_dim: 8 ids: LANGUAGE decoder: dec_lang_bottom: layer_stack_index: 0 layers: [0, 1] - hidden_size: 16 + hidden_dim: 16 ids: LANGUAGE dec_lang_mid: layer_stack_index: 1 layers: [0, 1, 2] - hidden_size: 16 + hidden_dim: 16 ids: LANGUAGE dec_lang_top: layer_stack_index: 2 layers: [0] - hidden_size: 16 + hidden_dim: 16 ids: LANGUAGE save_model: models/opus.spm32k.adafactor.hamburger.l2.dsae/opus.spm32k.adafactor.hamburger.l2.dsae @@ -107,7 +107,7 @@ encoder_type: transformer decoder_type: transformer rnn_size: 512 word_vec_size: 512 -transformer_ff: 2048 +ff_mult: 4 heads: 8 enc_layers: [3, 3] dec_layers: [2, 3, 1] diff --git a/mammoth/distributed/components.py b/mammoth/distributed/components.py index 7998cb46..bba4c211 100644 --- a/mammoth/distributed/components.py +++ b/mammoth/distributed/components.py @@ -77,6 +77,7 @@ def needs_communication(self) -> bool: return self.group is not None +# TODO: This is a misnomer: Not an entire XCoder, but just one AttentionLayers block @dataclass class DistributedXCoder(DistributedComponent, ABC): layer_stack_index: int diff --git a/mammoth/model_builder.py b/mammoth/model_builder.py index 08fe6c7b..23c8a9bc 100644 --- a/mammoth/model_builder.py +++ b/mammoth/model_builder.py @@ -4,24 +4,28 @@ """ import torch import torch.nn as nn -from torch.nn.init import xavier_uniform_ -from pathlib import Path - from collections import defaultdict - -import mammoth.modules - -from mammoth.models.adapters import ( +from functools import partial +from pathlib import Path +from torch.nn.init import xavier_uniform_ +from typing import Optional, List +from x_transformers import TransformerWrapper + +from mammoth.distributed.components import ( + DistributedAdapter, + DistributedComponent, + DistributedDecoder, + DistributedEncoder, + Side, +) +from mammoth.models import NMTModel +from mammoth.modules.adapters import ( + AdaptedAttentionLayers, Adapter, - EncoderAdapterLayer, - DecoderAdapterLayer, + FeedForwardAdapterLayer, + LoraAdapterLayer, ) -from mammoth.constants import DefaultTokens -from mammoth.modules.layer_stack_decoder import LayerStackDecoder -from mammoth.modules.layer_stack_encoder import LayerStackEncoder -from mammoth.modules import Embeddings -from mammoth.modules.embeddings import 
PluggableEmbeddings -from mammoth.modules.util_class import Cast +from mammoth.modules.layer_stack import AdaptedAttentionLayersStack, StackXcoder from mammoth.utils.logging import logger from mammoth.utils.misc import use_gpu from mammoth.utils.module_splitter import _combine_ordered_dicts @@ -30,292 +34,362 @@ from mammoth.modules.attention_bridge import AttentionBridge -def build_embeddings(opts, vocab, for_encoder=True): - """ - Args: - opts: the option in current environment. - vocab: stoi-ish object. - for_encoder(bool): build Embeddings for encoder or decoder? - """ - word_padding_idx = vocab.stoi[DefaultTokens.PAD] - opts.word_padding_idx = word_padding_idx - - freeze_word_vecs = opts.freeze_word_vecs_enc if for_encoder else opts.freeze_word_vecs_dec - emb = Embeddings( - word_vec_size=opts.model_dim, - position_encoding=opts.position_encoding, - dropout=opts.dropout[0] if isinstance(opts.dropout, list) else opts.dropout, - word_padding_idx=word_padding_idx, - word_vocab_size=len(vocab), - freeze_word_vecs=freeze_word_vecs, - enable_embeddingless=opts.enable_embeddingless - ) - if opts.enable_embeddingless: - logger.info("Creating an embeddingless model.") - return emb - - -def build_encoder(opts, embeddings, task_queue_manager): - """ - Various encoder dispatcher function. - Args: - opts: the option in current environment. - embeddings (Embeddings): vocab embeddings for this encoder. - """ - assert opts.encoder_type == 'transformer', 'Only Transformer is supported' - return LayerStackEncoder.from_opts(opts, embeddings, task_queue_manager) - - -def build_decoder(opts, embeddings, task_queue_manager): - """ - Various decoder dispatcher function. - Args: - opts: the option in current environment. - embeddings (Embeddings): vocab embeddings for this decoder. 
- """ - assert opts.decoder_type == 'transformer', 'Only Transformer is supported' - return LayerStackDecoder.from_opts(opts, embeddings, task_queue_manager) +def uses_adapters(opts): + return 'adapters' in opts and opts.adapters -def load_test_multitask_model(opts, task=None, model_path=None): - """If a checkpoint ending with ".pt" returns a full model - otherwise it builds a bilingual model""" +def load_test_multitask_model(opts, task_queue_manager, task=None, model_path=None): if task is None: raise ValueError('Must set task') if model_path is None: model_path = opts.models[0] - if model_path.endswith('.pt'): - return load_test_model(opts, model_path) - else: - checkpoint_modules = [ - (f'encoder.embeddings.embeddings_{task.src_lang}.', f'src_embeddings_{task.src_lang}'), - (f'decoder.embeddings.embeddings_{task.tgt_lang}.', f'tgt_embeddings_{task.tgt_lang}'), - (f'generator.generator_{task.tgt_lang}.', f'generator_{task.tgt_lang}'), - ('attention_bridge.', 'attention_bridge'), - ] - - for layer_stack_idx, layer_stack_key in enumerate(task.encoder_id): + checkpoint_modules = [ + (f'encoder.embeddings.embeddings_{task.src_lang}.', f'src_embeddings_{task.src_lang}'), + (f'decoder.embeddings.embeddings_{task.tgt_lang}.', f'tgt_embeddings_{task.tgt_lang}'), + (f'generator.generator_{task.tgt_lang}.', f'generator_{task.tgt_lang}'), + ('attention_bridge.', 'attention_bridge'), + ] + + for layer_stack_idx, layer_stack_key in enumerate(task.encoder_id): + checkpoint_modules.append( + ( + f'encoder.encoders.{layer_stack_idx}.{layer_stack_key}.', + f'encoder_{layer_stack_idx}_{layer_stack_key}' + ) + ) + if task.encoder_adapter_ids: + for layer_stack_idx, adapter_group, sub_id in task.encoder_adapter_ids: checkpoint_modules.append( ( - f'encoder.encoders.{layer_stack_idx}.{layer_stack_key}.', - f'encoder_{layer_stack_idx}_{layer_stack_key}' + f'encoder.encoders.{layer_stack_idx}.{layer_stack_key}.adapters.adapter_{adapter_group}_{sub_id}.', # noqa + f'encoder_adapter_{layer_stack_idx}_{layer_stack_key}_{adapter_group}_{sub_id}' ) ) - if task.encoder_adapter_ids: - for layer_stack_idx, adapter_group, sub_id in task.encoder_adapter_ids: - checkpoint_modules.append( - ( - f'encoder.encoders.{layer_stack_idx}.{layer_stack_key}.adapters.adapter_{adapter_group}_{sub_id}.', # noqa - f'encoder_adapter_{layer_stack_idx}_{layer_stack_key}_{adapter_group}_{sub_id}' - ) - ) - for layer_stack_idx, layer_stack_key in enumerate(task.decoder_id): + for layer_stack_idx, layer_stack_key in enumerate(task.decoder_id): + checkpoint_modules.append( + ( + f'decoder.decoders.{layer_stack_idx}.{layer_stack_key}.', + f'decoder_{layer_stack_idx}_{layer_stack_key}' + ) + ) + if task.decoder_adapter_ids: + for layer_stack_idx, adapter_group, sub_id in task.decoder_adapter_ids: checkpoint_modules.append( ( - f'decoder.decoders.{layer_stack_idx}.{layer_stack_key}.', - f'decoder_{layer_stack_idx}_{layer_stack_key}' + f'decoder.decoders.{layer_stack_idx}.{layer_stack_key}.adapters.adapter_{adapter_group}_{sub_id}.', # noqa + f'decoder_adapter_{layer_stack_idx}_{layer_stack_key}_{adapter_group}_{sub_id}' ) ) - if task.decoder_adapter_ids: - for layer_stack_idx, adapter_group, sub_id in task.decoder_adapter_ids: - checkpoint_modules.append( - ( - f'decoder.decoders.{layer_stack_idx}.{layer_stack_key}.adapters.adapter_{adapter_group}_{sub_id}.', # noqa - f'decoder_adapter_{layer_stack_idx}_{layer_stack_key}_{adapter_group}_{sub_id}' - ) - ) - - model_path = model_path.rstrip('_') - checkpoint_paths = [ - (prefix, 
f'{model_path}_{key}.pt') for (prefix, key) in checkpoint_modules - ] - - opts.model_frame = model_path + '_frame.pt' - frame = torch.load(opts.model_frame, map_location=lambda storage, loc: storage) - - checkpoint_state_dicts = { - prefix: torch.load(path, map_location=lambda storage, loc: storage) - for prefix, path in checkpoint_paths - } - - combined_state_dict = _combine_ordered_dicts(checkpoint_state_dicts) - - vocabs_dict = { - 'src': frame["vocab"].get(('src', task.src_lang)), - 'tgt': frame["vocab"].get(('tgt', task.tgt_lang)), - } - # FIXME - # fields["indices"] = Field(use_vocab=False, dtype=torch.long, sequential=False) - - model_opts = ArgumentParser.ckpt_model_opts(frame['opts']) - # Avoid functionality on inference - # model_opts.update_vocab = False - model = create_bilingual_model( - task=task, - model_opts=model_opts, - vocabs_dict=vocabs_dict - ) - model_params = {name for name, p in model.named_parameters()} - model_params.update(name for name, p in model.named_buffers()) - for key in set(combined_state_dict.keys()): - if key not in model_params: - print(f'Deleting unnecessary key: {key}') - del combined_state_dict[key] - for key in model_params: - if key not in combined_state_dict: - print(f'Key missing {key}') - model.load_state_dict(combined_state_dict) - device = torch.device("cuda" if use_gpu(opts) else "cpu") - model.to(device) - - model.eval() - return vocabs_dict, model, model_opts + model_path = model_path.rstrip('_') + checkpoint_paths = [ + (prefix, f'{model_path}_{key}.pt') for (prefix, key) in checkpoint_modules + ] + opts.model_frame = model_path + '_frame.pt' + frame = torch.load(opts.model_frame, map_location=lambda storage, loc: storage) -def load_test_model(opts, model_path=None): - if model_path is None: - model_path = opts.models[0] + checkpoint_state_dicts = { + prefix: torch.load(path, map_location=lambda storage, loc: storage) + for prefix, path in checkpoint_paths + } - if len(opts.models) > 1: - model_path_enc = opts.models[0] - checkpoint = torch.load(model_path_enc, map_location=lambda storage, loc: storage) - model = checkpoint['whole_model'] + combined_state_dict = _combine_ordered_dicts(checkpoint_state_dicts) - model_path_dec = opts.models[1] - model_dec = torch.load(model_path_dec, map_location=lambda storage, loc: storage)['whole_model'] - model.decoder = model_dec.decoder - model.generator = model_dec.generator - else: - checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage) - model = checkpoint['whole_model'] - - model_opts = ArgumentParser.ckpt_model_opts(checkpoint['opts']) - ArgumentParser.update_model_opts(model_opts) - ArgumentParser.validate_model_opts(model_opts) - vocabs = checkpoint['vocab'] - print("VOCABS") - print(vocabs) - if opts.gpu != -1: - device = torch.device("cuda") - model.to(device) - - lang_pair = opts.lang_pair - src_lang, tgt_lang = lang_pair.split("-") + vocabs_dict = { + 'src': frame["vocab"].get(('src', task.src_lang)), + 'tgt': frame["vocab"].get(('tgt', task.tgt_lang)), + } # FIXME - vocabs_dict = {} - vocabs_dict['src'] = vocabs[('src', src_lang)] - vocabs_dict['tgt'] = vocabs[('tgt', tgt_lang)] - # indices = None # Field(use_vocab=False, dtype=torch.long, sequential=False) - # fields["indices"] = indices + # fields["indices"] = Field(use_vocab=False, dtype=torch.long, sequential=False) + model_opts = ArgumentParser.ckpt_model_opts(frame['opts']) # Avoid functionality on inference # model_opts.update_vocab = False + model = build_model( + model_opts, + opts, + vocabs_dict, + 
task_queue_manager, + checkpoint=None, + single_task=task.corpus_id, + ) + model_params = {name for name, p in model.named_parameters()} + model_params.update(name for name, p in model.named_buffers()) + for key in set(combined_state_dict.keys()): + if key not in model_params: + print(f'Deleting unnecessary key: {key}') + del combined_state_dict[key] + for key in model_params: + if key not in combined_state_dict: + print(f'Key missing {key}') + model.load_state_dict(combined_state_dict) + device = torch.device("cuda" if use_gpu(opts) else "cpu") + model.to(device) - if opts.fp32: - model.float() - elif opts.int8: - if opts.gpu >= 0: - raise ValueError("Dynamic 8-bit quantization is not supported on GPU") - torch.quantization.quantize_dynamic(model, inplace=True) model.eval() - model.generator.eval() + return vocabs_dict, model, model_opts -def create_bilingual_model( - task, model_opts, vocabs_dict +def get_attention_layers_kwargs( + side: Side, + layer_stack_index, + xcoder_id, + model_opts, ): - """For translation.""" - src_lang = task.src_lang - tgt_lang = task.tgt_lang - generators_md = nn.ModuleDict() - src_emb = build_src_emb(model_opts, vocabs_dict['src']) - tgt_emb = build_tgt_emb(model_opts, vocabs_dict['tgt']) - pluggable_src_emb = PluggableEmbeddings({src_lang: src_emb}) - pluggable_tgt_emb = PluggableEmbeddings({tgt_lang: tgt_emb}) - - pluggable_src_emb.activate(src_lang) - pluggable_tgt_emb.activate(tgt_lang) - encoder = LayerStackEncoder.from_trans_opt(model_opts, pluggable_src_emb, task=task) - decoder = LayerStackDecoder.from_trans_opt(model_opts, pluggable_tgt_emb, task=task) - generator = build_generator(model_opts, len(vocabs_dict['tgt']), tgt_emb) - generators_md.add_module(f'generator_{tgt_lang}', generator) - - attention_bridge = AttentionBridge.from_opts(model_opts) - - nmt_model = mammoth.models.NMTModel( - encoder=encoder, - decoder=decoder, - attention_bridge=attention_bridge - ) - if uses_adapters(model_opts): - logger.info('Creating adapters...') - create_bilingual_adapters(nmt_model, model_opts, task) - else: - logger.info('Does not use adapters...') - print('built model:') - print(nmt_model) - nmt_model.generator = generators_md - return nmt_model - - -def build_src_emb(model_opts, src_vocab): - # Build embeddings. - src_emb = build_embeddings(model_opts, src_vocab) - return src_emb - - -def build_tgt_emb(model_opts, tgt_vocab): - # Build embeddings. - tgt_emb = build_embeddings(model_opts, tgt_vocab, for_encoder=False) - - # if share_embeddings: - # tgt_emb.word_lut.weight = src_emb.word_lut.weight - - return tgt_emb - - -def build_task_specific_model( + """Return arguments for x_transformers.AttentionLayers""" + depths = model_opts.enc_layers if side == Side.decoder else model_opts.dec_layers + depth = depths[layer_stack_index] + causal = side == Side.decoder + cross_attend = side == Side.decoder + is_last = layer_stack_index == len(depths) - 1 + # changed from default + use_simple_rmsnorm = True + attn_flash = True + ff_glu = True + pre_norm_has_final_norm = is_last + # Mostly x_transformers defaults. Make (some of this) configurable. 
+ return { + 'dim': model_opts.model_dim, + 'depth': depth, + 'heads': model_opts.heads, + 'causal': causal, + 'cross_attend': cross_attend, + 'only_cross': False, + 'use_scalenorm': False, + 'use_rmsnorm': False, + 'use_simple_rmsnorm': use_simple_rmsnorm, + 'use_adaptive_layernorm': False, + 'use_adaptive_rmsnorm': False, + 'use_adaptive_layerscale': False, + 'norm_add_unit_offset': True, + 'dim_condition': None, + 'adaptive_condition_mlp': False, + 'adaptive_condition_mlp_expansion': 4, + 'alibi_pos_bias': False, + 'alibi_num_heads': None, + 'rel_pos_bias': False, + 'rel_pos_num_buckets': 32, + 'rel_pos_max_distance': 128, + 'dynamic_pos_bias': False, + 'dynamic_pos_bias_log_distance': False, + 'dynamic_pos_bias_mlp_depth': 2, + 'dynamic_pos_bias_norm': False, + 'rotary_pos_emb': False, + 'rotary_emb_dim': None, + 'rotary_xpos': False, + 'rotary_interpolation_factor': 1., + 'rotary_xpos_scale_base': 512, + 'rotary_base_rescale_factor': 1., + 'weight_tie_layers': False, + 'custom_layers': None, + 'layers_execute_order': None, + # 'sandwich_coef': None, # Sandwich would be very unintuitive with multiple layerstacks + 'par_ratio': None, + 'residual_attn': False, + 'cross_residual_attn': False, + # 'macaron': False, # Can not support macaron and inject adapters at each 'f' layer + 'pre_norm': True, + 'pre_norm_has_final_norm': pre_norm_has_final_norm, + 'gate_residual': False, + 'scale_residual': False, + 'scale_residual_constant': 1., + 'shift_tokens': 0, + 'sandwich_norm': False, + 'softclamp_output': False, + 'softclamp_output_value': 30., + 'resi_dual': False, + 'resi_dual_scale': 1., + 'zero_init_branch_output': False, + 'layer_dropout': 0., + 'cross_attn_tokens_dropout': 0., + 'disable_abs_pos_emb': None, + 'use_layerscale': False, + 'layerscale_init_value': 0., + + 'ff_dim_out': None, + 'ff_mult': model_opts.ff_mult, + 'ff_glu': ff_glu, + 'ff_glu_mult_bias': False, + 'ff_swish': False, + 'ff_relu_squared': False, + 'ff_post_act_ln': False, + 'ff_dropout': 0., + 'ff_no_bias': False, + 'ff_zero_init_output': False, + + 'attn_dim_context': None, + 'attn_flash': attn_flash, + 'attn_talking_heads': False, + 'attn_head_scale': False, + 'attn_sparse_topk': None, + 'attn_num_mem_kv': 0, + 'attn_dropout': 0., + 'attn_on_attn': False, + 'attn_gate_value_heads': False, + 'attn_swiglu_values': False, + 'attn_gate_values': False, + 'attn_zero_init_output': False, + 'attn_max_attend_past': None, + 'attn_qk_norm': False, + 'attn_qk_norm_groups': 1, + 'attn_qk_norm_scale': 10, + 'attn_qk_norm_dim_scale': False, + 'attn_one_kv_head': False, + 'attn_kv_heads': None, + 'attn_shared_kv': False, + 'attn_value_dim_head': None, + 'attn_tensor_product': False, # https://arxiv.org/abs/2208.06061 + 'attn_add_zero_kv': False, # same as add_zero_attn in pytorch + 'attn_rotary_embed_values': False, + 'attn_use_cope': False, + 'attn_cope_max_pos': 16, + 'attn_cope_soft_onehot_pos': False, + 'attn_cope_talking_heads': False, + 'attn_softclamp_logits': False, + 'attn_logit_softclamp_value': 50., + 'attn_onnxable': False, + } + + +def build_xcoder( + side: Side, model_opts, vocabs_dict, device, task_queue_manager, - checkpoint, + single_task: Optional[str] = None, ): + my_components: List[DistributedComponent] = task_queue_manager.get_my_distributed_components() + my_components = [ + component for component in my_components + if hasattr(component, 'side') and component.side == side + ] + distributed_xcoder_class: type + if side == Side.encoder: + distributed_xcoder_class = DistributedEncoder + side_str = 'encoder' + 
else: + distributed_xcoder_class = DistributedDecoder + side_str = 'decoder' + if single_task: + my_components = [ + component for component in my_components + if single_task in component.task_ids + ] - src_embs = dict() - tgt_embs = dict() - - generators_md = nn.ModuleDict() + # Create AdaptedAttentionLayers objects (an extension of an x_transformers.AttentionLayers block) + attention_layers_components = [ + component for component in my_components + if isinstance(component, distributed_xcoder_class) + ] + attention_layer_blocks = defaultdict(dict) + for component in attention_layers_components: + layer_stack_index = component.layer_stack_index + xcoder_id = component.xcoder_id + attention_layers_kwargs = get_attention_layers_kwargs( + side=side, + layer_stack_index=layer_stack_index, + xcoder_id=xcoder_id, + model_opts=model_opts, + ) + attention_layer_blocks[layer_stack_index][xcoder_id] = AdaptedAttentionLayers(**attention_layers_kwargs) - # FIXME: it's getting late and I just want this to compile - for side, lang, _, vocab in task_queue_manager.get_my_vocabs(side='src', vocabs_dict=vocabs_dict): - src_emb = build_src_emb(model_opts, vocab) - src_embs[lang] = src_emb - pluggable_src_emb = PluggableEmbeddings(src_embs) - encoder = build_only_enc(model_opts, pluggable_src_emb, task_queue_manager, checkpoint) + # Create AdapterLayer objects and Adapter objects + if uses_adapters(model_opts): + adapter_components = [ + component for component in my_components + if isinstance(component, DistributedAdapter) and component.side == side + ] + adapter_params_by_group = dict() + for adapter_group, adapter_opts in model_opts.adapters[side_str].items(): + adapter_params_by_group[adapter_group] = { + 'layer_stack_index': adapter_opts['layer_stack_index'], + 'hidden_dim': adapter_opts['hidden_dim'], + 'layers': adapter_opts['layers'], + 'sub_ids': adapter_opts['ids'], + } + for component in adapter_components: + adapter_params = adapter_params_by_group[component.adapter_group] + if model_opts.adapter_type.lower() == 'lora': + adapter_layer_func = partial( + LoraAdapterLayer, + dim=model_opts.model_dim, + r=adapter_params['hidden_dim'], + ) + elif model_opts.adapter_type.lower() == 'ff': + mult = adapter_params['hidden_dim'] / model_opts.model_dim + # TODO: make norm locations and glu configurable + adapter_layer_func = partial( + FeedForwardAdapterLayer, + dim=model_opts.model_dim, + mult=mult, + pre_norm=True, + sandwich_norm=False, + glu=True, + ) + else: + raise ValueError(f'Unrecognized adapter_type {model_opts.adapter_type}') + for sub_id in adapter_params['sub_ids']: + for layer_idx in adapter_params['layers']: + adapter_layer = adapter_layer_func() + adapter = Adapter( + adapter_group=component.adapter_group, + sub_id=sub_id, + ) + adapter.add_layer(layer_idx, adapter_layer) + layer_stack_index = adapter_params['layer_stack_index'] + for attention_layers in attention_layer_blocks[layer_stack_index]: + attention_layers.add_adapter(adapter) + + # Create AdaptedAttentionLayersStack objects and TransformerWrapper objects + tasks = task_queue_manager.get_my_tasks() + if single_task: + tasks = [task for task in tasks if task.corpus_id == single_task] + transformer_wrappers = dict() + for task in tasks: + if side == Side.encoder: + xcoder_ids = task.encoder_id + else: + xcoder_ids = task.decoder_id + attention_layers_stack = [ + attention_layer_blocks[layer_stack_index][xcoder_id] + for layer_stack_index, xcoder_id in enumerate(xcoder_ids) + ] + adapted_attention_layers_stack = 
AdaptedAttentionLayersStack( + attention_layers_stack=attention_layers_stack + ) - for side, lang, _, vocab in task_queue_manager.get_my_vocabs(side='tgt', vocabs_dict=vocabs_dict): - tgt_emb = build_tgt_emb(model_opts, vocab) - tgt_embs[lang] = tgt_emb - generator = build_generator(model_opts, len(vocab), tgt_emb) - generators_md.add_module(f'generator_{lang}', generator) + side_alt_str = 'src' if side == Side.encoder else 'tgt' + lang = task.src_lang if side == Side.encoder else task.tgt_lang + vocab = vocabs_dict[(side_alt_str, lang)] + max_seq_len = 0 if model_opts.max_length is None else model_opts.max_length + post_emb_norm = True + tie_embedding = True + use_abs_pos_emb = True + emb_frac_gradient = 1. + # FIXME: this won't work: creates embeddings for each task, not for each language + # Have to reimplement TransformerWrapper to allow passing in an embedding + transformer_wrapper = TransformerWrapper( + num_tokens=len(vocab), + max_seq_len=max_seq_len, + attn_layers=adapted_attention_layers_stack, + emb_dim=model_opts.model_dim, + post_emb_norm=post_emb_norm, + tie_embedding=tie_embedding, + use_abs_pos_emb=use_abs_pos_emb, + emb_frac_gradient=emb_frac_gradient, + ) + transformer_wrappers[task.corpus_id] = transformer_wrapper - if checkpoint: - trainstep = int(checkpoint['optim']['training_step']) - 1 - for modname, gen in generators_md.items(): - mod_path = Path(checkpoint['opts'].save_model + f"_step_{trainstep}_{modname}.pt") - if mod_path.exists(): - module = torch.load(mod_path) - gen.load_state_dict(module) - logger.info(f"Successfully loaded {modname} from the checkpoint.") + # Create a StackXcoder + stack_xcoder = StackXcoder(transformer_wrappers) + return stack_xcoder - pluggable_tgt_emb = PluggableEmbeddings(tgt_embs) - decoder = build_only_dec(model_opts, pluggable_tgt_emb, task_queue_manager, checkpoint) - # TODO: implement hierarchical approach to layer sharing +def build_attention_bridge(model_opts): attention_bridge = AttentionBridge.from_opts(model_opts) if model_opts.param_init != 0.0: @@ -325,6 +399,24 @@ def build_task_specific_model( for p in attention_bridge.parameters(): if p.dim() > 1: xavier_uniform_(p, gain=nn.init.calculate_gain('relu')) + return attention_bridge + + +def restore_from_checkpoint(stack_xcoder, checkpoint): + # FIXME: saving and loading are broken + trainstep = int(checkpoint['optim']['training_step']) - 1 + for modname, gen in generators_md.items(): + mod_path = Path(checkpoint['opts'].save_model + f"_step_{trainstep}_{modname}.pt") + if mod_path.exists(): + module = torch.load(mod_path) + gen.load_state_dict(module) + logger.info(f"Successfully loaded {modname} from the checkpoint.") + + pluggable_tgt_emb = PluggableEmbeddings(tgt_embs) + decoder = build_only_dec(model_opts, pluggable_tgt_emb, task_queue_manager, checkpoint) + + # TODO: implement hierarchical approach to layer sharing + attention_bridge = build_attention_bridge(model_opts) if checkpoint: # trainstep= int(checkpoint['optim']['training_step'])-1 - already recoderd in generators @@ -336,11 +428,6 @@ def build_task_specific_model( if model_opts.model_dtype == 'fp16' and model_opts.optim == 'fusedadam': attention_bridge.half() - nmt_model = mammoth.models.NMTModel( - encoder=encoder, - decoder=decoder, - attention_bridge=attention_bridge - ) if uses_adapters(model_opts): logger.info('Creating adapters...') create_all_adapters(nmt_model, model_opts, task_queue_manager) @@ -353,150 +440,13 @@ def build_task_specific_model( return nmt_model, generators_md -def 
build_only_enc(model_opts, src_emb, task_queue_manager, checkpoint): - """Truly only builds encoder: no embeddings""" - encoder = build_encoder(model_opts, src_emb, task_queue_manager) - if model_opts.param_init != 0.0: - for name, p in encoder.named_parameters(): - if not ("embedding" in name and "pe" not in name and model_opts.enable_embeddingless is True): - p.data.uniform_(-model_opts.param_init, model_opts.param_init) - - if model_opts.param_init_glorot: - for name, p in encoder.named_parameters(): - if not ("embedding" in name and "pe" not in name and model_opts.enable_embeddingless is True): - if p.dim() > 1: - xavier_uniform_(p, gain=nn.init.calculate_gain('relu')) - if checkpoint: - logger.info("Loading from checkpoint") - trainstep = int(checkpoint['optim']['training_step']) - 1 - embnames = [srctgt['src_tgt'].split('-')[0] for srctgt in checkpoint['opts'].tasks.values()] - embnames = set(embnames) - groupnames = [ - (idx, modname) for srctgt in checkpoint['opts'].tasks.values() - for idx, modname in enumerate(srctgt['enc_sharing_group']) - ] - groupnames = set(groupnames) - # load embs - for modname in embnames: - module = torch.load(checkpoint['opts'].save_model + f"_step_{trainstep}_src_embeddings_{modname}.pt") - if f'embeddings_{modname}' in encoder.embeddings._modules.keys(): - encoder.embeddings._modules[f'embeddings_{modname}'].load_state_dict(module) - logger.info(f"Successfully loaded the embeddings of {modname} from the checkpoint.") - - # load layers - for idx, modname in groupnames: - mod_path = Path(checkpoint['opts'].save_model + f"_step_{trainstep}_encoder_{idx}_{modname}.pt") - if mod_path.exists() and modname in encoder.encoders._modules[str(idx)].keys(): - module = torch.load(mod_path) - encoder.encoders._modules[str(idx)][modname].load_state_dict(module) - logger.info(f"Successfully loaded layer {str(idx)} of {modname} from the checkpoint.") - if model_opts.model_dtype == 'fp16' and model_opts.optim == 'fusedadam': - encoder.half() - - return encoder - - -def build_only_dec(model_opts, tgt_emb, task_queue_manager, checkpoint): - decoder = build_decoder(model_opts, tgt_emb, task_queue_manager) - if model_opts.param_init != 0.0: - for name, p in decoder.named_parameters(): - if not ("embedding" in name and "pe" not in name and model_opts.enable_embeddingless is True): - p.data.uniform_(-model_opts.param_init, model_opts.param_init) - if model_opts.param_init_glorot: - for name, p in decoder.named_parameters(): - if not ("embedding" in name and "pe" not in name and model_opts.enable_embeddingless is True): - if p.dim() > 1: - xavier_uniform_(p, gain=nn.init.calculate_gain('relu')) - - if checkpoint: - logger.info("Loading from checkpoint") - trainstep = int(checkpoint['optim']['training_step']) - 1 - embnames = [srctgt['src_tgt'].split('-')[1] for srctgt in checkpoint['opts'].tasks.values()] - embnames = set(embnames) - groupnames = [ - (idx, modname) for srctgt in checkpoint['opts'].tasks.values() - for idx, modname in enumerate(srctgt['dec_sharing_group']) - ] - groupnames = set(groupnames) - # load embs - for modname in embnames: - if f'embeddings_{modname}' in decoder.embeddings._modules.keys(): - module = torch.load(checkpoint['opts'].save_model + f"_step_{trainstep}_tgt_embeddings_{modname}.pt") - decoder.embeddings._modules[f'embeddings_{modname}'].load_state_dict(module) - logger.info(f"Successfully loaded the embeddings of {modname} from the checkpoint.") - - # load layers - for idx, modname in groupnames: - mod_path = Path(checkpoint['opts'].save_model 
+ f"_step_{trainstep}_decoder_{idx}_{modname}.pt") - if mod_path.exists() and modname in decoder.decoders._modules[str(idx)].keys(): - module = torch.load(mod_path) - decoder.decoders._modules[str(idx)][modname].load_state_dict(module) - logger.info(f"Successfully loaded layer {str(idx)} of {modname} from the checkpoint.") - if model_opts.model_dtype == 'fp16' and model_opts.optim == 'fusedadam': - decoder.half() - - return decoder - - -def build_generator(model_opts, n_tgts, tgt_emb): - # Build Generator. - assert not model_opts.copy_attn, 'copy_attn not supported' - gen_func = nn.LogSoftmax(dim=-1) - generator = nn.Sequential( - nn.Linear(model_opts.model_dim, n_tgts), Cast(torch.float32), gen_func - ) - - if model_opts.share_decoder_embeddings: - generator[0].weight = tgt_emb.word_lut.weight - - if model_opts.param_init != 0.0: - for p in generator.parameters(): - p.data.uniform_(-model_opts.param_init, model_opts.param_init) - if model_opts.param_init_glorot: - for p in generator.parameters(): - if p.dim() > 1: - xavier_uniform_(p, gain=nn.init.calculate_gain('relu')) - - return generator - - -# TODO: confirm this was dead code -# def use_embeddings_from_checkpoint(fields, model, generator, checkpoint): -# # Update vocabulary embeddings with checkpoint embeddings -# logger.info("Updating vocabulary embeddings with checkpoint embeddings") -# # Embedding layers -# enc_emb_name = "encoder.embeddings.make_embedding.emb_luts.0.weight" -# dec_emb_name = "decoder.embeddings.make_embedding.emb_luts.0.weight" -# -# for field_name, emb_name in [("src", enc_emb_name), ("tgt", dec_emb_name)]: -# if emb_name not in checkpoint["model"]: -# continue -# multifield = fields[field_name] -# checkpoint_multifield = checkpoint["vocab"][field_name] -# for (name, field), (checkpoint_name, checkpoint_field) in zip(multifield, checkpoint_multifield): -# new_tokens = [] -# for i, tok in enumerate(field.vocab.itos): -# if tok in checkpoint_field.vocab.stoi: -# old_i = checkpoint_field.vocab.stoi[tok] -# model.state_dict()[emb_name][i] = checkpoint["model"][emb_name][old_i] -# if field_name == "tgt": -# generator.state_dict()["0.weight"][i] = checkpoint["generator"]["0.weight"][old_i] -# generator.state_dict()["0.bias"][i] = checkpoint["generator"]["0.bias"][old_i] -# else: -# # Just for debugging purposes -# new_tokens.append(tok) -# logger.info("%s: %d new tokens" % (name, len(new_tokens))) -# # Remove old vocabulary associated embeddings -# del checkpoint["model"][emb_name] -# del checkpoint["generator"]["0.weight"], checkpoint["generator"]["0.bias"] - - -def build_base_model_langspec( +def build_model( model_opts, + opts, vocabs_dict, - gpu, task_queue_manager, checkpoint=None, + single_task=None, ): """Build a model from opts. @@ -504,174 +454,49 @@ def build_base_model_langspec( model_opts: the option loaded from checkpoint. It's important that the opts have been updated and validated. See :class:`mammoth.utils.parse.ArgumentParser`. + opts: overriding options. vocabs_dict (dict[str, mammoth.inputters.Vocab]): `Vocab` objects for the model. - gpu (bool): whether to use gpu. - checkpoint: the model gnerated by train phase, or a resumed snapshot + task_queue_manager: TaskQueueManager + checkpoint: the model generated by train phase, or a resumed snapshot model from a stopped training. - gpu_id (int or NoneType): Which GPU to use. + single_task: corpus_id of task, to create a single-task model Returns: the NMTModel. 
""" - - # for back compat when attention_dropout was not defined - try: - model_opts.attention_dropout - except AttributeError: - model_opts.attention_dropout = model_opts.dropout - - # Build Model - logger.info("MODEL BUILDER") + logger.info('Building model...') + gpu = use_gpu(opts) if gpu: device = torch.device("cuda") else: device = torch.device("cpu") logger.info(device) - model, generators_md = build_task_specific_model( + + encoder = build_xcoder( + side=Side.encoder, model_opts=model_opts, vocabs_dict=vocabs_dict, device=device, task_queue_manager=task_queue_manager, - checkpoint=checkpoint, - ) - - model.generator = generators_md - model.to(device) - - return model, generators_md - - -def uses_adapters(opts): - return 'adapters' in opts and opts.adapters - - -def create_all_adapters(model, opts, task_queue_manager): - my_enc_adapter_ids = set() - my_dec_adapter_ids = set() - adapter_to_encoder_ids = defaultdict(set) - adapter_to_decoder_ids = defaultdict(set) - for task in task_queue_manager.get_my_tasks(): - for adapter_id in task.encoder_adapter_ids: - adapter_id = tuple(adapter_id) - my_enc_adapter_ids.add(adapter_id) - adapter_to_encoder_ids[adapter_id].add(tuple(task.encoder_id)) - for adapter_id in task.decoder_adapter_ids: - adapter_id = tuple(adapter_id) - my_dec_adapter_ids.add(adapter_id) - adapter_to_decoder_ids[adapter_id].add(tuple(task.decoder_id)) - _create_adapters( - model, - opts, - my_enc_adapter_ids, - adapter_to_encoder_ids, - my_dec_adapter_ids, - adapter_to_decoder_ids, + single_task=single_task, ) - - -def create_bilingual_adapters(model, opts, task): - my_enc_adapter_ids = [] - my_dec_adapter_ids = [] - adapter_to_encoder_ids = {} - adapter_to_decoder_ids = {} - - for adapter_id in task.encoder_adapter_ids: - adapter_id = tuple(adapter_id) - my_enc_adapter_ids.add(adapter_id) - # This is a list of list, because in general the adapter could be used in several stacks - adapter_to_encoder_ids[adapter_id] = [task.encoder_id] - for adapter_id in task.decoder_adapter_ids: - adapter_id = tuple(adapter_id) - my_dec_adapter_ids.add(adapter_id) - adapter_to_decoder_ids[adapter_id] = [task.decoder_id] - - _create_adapters( - model, - opts, - my_enc_adapter_ids, - adapter_to_encoder_ids, - my_dec_adapter_ids, - adapter_to_decoder_ids, - ) - - -def _create_adapters( - model, - opts, - my_enc_adapter_ids, - adapter_to_encoder_ids, - my_dec_adapter_ids, - adapter_to_decoder_ids, -): - my_enc_adapter_ids = [tuple(item) for item in my_enc_adapter_ids] - my_dec_adapter_ids = [tuple(item) for item in my_dec_adapter_ids] - for adapter_group, adapter_opts in opts.adapters['encoder'].items(): - layer_stack_index = adapter_opts['layer_stack_index'] - for sub_id in adapter_opts['ids']: - adapter_id_long = (layer_stack_index, adapter_group, sub_id) - if adapter_id_long not in my_enc_adapter_ids: - continue - adapter = Adapter(adapter_group, sub_id) - input_dim = opts.model_dim - hidden_dim = adapter_opts['hidden_size'] - - # all stacks to which this adapter should be added - adapted_stacks = set( - stacks[layer_stack_index] for stacks in adapter_to_encoder_ids[adapter_id_long] - ) - adapter_cls = EncoderAdapterLayer - - for layer_idx in adapter_opts['layers']: - adapter.add_layer( - layer_idx, - adapter_cls(input_dim, hidden_dim, pfeiffer=False, init='small') - ) - model.encoder.add_adapter( - adapter_group=adapter_group, - sub_id=sub_id, - adapter=adapter, - layer_stack_index=layer_stack_index, - module_ids=adapted_stacks, - ) - for adapter_group, adapter_opts in 
opts.adapters['decoder'].items(): - layer_stack_index = adapter_opts['layer_stack_index'] - for sub_id in adapter_opts['ids']: - adapter_id_long = (layer_stack_index, adapter_group, sub_id) - if adapter_id_long not in my_dec_adapter_ids: - continue - adapter = Adapter(adapter_group, sub_id) - input_dim = opts.model_dim - hidden_dim = adapter_opts['hidden_size'] - - adapted_stacks = set( - stacks[layer_stack_index] for stacks in adapter_to_decoder_ids[adapter_id_long] - ) - adapter_cls = DecoderAdapterLayer - - for layer_idx in adapter_opts['layers']: - adapter.add_layer( - layer_idx, - adapter_cls(input_dim, hidden_dim, pfeiffer=False, init='small') - ) - model.decoder.add_adapter( - adapter_group=adapter_group, - sub_id=sub_id, - adapter=adapter, - layer_stack_index=layer_stack_index, - module_ids=adapted_stacks, - ) - - -def build_model(model_opts, opts, vocabs_dict, task_queue_manager, checkpoint): - logger.info('Building model...') - model, generators_md = build_base_model_langspec( + decoder = build_xcoder( + side=Side.decoder, model_opts=model_opts, vocabs_dict=vocabs_dict, - gpu=use_gpu(opts), + device=device, task_queue_manager=task_queue_manager, - checkpoint=checkpoint, + single_task=single_task, + ) + attention_bridge = build_attention_bridge(model_opts) + model = NMTModel( + encoder=encoder, + decoder=decoder, + attention_bridge=attention_bridge ) + + model.to(device) # logger.info(model) logger.info('Building model - done!') - return model, generators_md + return model diff --git a/mammoth/models/adapters.py b/mammoth/models/adapters.py deleted file mode 100644 index d52ec20a..00000000 --- a/mammoth/models/adapters.py +++ /dev/null @@ -1,258 +0,0 @@ -""" -Inject small bottleneck layers with residual connection -into an already trained network, to adapt it for a new task. -""" - -import torch.nn as nn -import torch.nn.functional as F -from abc import ABC -from collections import defaultdict - -from mammoth.modules import TransformerEncoder -from mammoth.modules import TransformerDecoder -from mammoth.rmsnorm_torch import RMSNorm - - -class AdapterLayer(ABC, nn.Module): - """ - A single adapter layer module - - Implements Simple, Scalable Adaptation for Neural Machine Translation - (https://arxiv.org/abs/1909.08478) - See also fairseq implementation: - https://github.com/ahmetustun/fairseq/blob/master/fairseq/modules/adapter_layer.py - """ - - def __init__(self, input_dim, hidden_dim, pfeiffer=False, init='small', layernorm='layernorm'): - super().__init__() - # Omit LayerCache - self._does_not_need_cache = True - - self.down_proj = nn.Linear(input_dim, hidden_dim) - self.up_proj = nn.Linear(hidden_dim, input_dim) - self.pfeiffer = pfeiffer - if not self.pfeiffer: - if layernorm == 'rmsnorm': - self.layer_norm = RMSNorm(input_dim, eps=1e-6) - else: - self.layer_norm = nn.LayerNorm(input_dim, eps=1e-6) - - if init == 'small' or 'init' == 'bert': - if init == 'small': - almost_zero = 1e-5 - delta = 1e-6 - - def init_fn(tensor): - nn.init.uniform_( - tensor, - almost_zero - delta, almost_zero + delta - ) - elif init == 'bert': - - def init_fn(tensor): - nn.init.normal_(tensor, mean=0.0, std=0.02) - - # Init up. - init_fn(self.up_proj.weight) - init_fn(self.up_proj.bias) - - # Init down. 
- init_fn(self.down_proj.weight) - init_fn(self.down_proj.bias) - - def forward(self, x): - if self.pfeiffer: - y = self.down_proj(x) - y = F.relu(y) - y = self.up_proj(y) - else: - y = self.layer_norm(x) - y = self.down_proj(y) - y = F.relu(y) - y = self.up_proj(y) - y = x + y - return y - - -class EncoderAdapterLayer(AdapterLayer): - # same call signature as TransformerEncoderLayer - def forward(self, inputs, mask): - out = super().forward(inputs) - return out - - -class DecoderAdapterLayer(AdapterLayer): - # same call signature as TransformerDecoderLayer - def forward( - self, - output, - src_memory_bank, - src_pad_mask, - tgt_pad_mask, - layer_cache=None, - step=None, - with_align=False, - future=False, - ): - output = super().forward(output) - attn = None - attn_align = None - return output, attn, attn_align - - -class Adapter(nn.Module): - """ - A container for one or several AdapterLayers, - together with layer indices for injecting into the base network. - """ - - def __init__(self, adapter_group: str, sub_id: str): - super().__init__() - self.name = self._name(adapter_group, sub_id) - # mapping layer_idx -> ModuleList of AdapterLayer to inject after that layer - self.adapter_layers = nn.ModuleDict() - self._adapted_layer_indices = set() - - @staticmethod - def _name(adapter_group: str, sub_id: str) -> str: - assert isinstance(adapter_group, str), f'Expecting str, not {adapter_group}' - assert isinstance(sub_id, str), f'Expecting str, not {sub_id}' - return f'adapter_{adapter_group}_{sub_id}' - - def add_layer(self, layer_idx, adapter_layer: AdapterLayer): - self._adapted_layer_indices.add(layer_idx) - layer_idx_str = f'layer{layer_idx}' - if layer_idx_str not in self.adapter_layers: - self.adapter_layers[layer_idx_str] = nn.ModuleList() - self.adapter_layers[layer_idx_str].append(adapter_layer) - - def get_layers(self): - return self.adapter_layers.items() - - def __repr__(self): - return f'' - - -class TransformerAdapterMixin: - """ - Mixin to manage one or several Adapters - for a TransformerEncoder or TransformerDecoder. - """ - - def __init__(self, *args, **kwargs): - # run init of next parallel inheritance class - super(TransformerAdapterMixin, self).__init__(*args, **kwargs) - self.adapters = nn.ModuleDict() - self.active = set() - - def freeze_base_model(self, requires_grad=False): - adapter_parameters = {name for name, p in self.adapters.named_parameters()} - for name, p in self.named_parameters(): - if name not in adapter_parameters: - # freeze everything except the adapter parameters - p.requires_grad = requires_grad - - def get_adapter(self, adapter_group: str, sub_id: str): - name = Adapter._name(adapter_group, sub_id) - return self.adapters[name] - - def add_adapter(self, adapter_group: str, sub_id: str, adapter: Adapter): - name = Adapter._name(adapter_group, sub_id) - if name in self.adapters: - raise ValueError(f'Duplicate Adapter "{name}"') - max_layer_index = max(adapter._adapted_layer_indices) - if not self._check_n_layers(max_layer_index): - raise ValueError(f'Invalid number of layers {max_layer_index} in Adapter "{name}"') - self.adapters[name] = adapter - - def _check_n_layers(self, max_layer_index): - """Override this""" - return True - - def deactivate_adapters(self): - self.active = set() - - def activate_adapter(self, adapter_group: str, sub_id: str): - name = Adapter._name(adapter_group, sub_id) - if name not in self.adapters: - raise ValueError( - f'Nonexistent Adapter "{name}". 
' - f'Should be one of: {" ".join(self.adapters.keys())}' - ) - self.active.add(name) - - def _merge_active_adapters(self): - """ - Returns a single mapping layer_idx -> list of AdapterLayer, - containing the layers of all currently active adapters - """ - active_adapters = [ - adapter for name, adapter in self.adapters.items() - if name in self.active - ] - merged = defaultdict(list) - for adapter in active_adapters: - for layer_idx, layers in adapter.get_layers(): - merged[layer_idx].extend(layers) - return merged - - def _inject_adapters(self, base_layers): - active_layers = self._merge_active_adapters() - result = [] - for layer_idx, base_layer in enumerate(base_layers): - layer_idx = f'layer{layer_idx}' - result.append(base_layer) - if layer_idx in active_layers: - result.extend(active_layers[layer_idx]) - return result - - -class AdaptedTransformerEncoder(TransformerAdapterMixin, TransformerEncoder): - def _forward_loop(self, out, mask): - injected = self._inject_adapters(self.transformer) - for layer in injected: - out = layer(out, mask) - return out - - def _check_n_layers(self, max_layer_index): - return max_layer_index <= len(self.transformer) - - def state_dict(self, *args, include_adapters=False, **kwargs): - if not include_adapters: - # hide adapters - omitted_adapters = self.adapters - self.adapters = None - - result = TransformerEncoder.state_dict(self, *args, **kwargs) - - if not include_adapters: - # unhide adapters - self.adapters = omitted_adapters - - return result - - -class AdaptedTransformerDecoder(TransformerAdapterMixin, TransformerDecoder): - def forward(self, *args, **kwargs): - self._injected = self._inject_adapters(self.transformer_layers) - return super().forward(*args, **kwargs) - - def _get_layers(self): - return self._injected - - def _check_n_layers(self, max_layer_index): - return max_layer_index <= len(self.transformer_layers) - - def state_dict(self, *args, include_adapters=False, **kwargs): - if not include_adapters: - # hide adapters - omitted_adapters = self.adapters - self.adapters = None - - result = TransformerDecoder.state_dict(self, *args, **kwargs) - - if not include_adapters: - # unhide adapters - self.adapters = omitted_adapters - - return result diff --git a/mammoth/modules/__init__.py b/mammoth/modules/__init__.py index 975b2ef6..67f4a3b8 100644 --- a/mammoth/modules/__init__.py +++ b/mammoth/modules/__init__.py @@ -1,8 +1,6 @@ """Components libary""" from mammoth.modules.util_class import Elementwise from mammoth.modules.multi_headed_attn import MultiHeadedAttention -from mammoth.modules.embeddings import Embeddings, PositionalEncoding -# from mammoth.modules.weight_norm import WeightNormConv2d from mammoth.modules.average_attn import AverageAttention from mammoth.modules.attention_bridge import AttentionBridge @@ -14,28 +12,14 @@ from mammoth.modules.transformer_decoder import TransformerDecoder -str2enc = { - "transformer": TransformerEncoder, - "mean": MeanEncoder, -} - -str2dec = { - "transformer": TransformerDecoder, -} - - __all__ = [ "DecoderBase", "TransformerDecoder", - "str2dec", "EncoderBase", "TransformerEncoder", "MeanEncoder", - "str2enc", "Elementwise", "MultiHeadedAttention", - "Embeddings", - "PositionalEncoding", "AverageAttention", "AttentionBridge", ] diff --git a/mammoth/modules/adapters.py b/mammoth/modules/adapters.py new file mode 100644 index 00000000..65a84800 --- /dev/null +++ b/mammoth/modules/adapters.py @@ -0,0 +1,215 @@ +""" +Inject small bottleneck layers with residual connection. 
+Can be applied during main training,
+or injected into an already trained network to adapt it for a new task.
+"""
+
+import torch
+import torch.nn as nn
+from collections import defaultdict
+from typing import Union, Set
+from functools import partial
+
+from x_transformers.x_transformers import (
+    AttentionLayers,
+    FeedForward,
+    Residual,
+    SimpleRMSNorm,
+)
+
+
+class FeedForwardAdapterLayer(nn.Module):
+    """A separate adapter layer injected after a FeedForward. Has its own norms."""
+    def __init__(self, dim, pre_norm=True, sandwich_norm=False, **kwargs):
+        super().__init__()
+        norm_fn = partial(SimpleRMSNorm, dim)
+        self.pre_branch_norm = norm_fn() if pre_norm else None
+        self.post_branch_norm = norm_fn() if sandwich_norm else None
+        self.post_main_norm = norm_fn() if not pre_norm else None
+        self.ff = FeedForward(dim, **kwargs)
+        self.residual = Residual(dim, scale_residual=False)
+
+    @property
+    def is_wrapper(self):
+        return False
+
+    def as_layer_struct(self):
+        return nn.ModuleList([
+            nn.ModuleList([
+                self.pre_branch_norm,
+                self.post_branch_norm,
+                self.post_main_norm,
+            ]),
+            self.ff,
+            self.residual,
+        ])
+
+
+class LoraAdapterLayer(nn.Module):
+    """A LoRA adapter layer wrapping a FeedForward. No additional norms."""
+    def __init__(
+        self,
+        dim,
+        dim_out=None,
+        r=8,
+        alpha=None
+    ):
+        super().__init__()
+        dim_out = dim_out if dim_out is not None else dim
+        alpha = alpha if alpha is not None else r
+        self.scale = alpha / r
+
+        self.A = nn.Parameter(torch.randn(dim, r))
+        self.B = nn.Parameter(torch.zeros(r, dim_out))
+        self.wrapped_base_layer = None
+
+    @property
+    def is_wrapper(self):
+        return True
+
+    def wrap(self, base_layer):
+        self.wrapped_base_layer = base_layer
+        return self
+
+    @property
+    def weight(self):
+        return (self.A @ self.B) * self.scale
+
+    def forward(self, x):
+        return (x @ self.weight) + self.wrapped_base_layer.forward(x)
+
+
+AdapterLayer = Union[FeedForwardAdapterLayer, LoraAdapterLayer]
+
+
+class Adapter(nn.Module):
+    """
+    A container for one or several AdapterLayers,
+    together with layer indices for injecting into the base network.
+    """
+
+    def __init__(self, adapter_group: str, sub_id: str):
+        super().__init__()
+        self.adapter_group = adapter_group
+        self.sub_id = sub_id
+        self.name = self._name(adapter_group, sub_id)
+        # mapping layer_idx -> ModuleList of AdapterLayer to inject at that layer
+        self.adapter_layers = nn.ModuleDict()
+        self._adapted_layer_indices: Set[int] = set()
+
+    @staticmethod
+    def _name(adapter_group: str, sub_id: str) -> str:
+        assert isinstance(adapter_group, str), f'Expecting str, not {adapter_group}'
+        assert isinstance(sub_id, str), f'Expecting str, not {sub_id}'
+        return f'adapter_{adapter_group}_{sub_id}'
+
+    def add_layer(self, layer_idx: int, adapter_layer: AdapterLayer):
+        self._adapted_layer_indices.add(layer_idx)
+        layer_idx_str = f'layer{layer_idx}'
+        if layer_idx_str not in self.adapter_layers:
+            self.adapter_layers[layer_idx_str] = nn.ModuleList()
+        self.adapter_layers[layer_idx_str].append(adapter_layer)
+
+    def get_layers(self):
+        return self.adapter_layers.items()
+
+    def __repr__(self):
+        return f''
+
+
+class AdaptedAttentionLayers(AttentionLayers):
+    """
+    Extends an x_transformers.AttentionLayers block
+    with the ability to inject additional layers
+    or dynamically wrap layers in LoRA wrappers.
+    """
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._base_layer_types = tuple(self.layer_types)
+        self._base_layers = nn.ModuleList(self.layers)
+        self._base_layer_dropouts = tuple(self.layer_dropouts)
+
+    def freeze_base_model(self, requires_grad=False):
+        for name, p in self._base_layers.named_parameters():
+            p.requires_grad = requires_grad
+
+    def get_adapter(self, adapter_group: str, sub_id: str):
+        name = Adapter._name(adapter_group, sub_id)
+        return self.adapters[name]
+
+    def add_adapter(self, adapter: Adapter):
+        if adapter.name in self.adapters:
+            raise ValueError(f'Duplicate Adapter "{adapter.name}"')
+        max_layer_index = max(adapter._adapted_layer_indices)
+        n_feedforwards = sum(layer_type == 'f' for layer_type in self._base_layer_types)
+        if max_layer_index >= n_feedforwards:
+            raise ValueError(
+                f'Invalid layer number {max_layer_index} in Adapter "{adapter.name}". '
+                f'There are only {n_feedforwards} layers.'
+            )
+        self.adapters[adapter.name] = adapter
+
+    def deactivate_adapters(self):
+        self.active = set()
+
+    def activate_adapter(self, adapter_group: str, sub_id: str):
+        name = Adapter._name(adapter_group, sub_id)
+        if name in self.adapters:
+            self.active.add(name)
+
+    def _merge_active_adapters(self):
+        """
+        Returns a single mapping layer_idx -> list of AdapterLayer,
+        containing the layers of all currently active adapters
+        """
+        active_adapters = [
+            adapter for name, adapter in self.adapters.items()
+            if name in self.active
+        ]
+        merged = defaultdict(list)
+        for adapter in active_adapters:
+            for layer_idx, layers in adapter.get_layers():
+                merged[layer_idx].extend(layers)
+        return merged
+
+    def _inject_adapters(self):
+        adapted_layer_types = []
+        adapted_layers = nn.ModuleList()
+        adapted_layer_dropouts = []
+        i = 0
+        for layer_type, layer_struct, layer_dropout in zip(
+            self._base_layer_types,
+            self._base_layers,
+            self._base_layer_dropouts,
+        ):
+            adapter_layers_by_index = self._merge_active_adapters()
+            if layer_type == 'f':
+                # Adapters apply to feedforward layers
+                adapter_layers = adapter_layers_by_index[i]
+                for adapter_layer in adapter_layers:
+                    if adapter_layer.is_wrapper:
+                        # LoRA wraps the base ff
+                        adapted_layer_types.append('f')
+                        adapted_layers.append(adapter_layer.wrap(layer_struct))
+                        adapted_layer_dropouts.append(layer_dropout)
+                    else:
+                        # FeedForwards are injected after the base ff
+                        adapted_layer_types.append('f')
+                        adapted_layer_types.append('f')
+                        adapted_layers.append(layer_struct)
+                        adapted_layers.append(adapter_layer.as_layer_struct())
+                        adapted_layer_dropouts.append(layer_dropout)
+                        adapted_layer_dropouts.append(layer_dropout)
+                i += 1
+            else:
+                # Attention layers are unmodified
+                adapted_layer_types.append(layer_type)
+                adapted_layers.append(layer_struct)
+                adapted_layer_dropouts.append(layer_dropout)
+        self.layer_types = adapted_layer_types
+        self.layers = adapted_layers
+        self.layer_dropouts = adapted_layer_dropouts
+
+    def forward(self, *args, **kwargs):
+        self._inject_adapters()
+        return super().forward(*args, **kwargs)
diff --git a/mammoth/modules/embeddings.py b/mammoth/modules/embeddings.py
deleted file mode 100644
index 7f4cbbca..00000000
--- a/mammoth/modules/embeddings.py
+++ /dev/null
@@ -1,416 +0,0 @@
-""" Embeddings module """
-import math
-import warnings
-
-import torch
-import torch.nn as nn
-
-from mammoth.modules.util_class import Elementwise
-# from mammoth.utils.logging import logger
-
-import torch.nn.functional as F
-# import bitsandbytes as bnb
-
-
-class 
SequenceTooLongError(Exception): - pass - - -class PositionalEncoding(nn.Module): - """Sinusoidal positional encoding for non-recurrent neural networks. - - Implementation based on "Attention Is All You Need" - :cite:`DBLP:journals/corr/VaswaniSPUJGKP17` - - Args: - dropout (float): dropout parameter - dim (int): embedding size - """ - - def __init__(self, dropout, dim, max_len=5000): - if dim % 2 != 0: - raise ValueError("Cannot use sin/cos positional encoding with odd dim (got dim={:d})".format(dim)) - pe = torch.zeros(max_len, dim) - position = torch.arange(0, max_len).unsqueeze(1) - div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim))) - pe[:, 0::2] = torch.sin(position.float() * div_term) - pe[:, 1::2] = torch.cos(position.float() * div_term) - pe = pe.unsqueeze(1) - super(PositionalEncoding, self).__init__() - self.register_buffer('pe', pe) - self.dropout = nn.Dropout(p=dropout) - self.dim = dim - - def forward(self, emb, step=None): - """Embed inputs. - - Args: - emb (FloatTensor): Sequence of word vectors - ``(seq_len, batch_size, self.dim)`` - step (int or NoneType): If stepwise (``seq_len = 1``), use - the encoding for this position. - """ - - emb = emb * math.sqrt(self.dim) - step = step or 0 - if self.pe.size(0) < step + emb.size(0): - raise SequenceTooLongError( - f"Sequence is {emb.size(0) + step} but PositionalEncoding is" - f" limited to {self.pe.size(0)}. See max_len argument." - ) - emb = emb + self.pe[step:(emb.size(0) + step)] - emb = self.dropout(emb) - return emb - - -class Embeddings(nn.Module): - """Words embeddings for encoder/decoder. - - Additionally includes ability to add input features - based on "Linguistic Input Features Improve Neural Machine Translation" - :cite:`sennrich2016linguistic`. - - - .. mermaid:: - - graph LR - A[Input] - C[Feature 1 Lookup] - A-->B[Word Lookup] - A-->C - A-->D[Feature N Lookup] - B-->E[MLP/Concat] - C-->E - D-->E - E-->F[Output] - - Args: - word_vec_size (int): size of the dictionary of embeddings. - word_padding_idx (int): padding index for words in the embeddings. - feat_padding_idx (List[int]): padding index for a list of features - in the embeddings. - word_vocab_size (int): size of dictionary of embeddings for words. - feat_vocab_sizes (List[int], optional): list of size of dictionary - of embeddings for each feature. - position_encoding (bool): see :class:`~mammoth.modules.PositionalEncoding` - feat_merge (string): merge action for the features embeddings: - concat, sum or mlp. - feat_vec_exponent (float): when using `-feat_merge concat`, feature - embedding size is N^feat_dim_exponent, where N is the - number of values the feature takes. - feat_vec_size (int): embedding dimension for features when using - `-feat_merge mlp` - dropout (float): dropout probability. - freeze_word_vecs (bool): freeze weights of word vectors. 
- """ - - def __init__( - self, - word_vec_size, - word_vocab_size, - word_padding_idx, - position_encoding=False, - feat_merge="concat", - feat_vec_exponent=0.7, - feat_vec_size=-1, - feat_padding_idx=[], - feat_vocab_sizes=[], - dropout=0, - freeze_word_vecs=False, - enable_embeddingless=False - ): - self._validate_args(feat_merge, feat_vocab_sizes, feat_vec_exponent, feat_vec_size, feat_padding_idx) - - if feat_padding_idx is None: - feat_padding_idx = [] - self.word_padding_idx = word_padding_idx - - self.word_vec_size = word_vec_size - - # Dimensions and padding for constructing the word embedding matrix - vocab_sizes = [word_vocab_size] - emb_dims = [word_vec_size] - pad_indices = [word_padding_idx] - - # Dimensions and padding for feature embedding matrices - # (these have no effect if feat_vocab_sizes is empty) - if feat_merge == 'sum': - feat_dims = [word_vec_size] * len(feat_vocab_sizes) - elif feat_vec_size > 0: - feat_dims = [feat_vec_size] * len(feat_vocab_sizes) - else: - feat_dims = [int(vocab**feat_vec_exponent) for vocab in feat_vocab_sizes] - vocab_sizes.extend(feat_vocab_sizes) - emb_dims.extend(feat_dims) - pad_indices.extend(feat_padding_idx) - - # The embedding matrix look-up tables. The first look-up table - # is for words. Subsequent ones are for features, if any exist. - emb_params = zip(vocab_sizes, emb_dims, pad_indices) - - emb_params = zip(vocab_sizes, emb_dims, pad_indices) - if enable_embeddingless is False: - embeddings = [nn.Embedding(vocab, dim, padding_idx=pad) for vocab, dim, pad in emb_params] - - else: - - def create_embeddingless(vocab, dim, padding_idx): - one_hot_matrix = F.one_hot(torch.arange(vocab)).float() - one_hot_embed = torch.cat((one_hot_matrix, torch.zeros((vocab, dim - vocab))), dim=1) - one_hot_embed[padding_idx] = torch.zeros(dim).unsqueeze(0) - emb = nn.Embedding(vocab, dim, padding_idx=padding_idx) - emb.weight = torch.nn.parameter.Parameter(one_hot_embed, requires_grad=False) - return emb - embeddings = [ - create_embeddingless(vocab, dim, padding_idx=pad) - for vocab, dim, pad in emb_params - ] - emb_luts = Elementwise(feat_merge, embeddings) - - # The final output size of word + feature vectors. This can vary - # from the word vector size if and only if features are defined. - # This is the attribute you should access if you need to know - # how big your embeddings are going to be. - self.embedding_size = sum(emb_dims) if feat_merge == 'concat' else word_vec_size - - # The sequence of operations that converts the input sequence - # into a sequence of embeddings. At minimum this consists of - # looking up the embeddings for each word and feature in the - # input. Model parameters may require the sequence to contain - # additional operations as well. 
-        super(Embeddings, self).__init__()
-        self.make_embedding = nn.Sequential()
-        self.make_embedding.add_module('emb_luts', emb_luts)
-
-        if feat_merge == 'mlp' and len(feat_vocab_sizes) > 0:
-            in_dim = sum(emb_dims)
-            mlp = nn.Sequential(nn.Linear(in_dim, word_vec_size), nn.ReLU())
-            self.make_embedding.add_module('mlp', mlp)
-
-        self.position_encoding = position_encoding
-
-        if self.position_encoding:
-            pe = PositionalEncoding(dropout, self.embedding_size)
-            self.make_embedding.add_module('pe', pe)
-
-        if freeze_word_vecs:
-            self.word_lut.weight.requires_grad = False
-
-    def _validate_args(self, feat_merge, feat_vocab_sizes, feat_vec_exponent, feat_vec_size, feat_padding_idx):
-        if feat_merge == "sum":
-            # features must use word_vec_size
-            if feat_vec_exponent != 0.7:
-                warnings.warn("Merging with sum, but got non-default feat_vec_exponent. It will be unused.")
-            if feat_vec_size != -1:
-                warnings.warn("Merging with sum, but got non-default feat_vec_size. It will be unused.")
-        elif feat_vec_size > 0:
-            # features will use feat_vec_size
-            if feat_vec_exponent != -1:
-                warnings.warn(
-                    "Not merging with sum and positive "
-                    "feat_vec_size, but got non-default "
-                    "feat_vec_exponent. It will be unused."
-                )
-        else:
-            if feat_vec_exponent <= 0:
-                raise ValueError(
-                    "Using feat_vec_exponent to determine "
-                    "feature vec size, but got feat_vec_exponent "
-                    "less than or equal to 0."
-                )
-        n_feats = len(feat_vocab_sizes)
-        if n_feats != len(feat_padding_idx):
-            raise ValueError(
-                "Got unequal number of feat_vocab_sizes and "
-                "feat_padding_idx ({:d} != {:d})".format(n_feats, len(feat_padding_idx))
-            )
-
-    @property
-    def word_lut(self):
-        """Word look-up table."""
-        return self.make_embedding[0][0]
-
-    @property
-    def emb_luts(self):
-        """Embedding look-up table."""
-        return self.make_embedding[0]
-
-    def load_pretrained_vectors(self, emb_file):
-        """Load in pretrained embeddings.
-
-        Args:
-            emb_file (str) : path to torch serialized embeddings
-        """
-
-        if emb_file:
-            pretrained = torch.load(emb_file)
-            pretrained_vec_size = pretrained.size(1)
-            if self.word_vec_size > pretrained_vec_size:
-                self.word_lut.weight.data[:, :pretrained_vec_size] = pretrained
-            elif self.word_vec_size < pretrained_vec_size:
-                self.word_lut.weight.data.copy_(pretrained[:, : self.word_vec_size])
-            else:
-                self.word_lut.weight.data.copy_(pretrained)
-
-    def forward(self, source, step=None):
-        """Computes the embeddings for words and features.
-
-        Args:
-            source (LongTensor): index tensor ``(len, batch, nfeat)``
-
-        Returns:
-            FloatTensor: Word embeddings ``(len, batch, embedding_size)``
-        """
-
-        if self.position_encoding:
-            for i, module in enumerate(self.make_embedding._modules.values()):
-                if i == len(self.make_embedding._modules.values()) - 1:
-                    source = module(source, step=step)
-                else:
-                    source = module(source)
-        else:
-            source = self.make_embedding(source)
-
-        return source
-
-    def update_dropout(self, dropout):
-        if self.position_encoding:
-            self._modules['make_embedding'][1].dropout.p = dropout
-
-
-class PluggableEmbeddings(nn.ModuleDict):
-    """
-    Wraps multiple embeddings,
-    allowing any of them to be plugged in by calling activate.
-    This is necessary to decouple encoders/decoder from embeddings:
-    it is possible to e.g. share a single encoder with multiple source
-    languages each having their own embeddings.
- """ - - def __init__(self, embedding_dict): - super().__init__() - for key, embeddings in embedding_dict.items(): - self.add_module(f'embeddings_{key}', embeddings) - self.active_key = None - - def activate(self, key): - assert f'embeddings_{key}' in self, f'Embeddings "embeddings_{key}" not in {self.keys()}' - self.active_key = key - - @property - def _active_embeddings(self): - if self.active_key is None: - raise Exception('Must activate PluggableEmbeddings before forward') - active_embeddings = self[f'embeddings_{self.active_key}'] - # print(f'plugging in embeddings_{self.active_key}') - return active_embeddings - - def forward(self, source, step=None): - return self._active_embeddings.forward(source, step=step) - - @property - def word_padding_idx(self): - return self._active_embeddings.word_padding_idx - - -# Some utilitary functions for pretrained embeddings - - -def read_embeddings(path, skip_lines=0, filter_set=None): - """ - Read an embeddings file in the glove format. - """ - embs = dict() - total_vectors_in_file = 0 - with open(path, 'rb') as f: - for i, line in enumerate(f): - if i < skip_lines: - continue - if not line: - break - if len(line) == 0: - # is this reachable? - continue - - l_split = line.decode('utf8').strip().split(' ') - if len(l_split) == 2: - continue - total_vectors_in_file += 1 - if filter_set is not None and l_split[0] not in filter_set: - continue - embs[l_split[0]] = [float(em) for em in l_split[1:]] - return embs, total_vectors_in_file - - -def calc_vocab_load_stats(vocab, loaded_embed_dict): - matching_count = len(set(vocab.stoi.keys()) & set(loaded_embed_dict.keys())) - missing_count = len(vocab) - matching_count - percent_matching = matching_count / len(vocab) * 100 - return matching_count, missing_count, percent_matching - - -def convert_to_torch_tensor(word_to_float_list_dict, vocab): - dim = len(next(iter(word_to_float_list_dict.values()))) - tensor = torch.zeros((len(vocab), dim)) - for word, values in word_to_float_list_dict.items(): - tensor[vocab.stoi[word]] = torch.Tensor(values) - return tensor - -# FIXME: seems it got nuked during the great refactoring of data -# def prepare_pretrained_embeddings(opts, fields): -# if all([opts.both_embeddings is None, opts.src_embeddings is None, opts.tgt_embeddings is None]): -# return -# -# assert ( -# opts.save_data -# ), "-save_data is required when using \ -# pretrained embeddings." 
-#
-#     vocs = []
-#     for side in ['src', 'tgt']:
-#         try:
-#             vocab = fields[side].base_field.vocab
-#         except AttributeError:
-#             vocab = fields[side].vocab
-#         vocs.append(vocab)
-#     enc_vocab, dec_vocab = vocs
-#
-#     skip_lines = 1 if opts.embeddings_type == "word2vec" else 0
-#     if opts.both_embeddings is not None:
-#         set_of_src_and_tgt_vocab = set(enc_vocab.stoi.keys()) | set(dec_vocab.stoi.keys())
-#         logger.info("Reading encoder and decoder embeddings from {}".format(opts.both_embeddings))
-#         src_vectors, total_vec_count = read_embeddings(opts.both_embeddings, skip_lines, set_of_src_and_tgt_vocab)
-#         tgt_vectors = src_vectors
-#         logger.info("\tFound {} total vectors in file".format(total_vec_count))
-#     else:
-#         if opts.src_embeddings is not None:
-#             logger.info("Reading encoder embeddings from {}".format(opts.src_embeddings))
-#             src_vectors, total_vec_count = read_embeddings(opts.src_embeddings, skip_lines, filter_set=enc_vocab.stoi)
-#             logger.info("\tFound {} total vectors in file.".format(total_vec_count))
-#         else:
-#             src_vectors = None
-#         if opts.tgt_embeddings is not None:
-#             logger.info("Reading decoder embeddings from {}".format(opts.tgt_embeddings))
-#             tgt_vectors, total_vec_count = read_embeddings(opts.tgt_embeddings, skip_lines, filter_set=dec_vocab.stoi)
-#             logger.info("\tFound {} total vectors in file".format(total_vec_count))
-#         else:
-#             tgt_vectors = None
-#     logger.info("After filtering to vectors in vocab:")
-#     if opts.src_embeddings is not None or opts.both_embeddings is not None:
-#         logger.info("\t* enc: %d match, %d missing, (%.2f%%)" % calc_vocab_load_stats(enc_vocab, src_vectors))
-#     if opts.tgt_embeddings is not None or opts.both_embeddings is not None:
-#         logger.info("\t* dec: %d match, %d missing, (%.2f%%)" % calc_vocab_load_stats(dec_vocab, tgt_vectors))
-#
-#     # Write to file
-#     enc_output_file = opts.save_data + ".enc_embeddings.pt"
-#     dec_output_file = opts.save_data + ".dec_embeddings.pt"
-#     if opts.src_embeddings is not None or opts.both_embeddings is not None:
-#         logger.info("\nSaving encoder embeddings as:\n\t* enc: %s" % enc_output_file)
-#         torch.save(convert_to_torch_tensor(src_vectors, enc_vocab), enc_output_file)
-#         # set the opts in place
-#         opts.pre_word_vecs_enc = enc_output_file
-#     if opts.tgt_embeddings is not None or opts.both_embeddings is not None:
-#         logger.info("\nSaving decoder embeddings as:\n\t* dec: %s" % dec_output_file)
-#         torch.save(convert_to_torch_tensor(tgt_vectors, dec_vocab), dec_output_file)
-#         # set the opts in place
-#         opts.pre_word_vecs_dec = dec_output_file
diff --git a/mammoth/modules/layer_stack.py b/mammoth/modules/layer_stack.py
new file mode 100644
index 00000000..8d2b59fe
--- /dev/null
+++ b/mammoth/modules/layer_stack.py
@@ -0,0 +1,66 @@
+from torch import nn
+from typing import List, Sequence, Optional, Tuple
+
+from mammoth.modules.adapters import AdaptedAttentionLayers
+
+
+class AdaptedAttentionLayersStack(nn.Module):
+    """
+    Wrapper that allows stacking multiple AdaptedAttentionLayers.
+    Represents one particular stacking: does not allow switching out entire layers
+    (but does delegate the switching out of adapters to its components)
+    """
+    def __init__(self, attention_layers_stack: Sequence[AdaptedAttentionLayers]):
+        super().__init__()
+        self.attention_layers_stack = nn.ModuleList(attention_layers_stack)
+        assert len(set(attention_layers.dim for attention_layers in attention_layers_stack)) == 1, \
+            'All AdaptedAttentionLayers must have the same dimension'
+
+    def forward(self, x, return_hiddens=False, **kwargs):
+        all_intermediates = []
+        for attention_layers in self.attention_layers_stack:
+            if return_hiddens:
+                x, intermediates = attention_layers.forward(x, return_hiddens=True, **kwargs)
+                all_intermediates.append(intermediates)
+            else:
+                x = attention_layers.forward(x, return_hiddens=False, **kwargs)
+        if return_hiddens:
+            return x, all_intermediates
+        else:
+            return x
+
+    def freeze_base_model(self, requires_grad=False):
+        for attention_layers in self.attention_layers_stack:
+            attention_layers.freeze_base_model(requires_grad=requires_grad)
+
+    def deactivate_adapters(self):
+        for attention_layers in self.attention_layers_stack:
+            attention_layers.deactivate_adapters()
+
+    def activate_adapter(self, layer_stack_index: int, adapter_group: str, sub_id: str):
+        attention_layers = self.attention_layers_stack[layer_stack_index]
+        attention_layers.activate_adapter(adapter_group, sub_id)
+
+    @property
+    def dim(self):
+        return self.attention_layers_stack[0].dim
+
+    @property
+    def disable_abs_pos_emb(self):
+        return self.attention_layers_stack[0].disable_abs_pos_emb
+
+
+class StackXcoder(nn.ModuleDict):
+    """
+    Switches between different AdaptedAttentionLayersStacks depending on the task.
+    """
+    # TransformerWrapper wraps an AttentionLayers in embeddings and some other functionality.
+    # We use one TransformerWrapper per task.
+    def activate(self, task_id: str, adapter_ids: Optional[List[Tuple[int, str, str]]]):
+        transformer_wrapper = self[task_id]
+        attention_layers_stack = transformer_wrapper.attn_layers
+        if adapter_ids:
+            attention_layers_stack.deactivate_adapters()
+            for layer_stack_index, adapter_group, sub_id in adapter_ids:
+                attention_layers_stack.activate_adapter(layer_stack_index, adapter_group, sub_id)
+        return transformer_wrapper
diff --git a/mammoth/opts.py b/mammoth/opts.py
index bbbaafcd..f57bf157 100644
--- a/mammoth/opts.py
+++ b/mammoth/opts.py
@@ -352,10 +352,9 @@ def model_opts(parser):
     )
     group.add('--heads', '-heads', type=int, default=8, help='Number of heads for transformer self-attention')
     group.add(
-        '--transformer_ff', '-transformer_ff', type=int, default=2048, help='Size of hidden transformer feed-forward'
+        '--ff_mult', '-ff_mult', type=int, default=4,
+        help='Size of hidden transformer feed-forward, as a factor of model_dim'
     )
-    # TODO is this actually in use?
-    group.add('--aan_useffn', '-aan_useffn', action="store_true", help='Turn on the FFN layer in the AAN decoder')

     # Alignement options
     # TODO is this actually in use?
diff --git a/mammoth/train_single.py b/mammoth/train_single.py
index cd6b4156..8167a58e 100644
--- a/mammoth/train_single.py
+++ b/mammoth/train_single.py
@@ -108,7 +108,7 @@ def main(

     # Build model.
-    model, generators_md = build_model(model_opts, opts, vocabs_dict, task_queue_manager, checkpoint)
+    model = build_model(model_opts, opts, vocabs_dict, task_queue_manager, checkpoint)

     logger.info("{} - Init model".format(device_context.id))
     if device_context.is_distributed():
@@ -138,7 +138,6 @@ def main(
         optim,
         task_queue_manager=task_queue_manager,
         model_saver=model_saver,
-        generators_md=generators_md,
     )

     logger.info("{} - Trainer built".format(device_context.id))
diff --git a/mammoth/trainer.py b/mammoth/trainer.py
index cb1a0b50..0319d052 100644
--- a/mammoth/trainer.py
+++ b/mammoth/trainer.py
@@ -37,7 +37,6 @@ def build_trainer(
    optim,
    task_queue_manager,
    model_saver=None,
-    generators_md=None,
 ):
    """
    Simplify `Trainer` creation based on user `opts`s*
@@ -58,6 +57,9 @@ def build_trainer(
    logger.info("BUILD TRAINER")

    for (side, lang, component_id, tgt_vocab) in task_queue_manager.get_my_vocabs('tgt', vocabs_dict):
+        # FIXME: OpenNMT losses require a separate generator, which is not available in x_transformers
+        # Just rip it out and use F.cross_entropy? Maybe label smoothing?
+        # Or get the necessary components from the model to create a generator?
        generator = generators_md[f'generator_{lang}']
        train_loss_md.add_module(
            f'trainloss{lang}',
diff --git a/mammoth/utils/module_splitter.py b/mammoth/utils/module_splitter.py
index 6e190a3a..d3c6a8d3 100644
--- a/mammoth/utils/module_splitter.py
+++ b/mammoth/utils/module_splitter.py
@@ -11,6 +11,7 @@ def _combine_ordered_dicts(input_dicts: Dict[str, OrderedDict]) -> OrderedDict:


 def explode_model(full_ab_model):
+    # FIXME: saving and loading are broken
     encoder = full_ab_model["whole_model"].encoder
     decoder = full_ab_model["whole_model"].decoder
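One possible resolution for the FIXME added in mammoth/trainer.py above, sketched under the assumption that the x_transformers TransformerWrapper already projects hidden states to vocabulary logits of shape (batch, seq_len, vocab_size): drop the separate OpenNMT-style generator and compute the loss directly with F.cross_entropy. This is not part of the patch; the helper name, padding index argument, and label-smoothing value are illustrative placeholders.

    import torch.nn.functional as F

    def compute_loss(logits, target, padding_idx, label_smoothing=0.1):
        # logits: (batch, seq_len, vocab_size) as returned by the TransformerWrapper
        # target: (batch, seq_len) gold token indices
        # Padding positions are excluded via ignore_index; label smoothing is optional.
        return F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            target.reshape(-1),
            ignore_index=padding_idx,
            label_smoothing=label_smoothing,
        )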