make dropout model modules
kuacakuaca committed Aug 21, 2024
1 parent 2bf2c89 commit 600752e
Showing 3 changed files with 23 additions and 40 deletions.
15 changes: 5 additions & 10 deletions i6_models/parts/conformer/feedforward.py
@@ -90,27 +90,22 @@ class ConformerPositionwiseFeedForwardV2(ConformerPositionwiseFeedForwardV1):
     def __init__(self, cfg: ConformerPositionwiseFeedForwardV2Config):
         super().__init__(cfg)
 
+        self.dropout = nn.Dropout1d(cfg.dropout) if cfg.dropout_broadcast_axes else nn.Dropout(cfg.dropout)
         self.dropout_broadcast_axes = cfg.dropout_broadcast_axes
 
     def _broadcast_dropout(self, tensor: torch.Tensor) -> torch.Tensor:
         if self.dropout_broadcast_axes is None:
-            tensor = torch.nn.functional.dropout(tensor, p=self.dropout, training=self.training)
+            tensor = self.dropout(tensor)
         elif self.dropout_broadcast_axes == "T":
-            tensor = torch.nn.functional.dropout1d(
-                tensor.transpose(1, 2), p=self.dropout, training=self.training
-            ).transpose(1, 2)
+            tensor = self.dropout(tensor.transpose(1, 2)).transpose(1, 2)
         elif self.dropout_broadcast_axes == "B":
-            tensor = torch.nn.functional.dropout1d(
-                tensor.permute(1, 2, 0), p=self.dropout, training=self.training
-            ).permute(2, 0, 1)
+            tensor = self.dropout(tensor.permute(1, 2, 0)).permute(2, 0, 1)
         elif self.dropout_broadcast_axes == "BT":
             batch_dim_size = tensor.shape[0]
             feature_dim_size = tensor.shape[-1]
 
             tensor = (
-                torch.nn.functional.dropout1d(
-                    tensor.reshape(-1, feature_dim_size).transpose(0, 1), p=self.dropout, training=self.training
-                )
+                self.dropout(tensor.reshape(-1, feature_dim_size).transpose(0, 1))
                 .transpose(0, 1)
                 .reshape(batch_dim_size, -1, feature_dim_size)
             )
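Why this works: nn.Dropout1d zeroes entire channels of an [N, C, L] input, so the transposes and permutes above simply move the axis that must share a dropout mask out of the channel slot. A minimal sketch of the "T" branch (not part of the commit; the toy shapes and seed are assumptions for illustration):

# Dropout1d drops whole channels, so with [B, T, F] transposed to [B, F, T]
# each feature channel gets one dropout decision shared across all T steps.
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed, for a reproducible demo
drop = nn.Dropout1d(p=0.5)
drop.train()

x = torch.ones(2, 4, 3)  # assumed toy shape [B=2, T=4, F=3]
y = drop(x.transpose(1, 2)).transpose(1, 2)  # same path as the "T" branch

# per sequence, each feature is fully kept or fully dropped over time
for b in range(x.shape[0]):
    for f in range(x.shape[-1]):
        col = y[b, :, f]
        assert col.eq(0).all() or col.ne(0).all()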
15 changes: 5 additions & 10 deletions i6_models/parts/conformer/mhsa.py
@@ -95,6 +95,7 @@ def __init__(self, cfg: ConformerMHSAV2Config):

         super().__init__(cfg)
 
+        self.dropout = torch.nn.Dropout1d(cfg.dropout) if cfg.dropout_broadcast_axes else torch.nn.Dropout(cfg.dropout)
         self.dropout_broadcast_axes = cfg.dropout_broadcast_axes
 
     def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> torch.Tensor:
@@ -113,23 +114,17 @@ def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> torch.Tensor:
         ) # [B,T,F]
 
         if self.dropout_broadcast_axes is None:
-            output_tensor = torch.nn.functional.dropout(output_tensor, p=self.dropout, training=self.training)
+            output_tensor = self.dropout(output_tensor)
         elif self.dropout_broadcast_axes == "T":
-            output_tensor = torch.nn.functional.dropout1d(
-                output_tensor.transpose(1, 2), p=self.dropout, training=self.training
-            ).transpose(1, 2)
+            output_tensor = self.dropout(output_tensor.transpose(1, 2)).transpose(1, 2)
         elif self.dropout_broadcast_axes == "B":
-            output_tensor = torch.nn.functional.dropout1d(
-                output_tensor.permute(1, 2, 0), p=self.dropout, training=self.training
-            ).permute(2, 0, 1)
+            output_tensor = self.dropout(output_tensor.permute(1, 2, 0)).permute(2, 0, 1)
         elif self.dropout_broadcast_axes == "BT":
             batch_dim_size = output_tensor.shape[0]
             feature_dim_size = output_tensor.shape[-1]
 
             output_tensor = (
-                torch.nn.functional.dropout1d(
-                    output_tensor.reshape(-1, feature_dim_size).transpose(0, 1), p=self.dropout, training=self.training
-                )
+                self.dropout(output_tensor.reshape(-1, feature_dim_size).transpose(0, 1))
                 .transpose(0, 1)
                 .reshape(batch_dim_size, -1, feature_dim_size)
             )
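The "BT" branch works because nn.Dropout1d also accepts a 2D input, treating it as an unbatched [C, L] tensor: flattening [B, T, F] to [B*T, F] and transposing yields [F, B*T], so a single per-feature decision covers the whole batch and all time steps. A minimal sketch under the same assumptions as above:

# reshape + transpose puts features in the channel slot of an unbatched
# [C, L] tensor, so Dropout1d broadcasts each feature's mask over B and T
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed, for a reproducible demo
drop = nn.Dropout1d(p=0.5)
drop.train()

x = torch.ones(2, 4, 3)  # assumed toy shape [B=2, T=4, F=3]
batch_dim_size, feature_dim_size = x.shape[0], x.shape[-1]
y = (
    drop(x.reshape(-1, feature_dim_size).transpose(0, 1))
    .transpose(0, 1)
    .reshape(batch_dim_size, -1, feature_dim_size)
)

# each feature is either kept or dropped for the entire batch
for f in range(feature_dim_size):
    plane = y[:, :, f]
    assert plane.eq(0).all() or plane.ne(0).all()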
33 changes: 13 additions & 20 deletions i6_models/parts/conformer/mhsa_rel_pos.py
@@ -53,7 +53,10 @@ def __post_init__(self) -> None:

 class ConformerMHSARelPosV1(nn.Module):
     """
-    Conformer multi-headed self-attention module with relative positional encoding proposed by Shaw et al. (cf. https://arxiv.org/abs/1803.02155)
+    Conformer multi-headed self-attention module supporting
+        - relative positional encoding proposed by Shaw et al. (cf. https://arxiv.org/abs/1803.02155) by setting `learnable_pos_emb` to True and `with_pos_bias` to False
+        - and Transformer-XL style relative PE by Dai et al. (cf. https://arxiv.org/abs/1901.02860) by setting `learnable_pos_emb` to False and `with_pos_bias` to True
     """
 
     def __init__(self, cfg: ConformerMHSARelPosV1Config):
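For illustration, the two modes called out in the new docstring, side by side. This is a hedged sketch: the flag names come from this diff, but `rel_pos_clip=16` is an arbitrary stand-in and the plain dicts are not the real `ConformerMHSARelPosV1Config` constructor:

# Shaw et al. style: learned, clipped relative position embeddings;
# the assert in __init__ below requires rel_pos_clip to be set in this mode
shaw_style = dict(learnable_pos_emb=True, with_pos_bias=False, rel_pos_clip=16)

# Transformer-XL style (Dai et al.): non-learnable relative embeddings
# combined with the learned pos_bias_u / pos_bias_v terms
transformer_xl_style = dict(learnable_pos_emb=False, with_pos_bias=True)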
@@ -69,11 +72,11 @@ def __init__(self, cfg: ConformerMHSARelPosV1Config):
         self.learnable_pos_emb = cfg.learnable_pos_emb
         self.rel_pos_clip = cfg.rel_pos_clip
         self.with_pos_bias = cfg.with_pos_bias
-        self.pos_emb_dropout = cfg.pos_emb_dropout
+        self.pos_emb_dropout = nn.Dropout(cfg.pos_emb_dropout)
 
         assert not self.learnable_pos_emb or self.rel_pos_clip
 
-        self.att_weights_dropout = cfg.att_weights_dropout
+        self.att_weights_dropout = nn.Dropout(cfg.att_weights_dropout)
 
         assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"

@@ -98,7 +101,7 @@ def __init__(self, cfg: ConformerMHSARelPosV1Config):
         self.pos_bias_u = nn.parameter.Parameter(torch.empty(self.num_heads, self.embed_dim_per_head))
         self.pos_bias_v = nn.parameter.Parameter(torch.empty(self.num_heads, self.embed_dim_per_head))
 
-        self.dropout = cfg.dropout
+        self.dropout = nn.Dropout1d(cfg.dropout) if cfg.dropout_broadcast_axes else nn.Dropout(cfg.dropout)
         self.dropout_broadcast_axes = cfg.dropout_broadcast_axes
 
         self._reset_parameters()  # initialize parameters
@@ -167,7 +170,7 @@ def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> torch.Tensor:
         ) # [T, T+T'-1, F']
 
         # dropout relative positional embeddings
-        rel_pos_embeddings = F.dropout(rel_pos_embeddings, p=self.pos_emb_dropout, training=self.training)
+        rel_pos_embeddings = self.pos_emb_dropout(rel_pos_embeddings)
 
         q_with_bias_u = q + self.pos_bias_u if self.with_pos_bias else q  # [B, T, #heads, F']
         q_with_bias_v = q + self.pos_bias_v if self.with_pos_bias else q
@@ -185,9 +188,7 @@ def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> torch.Tensor:
         attn_scaled = attn * (math.sqrt(1.0 / float(self.embed_dim_per_head)))  # [B, #heads, T, T']
 
         # softmax and dropout
-        attn_output_weights = F.dropout(
-            F.softmax(attn_scaled, dim=-1), p=self.att_weights_dropout, training=self.training
-        )  # [B, #heads, T, T']
+        attn_output_weights = self.att_weights_dropout(F.softmax(attn_scaled, dim=-1))  # [B, #heads, T, T']
 
         # sequence of weighted sums over value sequence
         v = value_seq.view(batch_dim_size, -1, self.num_heads, self.embed_dim_per_head)  # [B, T, H, F']
@@ -200,26 +201,18 @@ def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> torch.Tensor:

         output_tensor = self.out_proj(attn_output)
 
-        output_tensor = F.dropout(output_tensor, p=self.dropout, training=self.training)  # [B,T,F]
-
         if self.dropout_broadcast_axes is None:
-            output_tensor = torch.nn.functional.dropout(output_tensor, p=self.dropout, training=self.training)
+            output_tensor = self.dropout(output_tensor)
         elif self.dropout_broadcast_axes == "T":
-            output_tensor = torch.nn.functional.dropout1d(
-                output_tensor.transpose(1, 2), p=self.dropout, training=self.training
-            ).transpose(1, 2)
+            output_tensor = self.dropout(output_tensor.transpose(1, 2)).transpose(1, 2)
         elif self.dropout_broadcast_axes == "B":
-            output_tensor = torch.nn.functional.dropout1d(
-                output_tensor.permute(1, 2, 0), p=self.dropout, training=self.training
-            ).permute(2, 0, 1)
+            output_tensor = self.dropout(output_tensor.permute(1, 2, 0)).permute(2, 0, 1)
         elif self.dropout_broadcast_axes == "BT":
             batch_dim_size = output_tensor.shape[0]
             feature_dim_size = output_tensor.shape[-1]
 
             output_tensor = (
-                torch.nn.functional.dropout1d(
-                    output_tensor.reshape(-1, feature_dim_size).transpose(0, 1), p=self.dropout, training=self.training
-                )
+                self.dropout(output_tensor.reshape(-1, feature_dim_size).transpose(0, 1))
                 .transpose(0, 1)
                 .reshape(batch_dim_size, -1, feature_dim_size)
             )
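A note on why the module form is a faithful replacement: the deleted functional calls passed training=self.training by hand, while nn.Dropout/nn.Dropout1d consult their own .training flag, which model.train()/model.eval() toggle recursively over submodules. A minimal sketch (not part of the commit):

# dropout modules are identity in eval mode and scale survivors by
# 1/(1-p) in train mode, matching the functional training= argument
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(4)

drop.eval()
assert torch.equal(drop(x), x)  # no-op at inference time

drop.train()
y = drop(x)  # each element is 0.0 (dropped) or 2.0 (kept, scaled by 1/(1-p))
assert bool(((y == 0.0) | (y == 2.0)).all())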