diff --git a/onmt/decoders/layer_stack_decoder.py b/onmt/decoders/layer_stack_decoder.py
index 9eb1fff6..5fc7f594 100644
--- a/onmt/decoders/layer_stack_decoder.py
+++ b/onmt/decoders/layer_stack_decoder.py
@@ -17,10 +17,11 @@ def __init__(self, embeddings, decoders):
         self._active: List[str] = []
 
     @classmethod
-    def from_opt(cls, opt, embeddings, task_queue_manager):
+    def from_opt(cls, opt, embeddings, task_queue_manager, is_on_top=False):
         """Alternate constructor for use during training."""
         decoders = nn.ModuleList()
         for layer_stack_index, n_layers in enumerate(opt.dec_layers):
+            is_on_top = layer_stack_index == len(opt.dec_layers) - 1
             stacks = nn.ModuleDict()
             for module_id in task_queue_manager.get_decoders(layer_stack_index):
                 if module_id in stacks:
@@ -46,6 +47,10 @@ def from_opt(cls, opt, embeddings, task_queue_manager):
                     opt.alignment_layer,
                     alignment_heads=opt.alignment_heads,
                     pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
+                    layer_norm_module=(
+                        nn.LayerNorm(opt.dec_rnn_size, eps=1e-6) if is_on_top
+                        else nn.Identity()
+                    ),
                 )
             decoders.append(stacks)
         return cls(embeddings, decoders)
@@ -56,6 +61,7 @@ def from_trans_opt(cls, model_opt, embeddings, opt_stack):
         decoders = nn.ModuleList()
         for layer_stack_index, n_layers in enumerate(model_opt.dec_layers):
             stacks = nn.ModuleDict()
+            is_on_top = layer_stack_index == len(model_opt.dec_layers) - 1
             module_opts = opt_stack['decoder'][layer_stack_index]
             module_id = module_opts['id']
             stacks[module_id] = AdaptedTransformerDecoder(
@@ -78,6 +84,10 @@ def from_trans_opt(cls, model_opt, embeddings, opt_stack):
                 model_opt.alignment_layer,
                 alignment_heads=model_opt.alignment_heads,
                 pos_ffn_activation_fn=model_opt.pos_ffn_activation_fn,
+                layer_norm_module=(
+                    nn.LayerNorm(model_opt.dec_rnn_size, eps=1e-6) if is_on_top
+                    else nn.Identity()
+                ),
             )
             decoders.append(stacks)
         return cls(embeddings, decoders)
diff --git a/onmt/decoders/transformer.py b/onmt/decoders/transformer.py
index 3e3a4ec1..637db455 100644
--- a/onmt/decoders/transformer.py
+++ b/onmt/decoders/transformer.py
@@ -266,7 +266,7 @@ def _forward(
 
 
 class TransformerDecoderBase(DecoderBase):
-    def __init__(self, d_model, copy_attn, embeddings, alignment_layer):
+    def __init__(self, d_model, copy_attn, embeddings, alignment_layer, layer_norm_module):
         super(TransformerDecoderBase, self).__init__()
 
         self.embeddings = embeddings
@@ -278,12 +278,12 @@ def __init__(self, d_model, copy_attn, embeddings, alignment_layer):
         # attention. But it was never actually used -- the "copy" attention
         # just reuses the context attention.
         self._copy = copy_attn
-        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
+        self.layer_norm = layer_norm_module
 
         self.alignment_layer = alignment_layer
 
     @classmethod
-    def from_opt(cls, opt, embeddings):
+    def from_opt(cls, opt, embeddings, is_on_top=False):
         """Alternate constructor."""
         return cls(
             opt.dec_layers,
@@ -301,6 +301,10 @@ def from_opt(cls, opt, embeddings):
             opt.alignment_layer,
             alignment_heads=opt.alignment_heads,
             pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
+            layer_norm_module=(
+                nn.LayerNorm(opt.dec_rnn_size, eps=1e-6) if is_on_top
+                else nn.Identity()
+            ),
         )
 
     def init_state(self, src, memory_bank, enc_hidden):
@@ -391,8 +395,9 @@ def __init__(
         alignment_layer,
         alignment_heads,
         pos_ffn_activation_fn=ActivationFunction.relu,
+        layer_norm_module=None,
     ):
-        super(TransformerDecoder, self).__init__(d_model, copy_attn, embeddings, alignment_layer)
+        super(TransformerDecoder, self).__init__(d_model, copy_attn, embeddings, alignment_layer, layer_norm_module)
 
         self.transformer_layers = nn.ModuleList(
             [
diff --git a/onmt/encoders/layer_stack_encoder.py b/onmt/encoders/layer_stack_encoder.py
index 412b6b39..77073fd9 100644
--- a/onmt/encoders/layer_stack_encoder.py
+++ b/onmt/encoders/layer_stack_encoder.py
@@ -22,6 +22,7 @@ def from_opt(cls, opt, embeddings, task_queue_manager):
         encoders = nn.ModuleList()
         for layer_stack_index, n_layers in enumerate(opt.enc_layers):
             stacks = nn.ModuleDict()
+            is_on_top = layer_stack_index == len(opt.enc_layers) - 1
             for module_id in task_queue_manager.get_encoders(layer_stack_index):
                 if module_id in stacks:
                     # several tasks using the same layer stack
@@ -40,6 +41,10 @@ def from_opt(cls, opt, embeddings, task_queue_manager):
                     None,  # embeddings,
                     opt.max_relative_positions,
                     pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
+                    layer_norm_module=(
+                        nn.LayerNorm(opt.enc_rnn_size, eps=1e-6) if is_on_top
+                        else nn.Identity()
+                    )
                 )
             encoders.append(stacks)
         return cls(embeddings, encoders)
@@ -52,6 +57,7 @@ def from_trans_opt(cls, model_opt, embeddings, opt_stack):
             stacks = nn.ModuleDict()
             module_opts = opt_stack['encoder'][layer_stack_index]
             module_id = module_opts['id']
+            is_on_top = layer_stack_index == len(model_opt.enc_layers) - 1
             stacks[module_id] = AdaptedTransformerEncoder(
                 n_layers,
                 model_opt.enc_rnn_size,
@@ -66,6 +72,10 @@ def from_trans_opt(cls, model_opt, embeddings, opt_stack):
                 None,  # embeddings,
                 model_opt.max_relative_positions,
                 pos_ffn_activation_fn=model_opt.pos_ffn_activation_fn,
+                layer_norm_module=(
+                    nn.LayerNorm(model_opt.enc_rnn_size, eps=1e-6) if is_on_top
+                    else nn.Identity()
+                )
             )
             encoders.append(stacks)
         return cls(embeddings, encoders)
diff --git a/onmt/encoders/transformer.py b/onmt/encoders/transformer.py
index 33e48c87..c020aa7d 100644
--- a/onmt/encoders/transformer.py
+++ b/onmt/encoders/transformer.py
@@ -111,6 +111,7 @@ def __init__(
         embeddings,
         max_relative_positions,
         pos_ffn_activation_fn=ActivationFunction.relu,
+        layer_norm_module=None,
     ):
         super(TransformerEncoder, self).__init__()
 
@@ -129,10 +130,10 @@ def __init__(
                 for i in range(num_layers)
             ]
         )
-        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
+        self.layer_norm = layer_norm_module
 
     @classmethod
-    def from_opt(cls, opt, embeddings):
+    def from_opt(cls, opt, embeddings, is_on_top=False):
         """Alternate constructor."""
         return cls(
             opt.enc_layers,
@@ -144,6 +145,10 @@ def from_opt(cls, opt, embeddings):
             embeddings,
             opt.max_relative_positions,
             pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
+            layer_norm_module=(
+                nn.LayerNorm(opt.enc_rnn_size, eps=1e-6) if is_on_top
+                else nn.Identity()
+            )
        )
 
     def forward(self, src, lengths=None, skip_embedding=False, mask=None):
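
For reference, a minimal standalone sketch of the layer_norm_module selection pattern these hunks introduce (this is not the actual onmt classes; d_model and n_stacks are made-up values): only the topmost stack in a chain of layer stacks receives a real nn.LayerNorm, lower stacks get nn.Identity(), so the hidden states are normalized exactly once, at the top of the chain.

import torch
from torch import nn

d_model = 512   # hypothetical model dimension
n_stacks = 3    # hypothetical number of chained layer stacks

stacks = nn.ModuleList()
for layer_stack_index in range(n_stacks):
    # Same selection as in the diff: only the last (topmost) stack normalizes.
    is_on_top = layer_stack_index == n_stacks - 1
    layer_norm_module = (
        nn.LayerNorm(d_model, eps=1e-6) if is_on_top
        else nn.Identity()
    )
    # Stand-in for AdaptedTransformerEncoder/AdaptedTransformerDecoder:
    # a block whose final normalization is the injected module.
    stacks.append(nn.Sequential(nn.Linear(d_model, d_model), layer_norm_module))

x = torch.randn(4, d_model)
for stack in stacks:
    x = stack(x)  # lower stacks pass through Identity; only the top applies LayerNorm

Note that layer_norm_module defaults to None in the TransformerEncoder/TransformerDecoder constructors, so code paths that construct these classes directly rather than via from_opt would need to pass an explicit module (e.g. nn.LayerNorm(...) or nn.Identity()).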