20
20
21
21
22
22
class ConvBNReLU (nn .Module ):
23
- """Sequence of convolution-BatchNorm-ReLU layers."""
24
-
25
- def __init__ (self , out_channels , kernel_sizes , strides , in_channels = 1 ):
23
+ """Sequence of convolution-[BatchNorm]-ReLU layers.
24
+
25
+ Args:
26
+ out_channels (int): the number of output channels of conv layer
27
+ kernel_sizes (int or tuple): kernel sizes
28
+ strides (int or tuple): strides
29
+ in_channels (int, optional): the number of input channels (default: 1)
30
+ apply_batchnorm (bool, optional): if True apply BatchNorm after each convolution layer (default: True)
31
+ """
32
+
33
+ def __init__ (
34
+ self , out_channels , kernel_sizes , strides , in_channels = 1 , apply_batchnorm = True
35
+ ):
26
36
super ().__init__ ()
27
37
if not has_packaging :
28
38
raise ImportError ("Please install packaging with: pip install packaging" )
@@ -35,7 +45,7 @@ def __init__(self, out_channels, kernel_sizes, strides, in_channels=1):
35
45
assert num_layers == len (kernel_sizes ) and num_layers == len (strides )
36
46
37
47
self .convolutions = nn .ModuleList ()
38
- self .batchnorms = nn .ModuleList ()
48
+ self .batchnorms = nn .ModuleList () if apply_batchnorm else None
39
49
for i in range (num_layers ):
40
50
self .convolutions .append (
41
51
Convolution2d (
@@ -45,7 +55,8 @@ def __init__(self, out_channels, kernel_sizes, strides, in_channels=1):
45
55
self .strides [i ],
46
56
)
47
57
)
48
- self .batchnorms .append (nn .BatchNorm2d (out_channels [i ]))
58
+ if apply_batchnorm :
59
+ self .batchnorms .append (nn .BatchNorm2d (out_channels [i ]))
49
60
50
61
def output_lengths (self , in_lengths : Union [torch .Tensor , int ]):
51
62
out_lengths = in_lengths
@@ -65,18 +76,22 @@ def output_lengths(self, in_lengths: Union[torch.Tensor, int]):
65
76
return out_lengths
66
77
67
78
def forward (self , src , src_lengths ):
68
- # B X T X C -> B X (input channel num) x T X (C / input channel num)
79
+ # B x T x C -> B x (input channel num) x T x (C / input channel num)
69
80
x = src .view (
70
81
src .size (0 ),
71
82
src .size (1 ),
72
83
self .in_channels ,
73
84
src .size (2 ) // self .in_channels ,
74
85
).transpose (1 , 2 )
75
- for conv , bn in zip (self .convolutions , self .batchnorms ):
76
- x = F .relu (bn (conv (x )))
77
- # B X (output channel num) x T X C' -> B X T X (output channel num) X C'
86
+ if self .batchnorms is not None :
87
+ for conv , bn in zip (self .convolutions , self .batchnorms ):
88
+ x = F .relu (bn (conv (x )))
89
+ else :
90
+ for conv in self .convolutions :
91
+ x = F .relu (conv (x ))
92
+ # B x (output channel num) x T x C' -> B x T x (output channel num) x C'
78
93
x = x .transpose (1 , 2 )
79
- # B X T X (output channel num) X C' -> B X T X C
94
+ # B x T x (output channel num) x C' -> B x T x C
80
95
x = x .contiguous ().view (x .size (0 ), x .size (1 ), x .size (2 ) * x .size (3 ))
81
96
82
97
x_lengths = self .output_lengths (src_lengths )
0 commit comments