diff --git a/i6_models/parts/conformer/norm.py b/i6_models/parts/conformer/norm.py index d0401e5..c46155f 100644 --- a/i6_models/parts/conformer/norm.py +++ b/i6_models/parts/conformer/norm.py @@ -9,11 +9,11 @@ class LayerNormNC(nn.LayerNorm): see here: https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html """ - def __init__(self, channels: int): + def __init__(self, channels: int, **kwargs): """ :param channels: number of channels for normalization """ - super().__init__(channels) + super().__init__(channels, **kwargs) def forward(self, tensor: torch.Tensor) -> torch.Tensor: """