Pass kwargs also to TransformerWrapper
Waino committed Oct 7, 2024
1 parent 203d4d5 commit 83c8c26
Showing 2 changed files with 32 additions and 13 deletions.

mammoth/model_builder.py (39 changes: 29 additions & 10 deletions)

@@ -31,6 +31,14 @@
 from mammoth.utils.logging import logger
 from mammoth.utils.misc import use_gpu

+TRANSFORMER_WRAPPER_OPTS = {
+    'post_emb_norm',
+    'tie_embedding',
+    'use_abs_pos_emb',
+    'scaled_sinu_pos_emb',
+    'emb_frac_gradient',
+}
+

 def _combine_ordered_dicts(input_dicts: Dict[str, OrderedDict]) -> OrderedDict:
     result = []
@@ -59,6 +67,7 @@ def get_attention_layers_kwargs(
     is_last = layer_stack_index == len(depths) - 1
     pre_norm_has_final_norm = is_last
     kwargs = model_opts.x_transformers_opts if model_opts.x_transformers_opts else dict()
+    kwargs = {key: val for key, val in kwargs.items() if key not in TRANSFORMER_WRAPPER_OPTS}
     kwargs.update({
         'dim': model_opts.model_dim,
         'depth': depth,
@@ -69,6 +78,21 @@
     return kwargs


+def get_transformer_wrapper_kwargs(
+    side: Side,
+    model_opts,
+):
+    """Return arguments for x_transformers.TransformerWrapper"""
+    assert side in {Side.encoder, Side.decoder}, f'Invalid side "{side}"'
+    kwargs = model_opts.x_transformers_opts if model_opts.x_transformers_opts else dict()
+    kwargs = {key: val for key, val in kwargs.items() if key in TRANSFORMER_WRAPPER_OPTS}
+    max_seq_len = 0 if model_opts.max_length is None else model_opts.max_length
+    kwargs.update({
+        'max_seq_len': max_seq_len,
+    })
+    return kwargs
+
+
 def build_xcoder(
     side: Side,
     model_opts,
@@ -196,6 +220,10 @@ def build_xcoder(
     if single_task:
         tasks = [task for task in tasks if task.corpus_id == single_task]
     transformer_wrappers = dict()
+    transformer_wrapper_kwargs = get_transformer_wrapper_kwargs(
+        side=side,
+        model_opts=model_opts,
+    )
     for task in tasks:
         if side == Side.encoder:
             xcoder_ids = task.encoder_id
@@ -211,22 +239,13 @@

         lang = task.src_lang if side == Side.encoder else task.tgt_lang
         vocab = vocabs_dict[(side_alt_str, lang)]
-        max_seq_len = 0 if model_opts.max_length is None else model_opts.max_length
-        post_emb_norm = True
-        tie_embedding = True
-        use_abs_pos_emb = True
-        emb_frac_gradient = 1.
         # Using custom extended TransformerWrapper to allow passing in an embedding
         transformer_wrapper = TransformerWrapper(
             num_tokens=len(vocab),
-            max_seq_len=max_seq_len,
             attn_layers=adapted_attention_layers_stack,
             emb_dim=model_opts.model_dim,
-            post_emb_norm=post_emb_norm,
-            tie_embedding=tie_embedding,
-            use_abs_pos_emb=use_abs_pos_emb,
-            emb_frac_gradient=emb_frac_gradient,
             token_emb=token_embs[lang],
+            **transformer_wrapper_kwargs,
         )
         transformer_wrappers[task.corpus_id] = transformer_wrapper
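
For reference, the sketch below (not part of the commit) shows the partitioning behaviour this change introduces: keys listed in TRANSFORMER_WRAPPER_OPTS are routed to TransformerWrapper, while every other key in x_transformers_opts continues to go to the AttentionLayers stack. The example opts dict and its values are made up for illustration.

# A minimal, self-contained sketch of the option partitioning introduced in
# this commit. The example x_transformers_opts dict is hypothetical.
TRANSFORMER_WRAPPER_OPTS = {
    'post_emb_norm',
    'tie_embedding',
    'use_abs_pos_emb',
    'scaled_sinu_pos_emb',
    'emb_frac_gradient',
}

x_transformers_opts = {
    'post_emb_norm': True,   # consumed by TransformerWrapper
    'tie_embedding': True,   # consumed by TransformerWrapper
    'attn_flash': True,      # passed through to the AttentionLayers stack
    'ff_glu': True,          # passed through to the AttentionLayers stack
}

# Keys outside the allowlist stay with the attention layers ...
attention_layers_kwargs = {
    key: val for key, val in x_transformers_opts.items()
    if key not in TRANSFORMER_WRAPPER_OPTS
}
# ... and keys on the allowlist are routed to TransformerWrapper.
transformer_wrapper_kwargs = {
    key: val for key, val in x_transformers_opts.items()
    if key in TRANSFORMER_WRAPPER_OPTS
}

assert attention_layers_kwargs == {'attn_flash': True, 'ff_glu': True}
assert transformer_wrapper_kwargs == {'post_emb_norm': True, 'tie_embedding': True}

Since both comprehensions filter on the same allowlist, each user-supplied option ends up in exactly one of the two kwargs dicts, and max_seq_len is then added on top by get_transformer_wrapper_kwargs.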

mammoth/trainer.py (6 changes: 3 additions & 3 deletions)

@@ -471,9 +471,9 @@ def _gradient_accumulation(

             with torch.cuda.amp.autocast(enabled=self.optim.amp):
                 logits, decoder_output = self.model(
-                    rearrange(src, 't b 1 -> b t'),
-                    rearrange(decoder_input, 't b 1 -> b t'),
-                    rearrange(src_mask, 't b -> b t'),
+                    src=rearrange(src, 't b 1 -> b t'),
+                    decoder_input=rearrange(decoder_input, 't b 1 -> b t'),
+                    src_mask=rearrange(src_mask, 't b -> b t'),
                     metadata=metadata,
                 )
                 logits = rearrange(logits, 'b t i -> t b i')
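
For readers unfamiliar with einops, the sketch below (not from the repository) illustrates what the rearrange patterns above do; the time-major shapes (seq_len, batch, 1) for token ids and (seq_len, batch) for the mask are assumptions read off the patterns themselves.

import torch
from einops import rearrange

seq_len, batch = 5, 3
src = torch.zeros(seq_len, batch, 1, dtype=torch.long)    # token ids, time-major (t, b, 1)
src_mask = torch.ones(seq_len, batch, dtype=torch.bool)   # padding mask, time-major (t, b)

# Drop the trailing size-1 feature axis and transpose to batch-major (b, t).
src_bt = rearrange(src, 't b 1 -> b t')
# The mask has no feature axis, so only the transpose is needed.
mask_bt = rearrange(src_mask, 't b -> b t')

assert src_bt.shape == (batch, seq_len)
assert mask_bt.shape == (batch, seq_len)

Switching the call to keyword arguments, as the hunk above does, makes the mapping onto the model's forward parameters explicit rather than dependent on positional order.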
