diff --git a/i6_models/parts/conformer/convolution.py b/i6_models/parts/conformer/convolution.py index b0b3d023..87c56bc3 100644 --- a/i6_models/parts/conformer/convolution.py +++ b/i6_models/parts/conformer/convolution.py @@ -9,11 +9,12 @@ from dataclasses import dataclass from copy import deepcopy +from typing import Callable, Union, Optional, Literal import torch from torch import nn from i6_models.config import ModelConfiguration -from typing import Callable, Union, Optional +from i6_models.parts.dropout import BroadcastDropout @dataclass @@ -101,12 +102,13 @@ class ConformerConvolutionV2Config(ConformerConvolutionV1Config): Allows even kernel size """ - dropout_broadcast_axes: Optional[str] = None + dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] = None def check_valid(self): assert self.kernel_size % 2 == 1, "ConformerConvolutionV1 only supports odd kernel sizes" - assert self.dropout_broadcast_axes is None or self.dropout_broadcast_axes in [ + assert self.dropout_broadcast_axes in [ + None, "B", "T", "BT", @@ -124,10 +126,7 @@ def __init__(self, model_cfg: ConformerConvolutionV2Config): """ super().__init__(model_cfg) - self.dropout_broadcast_axes = model_cfg.dropout_broadcast_axes - self.dropout = ( - nn.Dropout1d(model_cfg.dropout) if model_cfg.dropout_broadcast_axes else nn.Dropout(model_cfg.dropout) - ) + self.dropout = BroadcastDropout(model_cfg.dropout, dropout_broadcast_axes=model_cfg.dropout_broadcast_axes) def forward(self, tensor: torch.Tensor) -> torch.Tensor: """ @@ -148,20 +147,6 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor: tensor = self.activation(tensor) tensor = self.pointwise_conv2(tensor) - if self.dropout_broadcast_axes is None: - tensor = self.dropout(tensor) - elif self.dropout_broadcast_axes == "T": - tensor = self.dropout(tensor.transpose(1, 2)).transpose(1, 2) - elif self.dropout_broadcast_axes == "B": - tensor = self.dropout(tensor.permute(1, 2, 0)).permute(2, 0, 1) - elif self.dropout_broadcast_axes == "BT": - batch_dim_size = tensor.shape[0] - feature_dim_size = tensor.shape[-1] - - tensor = ( - self.dropout(tensor.reshape(-1, feature_dim_size).transpose(0, 1)) - .transpose(0, 1) - .reshape(batch_dim_size, -1, feature_dim_size) - ) + tensor = self.dropout(tensor) return tensor diff --git a/i6_models/parts/conformer/feedforward.py b/i6_models/parts/conformer/feedforward.py index 9628ce60..c58840dd 100644 --- a/i6_models/parts/conformer/feedforward.py +++ b/i6_models/parts/conformer/feedforward.py @@ -8,12 +8,13 @@ ] from dataclasses import dataclass -from typing import Callable, Optional +from typing import Callable, Optional, Literal import torch from torch import nn from i6_models.config import ModelConfiguration +from i6_models.parts.dropout import BroadcastDropout @dataclass @@ -68,10 +69,11 @@ class ConformerPositionwiseFeedForwardV2Config(ConformerPositionwiseFeedForwardV setting to None to disable broadcasting """ - dropout_broadcast_axes: Optional[str] = None + dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] = None def check_valid(self): - assert self.dropout_broadcast_axes is None or self.dropout_broadcast_axes in [ + assert self.dropout_broadcast_axes in [ + None, "B", "T", "BT", @@ -90,26 +92,7 @@ class ConformerPositionwiseFeedForwardV2(ConformerPositionwiseFeedForwardV1): def __init__(self, cfg: ConformerPositionwiseFeedForwardV2Config): super().__init__(cfg) - self.dropout = nn.Dropout1d(cfg.dropout) if cfg.dropout_broadcast_axes else nn.Dropout(cfg.dropout) - self.dropout_broadcast_axes = 
cfg.dropout_broadcast_axes - - def _broadcast_dropout(self, tensor: torch.Tensor) -> torch.Tensor: - if self.dropout_broadcast_axes is None: - tensor = self.dropout(tensor) - elif self.dropout_broadcast_axes == "T": - tensor = self.dropout(tensor.transpose(1, 2)).transpose(1, 2) - elif self.dropout_broadcast_axes == "B": - tensor = self.dropout(tensor.permute(1, 2, 0)).permute(2, 0, 1) - elif self.dropout_broadcast_axes == "BT": - batch_dim_size = tensor.shape[0] - feature_dim_size = tensor.shape[-1] - - tensor = ( - self.dropout(tensor.reshape(-1, feature_dim_size).transpose(0, 1)) - .transpose(0, 1) - .reshape(batch_dim_size, -1, feature_dim_size) - ) - return tensor + self.dropout = BroadcastDropout(cfg.dropout, dropout_broadcast_axes=cfg.dropout_broadcast_axes) def forward(self, tensor: torch.Tensor) -> torch.Tensor: """ @@ -120,8 +103,8 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor: tensor = self.linear_ff(tensor) # [B,T,F] tensor = self.activation(tensor) # [B,T,F] - tensor = self._broadcast_dropout(tensor) # [B,T,F] + tensor = self.dropout(tensor) # [B,T,F] tensor = self.linear_out(tensor) # [B,T,F] - tensor = self._broadcast_dropout(tensor) # [B,T,F] + tensor = self.dropout(tensor) # [B,T,F] return tensor diff --git a/i6_models/parts/conformer/mhsa.py b/i6_models/parts/conformer/mhsa.py index 8eb39327..1584df23 100644 --- a/i6_models/parts/conformer/mhsa.py +++ b/i6_models/parts/conformer/mhsa.py @@ -1,13 +1,14 @@ from __future__ import annotations __all__ = ["ConformerMHSAV1", "ConformerMHSAV1Config", "ConformerMHSAV2", "ConformerMHSAV2Config"] + from dataclasses import dataclass +from typing import Optional, Literal import torch from i6_models.config import ModelConfiguration from i6_models.util import compat - -from typing import Optional +from i6_models.parts.dropout import BroadcastDropout @dataclass @@ -72,10 +73,11 @@ class ConformerMHSAV2Config(ConformerMHSAV1Config): setting to None to disable broadcasting """ - dropout_broadcast_axes: Optional[str] = None + dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] = None def check_valid(self): - assert self.dropout_broadcast_axes is None or self.dropout_broadcast_axes in [ + assert self.dropout_broadcast_axes in [ + None, "B", "T", "BT", @@ -95,8 +97,7 @@ def __init__(self, cfg: ConformerMHSAV2Config): super().__init__(cfg) - self.dropout = torch.nn.Dropout1d(cfg.dropout) if cfg.dropout_broadcast_axes else torch.nn.Dropout(cfg.dropout) - self.dropout_broadcast_axes = cfg.dropout_broadcast_axes + self.dropout = BroadcastDropout(cfg.dropout, dropout_broadcast_axes=cfg.dropout_broadcast_axes) def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> torch.Tensor: """ @@ -113,20 +114,6 @@ def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> to output_tensor, output_tensor, output_tensor, key_padding_mask=inv_sequence_mask, need_weights=False ) # [B,T,F] - if self.dropout_broadcast_axes is None: - output_tensor = self.dropout(output_tensor) - elif self.dropout_broadcast_axes == "T": - output_tensor = self.dropout(output_tensor.transpose(1, 2)).transpose(1, 2) - elif self.dropout_broadcast_axes == "B": - output_tensor = self.dropout(output_tensor.permute(1, 2, 0)).permute(2, 0, 1) - elif self.dropout_broadcast_axes == "BT": - batch_dim_size = output_tensor.shape[0] - feature_dim_size = output_tensor.shape[-1] - - output_tensor = ( - self.dropout(output_tensor.reshape(-1, feature_dim_size).transpose(0, 1)) - .transpose(0, 1) - .reshape(batch_dim_size, -1, 
feature_dim_size) - ) + output_tensor = self.dropout(output_tensor) return output_tensor # [B,T,F] diff --git a/i6_models/parts/conformer/mhsa_rel_pos.py b/i6_models/parts/conformer/mhsa_rel_pos.py index 918f9f36..f2516c91 100644 --- a/i6_models/parts/conformer/mhsa_rel_pos.py +++ b/i6_models/parts/conformer/mhsa_rel_pos.py @@ -1,11 +1,10 @@ from __future__ import annotations - __all__ = ["ConformerMHSARelPosV1", "ConformerMHSARelPosV1Config"] from dataclasses import dataclass import math -from typing import Optional +from typing import Optional, Literal import torch from torch import nn @@ -13,6 +12,7 @@ from i6_models.config import ModelConfiguration from i6_models.util import compat +from i6_models.parts.dropout import BroadcastDropout @dataclass @@ -21,9 +21,12 @@ class ConformerMHSARelPosV1Config(ModelConfiguration): Attributes: input_dim: input dim and total dimension for query/key and value projections, should be divisible by `num_att_heads` num_att_heads: number of attention heads + with_bias: whether to add bias to qkv and output lienar projections att_weights_dropout: attention weights dropout learnable_pos_emb: whether to use learnable relative positional embeddings instead of fixed sinusoidal ones rel_pos_clip: maximal relative postion for embedding + with_linear_pos: whether to linearly transform the positional embeddings + separate_pos_emb_per_head: whether to apply separate linear transformation on positional embeddings for each head with_pos_bias: whether to add additional position bias terms to the attention scores pos_emb_dropout: dropout for the positional embeddings dropout: multi-headed self attention output dropout @@ -33,18 +36,22 @@ class ConformerMHSARelPosV1Config(ModelConfiguration): input_dim: int num_att_heads: int + with_bias: bool att_weights_dropout: float + learnable_pos_emb: bool + rel_pos_clip: Optional[int] + with_linear_pos: bool + with_pos_bias: bool + separate_pos_emb_per_head: bool + pos_emb_dropout: float dropout: float - learnable_pos_emb: bool = True - rel_pos_clip: Optional[int] = None - with_pos_bias: bool = False - pos_emb_dropout: float = 0.0 - dropout_broadcast_axes: Optional[str] = None + dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] = None def __post_init__(self) -> None: super().__post_init__() assert self.input_dim % self.num_att_heads == 0, "input_dim must be divisible by num_att_heads" - assert self.dropout_broadcast_axes is None or self.dropout_broadcast_axes in [ + assert self.dropout_broadcast_axes in [ + None, "B", "T", "BT", @@ -71,6 +78,7 @@ def __init__(self, cfg: ConformerMHSARelPosV1Config): self.learnable_pos_emb = cfg.learnable_pos_emb self.rel_pos_clip = cfg.rel_pos_clip + self.separate_pos_emb_per_head = cfg.separate_pos_emb_per_head self.with_pos_bias = cfg.with_pos_bias self.pos_emb_dropout = nn.Dropout(cfg.pos_emb_dropout) @@ -81,43 +89,42 @@ def __init__(self, cfg: ConformerMHSARelPosV1Config): assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads" # projection matrices - self.q_proj_weight = nn.parameter.Parameter(torch.empty((self.embed_dim, self.embed_dim))) - self.k_proj_weight = nn.parameter.Parameter(torch.empty((self.embed_dim, self.embed_dim))) - self.v_proj_weight = nn.parameter.Parameter(torch.empty((self.embed_dim, self.embed_dim))) - - self.in_proj_bias = nn.parameter.Parameter(torch.empty(3 * self.embed_dim)) - - self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) + self.qkv_proj = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=cfg.with_bias) 
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=cfg.with_bias) self.register_parameter("rel_pos_embeddings", None) self.register_parameter("pos_bias_u", None) self.register_parameter("pos_bias_v", None) + self.pos_emb_dim = ( + self.embed_dim if cfg.with_linear_pos or cfg.separate_pos_emb_per_head else self.embed_dim_per_head + ) if self.learnable_pos_emb: - self.rel_pos_embeddings = nn.parameter.Parameter( - torch.empty(self.rel_pos_clip * 2 + 1, self.embed_dim // self.num_heads) + self.rel_pos_embeddings = nn.parameter.Parameter(torch.empty(self.rel_pos_clip * 2 + 1, self.pos_emb_dim)) + if cfg.with_linear_pos: + self.linear_pos = nn.Linear( + self.pos_emb_dim, + self.embed_dim if cfg.separate_pos_emb_per_head else self.embed_dim_per_head, + bias=False, ) + else: + self.linear_pos = nn.Identity() + if self.with_pos_bias: self.pos_bias_u = nn.parameter.Parameter(torch.empty(self.num_heads, self.embed_dim_per_head)) self.pos_bias_v = nn.parameter.Parameter(torch.empty(self.num_heads, self.embed_dim_per_head)) - self.dropout = nn.Dropout1d(cfg.dropout) if cfg.dropout_broadcast_axes else nn.Dropout(cfg.dropout) - self.dropout_broadcast_axes = cfg.dropout_broadcast_axes + self.dropout = BroadcastDropout(cfg.dropout, dropout_broadcast_axes=cfg.dropout_broadcast_axes) - self._reset_parameters() # initialize parameters + self._reset_parameters() def _reset_parameters(self): - nn.init.xavier_uniform_(self.q_proj_weight) - nn.init.xavier_uniform_(self.k_proj_weight) - nn.init.xavier_uniform_(self.v_proj_weight) - - # TODO: choose kind of initialization if self.learnable_pos_emb: - nn.init.normal_(self.rel_pos_embeddings) + nn.init.xavier_normal_(self.rel_pos_embeddings) if self.with_pos_bias: - nn.init.constant_(self.pos_bias_u, 0.0) - nn.init.constant_(self.pos_bias_v, 0.0) - nn.init.constant_(self.in_proj_bias, 0.0) + # init taken from espnet default + nn.init.xavier_uniform_(self.pos_bias_u) + nn.init.xavier_uniform_(self.pos_bias_v) def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> torch.Tensor: """ @@ -138,16 +145,10 @@ def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> to torch.zeros_like(inv_sequence_mask, dtype=input_tensor.dtype) .masked_fill(inv_sequence_mask, float("-inf")) .view(batch_dim_size, 1, 1, time_dim_size) - .expand(-1, self.num_heads, -1, -1) - ) # [B, #heads, 1, T'] + ) # [B, 1, 1, T'] # query, key and value sequences - bias_k, bias_q, bias_v = self.in_proj_bias.chunk(3) - - query_seq = F.linear(output_tensor, self.q_proj_weight, bias_q) # [B, T, #heads * F'] - key_seq = F.linear(output_tensor, self.k_proj_weight, bias_k) - value_seq = F.linear(output_tensor, self.v_proj_weight, bias_v) - + query_seq, key_seq, value_seq = self.qkv_proj(output_tensor).chunk(3, dim=-1) # [B, T, #heads * F'] q = query_seq.view(batch_dim_size, -1, self.num_heads, self.embed_dim_per_head) # [B, T, #heads, F'] k = key_seq.view(batch_dim_size, -1, self.num_heads, self.embed_dim_per_head) # [B, T', #heads, F'] @@ -160,17 +161,28 @@ def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> to final_mat = distance_mat_clipped + self.rel_pos_clip - rel_pos_embeddings = self.rel_pos_embeddings[final_mat] # [T, T', F'] + rel_pos_embeddings = self.rel_pos_embeddings[final_mat] # [T, T', pos_emb_dim] else: rel_pos_embeddings = self._sinusoidal_pe( torch.arange(time_dim_size - 1, -time_dim_size, -1, device=input_tensor.device, dtype=torch.float32), - self.embed_dim_per_head, - ).expand( - time_dim_size, 2 * 
time_dim_size - 1, self.embed_dim_per_head - ) # [T, T+T'-1, F'] + self.pos_emb_dim, + ).view( + 1, 2 * time_dim_size - 1, self.pos_emb_dim + ) # [1, T+T'-1, pos_emb_dim] # dropout relative positional embeddings - rel_pos_embeddings = self.pos_emb_dropout(rel_pos_embeddings) + rel_pos_embeddings = self.pos_emb_dropout( + rel_pos_embeddings + ) # [T, T', pos_emb_dim] or [1, T+T'-1, pos_emb_dim] + rel_pos_embeddings = rel_pos_embeddings.unsqueeze(2) # [T, T', 1, pos_emb_dim] or [1, T+T'-1, 1, pos_emb_dim] + + # linear transformation or identity + rel_pos_embeddings = self.linear_pos(rel_pos_embeddings) # [T, T', 1, F'|F] or [1, T+T'-1, 1, F'|F] + + if self.separate_pos_emb_per_head: + rel_pos_embeddings = rel_pos_embeddings.squeeze(2).reshape( + *rel_pos_embeddings.shape[:2], -1, self.embed_dim_per_head + ) # [T, T', #heads, F'] or [1, T+T'-1, #heads, F'] q_with_bias_u = q + self.pos_bias_u if self.with_pos_bias else q # [B, T, #heads, F'] q_with_bias_v = q + self.pos_bias_v if self.with_pos_bias else q @@ -179,10 +191,12 @@ def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> to attn_ac = torch.einsum("bihf, bjhf -> bhij", q_with_bias_u, k) # [B, #heads, T, T'] # attention matrix b and d - attn_bd = torch.einsum("bihf, ijf -> bhij", q_with_bias_v, rel_pos_embeddings) + attn_bd = torch.einsum( + "bihf, ijhf -> bhij", q_with_bias_v, rel_pos_embeddings + ) # [B, #heads, T, T'] or [B, #heads, T, T+T'+1] if not self.learnable_pos_emb: - attn_bd = self._rel_shift_bhij(attn_bd, k_len=time_dim_size) + attn_bd = self._rel_shift_bhij(attn_bd, k_len=time_dim_size) # [B, #heads, T, T'] attn = attn_ac + attn_bd + mask # [B, #heads, T, T'] attn_scaled = attn * (math.sqrt(1.0 / float(self.embed_dim_per_head))) # [B, #heads, T, T'] @@ -201,21 +215,7 @@ def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> to output_tensor = self.out_proj(attn_output) - if self.dropout_broadcast_axes is None: - output_tensor = self.dropout(output_tensor) - elif self.dropout_broadcast_axes == "T": - output_tensor = self.dropout(output_tensor.transpose(1, 2)).transpose(1, 2) - elif self.dropout_broadcast_axes == "B": - output_tensor = self.dropout(output_tensor.permute(1, 2, 0)).permute(2, 0, 1) - elif self.dropout_broadcast_axes == "BT": - batch_dim_size = output_tensor.shape[0] - feature_dim_size = output_tensor.shape[-1] - - output_tensor = ( - self.output(output_tensor.reshape(-1, feature_dim_size).transpose(0, 1)) - .transpose(0, 1) - .reshape(batch_dim_size, -1, feature_dim_size) - ) + output_tensor = self.dropout(output_tensor) return output_tensor # [B,T,F] diff --git a/i6_models/parts/dropout.py b/i6_models/parts/dropout.py new file mode 100644 index 00000000..d240212f --- /dev/null +++ b/i6_models/parts/dropout.py @@ -0,0 +1,55 @@ +from typing import Optional, Literal + +import torch +from torch import nn + + +class BroadcastDropout(nn.Module): + """ + customized dropout module supporting dropout broadcasting + supported variants are: + - no broadcasting (default): dropout_broadcast_axes=None + - broadcast over the batch axis: dropout_broadcast_axes='B' + - broadcast over the time axis: dropout_broadcast_axes='T' + - broadcast over the batch and time axes: dropout_broadcast_axes='BT' + """ + + def __init__(self, p: float, dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] = None): + super().__init__() + + self.p = p + assert dropout_broadcast_axes is None or dropout_broadcast_axes in [ + "B", + "T", + "BT", + ], "invalid value, supported are None, 'B', 'T' and 
'BT'" + self.dropout_broadcast_axes = dropout_broadcast_axes + + def forward(self, tensor: torch.Tensor) -> torch.Tensor: + """ + assumes input tensor of shape [B, T, F] + return tensor of shape [B, T, F] + """ + if self.dropout_broadcast_axes is None: + tensor = torch.nn.functional.dropout(tensor, p=self.p, training=self.training) + elif self.dropout_broadcast_axes == "T": # [B, T, F] -> [B, F, T] -> [B, T, F] + tensor = torch.nn.functional.dropout1d(tensor.transpose(1, 2), p=self.p, training=self.training).transpose( + 1, 2 + ) + elif self.dropout_broadcast_axes == "B": # [B, T, F] -> [T, F, B] -> [B, T, F] + tensor = torch.nn.functional.dropout1d(tensor.permute(1, 2, 0), p=self.p, training=self.training).permute( + 2, 0, 1 + ) + elif self.dropout_broadcast_axes == "BT": # [B, T, F] -> [B*T, F] -> [F, B*T] -> [B*T, F] -> [B, T, F] + batch_dim_size = tensor.shape[0] + feature_dim_size = tensor.shape[-1] + + tensor = ( + torch.nn.functional.dropout1d( + tensor.reshape(-1, feature_dim_size).transpose(0, 1), p=self.p, training=self.training + ) + .transpose(0, 1) + .reshape(batch_dim_size, -1, feature_dim_size) + ) + + return tensor diff --git a/tests/test_conformer_rel_pos.py b/tests/test_conformer_rel_pos.py index 78e4439b..88b2dba3 100644 --- a/tests/test_conformer_rel_pos.py +++ b/tests/test_conformer_rel_pos.py @@ -76,10 +76,13 @@ def get_output_shape( input_shape, seq_len, input_dim, + with_bias=True, num_att_heads=8, att_weights_dropout=0.1, dropout=0.1, learnable_pos_emb=True, + with_linear_pos=False, + separate_pos_emb_per_head=False, rel_pos_clip=16, with_pos_bias=False, pos_emb_dropout=0.0, @@ -90,9 +93,12 @@ def get_output_shape( cfg = ConformerMHSARelPosV1Config( input_dim=input_dim, num_att_heads=num_att_heads, + with_bias=with_bias, att_weights_dropout=att_weights_dropout, dropout=dropout, learnable_pos_emb=learnable_pos_emb, + with_linear_pos=with_linear_pos, + separate_pos_emb_per_head=separate_pos_emb_per_head, rel_pos_clip=rel_pos_clip, with_pos_bias=with_pos_bias, pos_emb_dropout=pos_emb_dropout, @@ -110,7 +116,9 @@ def get_output_shape( input_shape = [4, 15, 32] # B,T,F seq_len = [15, 12, 10, 15] - for learnable_pos_emb, with_pos_bias, pos_emb_dropout in product([True, False], [True, False], [0.0, 0.1]): + for learnable_pos_emb, with_pos_bias, pos_emb_dropout, with_linear_pos, separate_pos_emb_per_head in product( + [True, False], [True, False], [0.0, 0.1], [True, False], [True, False] + ): assert get_output_shape( input_shape, seq_len, @@ -118,4 +126,6 @@ def get_output_shape( learnable_pos_emb=learnable_pos_emb, with_pos_bias=with_pos_bias, pos_emb_dropout=pos_emb_dropout, + with_linear_pos=with_linear_pos, + separate_pos_emb_per_head=separate_pos_emb_per_head, ) == [4, 15, 32]
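
Usage note (not part of the patch): a minimal sketch of how the new `BroadcastDropout` behaves for the different `dropout_broadcast_axes` settings. The all-ones input and the mask checks below are illustrative only; they assume the import path `i6_models.parts.dropout` introduced above.

```python
import torch

from i6_models.parts.dropout import BroadcastDropout

torch.manual_seed(0)

batch, time, feat = 4, 15, 32
x = torch.ones(batch, time, feat)  # all-ones input makes the dropout mask directly visible

for axes in [None, "B", "T", "BT"]:
    drop = BroadcastDropout(p=0.5, dropout_broadcast_axes=axes)
    drop.train()  # dropout is only active in training mode
    y = drop(x)  # [B, T, F]; kept entries are scaled by 1/(1-p), dropped entries are exactly 0
    assert y.shape == (batch, time, feat)

    mask = y == 0.0  # dropped positions
    if axes == "T":
        # same [B, F] drop pattern shared across all time steps
        assert torch.all(mask == mask[:, :1, :])
    elif axes == "B":
        # same [T, F] drop pattern shared across all sequences in the batch
        assert torch.all(mask == mask[:1, :, :])
    elif axes == "BT":
        # a feature channel is either kept or dropped for the whole batch and all time steps
        assert torch.all(mask == mask[:1, :1, :])
```

This mirrors the transpose/permute logic that was previously duplicated inline in the convolution, feed-forward and MHSA parts and is removed by this patch.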
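
Note that `ConformerMHSARelPosV1Config` no longer provides defaults for the positional-embedding options and gains the required fields `with_bias`, `with_linear_pos` and `separate_pos_emb_per_head`, so existing configs have to spell every field out. A construction sketch mirroring the values used in `tests/test_conformer_rel_pos.py` above; the concrete shapes and hyperparameters are only illustrative.

```python
import torch

from i6_models.parts.conformer.mhsa_rel_pos import (
    ConformerMHSARelPosV1,
    ConformerMHSARelPosV1Config,
)

cfg = ConformerMHSARelPosV1Config(
    input_dim=32,
    num_att_heads=8,
    with_bias=True,
    att_weights_dropout=0.1,
    learnable_pos_emb=True,
    rel_pos_clip=16,
    with_linear_pos=False,
    with_pos_bias=False,
    separate_pos_emb_per_head=False,
    pos_emb_dropout=0.0,
    dropout=0.1,
    dropout_broadcast_axes=None,  # or "B" / "T" / "BT" to broadcast the output dropout
)
mhsa = ConformerMHSARelPosV1(cfg)

x = torch.randn(4, 15, 32)  # [B, T, F]
sequence_mask = torch.ones(4, 15, dtype=torch.bool)  # True marks valid frames
out = mhsa(x, sequence_mask)  # [B, T, F]
assert out.shape == (4, 15, 32)
```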