
Commit

Updated optimizer to not be created if embeddings do not need training
Joseph Attieh authored and Joseph Attieh committed Jan 25, 2024
1 parent 68f7e60 commit ac59442
Showing 3 changed files with 12 additions and 11 deletions.
11 changes: 5 additions & 6 deletions mammoth/model_builder.py
@@ -40,15 +40,14 @@ def build_embeddings(opts, vocab, for_encoder=True):
opts.word_padding_idx = word_padding_idx

freeze_word_vecs = opts.freeze_word_vecs_enc if for_encoder else opts.freeze_word_vecs_dec

emb = Embeddings(
word_vec_size=opts.model_dim,
position_encoding=opts.position_encoding,
dropout=opts.dropout[0] if type(opts.dropout) is list else opts.dropout,
word_padding_idx=word_padding_idx,
word_vocab_size=len(vocab),
freeze_word_vecs=freeze_word_vecs,
-embeddingless=opts.embeddingless
+enable_embeddingless=opts.enable_embeddingless
)
return emb

@@ -369,13 +368,13 @@ def build_only_enc(model_opts, src_emb, task_queue_manager):
if model_opts.param_init != 0.0:
for name, module in encoder.named_modules():
for p in module.parameters():
if not("embedding" in name and model_opts.embeddingless is True):
if not("embedding" in name and model_opts.enable_embeddingless is True):
p.data.uniform_(-model_opts.param_init, model_opts.param_init)

if model_opts.param_init_glorot:
for name, module in encoder.named_modules():
for p in module.parameters():
if not("embedding" in name and model_opts.embeddingless is True):
if not("embedding" in name and model_opts.enable_embeddingless is True):
if p.dim() > 1:
xavier_uniform_(p, gain=nn.init.calculate_gain('relu'))
if model_opts.model_dtype == 'fp16' and model_opts.optim == 'fusedadam':
@@ -390,12 +389,12 @@ def build_only_dec(model_opts, tgt_emb, task_queue_manager):
if model_opts.param_init != 0.0:
for name, module in decoder.named_modules():
for p in module.parameters():
if not("embedding" in name and model_opts.embeddingless is True):
if not("embedding" in name and model_opts.embeddenable_embeddinglessingless is True):
p.data.uniform_(-model_opts.param_init, model_opts.param_init)
if model_opts.param_init_glorot:
for name, module in decoder.named_modules():
for p in module.parameters():
if not("embedding" in name and model_opts.embeddingless is True):
if not("embedding" in name and model_opts.enable_embeddingless is True):
if p.dim() > 1:
xavier_uniform_(p, gain=nn.init.calculate_gain('relu'))

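The hunks above skip re-initialization of embedding weights whenever `enable_embeddingless` is set, since those weights are fixed one-hot vectors rather than trainable parameters. A condensed standalone sketch of that guard, assuming `model_opts` carries the `param_init`, `param_init_glorot`, and `enable_embeddingless` options shown in the diff (`maybe_init` is a hypothetical helper, not repository code):

```python
import torch.nn as nn
from torch.nn.init import xavier_uniform_

def maybe_init(model, model_opts):
    """Initialize all parameters except frozen one-hot embeddings."""
    for name, module in model.named_modules():
        for p in module.parameters(recurse=False):
            # Frozen one-hot embeddings must keep their weights untouched.
            if "embedding" in name and model_opts.enable_embeddingless:
                continue
            if model_opts.param_init != 0.0:
                p.data.uniform_(-model_opts.param_init, model_opts.param_init)
            if model_opts.param_init_glorot and p.dim() > 1:
                xavier_uniform_(p, gain=nn.init.calculate_gain('relu'))
```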
7 changes: 4 additions & 3 deletions mammoth/modules/embeddings.py
@@ -154,15 +154,16 @@ def __init__(
embeddings = [nn.Embedding(vocab, dim, padding_idx=pad) for vocab, dim, pad in emb_params]

else:
-def create_embeddingless(vocab, dim, padding_idx, sparse):
+print("CREATING EMBEDDINGLESS")
+def create_embeddingless(vocab, dim, padding_idx):
one_hot_matrix = F.one_hot(torch.arange(vocab)).float()
one_hot_embed = torch.cat((one_hot_matrix, torch.zeros((vocab, dim - vocab))),dim=1)
one_hot_embed[padding_idx] = torch.zeros(dim).unsqueeze(0)
-emb = nn.Embedding(vocab, dim, padding_idx=padding_idx, sparse=sparse)
+emb = nn.Embedding(vocab, dim, padding_idx=padding_idx)
emb.weight = torch.nn.parameter.Parameter(one_hot_embed, requires_grad=False)
return emb
embeddings = [
-create_embeddingless(vocab, dim, padding_idx=pad, sparse=sparse)
+create_embeddingless(vocab, dim, padding_idx=pad)
for vocab, dim, pad in emb_params
]
emb_luts = Elementwise(feat_merge, embeddings)
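In the hunk above, `create_embeddingless` builds a fixed one-hot lookup table instead of a learned embedding (the `sparse` argument is dropped). A minimal standalone sketch of such a frozen embedding, with illustrative sizes rather than repository values:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

vocab, dim, padding_idx = 8, 16, 1          # dim must be >= vocab for the zero padding below
one_hot = F.one_hot(torch.arange(vocab)).float()                    # (vocab, vocab) identity rows
weight = torch.cat((one_hot, torch.zeros(vocab, dim - vocab)), dim=1)
weight[padding_idx] = torch.zeros(dim)                              # padding token maps to all zeros
emb = nn.Embedding(vocab, dim, padding_idx=padding_idx)
emb.weight = nn.Parameter(weight, requires_grad=False)              # frozen: nothing to train

tokens = torch.tensor([[2, 5, padding_idx]])
vectors = emb(tokens)                                               # fixed one-hot vectors
assert not any(p.requires_grad for p in emb.parameters())
```

Because `requires_grad` is `False`, the table contributes no trainable parameters, which is what the optimizer change below accounts for.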
5 changes: 3 additions & 2 deletions mammoth/utils/optimizers.py
@@ -32,8 +32,9 @@ def attention_bridge_optimizer(model, task_queue_manager, base_optimizer):
params.append(param)
if name in suboptimizers:
raise Exception(f'Trying to create second optimizer for "{name}"')
-optimizer = base_optimizer(params)
-suboptimizers[name] = optimizer
+if len(params)!=0:
+    optimizer = base_optimizer(params)
+    suboptimizers[name] = optimizer

for generator_id in task_queue_manager.get_generators():
generator = model.generator[f'generator_{generator_id}']
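This is the change the commit message refers to: a frozen one-hot embedding yields an empty parameter list, and constructing a PyTorch optimizer over an empty list raises an error (e.g. `torch.optim.Adam([])` fails with "optimizer got an empty parameter list"), so the per-component optimizer is only created when there is something to train. A rough standalone illustration of the same guard (`build_suboptimizer` is a hypothetical helper, not the mammoth API):

```python
import torch
import torch.nn as nn

def build_suboptimizer(module, base_optimizer=torch.optim.Adam):
    """Create an optimizer only if the module has trainable parameters."""
    params = [p for p in module.parameters() if p.requires_grad]
    if len(params) != 0:
        return base_optimizer(params)
    return None  # base_optimizer([]) would raise on an empty parameter list

frozen = nn.Embedding(8, 8)
frozen.weight.requires_grad_(False)
assert build_suboptimizer(frozen) is None        # embeddingless case: no optimizer

trainable = nn.Linear(8, 8)
assert build_suboptimizer(trainable) is not None
```

With the change above, no entry is added to `suboptimizers` for components whose parameters are all frozen.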
