From ac59442c00bdbfdea894853fb387ea6bad996890 Mon Sep 17 00:00:00 2001
From: Joseph Attieh
Date: Thu, 25 Jan 2024 13:48:31 +0200
Subject: [PATCH] Updated optimizer to not be created if embeddings do not need training

---
 mammoth/model_builder.py      | 11 +++++------
 mammoth/modules/embeddings.py |  7 ++++---
 mammoth/utils/optimizers.py   |  5 +++--
 3 files changed, 12 insertions(+), 11 deletions(-)

diff --git a/mammoth/model_builder.py b/mammoth/model_builder.py
index a42ed68d..df014e1c 100644
--- a/mammoth/model_builder.py
+++ b/mammoth/model_builder.py
@@ -40,7 +40,6 @@ def build_embeddings(opts, vocab, for_encoder=True):
     opts.word_padding_idx = word_padding_idx
 
     freeze_word_vecs = opts.freeze_word_vecs_enc if for_encoder else opts.freeze_word_vecs_dec
-
     emb = Embeddings(
         word_vec_size=opts.model_dim,
         position_encoding=opts.position_encoding,
@@ -48,7 +47,7 @@
         word_padding_idx=word_padding_idx,
         word_vocab_size=len(vocab),
         freeze_word_vecs=freeze_word_vecs,
-        embeddingless=opts.embeddingless
+        enable_embeddingless=opts.enable_embeddingless
     )
     return emb
 
@@ -369,13 +368,13 @@ def build_only_enc(model_opts, src_emb, task_queue_manager):
     if model_opts.param_init != 0.0:
         for name, module in encoder.named_modules():
             for p in module.parameters():
-                if not("embedding" in name and model_opts.embeddingless is True):
+                if not("embedding" in name and model_opts.enable_embeddingless is True):
                     p.data.uniform_(-model_opts.param_init, model_opts.param_init)
 
     if model_opts.param_init_glorot:
         for name, module in encoder.named_modules():
             for p in module.parameters():
-                if not("embedding" in name and model_opts.embeddingless is True):
+                if not("embedding" in name and model_opts.enable_embeddingless is True):
                     if p.dim() > 1:
                         xavier_uniform_(p, gain=nn.init.calculate_gain('relu'))
     if model_opts.model_dtype == 'fp16' and model_opts.optim == 'fusedadam':
@@ -390,12 +389,12 @@ def build_only_dec(model_opts, tgt_emb, task_queue_manager):
     if model_opts.param_init != 0.0:
         for name, module in decoder.named_modules():
             for p in module.parameters():
-                if not("embedding" in name and model_opts.embeddingless is True):
+                if not("embedding" in name and model_opts.enable_embeddingless is True):
                     p.data.uniform_(-model_opts.param_init, model_opts.param_init)
 
     if model_opts.param_init_glorot:
         for name, module in decoder.named_modules():
             for p in module.parameters():
-                if not("embedding" in name and model_opts.embeddingless is True):
+                if not("embedding" in name and model_opts.enable_embeddingless is True):
                     if p.dim() > 1:
                         xavier_uniform_(p, gain=nn.init.calculate_gain('relu'))
diff --git a/mammoth/modules/embeddings.py b/mammoth/modules/embeddings.py
index c70c23e6..63ad8fba 100644
--- a/mammoth/modules/embeddings.py
+++ b/mammoth/modules/embeddings.py
@@ -154,15 +154,16 @@ def __init__(
             embeddings = [nn.Embedding(vocab, dim, padding_idx=pad) for vocab, dim, pad in emb_params]
         else:
-            def create_embeddingless(vocab, dim, padding_idx, sparse):
+            print("CREATING EMBEDDINGLESS")
+            def create_embeddingless(vocab, dim, padding_idx):
                 one_hot_matrix = F.one_hot(torch.arange(vocab)).float()
                 one_hot_embed = torch.cat((one_hot_matrix, torch.zeros((vocab, dim - vocab))),dim=1)
                 one_hot_embed[padding_idx] = torch.zeros(dim).unsqueeze(0)
-                emb = nn.Embedding(vocab, dim, padding_idx=padding_idx, sparse=sparse)
+                emb = nn.Embedding(vocab, dim, padding_idx=padding_idx)
                 emb.weight = torch.nn.parameter.Parameter(one_hot_embed, requires_grad=False)
                 return emb
 
             embeddings = [
-                create_embeddingless(vocab, dim, padding_idx=pad, sparse=sparse)
+                create_embeddingless(vocab, dim, padding_idx=pad)
                 for vocab, dim, pad in emb_params
             ]
         emb_luts = Elementwise(feat_merge, embeddings)
diff --git a/mammoth/utils/optimizers.py b/mammoth/utils/optimizers.py
index 57e20ebf..43b7c6eb 100644
--- a/mammoth/utils/optimizers.py
+++ b/mammoth/utils/optimizers.py
@@ -32,8 +32,9 @@ def attention_bridge_optimizer(model, task_queue_manager, base_optimizer):
                 params.append(param)
         if name in suboptimizers:
             raise Exception(f'Trying to create second optimizer for "{name}"')
-        optimizer = base_optimizer(params)
-        suboptimizers[name] = optimizer
+        if len(params)!=0:
+            optimizer = base_optimizer(params)
+            suboptimizers[name] = optimizer
 
     for generator_id in task_queue_manager.get_generators():
         generator = model.generator[f'generator_{generator_id}']
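
For illustration only, a minimal standalone sketch of the behaviour this commit guards against: a frozen one-hot "embeddingless" table (built the same way as create_embeddingless in the patch) contributes no trainable parameters, and torch.optim optimizers raise ValueError on an empty parameter list, so the per-module suboptimizer has to be skipped. The helper names (build_one_hot_embedding, maybe_build_optimizer) and the use of Adam as the base optimizer are assumptions for the sketch, not mammoth APIs.

# Standalone sketch, not part of the patch.
import torch
import torch.nn as nn
import torch.nn.functional as F


def build_one_hot_embedding(vocab, dim, padding_idx):
    # Frozen one-hot lookup table padded with zeros out to `dim`, mirroring
    # create_embeddingless above (assumes dim >= vocab).
    one_hot = F.one_hot(torch.arange(vocab)).float()
    one_hot = torch.cat((one_hot, torch.zeros((vocab, dim - vocab))), dim=1)
    one_hot[padding_idx] = torch.zeros(dim)
    emb = nn.Embedding(vocab, dim, padding_idx=padding_idx)
    emb.weight = nn.Parameter(one_hot, requires_grad=False)
    return emb


def maybe_build_optimizer(module):
    # Same guard as the patched attention_bridge_optimizer: only build an
    # optimizer when there is something left to train. torch.optim.Adam([])
    # raises ValueError("optimizer got an empty parameter list").
    params = [p for p in module.parameters() if p.requires_grad]
    if len(params) != 0:
        return torch.optim.Adam(params)
    return None


print(maybe_build_optimizer(build_one_hot_embedding(8, 16, padding_idx=0)))  # None: nothing to train
print(maybe_build_optimizer(nn.Linear(16, 16)))  # a real Adam optimizer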