From 1ef0670335853d365c7905c1963b8c4c5131b173 Mon Sep 17 00:00:00 2001
From: Stig-Arne Grönroos
Date: Mon, 4 Sep 2023 13:58:44 +0300
Subject: [PATCH] Bugfix: remove double layer norm

Each LayerStack{Enc|Dec}oder applies layer norms between its layers and
also at the beginning and end of the stack. As a result, stacking
LayerStacks back to back applies two layer norms in a row at each stack
boundary. The final layer norm of each stack is now created
conditionally (a real nn.LayerNorm only for the topmost stack,
nn.Identity() otherwise), which avoids the duplication.

This change is courtesy of Timothee Mickus.
---
 onmt/decoders/layer_stack_decoder.py | 12 +++++++++++-
 onmt/decoders/transformer.py         | 13 +++++++++----
 onmt/encoders/layer_stack_encoder.py | 10 ++++++++++
 onmt/encoders/transformer.py         |  9 +++++++--
 4 files changed, 37 insertions(+), 7 deletions(-)

diff --git a/onmt/decoders/layer_stack_decoder.py b/onmt/decoders/layer_stack_decoder.py
index 9eb1fff6..5fc7f594 100644
--- a/onmt/decoders/layer_stack_decoder.py
+++ b/onmt/decoders/layer_stack_decoder.py
@@ -17,10 +17,11 @@ def __init__(self, embeddings, decoders):
         self._active: List[str] = []

     @classmethod
-    def from_opt(cls, opt, embeddings, task_queue_manager):
+    def from_opt(cls, opt, embeddings, task_queue_manager, is_on_top=False):
         """Alternate constructor for use during training."""
         decoders = nn.ModuleList()
         for layer_stack_index, n_layers in enumerate(opt.dec_layers):
+            is_on_top = layer_stack_index == len(opt.dec_layers) - 1
             stacks = nn.ModuleDict()
             for module_id in task_queue_manager.get_decoders(layer_stack_index):
                 if module_id in stacks:
@@ -46,6 +47,10 @@ def from_opt(cls, opt, embeddings, task_queue_manager):
                     opt.alignment_layer,
                     alignment_heads=opt.alignment_heads,
                     pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
+                    layer_norm_module=(
+                        nn.LayerNorm(opt.dec_rnn_size, eps=1e-6) if is_on_top
+                        else nn.Identity()
+                    ),
                 )
             decoders.append(stacks)
         return cls(embeddings, decoders)
@@ -56,6 +61,7 @@ def from_trans_opt(cls, model_opt, embeddings, opt_stack):
         decoders = nn.ModuleList()
         for layer_stack_index, n_layers in enumerate(model_opt.dec_layers):
             stacks = nn.ModuleDict()
+            is_on_top = layer_stack_index == len(model_opt.dec_layers) - 1
             module_opts = opt_stack['decoder'][layer_stack_index]
             module_id = module_opts['id']
             stacks[module_id] = AdaptedTransformerDecoder(
@@ -78,6 +84,10 @@ def from_trans_opt(cls, model_opt, embeddings, opt_stack):
                 model_opt.alignment_layer,
                 alignment_heads=model_opt.alignment_heads,
                 pos_ffn_activation_fn=model_opt.pos_ffn_activation_fn,
+                layer_norm_module=(
+                    nn.LayerNorm(model_opt.dec_rnn_size, eps=1e-6) if is_on_top
+                    else nn.Identity()
+                ),
             )
             decoders.append(stacks)
         return cls(embeddings, decoders)
diff --git a/onmt/decoders/transformer.py b/onmt/decoders/transformer.py
index 3e3a4ec1..637db455 100644
--- a/onmt/decoders/transformer.py
+++ b/onmt/decoders/transformer.py
@@ -266,7 +266,7 @@ def _forward(


 class TransformerDecoderBase(DecoderBase):
-    def __init__(self, d_model, copy_attn, embeddings, alignment_layer):
+    def __init__(self, d_model, copy_attn, embeddings, alignment_layer, layer_norm_module):
         super(TransformerDecoderBase, self).__init__()

         self.embeddings = embeddings
@@ -278,12 +278,12 @@ def __init__(self, d_model, copy_attn, embeddings, alignment_layer):
         # attention. But it was never actually used -- the "copy" attention
         # just reuses the context attention.
         self._copy = copy_attn
-        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
+        self.layer_norm = layer_norm_module

         self.alignment_layer = alignment_layer

     @classmethod
-    def from_opt(cls, opt, embeddings):
+    def from_opt(cls, opt, embeddings, is_on_top=False):
         """Alternate constructor."""
         return cls(
             opt.dec_layers,
@@ -301,6 +301,10 @@ def __init__(self, d_model, copy_attn, embeddings, alignment_layer):
             opt.alignment_layer,
             alignment_heads=opt.alignment_heads,
             pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
+            layer_norm_module=(
+                nn.LayerNorm(opt.dec_rnn_size, eps=1e-6) if is_on_top
+                else nn.Identity()
+            ),
         )

     def init_state(self, src, memory_bank, enc_hidden):
@@ -391,8 +395,9 @@ def __init__(
         alignment_layer,
         alignment_heads,
         pos_ffn_activation_fn=ActivationFunction.relu,
+        layer_norm_module=None,
     ):
-        super(TransformerDecoder, self).__init__(d_model, copy_attn, embeddings, alignment_layer)
+        super(TransformerDecoder, self).__init__(d_model, copy_attn, embeddings, alignment_layer, layer_norm_module)

         self.transformer_layers = nn.ModuleList(
             [
diff --git a/onmt/encoders/layer_stack_encoder.py b/onmt/encoders/layer_stack_encoder.py
index 412b6b39..77073fd9 100644
--- a/onmt/encoders/layer_stack_encoder.py
+++ b/onmt/encoders/layer_stack_encoder.py
@@ -22,6 +22,7 @@ def from_opt(cls, opt, embeddings, task_queue_manager):
         encoders = nn.ModuleList()
         for layer_stack_index, n_layers in enumerate(opt.enc_layers):
             stacks = nn.ModuleDict()
+            is_on_top = layer_stack_index == len(opt.enc_layers) - 1
             for module_id in task_queue_manager.get_encoders(layer_stack_index):
                 if module_id in stacks:
                     # several tasks using the same layer stack
@@ -40,6 +41,10 @@ def from_opt(cls, opt, embeddings, task_queue_manager):
                     None,  # embeddings,
                     opt.max_relative_positions,
                     pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
+                    layer_norm_module=(
+                        nn.LayerNorm(opt.enc_rnn_size, eps=1e-6) if is_on_top
+                        else nn.Identity()
+                    )
                 )
             encoders.append(stacks)
         return cls(embeddings, encoders)
@@ -52,6 +57,7 @@ def from_trans_opt(cls, model_opt, embeddings, opt_stack):
             stacks = nn.ModuleDict()
             module_opts = opt_stack['encoder'][layer_stack_index]
             module_id = module_opts['id']
+            is_on_top = layer_stack_index == len(model_opt.enc_layers) - 1
             stacks[module_id] = AdaptedTransformerEncoder(
                 n_layers,
                 model_opt.enc_rnn_size,
@@ -66,6 +72,10 @@ def from_trans_opt(cls, model_opt, embeddings, opt_stack):
                 None,  # embeddings,
                 model_opt.max_relative_positions,
                 pos_ffn_activation_fn=model_opt.pos_ffn_activation_fn,
+                layer_norm_module=(
+                    nn.LayerNorm(model_opt.enc_rnn_size, eps=1e-6) if is_on_top
+                    else nn.Identity()
+                )
             )
             encoders.append(stacks)
         return cls(embeddings, encoders)
diff --git a/onmt/encoders/transformer.py b/onmt/encoders/transformer.py
index 33e48c87..c020aa7d 100644
--- a/onmt/encoders/transformer.py
+++ b/onmt/encoders/transformer.py
@@ -111,6 +111,7 @@ def __init__(
         embeddings,
         max_relative_positions,
         pos_ffn_activation_fn=ActivationFunction.relu,
+        layer_norm_module=None,
     ):
         super(TransformerEncoder, self).__init__()

@@ -129,10 +130,10 @@ def __init__(
                 for i in range(num_layers)
             ]
         )
-        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
+        self.layer_norm = layer_norm_module

     @classmethod
-    def from_opt(cls, opt, embeddings):
+    def from_opt(cls, opt, embeddings, is_on_top=False):
         """Alternate constructor."""
         return cls(
             opt.enc_layers,
@@ -144,6 +145,10 @@ def from_opt(cls, opt, embeddings):
             embeddings,
             opt.max_relative_positions,
             pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
+            layer_norm_module=(
+                nn.LayerNorm(opt.enc_rnn_size, eps=1e-6) if is_on_top
+                else nn.Identity()
+            )
         )

     def forward(self, src, lengths=None, skip_embedding=False, mask=None):
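
The essence of the change, as a standalone sketch: only the topmost layer
stack keeps a real final LayerNorm, while the lower stacks get a no-op
nn.Identity(). The toy dimensions and stack sizes below are assumptions for
illustration, not values taken from the patched modules.

    import torch
    from torch import nn

    d_model = 8              # toy model dimension (assumption)
    enc_layers = [2, 2, 2]   # three hypothetical stacks of two layers each

    # A real final LayerNorm only for the topmost stack, nn.Identity() for
    # the others, so chaining one stack into the next no longer applies two
    # layer norms back to back.
    final_norms = nn.ModuleList()
    for layer_stack_index, _n_layers in enumerate(enc_layers):
        is_on_top = layer_stack_index == len(enc_layers) - 1
        final_norms.append(
            nn.LayerNorm(d_model, eps=1e-6) if is_on_top else nn.Identity()
        )

    x = torch.randn(4, 7, d_model)   # (batch, sequence, features)
    for norm in final_norms:
        # ... the transformer layers of this stack would run here ...
        x = norm(x)                  # no-op except after the topmost stack

    print([type(m).__name__ for m in final_norms])
    # ['Identity', 'Identity', 'LayerNorm']

In the patched modules the same conditional expression is passed as
layer_norm_module to the TransformerEncoder and TransformerDecoder
constructors, replacing their previously unconditional final nn.LayerNorm.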