diff --git a/i6_models/parts/frontend/generic_frontend.py b/i6_models/parts/frontend/generic_frontend.py
new file mode 100644
index 00000000..3df213d9
--- /dev/null
+++ b/i6_models/parts/frontend/generic_frontend.py
@@ -0,0 +1,218 @@
+from __future__ import annotations
+
+__all__ = [
+    "FrontendLayerType",
+    "GenericFrontendV1Config",
+    "GenericFrontendV1",
+]
+
+from dataclasses import dataclass
+from enum import Enum, auto
+from typing import Callable, Optional, Tuple, Union, Sequence
+
+import torch
+from torch import nn
+
+from i6_models.config import ModelConfiguration
+
+from i6_models.parts.frontend.common import get_same_padding, mask_pool, calculate_output_dim
+
+
+class FrontendLayerType(Enum):
+    Conv2d = auto()
+    Pool2d = auto()
+    Activation = auto()
+
+
+@dataclass
+class GenericFrontendV1Config(ModelConfiguration):
+    """
+    Attributes:
+        in_features: number of input features to the module
+        layer_ordering: the ordering of the frontend layers; each element must be a FrontendLayerType,
+            e.g. the ordering of VGG4LayerActFrontendV1 would be [FrontendLayerType.Conv2d, FrontendLayerType.Activation,
+            FrontendLayerType.Pool2d, FrontendLayerType.Conv2d, FrontendLayerType.Conv2d, FrontendLayerType.Activation,
+            FrontendLayerType.Pool2d]
+        conv_kernel_sizes: kernel sizes for each conv layer
+        conv_strides: stride sizes for each conv layer
+        conv_paddings: padding sizes for each conv layer
+        conv_out_dims: number of output channels for each conv layer
+        pool_kernel_sizes: kernel sizes for each pool layer
+        pool_strides: stride sizes for each pool layer
+        pool_paddings: padding sizes for each pool layer
+        activations: activation functions, one for each activation layer
+        out_features: output size of the final linear layer
+    """
+
+    in_features: int
+    layer_ordering: Sequence[FrontendLayerType]
+    conv_kernel_sizes: Optional[Sequence[Tuple[int, int]]]
+    conv_strides: Optional[Sequence[Tuple[int, int]]]
+    conv_paddings: Optional[Sequence[Tuple[int, int]]]
+    conv_out_dims: Optional[Sequence[int]]
+    pool_kernel_sizes: Optional[Sequence[Tuple[int, int]]]
+    pool_strides: Optional[Sequence[Tuple[int, int]]]
+    pool_paddings: Optional[Sequence[Tuple[int, int]]]
+    activations: Optional[Sequence[Union[nn.Module, Callable[[torch.Tensor], torch.Tensor]]]]
+    out_features: int
+
+    def check_valid(self):
+        num_convs = 0 if self.conv_kernel_sizes is None else len(self.conv_kernel_sizes)
+        num_pools = 0 if self.pool_kernel_sizes is None else len(self.pool_kernel_sizes)
+        num_activations = 0 if self.activations is None else len(self.activations)
+
+        assert num_convs == self.layer_ordering.count(
+            FrontendLayerType.Conv2d
+        ), "Number of convolution layers mismatch!"
+        assert num_activations == self.layer_ordering.count(
+            FrontendLayerType.Activation
+        ), "Number of activation layers mismatch!"
+        assert num_pools == self.layer_ordering.count(FrontendLayerType.Pool2d), "Number of pooling layers mismatch!"
+
+        if self.conv_strides is not None:
+            assert len(self.conv_strides) == num_convs, "Please specify stride for each convolution layer!"
+        if self.conv_paddings is not None:
+            assert len(self.conv_paddings) == num_convs, "Please specify padding for each convolution layer!"
+        if num_convs != 0:
+            assert (
+                len(self.conv_out_dims) == num_convs
+            ), "Please specify the number of channels for each convolution layer!"
+
+        if self.pool_strides is not None:
+            assert len(self.pool_strides) == num_pools, "Please specify stride for each pooling layer!"
+        if self.pool_paddings is not None:
+            assert len(self.pool_paddings) == num_pools, "Please specify padding for each pooling layer!"
+
+        assert len(self.layer_ordering) == num_convs + num_pools + num_activations, "Number of total layers mismatch!"
+
+        for kernel_sizes in filter(None, [self.conv_kernel_sizes, self.pool_kernel_sizes]):
+            for kernel_size in kernel_sizes:
+                assert all(k % 2 for k in kernel_size), "GenericFrontendV1 only supports odd kernel sizes"
+
+    def __post_init__(self):
+        super().__post_init__()
+        self.check_valid()
+
+
+class GenericFrontendV1(nn.Module):
+    def __init__(self, model_cfg: GenericFrontendV1Config):
+        """
+        Generic front-end
+        can be used to build a customized frontend by freely combining convolutional layers, pooling layers,
+        and activation functions
+
+        To get the ESPnet Conv2dSubsampling6 case, for example, use these options:
+        layer_ordering = [FrontendLayerType.Conv2d, FrontendLayerType.Conv2d]
+        conv_kernel_sizes = [(3, 3), (5, 5)]
+        conv_strides = [(2, 2), (3, 3)]
+
+        To get the i6_models VGG4LayerActFrontendV1, use the options:
+        layer_ordering = [FrontendLayerType.Conv2d, FrontendLayerType.Activation, FrontendLayerType.Pool2d,
+            FrontendLayerType.Conv2d, FrontendLayerType.Conv2d, FrontendLayerType.Activation, FrontendLayerType.Pool2d]
+        conv_kernel_sizes = [(3, 3), (3, 3), (3, 3)]
+        conv_out_dims = [32, 34, 64]
+        pool_kernel_sizes = [(3, 3), (3, 3)]
+        pool_strides = [(2, 2), (2, 2)]
+        activations = [torch.nn.ReLU(), torch.nn.ReLU()]
+        """
+        super().__init__()
+
+        model_cfg.check_valid()
+
+        self.cfg = model_cfg
+
+        self.frontend_layers = nn.ModuleList([])
+
+        conv_layer_index = 0
+        pool_layer_index = 0
+        activation_layer_index = 0
+        last_channel_dim = 1
+        last_feat_dim = model_cfg.in_features
+        for layer_type in model_cfg.layer_ordering:
+            if layer_type == FrontendLayerType.Conv2d:
+                conv_out_dim = model_cfg.conv_out_dims[conv_layer_index]
+                conv_kernel_size = model_cfg.conv_kernel_sizes[conv_layer_index]
+                conv_stride = (1, 1) if model_cfg.conv_strides is None else model_cfg.conv_strides[conv_layer_index]
+                conv_padding = (
+                    get_same_padding(conv_kernel_size)
+                    if model_cfg.conv_paddings is None
+                    else model_cfg.conv_paddings[conv_layer_index]
+                )
+
+                self.frontend_layers.append(
+                    nn.Conv2d(
+                        in_channels=last_channel_dim,
+                        out_channels=conv_out_dim,
+                        kernel_size=conv_kernel_size,
+                        stride=conv_stride,
+                        padding=conv_padding,
+                    )
+                )
+
+                last_channel_dim = conv_out_dim
+                last_feat_dim = calculate_output_dim(
+                    in_dim=last_feat_dim,
+                    filter_size=conv_kernel_size[1],
+                    stride=conv_stride[1],
+                    padding=conv_padding[1],
+                )
+                conv_layer_index += 1
+
+            elif layer_type == FrontendLayerType.Pool2d:
+                pool_stride = None if model_cfg.pool_strides is None else model_cfg.pool_strides[pool_layer_index]
+                pool_kernel_size = model_cfg.pool_kernel_sizes[pool_layer_index]
+                pool_padding = (
+                    get_same_padding(pool_kernel_size)
+                    if model_cfg.pool_paddings is None
+                    else model_cfg.pool_paddings[pool_layer_index]
+                )
+
+                self.frontend_layers.append(
+                    nn.MaxPool2d(
+                        kernel_size=pool_kernel_size,
+                        stride=pool_stride,
+                        padding=pool_padding,
+                    )
+                )
+                last_feat_dim = calculate_output_dim(
+                    in_dim=last_feat_dim,
+                    filter_size=pool_kernel_size[1],
+                    stride=pool_kernel_size[1] if pool_stride is None else pool_stride[1],
+                    padding=pool_padding[1],
+                )
+                pool_layer_index += 1
+
+            elif layer_type == FrontendLayerType.Activation:
+                self.frontend_layers.append(model_cfg.activations[activation_layer_index])
+                activation_layer_index += 1
+            else:
+                raise NotImplementedError
+
+        self.linear = nn.Linear(
+            in_features=last_feat_dim * last_channel_dim,
+            out_features=model_cfg.out_features,
+            bias=True,
+        )
+
+    def forward(self, tensor: torch.Tensor, sequence_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        assert tensor.shape[-1] == self.cfg.in_features
+        # add a channel dim
+        tensor = tensor[:, None, :, :]  # [B,C=1,T,F]
+
+        for layer in self.frontend_layers:
+            tensor = layer(tensor)
+
+            if isinstance(layer, (nn.Conv2d, nn.MaxPool2d)):
+                sequence_mask = mask_pool(
+                    sequence_mask,
+                    kernel_size=layer.kernel_size[0],
+                    stride=layer.stride[0],
+                    padding=layer.padding[0],
+                )
+
+        tensor = torch.transpose(tensor, 1, 2)  # transpose to [B,T",C,F"]
+        tensor = torch.flatten(tensor, start_dim=2, end_dim=-1)  # [B,T",C*F"]
+
+        tensor = self.linear(tensor)
+
+        return tensor, sequence_mask
diff --git a/tests/test_generic_frontend.py b/tests/test_generic_frontend.py
new file mode 100644
index 00000000..916efa45
--- /dev/null
+++ b/tests/test_generic_frontend.py
@@ -0,0 +1,162 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Sequence, Union, Callable, List
+
+import torch
+from torch import nn
+
+from i6_models.parts.frontend.generic_frontend import FrontendLayerType, GenericFrontendV1, GenericFrontendV1Config
+
+
+@dataclass
+class GenericFrontendV1TestParams:
+    batch: int
+    time: int
+    in_features: int
+    layer_ordering: Sequence[FrontendLayerType]
+    conv_kernel_sizes: Optional[Sequence[Tuple[int, int]]]
+    conv_strides: Optional[Sequence[Tuple[int, int]]]
+    conv_paddings: Optional[Sequence[Tuple[int, int]]]
+    conv_out_dims: Optional[Sequence[int]]
+    pool_kernel_sizes: Optional[Sequence[Tuple[int, int]]]
+    pool_strides: Optional[Sequence[Tuple[int, int]]]
+    pool_paddings: Optional[Sequence[Tuple[int, int]]]
+    activations: Optional[Sequence[Union[nn.Module, Callable[[torch.Tensor], torch.Tensor]]]]
+    out_features: int
+    output_shape: List[int]
+    in_sequence_mask: torch.Tensor
+    out_sequence_mask: torch.Tensor
+
+
+def test_generic_frontend_v1():
+    torch.manual_seed(42)
+
+    def get_output_shape(test_parameters: GenericFrontendV1TestParams):
+        data_input = torch.randn(
+            test_parameters.batch,
+            test_parameters.time,
+            test_parameters.in_features,
+        )
+
+        cfg = GenericFrontendV1Config(
+            in_features=test_parameters.in_features,
+            layer_ordering=test_parameters.layer_ordering,
+            conv_kernel_sizes=test_parameters.conv_kernel_sizes,
+            conv_strides=test_parameters.conv_strides,
+            conv_paddings=test_parameters.conv_paddings,
+            conv_out_dims=test_parameters.conv_out_dims,
+            pool_kernel_sizes=test_parameters.pool_kernel_sizes,
+            pool_strides=test_parameters.pool_strides,
+            pool_paddings=test_parameters.pool_paddings,
+            activations=test_parameters.activations,
+            out_features=test_parameters.out_features,
+        )
+
+        output, sequence_mask = GenericFrontendV1(cfg)(
+            data_input,
+            test_parameters.in_sequence_mask,
+        )
+
+        return output.shape, sequence_mask
+
+    for idx, test_params in enumerate(
+        [
+            GenericFrontendV1TestParams(
+                batch=10,
+                time=50,
+                in_features=50,
+                layer_ordering=[FrontendLayerType.Conv2d, FrontendLayerType.Conv2d],
+                conv_kernel_sizes=[(3, 3), (5, 5)],
+                conv_strides=[(1, 1), (1, 1)],
+                conv_paddings=None,
+                conv_out_dims=[32, 32],
+                pool_kernel_sizes=None,
+                pool_strides=None,
+                pool_paddings=None,
+                activations=None,
+                out_features=384,
+                output_shape=[10, 50, 384],
+                in_sequence_mask=torch.Tensor(10 * [25 * [True] + 25 * [False]]).bool(),
+                out_sequence_mask=torch.Tensor(10 * [25 * [True] + 25 * [False]]).bool(),
+            ),
+            GenericFrontendV1TestParams(
+                batch=10,
+                time=50,
+                in_features=50,
+                layer_ordering=[
+                    FrontendLayerType.Conv2d,
+                    FrontendLayerType.Pool2d,
+                    FrontendLayerType.Conv2d,
+                    FrontendLayerType.Activation,
+                ],
+                conv_kernel_sizes=[(3, 3), (5, 5)],
+                conv_strides=[(1, 1), (1, 1)],
+                conv_paddings=None,
+                conv_out_dims=[32, 32],
+                pool_kernel_sizes=[(3, 3)],
+                pool_strides=[(2, 2)],
+                pool_paddings=None,
+                activations=[nn.SiLU()],
+                out_features=384,
+                output_shape=[10, 25, 384],
+                in_sequence_mask=torch.Tensor(10 * [50 * [True] + 0 * [False]]).bool(),
+                out_sequence_mask=torch.Tensor(10 * [25 * [True] + 0 * [False]]).bool(),
+            ),
+            GenericFrontendV1TestParams(
+                batch=10,
+                time=50,
+                in_features=50,
+                layer_ordering=[
+                    FrontendLayerType.Conv2d,
+                    FrontendLayerType.Pool2d,
+                    FrontendLayerType.Activation,
+                    FrontendLayerType.Conv2d,
+                    FrontendLayerType.Activation,
+                ],
+                conv_kernel_sizes=[(3, 3), (3, 5)],
+                conv_strides=[(1, 1), (2, 1)],
+                conv_paddings=None,
+                conv_out_dims=[32, 32],
+                pool_kernel_sizes=[(3, 3)],
+                pool_strides=[(2, 2)],
+                pool_paddings=None,
+                activations=[nn.SiLU(), nn.SiLU()],
+                out_features=384,
+                output_shape=[10, 13, 384],
+                in_sequence_mask=torch.Tensor(10 * [50 * [True] + 0 * [False]]).bool(),
+                out_sequence_mask=torch.Tensor(10 * [13 * [True] + 0 * [False]]).bool(),
+            ),
+            GenericFrontendV1TestParams(
+                batch=10,
+                time=50,
+                in_features=50,
+                layer_ordering=[
+                    FrontendLayerType.Conv2d,
+                    FrontendLayerType.Pool2d,
+                    FrontendLayerType.Activation,
+                    FrontendLayerType.Conv2d,
+                    FrontendLayerType.Pool2d,
+                    FrontendLayerType.Activation,
+                ],
+                conv_kernel_sizes=[(3, 3), (3, 5)],
+                conv_strides=[(1, 1), (2, 2)],
+                conv_paddings=None,
+                conv_out_dims=[32, 32],
+                pool_kernel_sizes=[(3, 3), (5, 5)],
+                pool_strides=[(1, 2), (2, 3)],
+                pool_paddings=None,
+                activations=[nn.SiLU(), nn.SiLU()],
+                out_features=384,
+                output_shape=[10, 13, 384],
+                in_sequence_mask=torch.Tensor(10 * [50 * [True] + 0 * [False]]).bool(),
+                out_sequence_mask=torch.Tensor(10 * [13 * [True] + 0 * [False]]).bool(),
+            ),
+        ]
+    ):
+        shape, seq_mask = get_output_shape(test_params)
+        assert list(shape) == test_params.output_shape, (shape, test_params.output_shape)
+        assert torch.equal(seq_mask, test_params.out_sequence_mask), (
+            seq_mask.shape,
+            test_params.out_sequence_mask.shape,
+        )
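For reference, a minimal usage sketch of the new module, mirroring the VGG-style configuration described in the GenericFrontendV1 docstring. The concrete sizes (80 input features, 384 output features, the channel counts) are illustrative assumptions, not values taken from an existing recipe.

import torch
from torch import nn

from i6_models.parts.frontend.generic_frontend import (
    FrontendLayerType,
    GenericFrontendV1,
    GenericFrontendV1Config,
)

# Assumed example sizes: 80 log-mel input features, 384-dim output for a downstream encoder.
cfg = GenericFrontendV1Config(
    in_features=80,
    layer_ordering=[
        FrontendLayerType.Conv2d,
        FrontendLayerType.Activation,
        FrontendLayerType.Pool2d,
        FrontendLayerType.Conv2d,
        FrontendLayerType.Conv2d,
        FrontendLayerType.Activation,
        FrontendLayerType.Pool2d,
    ],
    conv_kernel_sizes=[(3, 3), (3, 3), (3, 3)],
    conv_strides=None,  # defaults to stride (1, 1) for every conv layer
    conv_paddings=None,  # defaults to "same" padding via get_same_padding
    conv_out_dims=[32, 64, 64],
    pool_kernel_sizes=[(3, 3), (3, 3)],
    pool_strides=[(2, 2), (2, 2)],  # each pooling layer downsamples the time axis by 2
    pool_paddings=None,
    activations=[nn.ReLU(), nn.ReLU()],
    out_features=384,
)

frontend = GenericFrontendV1(cfg)
features = torch.randn(4, 100, 80)  # [B, T, F]
sequence_mask = torch.ones(4, 100, dtype=torch.bool)
output, out_mask = frontend(features, sequence_mask)  # output: [B, T', 384], here T' = 25

With two pooling layers of time stride 2, the sequence mask is subsampled twice alongside the features, so the returned mask matches the reduced time axis of the output.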