Merge pull request #3 from Helsinki-NLP/fix/double-layernorm
Bugfix: remove double layer norm
Waino authored Sep 18, 2023
2 parents 4dda101 + 1ef0670 commit e95e197
Showing 4 changed files with 37 additions and 7 deletions.
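
Before this change, TransformerEncoder and TransformerDecoderBase hard-coded self.layer_norm = nn.LayerNorm(d_model, eps=1e-6), so every stack in a layer-stack model applied a final layer norm to its output, producing back-to-back normalization at the stack boundaries. The fix makes the final norm injectable through a new layer_norm_module argument: the topmost stack receives a real nn.LayerNorm, and the stacks below it receive nn.Identity(). Below is a minimal, self-contained sketch of that pattern; ToyStack, the tensor shapes, and the stack count are illustrative stand-ins, not the actual mammoth/OpenNMT classes.

    import torch
    import torch.nn as nn

    class ToyStack(nn.Module):
        def __init__(self, d_model, layer_norm_module):
            super().__init__()
            self.ff = nn.Linear(d_model, d_model)
            # LayerNorm on the top stack, Identity everywhere else,
            # so chained stacks normalize the output exactly once.
            self.layer_norm = layer_norm_module

        def forward(self, x):
            return self.layer_norm(self.ff(x))

    d_model, n_stacks = 8, 3
    stacks = nn.ModuleList(
        ToyStack(
            d_model,
            nn.LayerNorm(d_model, eps=1e-6) if i == n_stacks - 1 else nn.Identity(),
        )
        for i in range(n_stacks)
    )

    x = torch.randn(2, 5, d_model)
    for stack in stacks:
        x = stack(x)  # only the last stack applies the final LayerNorm
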
onmt/decoders/layer_stack_decoder.py: 11 additions & 1 deletion
@@ -17,10 +17,11 @@ def __init__(self, embeddings, decoders):
         self._active: List[str] = []

     @classmethod
-    def from_opt(cls, opt, embeddings, task_queue_manager):
+    def from_opt(cls, opt, embeddings, task_queue_manager, is_on_top=False):
         """Alternate constructor for use during training."""
         decoders = nn.ModuleList()
         for layer_stack_index, n_layers in enumerate(opt.dec_layers):
+            is_on_top = layer_stack_index == len(opt.dec_layers) - 1
             stacks = nn.ModuleDict()
             for module_id in task_queue_manager.get_decoders(layer_stack_index):
                 if module_id in stacks:
@@ -46,6 +47,10 @@ def from_opt(cls, opt, embeddings, task_queue_manager):
                     opt.alignment_layer,
                     alignment_heads=opt.alignment_heads,
                     pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
+                    layer_norm_module=(
+                        nn.LayerNorm(opt.dec_rnn_size, eps=1e-6) if is_on_top
+                        else nn.Identity()
+                    ),
                 )
             decoders.append(stacks)
         return cls(embeddings, decoders)
@@ -56,6 +61,7 @@ def from_trans_opt(cls, model_opt, embeddings, opt_stack):
         decoders = nn.ModuleList()
         for layer_stack_index, n_layers in enumerate(model_opt.dec_layers):
             stacks = nn.ModuleDict()
+            is_on_top = layer_stack_index == len(model_opt.dec_layers) - 1
             module_opts = opt_stack['decoder'][layer_stack_index]
             module_id = module_opts['id']
             stacks[module_id] = AdaptedTransformerDecoder(
@@ -78,6 +84,10 @@ def from_trans_opt(cls, model_opt, embeddings, opt_stack):
                 model_opt.alignment_layer,
                 alignment_heads=model_opt.alignment_heads,
                 pos_ffn_activation_fn=model_opt.pos_ffn_activation_fn,
+                layer_norm_module=(
+                    nn.LayerNorm(model_opt.dec_rnn_size, eps=1e-6) if is_on_top
+                    else nn.Identity()
+                ),
             )
             decoders.append(stacks)
         return cls(embeddings, decoders)
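
In both constructors above, is_on_top reduces to an index check against the last layer stack. A small illustration with assumed values (dec_layers and dec_rnn_size below are made-up examples standing in for opt.dec_layers and opt.dec_rnn_size):

    import torch.nn as nn

    dec_layers = [6, 6, 2]     # assumed opt.dec_layers
    dec_rnn_size = 512         # assumed opt.dec_rnn_size
    for layer_stack_index, n_layers in enumerate(dec_layers):
        is_on_top = layer_stack_index == len(dec_layers) - 1
        norm = nn.LayerNorm(dec_rnn_size, eps=1e-6) if is_on_top else nn.Identity()
        print(layer_stack_index, type(norm).__name__)
    # prints: 0 Identity / 1 Identity / 2 LayerNorm
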
onmt/decoders/transformer.py: 9 additions & 4 deletions
@@ -266,7 +266,7 @@ def _forward(


 class TransformerDecoderBase(DecoderBase):
-    def __init__(self, d_model, copy_attn, embeddings, alignment_layer):
+    def __init__(self, d_model, copy_attn, embeddings, alignment_layer, layer_norm_module):
         super(TransformerDecoderBase, self).__init__()

         self.embeddings = embeddings
@@ -278,12 +278,12 @@ def __init__(self, d_model, copy_attn, embeddings, alignment_layer):
         # attention. But it was never actually used -- the "copy" attention
         # just reuses the context attention.
         self._copy = copy_attn
-        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
+        self.layer_norm = layer_norm_module

         self.alignment_layer = alignment_layer

     @classmethod
-    def from_opt(cls, opt, embeddings):
+    def from_opt(cls, opt, embeddings, is_on_top=False):
         """Alternate constructor."""
         return cls(
             opt.dec_layers,
@@ -301,6 +301,10 @@ def from_opt(cls, opt, embeddings):
             opt.alignment_layer,
             alignment_heads=opt.alignment_heads,
             pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
+            layer_norm_module=(
+                nn.LayerNorm(opt.dec_rnn_size, eps=1e-6) if is_on_top
+                else nn.Identity()
+            ),
         )

     def init_state(self, src, memory_bank, enc_hidden):
@@ -391,8 +395,9 @@ def __init__(
         alignment_layer,
         alignment_heads,
         pos_ffn_activation_fn=ActivationFunction.relu,
+        layer_norm_module=None,
     ):
-        super(TransformerDecoder, self).__init__(d_model, copy_attn, embeddings, alignment_layer)
+        super(TransformerDecoder, self).__init__(d_model, copy_attn, embeddings, alignment_layer, layer_norm_module)

         self.transformer_layers = nn.ModuleList(
             [
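
On the decoder side, the injected norm is threaded through the constructor chain, and the keyword defaults (is_on_top=False, layer_norm_module=None) keep existing call sites valid; builders such as from_opt are expected to supply an actual module. A simplified, self-contained sketch of that wiring; the class names echo the diff, but the bodies are reduced to the relevant attribute and are not the real implementations:

    import torch.nn as nn

    class SketchDecoderBase(nn.Module):
        # The base class no longer hard-codes nn.LayerNorm(d_model, eps=1e-6);
        # it stores whatever final-norm module the subclass passes in.
        def __init__(self, d_model, layer_norm_module):
            super().__init__()
            self.layer_norm = layer_norm_module

    class SketchDecoder(SketchDecoderBase):
        def __init__(self, num_layers, d_model, layer_norm_module=None):
            # The keyword default keeps older call sites working; the builder
            # passes either a LayerNorm (top stack) or an Identity (lower stacks).
            super().__init__(d_model, layer_norm_module)
            self.layers = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(num_layers))

        def forward(self, x):
            for layer in self.layers:
                x = layer(x)
            return self.layer_norm(x)

    top = SketchDecoder(2, 16, layer_norm_module=nn.LayerNorm(16, eps=1e-6))
    lower = SketchDecoder(2, 16, layer_norm_module=nn.Identity())
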
onmt/encoders/layer_stack_encoder.py: 10 additions & 0 deletions
@@ -22,6 +22,7 @@ def from_opt(cls, opt, embeddings, task_queue_manager):
         encoders = nn.ModuleList()
         for layer_stack_index, n_layers in enumerate(opt.enc_layers):
             stacks = nn.ModuleDict()
+            is_on_top = layer_stack_index == len(opt.enc_layers) - 1
             for module_id in task_queue_manager.get_encoders(layer_stack_index):
                 if module_id in stacks:
                     # several tasks using the same layer stack
@@ -40,6 +41,10 @@ def from_opt(cls, opt, embeddings, task_queue_manager):
                     None,  # embeddings,
                     opt.max_relative_positions,
                     pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
+                    layer_norm_module=(
+                        nn.LayerNorm(opt.enc_rnn_size, eps=1e-6) if is_on_top
+                        else nn.Identity()
+                    )
                 )
             encoders.append(stacks)
         return cls(embeddings, encoders)
@@ -52,6 +57,7 @@ def from_trans_opt(cls, model_opt, embeddings, opt_stack):
             stacks = nn.ModuleDict()
             module_opts = opt_stack['encoder'][layer_stack_index]
             module_id = module_opts['id']
+            is_on_top = layer_stack_index == len(model_opt.enc_layers) - 1
             stacks[module_id] = AdaptedTransformerEncoder(
                 n_layers,
                 model_opt.enc_rnn_size,
@@ -66,6 +72,10 @@ def from_trans_opt(cls, model_opt, embeddings, opt_stack):
                 None,  # embeddings,
                 model_opt.max_relative_positions,
                 pos_ffn_activation_fn=model_opt.pos_ffn_activation_fn,
+                layer_norm_module=(
+                    nn.LayerNorm(model_opt.enc_rnn_size, eps=1e-6) if is_on_top
+                    else nn.Identity()
+                )
             )
             encoders.append(stacks)
         return cls(embeddings, encoders)
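
Note that nn.Identity is a parameter-free pass-through, so the non-top encoder stacks no longer add an extra normalization at their output or carry its weights. A quick check in plain PyTorch:

    import torch
    import torch.nn as nn

    x = torch.randn(2, 5, 512)
    ident = nn.Identity()
    assert torch.equal(ident(x), x)                          # returns its input unchanged
    assert sum(p.numel() for p in ident.parameters()) == 0   # no parameters
    norm = nn.LayerNorm(512, eps=1e-6)
    assert sum(p.numel() for p in norm.parameters()) == 512 * 2  # weight + bias
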
onmt/encoders/transformer.py: 7 additions & 2 deletions
@@ -111,6 +111,7 @@ def __init__(
         embeddings,
         max_relative_positions,
         pos_ffn_activation_fn=ActivationFunction.relu,
+        layer_norm_module=None,
     ):
         super(TransformerEncoder, self).__init__()

@@ -129,10 +130,10 @@ def __init__(
                 for i in range(num_layers)
             ]
         )
-        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
+        self.layer_norm = layer_norm_module

     @classmethod
-    def from_opt(cls, opt, embeddings):
+    def from_opt(cls, opt, embeddings, is_on_top=False):
         """Alternate constructor."""
         return cls(
             opt.enc_layers,
@@ -144,6 +145,10 @@ def from_opt(cls, opt, embeddings):
             embeddings,
             opt.max_relative_positions,
             pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
+            layer_norm_module=(
+                nn.LayerNorm(opt.enc_rnn_size, eps=1e-6) if is_on_top
+                else nn.Identity()
+            )
         )

     def forward(self, src, lengths=None, skip_embedding=False, mask=None):