Skip to content

Commit 5d32606

Browse files
committed
Makes BatchNorm configurable in the pre-encoder; turns off elementwise_affine for layernorm_embedding in the transformer encoder to stabilize training; adds a bias term to the joiner's final fc_out and uses its default initializer.
1 parent 321982b commit 5d32606

5 files changed

+43
-14
lines changed

espresso/models/speech_lstm.py

+1
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
235235
kernel_sizes,
236236
strides,
237237
in_channels=task.feat_in_channels,
238+
apply_batchnorm=True,
238239
)
239240
if out_channels is not None
240241
else None

espresso/models/transformer/speech_transformer_config.py

+6
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,12 @@ class SpeechEncoderConfig(SpeechEncDecBaseConfig):
5252
default="[(1, 1), (2, 2), (1, 1), (2, 2)]",
5353
metadata={"help": "list of encoder convolution's out strides"},
5454
)
55+
conv_apply_batchnorm: bool = field(
56+
default=True,
57+
metadata={
58+
"help": "whether to apply BatchNorm after each convolution layer in pre-encoder"
59+
},
60+
)
5561
transformer_context: Optional[str] = field(
5662
default=None,
5763
metadata={

espresso/models/transformer/speech_transformer_encoder.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,9 @@ def __init__(
105105
)
106106

107107
if cfg.layernorm_embedding:
108-
self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export)
108+
self.layernorm_embedding = LayerNorm(
109+
embed_dim, elementwise_affine=False, export=cfg.export
110+
) # sets elementwise_affine to False to stabilize training
109111
else:
110112
self.layernorm_embedding = None
111113

espresso/models/transformer/speech_transformer_transducer_base.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import torch
1010
import torch.nn as nn
1111
import torch.nn.functional as F
12+
from omegaconf import DictConfig
1213
from torch import Tensor
1314

1415
import espresso.tools.utils as speech_utils
@@ -79,14 +80,12 @@ def __init__(self, cfg, encoder, decoder):
7980
self.fc_out = nn.Linear(
8081
self.decoder.embed_tokens.embedding_dim,
8182
self.decoder.embed_tokens.num_embeddings,
82-
bias=False,
8383
)
8484
self.fc_out.weight = self.decoder.embed_tokens.weight
8585
else:
8686
self.fc_out = nn.Linear(
87-
cfg.joint_dim, self.decoder.embed_tokens.num_embeddings, bias=False
87+
cfg.joint_dim, self.decoder.embed_tokens.num_embeddings
8888
)
89-
nn.init.normal_(self.fc_out.weight, mean=0, std=cfg.joint_dim**-0.5)
9089
self.fc_out = nn.utils.weight_norm(self.fc_out, name="weight")
9190

9291
self.cfg = cfg
@@ -144,6 +143,7 @@ def build_model(cls, cfg, task):
144143
kernel_sizes,
145144
strides,
146145
in_channels=task.feat_in_channels,
146+
apply_batchnorm=cfg.encoder.conv_apply_batchnorm,
147147
)
148148
if out_channels is not None
149149
else None
@@ -310,3 +310,8 @@ def get_normalized_probs(
310310
):
311311
"""Get normalized probabilities (or log probs) from a net's output."""
312312
return self.get_normalized_probs_scriptable(net_output, log_probs, sample)
313+
314+
def prepare_for_inference_(self, cfg: DictConfig):
315+
"""Prepare model for inference."""
316+
self.fc_out = nn.utils.remove_weight_norm(self.fc_out)
317+
super().prepare_for_inference_(cfg)

espresso/modules/speech_convolutions.py

+25-10
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,19 @@
2020

2121

2222
class ConvBNReLU(nn.Module):
23-
"""Sequence of convolution-BatchNorm-ReLU layers."""
24-
25-
def __init__(self, out_channels, kernel_sizes, strides, in_channels=1):
23+
"""Sequence of convolution-[BatchNorm]-ReLU layers.
24+
25+
Args:
26+
out_channels (list of int): the number of output channels of each conv layer
27+
kernel_sizes (list of int or tuple): the kernel size of each conv layer
28+
strides (list of int or tuple): the stride of each conv layer
29+
in_channels (int, optional): the number of input channels (default: 1)
30+
apply_batchnorm (bool, optional): if True apply BatchNorm after each convolution layer (default: True)
31+
"""
32+
33+
def __init__(
34+
self, out_channels, kernel_sizes, strides, in_channels=1, apply_batchnorm=True
35+
):
2636
super().__init__()
2737
if not has_packaging:
2838
raise ImportError("Please install packaging with: pip install packaging")
@@ -35,7 +45,7 @@ def __init__(self, out_channels, kernel_sizes, strides, in_channels=1):
3545
assert num_layers == len(kernel_sizes) and num_layers == len(strides)
3646

3747
self.convolutions = nn.ModuleList()
38-
self.batchnorms = nn.ModuleList()
48+
self.batchnorms = nn.ModuleList() if apply_batchnorm else None
3949
for i in range(num_layers):
4050
self.convolutions.append(
4151
Convolution2d(
@@ -45,7 +55,8 @@ def __init__(self, out_channels, kernel_sizes, strides, in_channels=1):
4555
self.strides[i],
4656
)
4757
)
48-
self.batchnorms.append(nn.BatchNorm2d(out_channels[i]))
58+
if apply_batchnorm:
59+
self.batchnorms.append(nn.BatchNorm2d(out_channels[i]))
4960

5061
def output_lengths(self, in_lengths: Union[torch.Tensor, int]):
5162
out_lengths = in_lengths
@@ -65,18 +76,22 @@ def output_lengths(self, in_lengths: Union[torch.Tensor, int]):
6576
return out_lengths
6677

6778
def forward(self, src, src_lengths):
68-
# B X T X C -> B X (input channel num) x T X (C / input channel num)
79+
# B x T x C -> B x (input channel num) x T x (C / input channel num)
6980
x = src.view(
7081
src.size(0),
7182
src.size(1),
7283
self.in_channels,
7384
src.size(2) // self.in_channels,
7485
).transpose(1, 2)
75-
for conv, bn in zip(self.convolutions, self.batchnorms):
76-
x = F.relu(bn(conv(x)))
77-
# B X (output channel num) x T X C' -> B X T X (output channel num) X C'
86+
if self.batchnorms is not None:
87+
for conv, bn in zip(self.convolutions, self.batchnorms):
88+
x = F.relu(bn(conv(x)))
89+
else:
90+
for conv in self.convolutions:
91+
x = F.relu(conv(x))
92+
# B x (output channel num) x T x C' -> B x T x (output channel num) x C'
7893
x = x.transpose(1, 2)
79-
# B X T X (output channel num) X C' -> B X T X C
94+
# B x T x (output channel num) x C' -> B x T x C
8095
x = x.contiguous().view(x.size(0), x.size(1), x.size(2) * x.size(3))
8196

8297
x_lengths = self.output_lengths(src_lengths)

0 commit comments

Comments
 (0)