address feedback
- make own Dropout module
- add additional parameters `with_bias`, `with_linear_pos` and `separate_emb_pos_per_head`
- remove default values
- remove tensor expand exps
kuacakuaca committed Sep 4, 2024
1 parent b8db085 commit 83b23c9
Showing 6 changed files with 149 additions and 129 deletions.
29 changes: 7 additions & 22 deletions i6_models/parts/conformer/convolution.py
@@ -9,11 +9,12 @@
 
 from dataclasses import dataclass
 from copy import deepcopy
+from typing import Callable, Union, Optional, Literal
 
 import torch
 from torch import nn
 from i6_models.config import ModelConfiguration
-from typing import Callable, Union, Optional
+from i6_models.parts.dropout import BroadcastDropout
 
 
 @dataclass
@@ -101,12 +102,13 @@ class ConformerConvolutionV2Config(ConformerConvolutionV1Config):
     Allows even kernel size
     """
 
-    dropout_broadcast_axes: Optional[str] = None
+    dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] = None
 
     def check_valid(self):
         assert self.kernel_size % 2 == 1, "ConformerConvolutionV1 only supports odd kernel sizes"
 
-        assert self.dropout_broadcast_axes is None or self.dropout_broadcast_axes in [
+        assert self.dropout_broadcast_axes in [
+            None,
             "B",
             "T",
             "BT",
@@ -124,10 +126,7 @@ def __init__(self, model_cfg: ConformerConvolutionV2Config):
         """
         super().__init__(model_cfg)
 
-        self.dropout_broadcast_axes = model_cfg.dropout_broadcast_axes
-        self.dropout = (
-            nn.Dropout1d(model_cfg.dropout) if model_cfg.dropout_broadcast_axes else nn.Dropout(model_cfg.dropout)
-        )
+        self.dropout = BroadcastDropout(model_cfg.dropout, dropout_broadcast_axes=model_cfg.dropout_broadcast_axes)
 
     def forward(self, tensor: torch.Tensor) -> torch.Tensor:
         """
@@ -148,20 +147,6 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
         tensor = self.activation(tensor)
         tensor = self.pointwise_conv2(tensor)
 
-        if self.dropout_broadcast_axes is None:
-            tensor = self.dropout(tensor)
-        elif self.dropout_broadcast_axes == "T":
-            tensor = self.dropout(tensor.transpose(1, 2)).transpose(1, 2)
-        elif self.dropout_broadcast_axes == "B":
-            tensor = self.dropout(tensor.permute(1, 2, 0)).permute(2, 0, 1)
-        elif self.dropout_broadcast_axes == "BT":
-            batch_dim_size = tensor.shape[0]
-            feature_dim_size = tensor.shape[-1]
-
-            tensor = (
-                self.dropout(tensor.reshape(-1, feature_dim_size).transpose(0, 1))
-                .transpose(0, 1)
-                .reshape(batch_dim_size, -1, feature_dim_size)
-            )
+        tensor = self.dropout(tensor)
 
         return tensor
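
Note: the inline broadcasting logic removed above now lives in the new module i6_models/parts/dropout.py, whose diff is not part of the loaded portion of this page. Below is a minimal sketch of what BroadcastDropout plausibly looks like, reconstructed from the removed if/elif chain and the call sites in this commit; the class name and constructor arguments are taken from the diff, everything else is an assumption.

from typing import Literal, Optional

import torch
from torch import nn


class BroadcastDropout(nn.Module):
    """
    Dropout with an optional mask shared (broadcast) over the batch axis "B",
    the time axis "T", or both "BT". Sketch only: reconstructed from the inline
    logic this commit removes; the real i6_models/parts/dropout.py may differ.
    """

    def __init__(self, p: float, dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] = None):
        super().__init__()
        self.dropout_broadcast_axes = dropout_broadcast_axes
        # nn.Dropout1d zeroes whole channels, so permuting the broadcast axes
        # into the "length" position yields one mask shared along them.
        self.dropout = nn.Dropout1d(p) if dropout_broadcast_axes else nn.Dropout(p)

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        :param tensor: input of shape [B, T, F]
        """
        if self.dropout_broadcast_axes is None:
            return self.dropout(tensor)  # independent mask per element
        if self.dropout_broadcast_axes == "T":
            # [B, F, T]: one mask per (batch, feature), shared across time
            return self.dropout(tensor.transpose(1, 2)).transpose(1, 2)
        if self.dropout_broadcast_axes == "B":
            # [T, F, B]: one mask per (time, feature), shared across the batch
            return self.dropout(tensor.permute(1, 2, 0)).permute(2, 0, 1)
        # "BT": flatten to [F, B*T] so one mask per feature covers batch and time
        batch_dim_size = tensor.shape[0]
        feature_dim_size = tensor.shape[-1]
        return (
            self.dropout(tensor.reshape(-1, feature_dim_size).transpose(0, 1))
            .transpose(0, 1)
            .reshape(batch_dim_size, -1, feature_dim_size)
        )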
33 changes: 8 additions & 25 deletions i6_models/parts/conformer/feedforward.py
@@ -8,12 +8,13 @@
 ]
 
 from dataclasses import dataclass
-from typing import Callable, Optional
+from typing import Callable, Optional, Literal
 
 import torch
 from torch import nn
 
 from i6_models.config import ModelConfiguration
+from i6_models.parts.dropout import BroadcastDropout
 
 
 @dataclass
@@ -68,10 +69,11 @@ class ConformerPositionwiseFeedForwardV2Config(ConformerPositionwiseFeedForwardV1Config):
     setting to None to disable broadcasting
     """
 
-    dropout_broadcast_axes: Optional[str] = None
+    dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] = None
 
     def check_valid(self):
-        assert self.dropout_broadcast_axes is None or self.dropout_broadcast_axes in [
+        assert self.dropout_broadcast_axes in [
+            None,
             "B",
             "T",
             "BT",
@@ -90,26 +92,7 @@ class ConformerPositionwiseFeedForwardV2(ConformerPositionwiseFeedForwardV1):
     def __init__(self, cfg: ConformerPositionwiseFeedForwardV2Config):
         super().__init__(cfg)
 
-        self.dropout = nn.Dropout1d(cfg.dropout) if cfg.dropout_broadcast_axes else nn.Dropout(cfg.dropout)
-        self.dropout_broadcast_axes = cfg.dropout_broadcast_axes
-
-    def _broadcast_dropout(self, tensor: torch.Tensor) -> torch.Tensor:
-        if self.dropout_broadcast_axes is None:
-            tensor = self.dropout(tensor)
-        elif self.dropout_broadcast_axes == "T":
-            tensor = self.dropout(tensor.transpose(1, 2)).transpose(1, 2)
-        elif self.dropout_broadcast_axes == "B":
-            tensor = self.dropout(tensor.permute(1, 2, 0)).permute(2, 0, 1)
-        elif self.dropout_broadcast_axes == "BT":
-            batch_dim_size = tensor.shape[0]
-            feature_dim_size = tensor.shape[-1]
-
-            tensor = (
-                self.dropout(tensor.reshape(-1, feature_dim_size).transpose(0, 1))
-                .transpose(0, 1)
-                .reshape(batch_dim_size, -1, feature_dim_size)
-            )
-        return tensor
+        self.dropout = BroadcastDropout(cfg.dropout, dropout_broadcast_axes=cfg.dropout_broadcast_axes)
 
     def forward(self, tensor: torch.Tensor) -> torch.Tensor:
         """
@@ -120,8 +103,8 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
         tensor = self.linear_ff(tensor)  # [B,T,F]
         tensor = self.activation(tensor)  # [B,T,F]
 
-        tensor = self._broadcast_dropout(tensor)  # [B,T,F]
+        tensor = self.dropout(tensor)  # [B,T,F]
         tensor = self.linear_out(tensor)  # [B,T,F]
-        tensor = self._broadcast_dropout(tensor)  # [B,T,F]
+        tensor = self.dropout(tensor)  # [B,T,F]
 
         return tensor
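
For intuition on what the broadcasting buys: with dropout_broadcast_axes="T", the same features are dropped at every time step of a sequence instead of a fresh mask per frame. A quick check against the sketch above (hypothetical usage, assuming the reconstructed semantics):

torch.manual_seed(0)
drop = BroadcastDropout(p=0.5, dropout_broadcast_axes="T")
drop.train()  # dropout only applies in training mode

x = torch.ones(2, 4, 8)  # [B, T, F]
y = drop(x)

# the mask (and the 1/(1-p) rescaling) is identical at every time step
assert torch.equal(y[:, 0, :], y[:, 1, :])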
29 changes: 8 additions & 21 deletions i6_models/parts/conformer/mhsa.py
@@ -1,13 +1,14 @@
 from __future__ import annotations
 
 __all__ = ["ConformerMHSAV1", "ConformerMHSAV1Config", "ConformerMHSAV2", "ConformerMHSAV2Config"]
 
 from dataclasses import dataclass
+from typing import Optional, Literal
 import torch
 
 from i6_models.config import ModelConfiguration
 from i6_models.util import compat
 
-from typing import Optional
+from i6_models.parts.dropout import BroadcastDropout
 
 
 @dataclass
@@ -72,10 +73,11 @@ class ConformerMHSAV2Config(ConformerMHSAV1Config):
     setting to None to disable broadcasting
     """
 
-    dropout_broadcast_axes: Optional[str] = None
+    dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] = None
 
     def check_valid(self):
-        assert self.dropout_broadcast_axes is None or self.dropout_broadcast_axes in [
+        assert self.dropout_broadcast_axes in [
+            None,
             "B",
             "T",
             "BT",
@@ -95,8 +97,7 @@ def __init__(self, cfg: ConformerMHSAV2Config):
 
         super().__init__(cfg)
 
-        self.dropout = torch.nn.Dropout1d(cfg.dropout) if cfg.dropout_broadcast_axes else torch.nn.Dropout(cfg.dropout)
-        self.dropout_broadcast_axes = cfg.dropout_broadcast_axes
+        self.dropout = BroadcastDropout(cfg.dropout, dropout_broadcast_axes=cfg.dropout_broadcast_axes)
 
     def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> torch.Tensor:
         """
@@ -113,20 +114,6 @@ def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> torch.Tensor:
             output_tensor, output_tensor, output_tensor, key_padding_mask=inv_sequence_mask, need_weights=False
         )  # [B,T,F]
 
-        if self.dropout_broadcast_axes is None:
-            output_tensor = self.dropout(output_tensor)
-        elif self.dropout_broadcast_axes == "T":
-            output_tensor = self.dropout(output_tensor.transpose(1, 2)).transpose(1, 2)
-        elif self.dropout_broadcast_axes == "B":
-            output_tensor = self.dropout(output_tensor.permute(1, 2, 0)).permute(2, 0, 1)
-        elif self.dropout_broadcast_axes == "BT":
-            batch_dim_size = output_tensor.shape[0]
-            feature_dim_size = output_tensor.shape[-1]
-
-            output_tensor = (
-                self.dropout(output_tensor.reshape(-1, feature_dim_size).transpose(0, 1))
-                .transpose(0, 1)
-                .reshape(batch_dim_size, -1, feature_dim_size)
-            )
+        output_tensor = self.dropout(output_tensor)
 
         return output_tensor  # [B,T,F]
