From 2c622ac965b7b4aefc7dd163fbcda76895a76711 Mon Sep 17 00:00:00 2001
From: Mickus Timothee
Date: Wed, 20 Sep 2023 00:28:01 +0300
Subject: [PATCH] {enc,dec}_rnn_size merged

---
 mammoth/attention_bridge.py             | 12 ++++++------
 mammoth/decoders/layer_stack_decoder.py |  8 ++++----
 mammoth/decoders/transformer_decoder.py |  4 ++--
 mammoth/encoders/transformer_encoder.py |  2 +-
 mammoth/model_builder.py                |  3 +--
 mammoth/opts.py                         | 17 ++++++++---------
 mammoth/tests/test_models.py            | 11 +++--------
 mammoth/utils/parse.py                  |  7 +------
 8 files changed, 26 insertions(+), 38 deletions(-)

diff --git a/mammoth/attention_bridge.py b/mammoth/attention_bridge.py
index 29317d3b..c0257cb3 100644
--- a/mammoth/attention_bridge.py
+++ b/mammoth/attention_bridge.py
@@ -133,7 +133,7 @@ def __init__(
         attention_heads,
         hidden_ab_size,
         model_type,
-        dec_rnn_size,
+        rnn_size,
         ab_layer_norm=None,
     ):
         """Attention Heads Layer:"""
@@ -144,7 +144,7 @@ def __init__(
         self.dd = u
         self.model_type = model_type
         if self.model_type != "text":
-            d = dec_rnn_size
+            d = rnn_size
         self.ws1 = nn.Linear(d, u, bias=True)
         self.ws2 = nn.Linear(u, r, bias=True)
         self.relu = nn.ReLU()
@@ -161,7 +161,7 @@ def from_opt(cls, opt):
             opt.ab_fixed_length,
             opt.hidden_ab_size,
             opt.model_type,
-            opt.dec_rnn_size,
+            opt.rnn_size,
             opt.ab_layer_norm,
         )
 
@@ -246,7 +246,7 @@ def forward(self, intermediate_output, encoder_output, mask=None):
     @classmethod
     def from_opt(cls, opt):
         return cls(
-            opt.enc_rnn_size,
+            opt.rnn_size,
             opt.hidden_ab_size,
             opt.ab_fixed_length,
             opt.ab_layer_norm,
@@ -278,7 +278,7 @@ def forward(self, intermediate_output, encoder_output, mask=None):
     @classmethod
     def from_opt(cls, opt):
         return cls(
-            opt.enc_rnn_size,
+            opt.rnn_size,
             opt.heads,
             opt.hidden_ab_size,  # d_ff
             # TODO: that list indexing things seems suspicious to me...
@@ -315,7 +315,7 @@ def forward(self, intermediate_output, encoder_output, mask=None):
     @classmethod
     def from_opt(cls, opt):
         return cls(
-            opt.enc_rnn_size,
+            opt.rnn_size,
             opt.hidden_ab_size,
             opt.ab_layer_norm,
         )
diff --git a/mammoth/decoders/layer_stack_decoder.py b/mammoth/decoders/layer_stack_decoder.py
index e3a87240..93a7d925 100644
--- a/mammoth/decoders/layer_stack_decoder.py
+++ b/mammoth/decoders/layer_stack_decoder.py
@@ -29,7 +29,7 @@ def from_opt(cls, opt, embeddings, task_queue_manager, is_on_top=False):
                     continue
                 stacks[module_id] = AdaptedTransformerDecoder(
                     n_layers,
-                    opt.dec_rnn_size,
+                    opt.rnn_size,
                     opt.heads,
                     opt.transformer_ff,
                     opt.copy_attn,
@@ -48,7 +48,7 @@ def from_opt(cls, opt, embeddings, task_queue_manager, is_on_top=False):
                     alignment_heads=opt.alignment_heads,
                     pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
                     layer_norm_module=(
-                        nn.LayerNorm(opt.dec_rnn_size, eps=1e-6) if is_on_top
+                        nn.LayerNorm(opt.rnn_size, eps=1e-6) if is_on_top
                         else nn.Identity()
                     ),
                 )
@@ -66,7 +66,7 @@ def from_trans_opt(cls, model_opt, embeddings, opt_stack):
             module_id = module_opts['id']
             stacks[module_id] = AdaptedTransformerDecoder(
                 n_layers,
-                model_opt.dec_rnn_size,
+                model_opt.rnn_size,
                 model_opt.heads,
                 model_opt.transformer_ff,
                 model_opt.copy_attn,
@@ -85,7 +85,7 @@ def from_trans_opt(cls, model_opt, embeddings, opt_stack):
                 alignment_heads=model_opt.alignment_heads,
                 pos_ffn_activation_fn=model_opt.pos_ffn_activation_fn,
                 layer_norm_module=(
-                    nn.LayerNorm(model_opt.dec_rnn_size, eps=1e-6) if is_on_top
+                    nn.LayerNorm(model_opt.rnn_size, eps=1e-6) if is_on_top
                     else nn.Identity()
                 ),
             )
diff --git a/mammoth/decoders/transformer_decoder.py b/mammoth/decoders/transformer_decoder.py
index b41639d2..4307d155 100644
--- a/mammoth/decoders/transformer_decoder.py
+++ b/mammoth/decoders/transformer_decoder.py
@@ -287,7 +287,7 @@ def from_opt(cls, opt, embeddings, is_on_top=False):
         """Alternate constructor."""
         return cls(
             opt.dec_layers,
-            opt.dec_rnn_size,
+            opt.rnn_size,
             opt.heads,
             opt.transformer_ff,
             opt.copy_attn,
@@ -302,7 +302,7 @@ def from_opt(cls, opt, embeddings, is_on_top=False):
             alignment_heads=opt.alignment_heads,
             pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
             layer_norm_module=(
-                nn.LayerNorm(opt.dec_rnn_size, eps=1e-6) if is_on_top
+                nn.LayerNorm(opt.rnn_size, eps=1e-6) if is_on_top
                 else nn.Identity()
             ),
         )
diff --git a/mammoth/encoders/transformer_encoder.py b/mammoth/encoders/transformer_encoder.py
index 44aa40c8..cd5d23d1 100644
--- a/mammoth/encoders/transformer_encoder.py
+++ b/mammoth/encoders/transformer_encoder.py
@@ -137,7 +137,7 @@ def from_opt(cls, opt, embeddings, is_on_top=False):
         """Alternate constructor."""
         return cls(
             opt.enc_layers,
-            opt.enc_rnn_size,
+            opt.rnn_size,
             opt.heads,
             opt.transformer_ff,
             opt.dropout[0] if type(opt.dropout) is list else opt.dropout,
diff --git a/mammoth/model_builder.py b/mammoth/model_builder.py
index ff4ec6ce..4c901d5c 100644
--- a/mammoth/model_builder.py
+++ b/mammoth/model_builder.py
@@ -7,7 +7,6 @@
 from torch.nn.init import xavier_uniform_
 from collections import defaultdict
 
-# from torchtext.legacy.data import Field
 
 import mammoth.modules
 
@@ -404,7 +403,7 @@ def build_generator(model_opt, n_tgts, tgt_emb):
     else:
         gen_func = nn.LogSoftmax(dim=-1)
     generator = nn.Sequential(
-        nn.Linear(model_opt.dec_rnn_size, n_tgts), Cast(torch.float32), gen_func
+        nn.Linear(model_opt.rnn_size, n_tgts), Cast(torch.float32), gen_func
     )
 
     if model_opt.share_decoder_embeddings:
diff --git a/mammoth/opts.py b/mammoth/opts.py
index bb67b7c9..a8e7e17d 100644
--- a/mammoth/opts.py
+++ b/mammoth/opts.py
@@ -385,21 +385,21 @@ def model_opts(parser):
         '--encoder_type',
         '-encoder_type',
         type=str,
-        default='rnn',
-        choices=['rnn', 'brnn', 'ggnn', 'mean', 'transformer', 'cnn', 'transformer_lm'],
+        default='transformer',
+        choices=['mean', 'transformer'],
         help="Type of encoder layer to use. Non-RNN layers "
         "are experimental. Options are "
-        "[rnn|brnn|ggnn|mean|transformer|cnn|transformer_lm].",
+        "[mean|transformer].",
     )
     group.add(
         '--decoder_type',
         '-decoder_type',
         type=str,
-        default='rnn',
-        choices=['rnn', 'transformer', 'cnn', 'transformer_lm'],
+        default='transformer',
+        choices=['transformer'],
         help="Type of decoder layer to use. Non-RNN layers "
         "are experimental. Options are "
-        "[rnn|transformer|cnn|transformer].",
+        "[transformer].",
     )
 
     group.add('--layers', '-layers', type=int, default=-1, help='Deprecated')
@@ -410,10 +410,9 @@ def model_opts(parser):
         '-rnn_size',
         type=int,
         default=-1,
-        help="Size of rnn hidden states. Overwrites enc_rnn_size and dec_rnn_size",
+        help="Size of rnn hidden states.",
     )
-    group.add('--enc_rnn_size', '-enc_rnn_size', type=int, default=500, help="Size of encoder rnn hidden states.")
-    group.add('--dec_rnn_size', '-dec_rnn_size', type=int, default=500, help="Size of decoder rnn hidden states.")
+
     group.add(
         '--cnn_kernel_width',
         '-cnn_kernel_width',
diff --git a/mammoth/tests/test_models.py b/mammoth/tests/test_models.py
index e795b1b0..6c1803d7 100644
--- a/mammoth/tests/test_models.py
+++ b/mammoth/tests/test_models.py
@@ -63,8 +63,6 @@ def encoder_forward(self, opt, source_l=3, bsize=1):
             source_l: Length of generated input sentence
             bsize: Batchsize of generated input
         '''
-        if opt.rnn_size > 0:
-            opt.enc_rnn_size = opt.rnn_size
         word_field = self.get_field()
         embeddings = build_embeddings(opt, word_field)
         enc = build_encoder(opt, embeddings)
@@ -74,8 +72,8 @@ def encoder_forward(self, opt, source_l=3, bsize=1):
         hidden_t, outputs, test_length = enc(test_src, test_length)
 
         # Initialize vectors to compare size with
-        test_hid = torch.zeros(self.opt.enc_layers, bsize, opt.enc_rnn_size)
-        test_out = torch.zeros(source_l, bsize, opt.dec_rnn_size)
+        test_hid = torch.zeros(self.opt.enc_layers, bsize, opt.rnn_size)
+        test_out = torch.zeros(source_l, bsize, opt.rnn_size)
 
         # Ensure correct sizes and types
         self.assertEqual(test_hid.size(), hidden_t[0].size(), hidden_t[1].size())
@@ -92,9 +90,6 @@ def nmtmodel_forward(self, opt, source_l=3, bsize=1):
            source_l: length of input sequence
            bsize: batchsize
        """
-        if opt.rnn_size > 0:
-            opt.enc_rnn_size = opt.rnn_size
-            opt.dec_rnn_size = opt.rnn_size
         word_field = self.get_field()
         embeddings = build_embeddings(opt, word_field)
 
@@ -107,7 +102,7 @@ def nmtmodel_forward(self, opt, source_l=3, bsize=1):
 
         test_src, test_tgt, test_length = self.get_batch(source_l=source_l, bsize=bsize)
         outputs, attn = model(test_src, test_tgt, test_length)
-        outputsize = torch.zeros(source_l - 1, bsize, opt.dec_rnn_size)
+        outputsize = torch.zeros(source_l - 1, bsize, opt.rnn_size)
         # Make sure that output has the correct size and type
         self.assertEqual(outputs.size(), outputsize.size())
         self.assertEqual(type(outputs), torch.Tensor)
diff --git a/mammoth/utils/parse.py b/mammoth/utils/parse.py
index 0d0f34e2..759ba589 100644
--- a/mammoth/utils/parse.py
+++ b/mammoth/utils/parse.py
@@ -285,10 +285,6 @@ def update_model_opts(cls, model_opt):
         if model_opt.layers > 0:
             raise Exception('--layers is deprecated')
 
-        if model_opt.rnn_size > 0:
-            model_opt.enc_rnn_size = model_opt.rnn_size
-            model_opt.dec_rnn_size = model_opt.rnn_size
-
         model_opt.brnn = model_opt.encoder_type == "brnn"
 
         if model_opt.copy_attn_type is None:
@@ -304,8 +300,7 @@ def validate_model_opts(cls, model_opt):
         assert model_opt.model_type in ["text"], "Unsupported model type %s" % model_opt.model_type
 
         # encoder and decoder should be same sizes
-        same_size = model_opt.enc_rnn_size == model_opt.dec_rnn_size
-        assert same_size, "The encoder and decoder rnns must be the same size for now"
+        # assert same_size, "The encoder and decoder rnns must be the same size for now"
 
         assert model_opt.rnn_type != "SRU" or model_opt.gpu_ranks, "Using SRU requires -gpu_ranks set."
         if model_opt.share_embeddings:
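
Note: the sketch below is illustrative only and not part of the patch. It shows what the consolidation implies for downstream code: with enc_rnn_size and dec_rnn_size merged, the encoder, decoder and generator widths all come from a single opt.rnn_size, which is why the size-equality assertion in mammoth/utils/parse.py can be retired. The name build_toy_model and the Namespace fields are hypothetical, not MAMMOTH's real builder API.

# Minimal sketch, assuming the merged option surface described in the patch.
from argparse import Namespace

import torch.nn as nn


def build_toy_model(opt: Namespace) -> nn.ModuleDict:
    # Encoder, decoder and generator all read the same shared width (opt.rnn_size),
    # so no enc/dec size-equality check is needed.
    encoder_layer = nn.TransformerEncoderLayer(d_model=opt.rnn_size, nhead=opt.heads)
    decoder_layer = nn.TransformerDecoderLayer(d_model=opt.rnn_size, nhead=opt.heads)
    return nn.ModuleDict({
        "encoder": nn.TransformerEncoder(encoder_layer, num_layers=opt.enc_layers),
        "decoder": nn.TransformerDecoder(decoder_layer, num_layers=opt.dec_layers),
        "generator": nn.Linear(opt.rnn_size, opt.n_tgts),
    })


if __name__ == "__main__":
    # Hypothetical option values for illustration only.
    opt = Namespace(rnn_size=512, heads=8, enc_layers=6, dec_layers=6, n_tgts=32000)
    print(build_toy_model(opt))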