add E-branchformer to i6_models (#27)
* add cgMLP part
* ConvolutionalGatingMLPV1Config valid check
* add merge module part
* add merge module part test
* e_branchformer block
* add weight dropout
* update
* update
* remove post init
* update
* update
* add additional checks
* Update i6_models/parts/e_branchformer/merge.py

Co-authored-by: Christoph M. Lüscher <[email protected]>
Co-authored-by: Nick Rossenbach <[email protected]>
1 parent 8ee4ee0, commit 871fdce. Showing 6 changed files with 306 additions and 0 deletions.
@@ -0,0 +1 @@
from .e_branchformer_v1 import *
124 changes: 124 additions & 0 deletions
i6_models/assemblies/e_branchformer/e_branchformer_v1.py
@@ -0,0 +1,124 @@
from __future__ import annotations

__all__ = [
    "EbranchformerBlockV1Config",
    "EbranchformerBlockV1",
    "EbranchformerEncoderV1Config",
    "EbranchformerEncoderV1",
]

import torch
from torch import nn
from dataclasses import dataclass
from typing import Tuple

from i6_models.config import ModelConfiguration, ModuleFactoryV1
from i6_models.parts.conformer import (
    ConformerMHSAV1 as MHSAV1,
    ConformerMHSAV1Config as MHSAV1Config,
    ConformerPositionwiseFeedForwardV1 as PositionwiseFeedForwardV1,
    ConformerPositionwiseFeedForwardV1Config as PositionwiseFeedForwardV1Config,
)
from i6_models.parts.e_branchformer import (
    ConvolutionalGatingMLPV1Config,
    ConvolutionalGatingMLPV1,
    MergerV1Config,
    MergerV1,
)


@dataclass
class EbranchformerBlockV1Config(ModelConfiguration):
    """
    Attributes:
        ff_cfg: Configuration for PositionwiseFeedForwardV1 module
        mhsa_cfg: Configuration for MHSAV1 module
        cgmlp_cfg: Configuration for ConvolutionalGatingMLPV1 module
        merger_cfg: Configuration for MergerV1 module
    """

    ff_cfg: PositionwiseFeedForwardV1Config
    mhsa_cfg: MHSAV1Config
    cgmlp_cfg: ConvolutionalGatingMLPV1Config
    merger_cfg: MergerV1Config


class EbranchformerBlockV1(nn.Module):
    """
    E-Branchformer block module
    """

    def __init__(self, cfg: EbranchformerBlockV1Config):
        """
        :param cfg: e-branchformer block configuration with subunits for the different e-branchformer parts
        """
        super().__init__()
        self.ff_1 = PositionwiseFeedForwardV1(cfg=cfg.ff_cfg)
        self.mhsa = MHSAV1(cfg=cfg.mhsa_cfg)
        self.cgmlp = ConvolutionalGatingMLPV1(model_cfg=cfg.cgmlp_cfg)
        self.merger = MergerV1(model_cfg=cfg.merger_cfg)
        self.ff_2 = PositionwiseFeedForwardV1(cfg=cfg.ff_cfg)
        self.final_layer_norm = torch.nn.LayerNorm(cfg.ff_cfg.input_dim)

    def forward(self, x: torch.Tensor, /, sequence_mask: torch.Tensor) -> torch.Tensor:
        """
        :param x: input tensor of shape [B, T, F]
        :param sequence_mask: mask tensor where 1 defines positions within the sequence and 0 outside, shape: [B, T]
        :return: torch.Tensor of shape [B, T, F]
        """
        x = 0.5 * self.ff_1(x) + x  # first half-step feed-forward with residual, [B, T, F]
        x_1 = self.mhsa(x, sequence_mask)  # global extractor branch (multi-head self-attention), [B, T, F]
        x_2 = self.cgmlp(x)  # local extractor branch (cgMLP), [B, T, F]
        x = self.merger(x_1, x_2) + x  # merge both branches with residual, [B, T, F]
        x = 0.5 * self.ff_2(x) + x  # second half-step feed-forward with residual, [B, T, F]
        x = self.final_layer_norm(x)  # [B, T, F]
        return x


@dataclass
class EbranchformerEncoderV1Config(ModelConfiguration):
    """
    Attributes:
        num_layers: Number of e-branchformer layers in the e-branchformer encoder
        frontend: ModuleFactoryV1 pairing a frontend module and its corresponding config
        block_cfg: Configuration for EbranchformerBlockV1
    """

    num_layers: int

    # nested configurations
    frontend: ModuleFactoryV1
    block_cfg: EbranchformerBlockV1Config


class EbranchformerEncoderV1(nn.Module):
    """
    Implementation of the Branchformer with enhanced merging (E-Branchformer for short), as in the original publication.
    The model consists of a frontend and a stack of N e-branchformer blocks.
    Cf. https://arxiv.org/pdf/2210.00077.pdf
    """

    def __init__(self, cfg: EbranchformerEncoderV1Config):
        """
        :param cfg: e-branchformer encoder configuration with subunits for frontend and e-branchformer blocks
        """
        super().__init__()

        self.frontend = cfg.frontend()
        self.module_list = torch.nn.ModuleList([EbranchformerBlockV1(cfg.block_cfg) for _ in range(cfg.num_layers)])

    def forward(self, data_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :param data_tensor: input tensor of shape [B, T', F']
        :param sequence_mask: mask tensor where 1 defines positions within the sequence and 0 outside, shape: [B, T']
        :return: (output, out_seq_mask)
            where output is torch.Tensor of shape [B, T, F],
            out_seq_mask is a torch.Tensor of shape [B, T]

        F': input feature dim, F: internal and output feature dim
        T': data time dim, T: down-sampled time dim (internal time dim)
        """
        x, sequence_mask = self.frontend(data_tensor, sequence_mask)  # [B, T, F]
        for module in self.module_list:
            x = module(x, sequence_mask)  # [B, T, F]

        return x, sequence_mask
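
To see how the new pieces fit together, here is a minimal usage sketch that builds a single EbranchformerBlockV1 and runs a forward pass on random data. The field names of the conformer part configs (ConformerPositionwiseFeedForwardV1Config, ConformerMHSAV1Config) are assumed to match their definitions in i6_models.parts.conformer, and all dimensions are illustrative, not recommendations.

```python
import torch
from torch import nn

from i6_models.parts.conformer import ConformerMHSAV1Config, ConformerPositionwiseFeedForwardV1Config
from i6_models.parts.e_branchformer import ConvolutionalGatingMLPV1Config, MergerV1Config
from i6_models.assemblies.e_branchformer import EbranchformerBlockV1, EbranchformerBlockV1Config

input_dim = 256  # illustrative model dimension

block_cfg = EbranchformerBlockV1Config(
    # assumed fields of the conformer feed-forward config: input_dim, hidden_dim, dropout, activation
    ff_cfg=ConformerPositionwiseFeedForwardV1Config(
        input_dim=input_dim, hidden_dim=4 * input_dim, dropout=0.1, activation=nn.functional.silu
    ),
    # assumed fields of the conformer MHSA config: input_dim, num_att_heads, att_weights_dropout, dropout
    mhsa_cfg=ConformerMHSAV1Config(input_dim=input_dim, num_att_heads=4, att_weights_dropout=0.1, dropout=0.1),
    cgmlp_cfg=ConvolutionalGatingMLPV1Config(
        input_dim=input_dim, hidden_dim=6 * input_dim, kernel_size=31, dropout=0.1, activation=nn.functional.gelu
    ),
    merger_cfg=MergerV1Config(input_dim=input_dim, kernel_size=31, dropout=0.1),
)

block = EbranchformerBlockV1(block_cfg)
x = torch.randn(2, 50, input_dim)  # [B, T, F]
sequence_mask = torch.ones(2, 50, dtype=torch.bool)  # 1/True marks positions inside the sequence
out = block(x, sequence_mask=sequence_mask)
print(out.shape)  # torch.Size([2, 50, 256])
```

The full EbranchformerEncoderV1 additionally wraps a frontend via ModuleFactoryV1 and stacks num_layers such blocks.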
@@ -0,0 +1,2 @@
from .cgmlp import *
from .merge import *
@@ -0,0 +1,81 @@
from __future__ import annotations

__all__ = ["ConvolutionalGatingMLPV1Config", "ConvolutionalGatingMLPV1"]

from dataclasses import dataclass
from typing import Callable

import torch
from torch import nn

from i6_models.config import ModelConfiguration


@dataclass
class ConvolutionalGatingMLPV1Config(ModelConfiguration):
    """
    Attributes:
        input_dim: input dimension
        hidden_dim: hidden dimension (normally set to 6*input_dim as suggested by the paper)
        kernel_size: kernel size of the depthwise convolution layer
        dropout: dropout probability
        activation: activation function
    """

    input_dim: int
    hidden_dim: int
    kernel_size: int
    dropout: float
    activation: Callable[[torch.Tensor], torch.Tensor] = nn.functional.gelu

    def check_valid(self):
        assert self.kernel_size % 2 == 1, "ConvolutionalGatingMLPV1 only supports odd kernel sizes"
        assert self.hidden_dim % 2 == 0, "ConvolutionalGatingMLPV1 only supports even hidden_dim"

    def __post_init__(self):
        super().__post_init__()
        self.check_valid()


class ConvolutionalGatingMLPV1(nn.Module):
    """Convolutional Gating MLP (cgMLP)."""

    def __init__(self, model_cfg: ConvolutionalGatingMLPV1Config):
        super().__init__()

        self.layer_norm_input = nn.LayerNorm(model_cfg.input_dim)
        self.linear_ff = nn.Linear(in_features=model_cfg.input_dim, out_features=model_cfg.hidden_dim, bias=True)
        self.activation = model_cfg.activation
        self.layer_norm_csgu = nn.LayerNorm(model_cfg.hidden_dim // 2)
        self.depthwise_conv = nn.Conv1d(
            in_channels=model_cfg.hidden_dim // 2,
            out_channels=model_cfg.hidden_dim // 2,
            kernel_size=model_cfg.kernel_size,
            padding=(model_cfg.kernel_size - 1) // 2,
            groups=model_cfg.hidden_dim // 2,
        )
        self.linear_out = nn.Linear(in_features=model_cfg.hidden_dim // 2, out_features=model_cfg.input_dim, bias=True)
        self.dropout = model_cfg.dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: shape [B, T, F], F=input_dim
        :return: shape [B, T, F], F=input_dim
        """
        x = self.layer_norm_input(x)  # [B, T, F]
        x = self.linear_ff(x)  # [B, T, F']
        x = self.activation(x)

        # convolutional spatial gating unit (csgu)
        x_1, x_2 = x.chunk(2, dim=-1)  # [B, T, F'//2], [B, T, F'//2]
        x_2 = self.layer_norm_csgu(x_2)
        # conv layers expect shape [B, F, T] so we have to transpose here
        x_2 = x_2.transpose(1, 2)  # [B, F'//2, T]
        x_2 = self.depthwise_conv(x_2)  # [B, F'//2, T]
        x_2 = x_2.transpose(1, 2)  # [B, T, F'//2]
        x = x_1 * x_2  # [B, T, F'//2]
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)

        x = self.linear_out(x)  # [B, T, F]
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        return x
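
As a quick standalone check of the cgMLP part, and of the hidden_dim = 6*input_dim setting suggested in the config docstring, a minimal sketch with illustrative dimensions:

```python
import torch
from torch import nn

from i6_models.parts.e_branchformer.cgmlp import ConvolutionalGatingMLPV1Config, ConvolutionalGatingMLPV1

input_dim = 64
cfg = ConvolutionalGatingMLPV1Config(
    input_dim=input_dim,
    hidden_dim=6 * input_dim,  # expansion factor suggested by the paper
    kernel_size=31,  # must be odd, see check_valid()
    dropout=0.1,
    activation=nn.functional.gelu,
)
cgmlp = ConvolutionalGatingMLPV1(cfg)

x = torch.randn(2, 50, input_dim)  # [B, T, F]
y = cgmlp(x)
assert y.shape == x.shape  # cgMLP preserves the [B, T, F] shape
```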
@@ -0,0 +1,62 @@
from __future__ import annotations

__all__ = ["MergerV1Config", "MergerV1"]

from dataclasses import dataclass

import torch
from torch import nn

from i6_models.config import ModelConfiguration


@dataclass
class MergerV1Config(ModelConfiguration):
    """
    Attributes:
        input_dim: input dimension
        kernel_size: kernel size of the depthwise convolution layer
        dropout: dropout probability
    """

    input_dim: int
    kernel_size: int
    dropout: float

    def check_valid(self):
        assert self.kernel_size % 2 == 1, "MergerV1 only supports odd kernel sizes"

    def __post_init__(self):
        super().__post_init__()
        self.check_valid()


class MergerV1(nn.Module):
    def __init__(self, model_cfg: MergerV1Config):
        """
        The merge module to merge the outputs of the local extractor and the global extractor.
        Here we take the best variant from the E-Branchformer paper (Fig. 3c); refer to
        https://arxiv.org/abs/2210.00077 for more merge module variants.
        """
        super().__init__()

        self.depthwise_conv = nn.Conv1d(
            in_channels=model_cfg.input_dim * 2,
            out_channels=model_cfg.input_dim * 2,
            kernel_size=model_cfg.kernel_size,
            padding=(model_cfg.kernel_size - 1) // 2,
            groups=model_cfg.input_dim * 2,
        )
        self.linear_ff = nn.Linear(in_features=2 * model_cfg.input_dim, out_features=model_cfg.input_dim, bias=True)
        self.dropout = model_cfg.dropout

    def forward(self, x_1: torch.Tensor, x_2: torch.Tensor) -> torch.Tensor:
        """
        :param x_1: first branch output, shape [B, T, F]
        :param x_2: second branch output, shape [B, T, F]
        :return: merged output of shape [B, T, F]
        """
        x_concat = torch.cat([x_1, x_2], dim=-1)  # [B, T, 2F]
        # conv layers expect shape [B, F, T] so we have to transpose here
        x = x_concat.transpose(1, 2)  # [B, 2F, T]
        x = self.depthwise_conv(x)
        x = x.transpose(1, 2)  # [B, T, 2F]
        x = x + x_concat  # residual around the depthwise convolution
        x = self.linear_ff(x)  # project back to input_dim, [B, T, F]
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        return x
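
A minimal sketch of the merge module in isolation: two branch outputs of shape [B, T, F] (standing in for the MHSA and cgMLP branches) are merged back to [B, T, F]; the dimensions are illustrative.

```python
import torch

from i6_models.parts.e_branchformer.merge import MergerV1Config, MergerV1

input_dim = 64
merger = MergerV1(MergerV1Config(input_dim=input_dim, kernel_size=31, dropout=0.1))

x_global = torch.randn(2, 50, input_dim)  # e.g. MHSA branch output, [B, T, F]
x_local = torch.randn(2, 50, input_dim)  # e.g. cgMLP branch output, [B, T, F]
merged = merger(x_global, x_local)
assert merged.shape == x_global.shape  # merging preserves the [B, T, F] shape
```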
@@ -0,0 +1,36 @@
from itertools import product

import torch
from torch import nn

from i6_models.parts.e_branchformer.cgmlp import ConvolutionalGatingMLPV1Config, ConvolutionalGatingMLPV1
from i6_models.parts.e_branchformer.merge import MergerV1Config, MergerV1


def test_ConvolutionalGatingMLPV1():
    def get_output_shape(input_shape, hidden_dim, kernel_size, dropout, activation):
        input_dim = input_shape[-1]
        cfg = ConvolutionalGatingMLPV1Config(input_dim, hidden_dim, kernel_size, dropout, activation)
        e_branchformer_cgmlp_part = ConvolutionalGatingMLPV1(cfg)
        x = torch.randn(input_shape)
        y = e_branchformer_cgmlp_part(x)
        return y.shape

    for input_shape, hidden_dim, kernel_size, dropout, activation in product(
        [(100, 5, 20), (200, 30, 10)], [120, 60], [9, 15], [0.1, 0.3], [nn.functional.gelu, nn.functional.relu]
    ):
        assert get_output_shape(input_shape, hidden_dim, kernel_size, dropout, activation) == input_shape


def test_MergerV1():
    def get_output_shape(input_shape, kernel_size, dropout):
        input_dim = input_shape[-1]
        cfg = MergerV1Config(input_dim, kernel_size, dropout)
        e_branchformer_merge_part = MergerV1(cfg)
        tensor_local = torch.randn(input_shape)
        tensor_global = torch.randn(input_shape)
        y = e_branchformer_merge_part(tensor_local, tensor_global)
        return y.shape

    for input_shape, kernel_size, dropout in product([(100, 5, 20), (200, 30, 10)], [15, 31], [0.1, 0.3]):
        assert get_output_shape(input_shape, kernel_size, dropout) == input_shape