add E-branchformer to i6_models (#27)
* add cgMLP part

* ConvolutionalGatingMLPV1Config valid check

* add merge module part

* add merge module part test

* e_branchformer block

* add weight dropout

* update

* update

* remove post init

* update

* update

* add additional checks

* Update i6_models/parts/e_branchformer/merge.py

Co-authored-by: Christoph M. Lüscher <[email protected]>

---------

Co-authored-by: Nick Rossenbach <[email protected]>
Co-authored-by: Christoph M. Lüscher <[email protected]>
3 people authored Oct 26, 2023
1 parent 8ee4ee0 commit 871fdce
Showing 6 changed files with 306 additions and 0 deletions.
1 change: 1 addition & 0 deletions i6_models/assemblies/e_branchformer/__init__.py
@@ -0,0 +1 @@
from .e_branchformer_v1 import *
124 changes: 124 additions & 0 deletions i6_models/assemblies/e_branchformer/e_branchformer_v1.py
@@ -0,0 +1,124 @@
from __future__ import annotations

__all__ = [
"EbranchformerBlockV1Config",
"EbranchformerBlockV1",
"EbranchformerEncoderV1Config",
"EbranchformerEncoderV1",
]

import torch
from torch import nn
from dataclasses import dataclass
from typing import Tuple

from i6_models.config import ModelConfiguration, ModuleFactoryV1
from i6_models.parts.conformer import (
ConformerMHSAV1 as MHSAV1,
ConformerMHSAV1Config as MHSAV1Config,
ConformerPositionwiseFeedForwardV1 as PositionwiseFeedForwardV1,
ConformerPositionwiseFeedForwardV1Config as PositionwiseFeedForwardV1Config,
)
from i6_models.parts.e_branchformer import (
ConvolutionalGatingMLPV1Config,
ConvolutionalGatingMLPV1,
MergerV1Config,
MergerV1,
)


@dataclass
class EbranchformerBlockV1Config(ModelConfiguration):
"""
Attributes:
ff_cfg: Configuration for PositionwiseFeedForwardV1 module
mhsa_cfg: Configuration for MHSAV1 module
cgmlp_cfg: Configuration for ConvolutionalGatingMLPV1 module
merger_cfg: Configuration for MergerV1 module
"""

ff_cfg: PositionwiseFeedForwardV1Config
mhsa_cfg: MHSAV1Config
cgmlp_cfg: ConvolutionalGatingMLPV1Config
merger_cfg: MergerV1Config


class EbranchformerBlockV1(nn.Module):
"""
Ebranchformer block module
"""

def __init__(self, cfg: EbranchformerBlockV1Config):
"""
:param cfg: e-branchformer block configuration with subunits for the different e-branchformer parts
"""
super().__init__()
self.ff_1 = PositionwiseFeedForwardV1(cfg=cfg.ff_cfg)
self.mhsa = MHSAV1(cfg=cfg.mhsa_cfg)
self.cgmlp = ConvolutionalGatingMLPV1(model_cfg=cfg.cgmlp_cfg)
self.merger = MergerV1(model_cfg=cfg.merger_cfg)
self.ff_2 = PositionwiseFeedForwardV1(cfg=cfg.ff_cfg)
self.final_layer_norm = torch.nn.LayerNorm(cfg.ff_cfg.input_dim)

def forward(self, x: torch.Tensor, /, sequence_mask: torch.Tensor) -> torch.Tensor:
"""
:param x: input tensor of shape [B, T, F]
:param sequence_mask: mask tensor where 1 defines positions within the sequence and 0 outside, shape: [B, T]
:return: torch.Tensor of shape [B, T, F]
"""
x = 0.5 * self.ff_1(x) + x # [B, T, F]
x_1 = self.mhsa(x, sequence_mask) # [B, T, F]
x_2 = self.cgmlp(x) # [B, T, F]
x = self.merger(x_1, x_2) + x # [B, T, F]
x = 0.5 * self.ff_2(x) + x # [B, T, F]
x = self.final_layer_norm(x) # [B, T, F]
return x


@dataclass
class EbranchformerEncoderV1Config(ModelConfiguration):
"""
Attributes:
num_layers: Number of e-branchformer layers in the e-branchformer encoder
frontend: ModuleFactoryV1 wrapping the frontend module (e.g. a Conformer frontend) and its corresponding config
block_cfg: Configuration for EbranchformerBlockV1
"""

num_layers: int

# nested configurations
frontend: ModuleFactoryV1
block_cfg: EbranchformerBlockV1Config


class EbranchformerEncoderV1(nn.Module):
"""
Implementation of the Branchformer with Enhanced merging (e-branchformer for short), as in the original publication.
The model consists of a frontend and a stack of N e-branchformer blocks.
Cf. https://arxiv.org/pdf/2210.00077.pdf
"""

def __init__(self, cfg: EbranchformerEncoderV1Config):
"""
:param cfg: e-branchformer encoder configuration with subunits for frontend and e-branchformer blocks
"""
super().__init__()

self.frontend = cfg.frontend()
self.module_list = torch.nn.ModuleList([EbranchformerBlockV1(cfg.block_cfg) for _ in range(cfg.num_layers)])

def forward(self, data_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
:param data_tensor: input tensor of shape [B, T', F']
:param sequence_mask: mask tensor where 1 defines positions within the sequence and 0 outside, shape: [B, T']
:return: (output, out_seq_mask)
where output is torch.Tensor of shape [B, T, F],
out_seq_mask is a torch.Tensor of shape [B, T]
F': input feature dim, F: internal and output feature dim
T': data time dim, T: down-sampled time dim (internal time dim)
"""
x, sequence_mask = self.frontend(data_tensor, sequence_mask) # [B, T, F]
for module in self.module_list:
x = module(x, sequence_mask) # [B, T, F]

return x, sequence_mask
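
For orientation, a minimal usage sketch of the new block follows (not part of this commit). The field names of the Conformer part configs (ConformerPositionwiseFeedForwardV1Config, ConformerMHSAV1Config) are assumptions based on the existing i6_models.parts.conformer code:

import torch
from torch import nn

from i6_models.assemblies.e_branchformer import EbranchformerBlockV1, EbranchformerBlockV1Config
from i6_models.parts.conformer import (
    ConformerMHSAV1Config,
    ConformerPositionwiseFeedForwardV1Config,
)
from i6_models.parts.e_branchformer import ConvolutionalGatingMLPV1Config, MergerV1Config

input_dim = 256

# one block config built from the four part configs
block_cfg = EbranchformerBlockV1Config(
    ff_cfg=ConformerPositionwiseFeedForwardV1Config(
        input_dim=input_dim, hidden_dim=4 * input_dim, dropout=0.1, activation=nn.functional.silu
    ),
    mhsa_cfg=ConformerMHSAV1Config(input_dim=input_dim, num_att_heads=4, att_weights_dropout=0.1, dropout=0.1),
    cgmlp_cfg=ConvolutionalGatingMLPV1Config(
        input_dim=input_dim, hidden_dim=6 * input_dim, kernel_size=31, dropout=0.1, activation=nn.functional.gelu
    ),
    merger_cfg=MergerV1Config(input_dim=input_dim, kernel_size=31, dropout=0.1),
)

block = EbranchformerBlockV1(block_cfg)
x = torch.randn(10, 50, input_dim)  # [B, T, F]
sequence_mask = torch.ones(10, 50, dtype=torch.bool)  # 1 = position is within the sequence
y = block(x, sequence_mask)  # [B, T, F]

For the full encoder, EbranchformerEncoderV1Config additionally takes num_layers and a frontend ModuleFactoryV1 producing the down-sampling frontend.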
2 changes: 2 additions & 0 deletions i6_models/parts/e_branchformer/__init__.py
@@ -0,0 +1,2 @@
from .cgmlp import *
from .merge import *
81 changes: 81 additions & 0 deletions i6_models/parts/e_branchformer/cgmlp.py
@@ -0,0 +1,81 @@
from __future__ import annotations

__all__ = ["ConvolutionalGatingMLPV1Config", "ConvolutionalGatingMLPV1"]

from dataclasses import dataclass
from typing import Callable

import torch
from torch import nn

from i6_models.config import ModelConfiguration


@dataclass
class ConvolutionalGatingMLPV1Config(ModelConfiguration):
"""
Attributes:
input_dim: input dimension
hidden_dim: hidden dimension (normally set to 6*input_dim as suggested by the paper)
kernel_size: kernel size of the depthwise convolution layer
dropout: dropout probability
activation: activation function
"""

input_dim: int
hidden_dim: int
kernel_size: int
dropout: float
activation: Callable[[torch.Tensor], torch.Tensor] = nn.functional.gelu

def check_valid(self):
assert self.kernel_size % 2 == 1, "ConvolutionalGatingMLPV1 only supports odd kernel sizes"
assert self.hidden_dim % 2 == 0, "ConvolutionalGatingMLPV1 only supports even hidden_dim"

def __post_init__(self):
super().__post_init__()
self.check_valid()


class ConvolutionalGatingMLPV1(nn.Module):
"""Convolutional Gating MLP (cgMLP)."""

def __init__(self, model_cfg: ConvolutionalGatingMLPV1Config):
super().__init__()

self.layer_norm_input = nn.LayerNorm(model_cfg.input_dim)
self.linear_ff = nn.Linear(in_features=model_cfg.input_dim, out_features=model_cfg.hidden_dim, bias=True)
self.activation = model_cfg.activation
self.layer_norm_csgu = nn.LayerNorm(model_cfg.hidden_dim // 2)
self.depthwise_conv = nn.Conv1d(
in_channels=model_cfg.hidden_dim // 2,
out_channels=model_cfg.hidden_dim // 2,
kernel_size=model_cfg.kernel_size,
padding=(model_cfg.kernel_size - 1) // 2,
groups=model_cfg.hidden_dim // 2,
)
self.linear_out = nn.Linear(in_features=model_cfg.hidden_dim // 2, out_features=model_cfg.input_dim, bias=True)
self.dropout = model_cfg.dropout

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
:param x: shape [B, T, F], F=input_dim
:return: shape [B, T, F], F=input_dim
"""
x = self.layer_norm_input(x) # [B, T, F]
x = self.linear_ff(x) # [B, T, F']
x = self.activation(x)

# convolutional spatial gating unit (csgu)
x_1, x_2 = x.chunk(2, dim=-1) # [B, T, F'//2], [B, T, F'//2]
x_2 = self.layer_norm_csgu(x_2)
# conv layers expect shape [B, F, T] so we have to transpose here
x_2 = x_2.transpose(1, 2) # [B, F'//2, T]
x_2 = self.depthwise_conv(x_2) # [B, F'//2, T]
x_2 = x_2.transpose(1, 2) # [B, T, F'//2]
x = x_1 * x_2 # [B, T, F'//2]
x = nn.functional.dropout(x, p=self.dropout, training=self.training)

x = self.linear_out(x) # [B, T, F]
x = nn.functional.dropout(x, p=self.dropout, training=self.training)
return x
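
As a shape sanity check (not part of the commit): the cgMLP keeps the input feature dimension. The projection expands [B, T, F] to [B, T, F'], the chunk yields two [B, T, F'//2] halves for the spatial gating unit, and linear_out maps back to [B, T, F]. A minimal sketch with assumed example sizes:

import torch
from torch import nn

from i6_models.parts.e_branchformer import ConvolutionalGatingMLPV1, ConvolutionalGatingMLPV1Config

# assumed example sizes: B=4, T=25, input_dim=20, hidden_dim=6*input_dim=120
cfg = ConvolutionalGatingMLPV1Config(input_dim=20, hidden_dim=120, kernel_size=31, dropout=0.1, activation=nn.functional.gelu)
cgmlp = ConvolutionalGatingMLPV1(cfg)

x = torch.randn(4, 25, 20)  # [B, T, F]
y = cgmlp(x)  # internally [B, T, 120] -> two [B, T, 60] halves -> [B, T, 20]
assert y.shape == x.shape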
62 changes: 62 additions & 0 deletions i6_models/parts/e_branchformer/merge.py
@@ -0,0 +1,62 @@
from __future__ import annotations

__all__ = ["MergerV1Config", "MergerV1"]

from dataclasses import dataclass

import torch
from torch import nn

from i6_models.config import ModelConfiguration


@dataclass
class MergerV1Config(ModelConfiguration):
"""
Attributes:
input_dim: input dimension
kernel_size: kernel size of the depthwise convolution layer
dropout: dropout probability
"""

input_dim: int
kernel_size: int
dropout: float

def check_valid(self):
assert self.kernel_size % 2 == 1, "MergerV1 only supports odd kernel sizes"

def __post_init__(self):
super().__post_init__()
self.check_valid()


class MergerV1(nn.Module):
def __init__(self, model_cfg: MergerV1Config):
"""
The merge module to merge the outputs of the local extractor and the global extractor.
Here we take the best variant from the E-branchformer paper (Fig. 3c); refer to
https://arxiv.org/abs/2210.00077 for more merge module variants.
"""
super().__init__()

self.depthwise_conv = nn.Conv1d(
in_channels=model_cfg.input_dim * 2,
out_channels=model_cfg.input_dim * 2,
kernel_size=model_cfg.kernel_size,
padding=(model_cfg.kernel_size - 1) // 2,
groups=model_cfg.input_dim * 2,
)
self.linear_ff = nn.Linear(in_features=2 * model_cfg.input_dim, out_features=model_cfg.input_dim, bias=True)
self.dropout = model_cfg.dropout

def forward(self, x_1: torch.Tensor, x_2: torch.Tensor) -> torch.Tensor:
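"""
:param x_1: output of one branch, e.g. the global extractor (MHSA), shape [B, T, F]
:param x_2: output of the other branch, e.g. the local extractor (cgMLP), shape [B, T, F]
:return: merged output of shape [B, T, F]
"""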
x_concat = torch.cat([x_1, x_2], dim=-1) # [B, T, 2F]
# conv layers expect shape [B, F, T] so we have to transpose here
x = x_concat.transpose(1, 2) # [B, 2F, T]
x = self.depthwise_conv(x)
x = x.transpose(1, 2) # [B, T, 2F]
x = x + x_concat
x = self.linear_ff(x)
x = nn.functional.dropout(x, p=self.dropout, training=self.training)
return x
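
A corresponding sketch for the merge module (again not part of the commit), merging two branch outputs of the same shape; sizes are assumed for illustration:

import torch

from i6_models.parts.e_branchformer import MergerV1, MergerV1Config

# assumed example sizes: B=4, T=25, input_dim=20
cfg = MergerV1Config(input_dim=20, kernel_size=31, dropout=0.1)
merger = MergerV1(cfg)

x_global = torch.randn(4, 25, 20)  # e.g. MHSA branch output
x_local = torch.randn(4, 25, 20)  # e.g. cgMLP branch output
y = merger(x_global, x_local)  # concat to [B, T, 2F], depthwise conv + residual, project back to [B, T, F]
assert y.shape == x_global.shape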
36 changes: 36 additions & 0 deletions tests/test_e_branchformer.py
@@ -0,0 +1,36 @@
from itertools import product

import torch
from torch import nn

from i6_models.parts.e_branchformer.cgmlp import ConvolutionalGatingMLPV1Config, ConvolutionalGatingMLPV1
from i6_models.parts.e_branchformer.merge import MergerV1Config, MergerV1


def test_ConvolutionalGatingMLPV1():
def get_output_shape(input_shape, hidden_dim, kernel_size, dropout, activation):
input_dim = input_shape[-1]
cfg = ConvolutionalGatingMLPV1Config(input_dim, hidden_dim, kernel_size, dropout, activation)
e_branchformer_cgmlp_part = ConvolutionalGatingMLPV1(cfg)
x = torch.randn(input_shape)
y = e_branchformer_cgmlp_part(x)
return y.shape

for input_shape, hidden_dim, kernel_size, dropout, activation in product(
[(100, 5, 20), (200, 30, 10)], [120, 60], [9, 15], [0.1, 0.3], [nn.functional.gelu, nn.functional.relu]
):
assert get_output_shape(input_shape, hidden_dim, kernel_size, dropout, activation) == input_shape


def test_MergerV1():
def get_output_shape(input_shape, kernel_size, dropout):
input_dim = input_shape[-1]
cfg = MergerV1Config(input_dim, kernel_size, dropout)
e_branchformer_merge_part = MergerV1(cfg)
tensor_local = torch.randn(input_shape)
tensor_global = torch.randn(input_shape)
y = e_branchformer_merge_part(tensor_local, tensor_global)
return y.shape

for input_shape, kernel_size, dropout in product([(100, 5, 20), (200, 30, 10)], [15, 31], [0.1, 0.3]):
assert get_output_shape(input_shape, kernel_size, dropout) == input_shape
