{enc,dec}_rnn_size merged
Mickus Timothee committed Sep 19, 2023
1 parent 821d292 commit 2c622ac
Showing 8 changed files with 26 additions and 38 deletions.
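
In short: the commit collapses the separate encoder-side and decoder-side width options into the single rnn_size, and every consumer of the old fields is updated in the diffs below. Roughly, the user-facing effect on a training invocation (flag spellings are taken from the opts.py diff below; the entry point and any other arguments are illustrative, not from the commit):

    # before: two flags that validate_model_opts forced to be equal anyway
    #   ... -enc_rnn_size 512 -dec_rnn_size 512
    # after: one flag sizes both sides of the model
    #   ... -rnn_size 512
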
12 changes: 6 additions & 6 deletions mammoth/attention_bridge.py
@@ -133,7 +133,7 @@ def __init__(
         attention_heads,
         hidden_ab_size,
         model_type,
-        dec_rnn_size,
+        rnn_size,
         ab_layer_norm=None,
     ):
         """Attention Heads Layer:"""
@@ -144,7 +144,7 @@ def __init__(
         self.dd = u
         self.model_type = model_type
         if self.model_type != "text":
-            d = dec_rnn_size
+            d = rnn_size
         self.ws1 = nn.Linear(d, u, bias=True)
         self.ws2 = nn.Linear(u, r, bias=True)
         self.relu = nn.ReLU()
@@ -161,7 +161,7 @@ def from_opt(cls, opt):
             opt.ab_fixed_length,
             opt.hidden_ab_size,
             opt.model_type,
-            opt.dec_rnn_size,
+            opt.rnn_size,
             opt.ab_layer_norm,
         )

@@ -246,7 +246,7 @@ def forward(self, intermediate_output, encoder_output, mask=None):
     @classmethod
     def from_opt(cls, opt):
         return cls(
-            opt.enc_rnn_size,
+            opt.rnn_size,
             opt.hidden_ab_size,
             opt.ab_fixed_length,
             opt.ab_layer_norm,
@@ -278,7 +278,7 @@ def forward(self, intermediate_output, encoder_output, mask=None):
     @classmethod
     def from_opt(cls, opt):
         return cls(
-            opt.enc_rnn_size,
+            opt.rnn_size,
             opt.heads,
             opt.hidden_ab_size,  # d_ff
             # TODO: that list indexing things seems suspicious to me...
@@ -315,7 +315,7 @@ def forward(self, intermediate_output, encoder_output, mask=None):
     @classmethod
     def from_opt(cls, opt):
         return cls(
-            opt.enc_rnn_size,
+            opt.rnn_size,
             opt.hidden_ab_size,
             opt.ab_layer_norm,
         )
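
Each from_opt above now reads opt.rnn_size where it previously read opt.enc_rnn_size or opt.dec_rnn_size, so encoder-facing and decoder-facing bridge layers can no longer be configured with diverging widths. A minimal sketch of the calling pattern (the options object is hand-built, and the class name is a stand-in, not one from the diff):

    from types import SimpleNamespace

    opt = SimpleNamespace(
        rnn_size=512,        # the single model width; no enc_/dec_ variants left
        hidden_ab_size=2048,
        ab_fixed_length=50,
        ab_layer_norm=None,
    )
    # bridge = SomeBridgeLayer.from_opt(opt)  # stand-in name for any layer above
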
8 changes: 4 additions & 4 deletions mammoth/decoders/layer_stack_decoder.py
@@ -29,7 +29,7 @@ def from_opt(cls, opt, embeddings, task_queue_manager, is_on_top=False):
                 continue
             stacks[module_id] = AdaptedTransformerDecoder(
                 n_layers,
-                opt.dec_rnn_size,
+                opt.rnn_size,
                 opt.heads,
                 opt.transformer_ff,
                 opt.copy_attn,
@@ -48,7 +48,7 @@ def from_opt(cls, opt, embeddings, task_queue_manager, is_on_top=False):
                 alignment_heads=opt.alignment_heads,
                 pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
                 layer_norm_module=(
-                    nn.LayerNorm(opt.dec_rnn_size, eps=1e-6) if is_on_top
+                    nn.LayerNorm(opt.rnn_size, eps=1e-6) if is_on_top
                     else nn.Identity()
                 ),
             )
@@ -66,7 +66,7 @@ def from_trans_opt(cls, model_opt, embeddings, opt_stack):
             module_id = module_opts['id']
             stacks[module_id] = AdaptedTransformerDecoder(
                 n_layers,
-                model_opt.dec_rnn_size,
+                model_opt.rnn_size,
                 model_opt.heads,
                 model_opt.transformer_ff,
                 model_opt.copy_attn,
@@ -85,7 +85,7 @@ def from_trans_opt(cls, model_opt, embeddings, opt_stack):
                 alignment_heads=model_opt.alignment_heads,
                 pos_ffn_activation_fn=model_opt.pos_ffn_activation_fn,
                 layer_norm_module=(
-                    nn.LayerNorm(model_opt.dec_rnn_size, eps=1e-6) if is_on_top
+                    nn.LayerNorm(model_opt.rnn_size, eps=1e-6) if is_on_top
                     else nn.Identity()
                 ),
             )
4 changes: 2 additions & 2 deletions mammoth/decoders/transformer_decoder.py
@@ -287,7 +287,7 @@ def from_opt(cls, opt, embeddings, is_on_top=False):
         """Alternate constructor."""
         return cls(
             opt.dec_layers,
-            opt.dec_rnn_size,
+            opt.rnn_size,
             opt.heads,
             opt.transformer_ff,
             opt.copy_attn,
@@ -302,7 +302,7 @@ def from_opt(cls, opt, embeddings, is_on_top=False):
             alignment_heads=opt.alignment_heads,
             pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
             layer_norm_module=(
-                nn.LayerNorm(opt.dec_rnn_size, eps=1e-6) if is_on_top
+                nn.LayerNorm(opt.rnn_size, eps=1e-6) if is_on_top
                 else nn.Identity()
             ),
         )
2 changes: 1 addition & 1 deletion mammoth/encoders/transformer_encoder.py
@@ -137,7 +137,7 @@ def from_opt(cls, opt, embeddings, is_on_top=False):
         """Alternate constructor."""
         return cls(
             opt.enc_layers,
-            opt.enc_rnn_size,
+            opt.rnn_size,
             opt.heads,
             opt.transformer_ff,
             opt.dropout[0] if type(opt.dropout) is list else opt.dropout,
3 changes: 1 addition & 2 deletions mammoth/model_builder.py
@@ -7,7 +7,6 @@
 from torch.nn.init import xavier_uniform_

 from collections import defaultdict
-# from torchtext.legacy.data import Field

 import mammoth.modules

@@ -404,7 +403,7 @@ def build_generator(model_opt, n_tgts, tgt_emb):
     else:
         gen_func = nn.LogSoftmax(dim=-1)
     generator = nn.Sequential(
-        nn.Linear(model_opt.dec_rnn_size, n_tgts), Cast(torch.float32), gen_func
+        nn.Linear(model_opt.rnn_size, n_tgts), Cast(torch.float32), gen_func
    )

     if model_opt.share_decoder_embeddings:
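
The generator head maps decoder states to vocabulary logits, so its input dimension must equal the (now unique) model width. A self-contained sketch of that shape logic (sizes are made up, and the project's Cast wrapper is omitted):

    import torch
    import torch.nn as nn

    rnn_size, n_tgts = 512, 32000       # illustrative sizes
    generator = nn.Sequential(
        nn.Linear(rnn_size, n_tgts),    # decoder hidden state -> vocab logits
        nn.LogSoftmax(dim=-1),
    )
    dec_out = torch.randn(8, rnn_size)  # fake batch of decoder states
    log_probs = generator(dec_out)      # shape: (8, 32000)
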
17 changes: 8 additions & 9 deletions mammoth/opts.py
@@ -385,21 +385,21 @@ def model_opts(parser):
         '--encoder_type',
         '-encoder_type',
         type=str,
-        default='rnn',
-        choices=['rnn', 'brnn', 'ggnn', 'mean', 'transformer', 'cnn', 'transformer_lm'],
+        default='transformer',
+        choices=['mean', 'transformer'],
         help="Type of encoder layer to use. Non-RNN layers "
         "are experimental. Options are "
-        "[rnn|brnn|ggnn|mean|transformer|cnn|transformer_lm].",
+        "[mean|transformer].",
     )
     group.add(
         '--decoder_type',
         '-decoder_type',
         type=str,
-        default='rnn',
-        choices=['rnn', 'transformer', 'cnn', 'transformer_lm'],
+        default='transformer',
+        choices=['transformer'],
         help="Type of decoder layer to use. Non-RNN layers "
         "are experimental. Options are "
-        "[rnn|transformer|cnn|transformer].",
+        "[transformer].",
     )

     group.add('--layers', '-layers', type=int, default=-1, help='Deprecated')
@@ -410,10 +410,9 @@ def model_opts(parser):
         '-rnn_size',
         type=int,
         default=-1,
-        help="Size of rnn hidden states. Overwrites enc_rnn_size and dec_rnn_size",
+        help="Size of rnn hidden states.",
     )
-    group.add('--enc_rnn_size', '-enc_rnn_size', type=int, default=500, help="Size of encoder rnn hidden states.")
-    group.add('--dec_rnn_size', '-dec_rnn_size', type=int, default=500, help="Size of decoder rnn hidden states.")

     group.add(
         '--cnn_kernel_width',
         '-cnn_kernel_width',
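
With the two per-side flags deleted, --rnn_size stops being an override (its old help text said it overwrote the other two) and becomes the only width option. A standalone argparse sketch of the merged behaviour (not the project's model_opts, which registers options through group.add):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--rnn_size', '-rnn_size', type=int, default=-1,
                        help="Size of rnn hidden states.")
    opt = parser.parse_args(['-rnn_size', '512'])
    assert opt.rnn_size == 512  # one value now sizes encoder and decoder alike
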
11 changes: 3 additions & 8 deletions mammoth/tests/test_models.py
@@ -63,8 +63,6 @@ def encoder_forward(self, opt, source_l=3, bsize=1):
         source_l: Length of generated input sentence
         bsize: Batchsize of generated input
         '''
-        if opt.rnn_size > 0:
-            opt.enc_rnn_size = opt.rnn_size
         word_field = self.get_field()
         embeddings = build_embeddings(opt, word_field)
         enc = build_encoder(opt, embeddings)
@@ -74,8 +72,8 @@ def encoder_forward(self, opt, source_l=3, bsize=1):
         hidden_t, outputs, test_length = enc(test_src, test_length)

         # Initialize vectors to compare size with
-        test_hid = torch.zeros(self.opt.enc_layers, bsize, opt.enc_rnn_size)
-        test_out = torch.zeros(source_l, bsize, opt.dec_rnn_size)
+        test_hid = torch.zeros(self.opt.enc_layers, bsize, opt.rnn_size)
+        test_out = torch.zeros(source_l, bsize, opt.rnn_size)

         # Ensure correct sizes and types
         self.assertEqual(test_hid.size(), hidden_t[0].size(), hidden_t[1].size())
@@ -92,9 +90,6 @@ def nmtmodel_forward(self, opt, source_l=3, bsize=1):
         source_l: length of input sequence
         bsize: batchsize
         """
-        if opt.rnn_size > 0:
-            opt.enc_rnn_size = opt.rnn_size
-            opt.dec_rnn_size = opt.rnn_size
         word_field = self.get_field()

         embeddings = build_embeddings(opt, word_field)
@@ -107,7 +102,7 @@ def nmtmodel_forward(self, opt, source_l=3, bsize=1):

         test_src, test_tgt, test_length = self.get_batch(source_l=source_l, bsize=bsize)
         outputs, attn = model(test_src, test_tgt, test_length)
-        outputsize = torch.zeros(source_l - 1, bsize, opt.dec_rnn_size)
+        outputsize = torch.zeros(source_l - 1, bsize, opt.rnn_size)
         # Make sure that output has the correct size and type
         self.assertEqual(outputs.size(), outputsize.size())
         self.assertEqual(type(outputs), torch.Tensor)
7 changes: 1 addition & 6 deletions mammoth/utils/parse.py
@@ -285,10 +285,6 @@ def update_model_opts(cls, model_opt):
         if model_opt.layers > 0:
             raise Exception('--layers is deprecated')

-        if model_opt.rnn_size > 0:
-            model_opt.enc_rnn_size = model_opt.rnn_size
-            model_opt.dec_rnn_size = model_opt.rnn_size
-
         model_opt.brnn = model_opt.encoder_type == "brnn"

         if model_opt.copy_attn_type is None:
@@ -304,8 +300,7 @@ def validate_model_opts(cls, model_opt):
         assert model_opt.model_type in ["text"], "Unsupported model type %s" % model_opt.model_type

         # encoder and decoder should be same sizes
-        same_size = model_opt.enc_rnn_size == model_opt.dec_rnn_size
-        assert same_size, "The encoder and decoder rnns must be the same size for now"
+        # assert same_size, "The encoder and decoder rnns must be the same size for now"

         assert model_opt.rnn_type != "SRU" or model_opt.gpu_ranks, "Using SRU requires -gpu_ranks set."
         if model_opt.share_embeddings:
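
The deleted fan-out in update_model_opts and the same-size assertion in validate_model_opts were two halves of one invariant, keeping enc_rnn_size equal to dec_rnn_size; with a single field that invariant holds by construction. A tiny sketch of why the check became vacuous (hand-built options object, not project code):

    from types import SimpleNamespace

    model_opt = SimpleNamespace(rnn_size=512)
    # Encoder and decoder both read the same attribute now,
    # so the old equality assert can never fail:
    enc_width = dec_width = model_opt.rnn_size
    assert enc_width == dec_width  # true by construction
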
