From 83b23c9b0e5ebc6e711d880485d6928d15ad24e8 Mon Sep 17 00:00:00 2001
From: Ping Zheng
Date: Wed, 4 Sep 2024 12:40:52 -0400
Subject: [PATCH] address feedback - make own Dropout module - add additional
parameters `with_bias`, `with_linear_pos` and `separate_pos_emb_per_head` -
remove default values - remove tensor expand expressions
---
i6_models/parts/conformer/convolution.py | 29 ++----
i6_models/parts/conformer/feedforward.py | 33 ++----
i6_models/parts/conformer/mhsa.py | 29 ++----
i6_models/parts/conformer/mhsa_rel_pos.py | 120 +++++++++++-----------
i6_models/parts/dropout.py | 55 ++++++++++
tests/test_conformer_rel_pos.py | 12 ++-
6 files changed, 149 insertions(+), 129 deletions(-)
create mode 100644 i6_models/parts/dropout.py
diff --git a/i6_models/parts/conformer/convolution.py b/i6_models/parts/conformer/convolution.py
index b0b3d023..87c56bc3 100644
--- a/i6_models/parts/conformer/convolution.py
+++ b/i6_models/parts/conformer/convolution.py
@@ -9,11 +9,12 @@
from dataclasses import dataclass
from copy import deepcopy
+from typing import Callable, Union, Optional, Literal
import torch
from torch import nn
from i6_models.config import ModelConfiguration
-from typing import Callable, Union, Optional
+from i6_models.parts.dropout import BroadcastDropout
@dataclass
@@ -101,12 +102,13 @@ class ConformerConvolutionV2Config(ConformerConvolutionV1Config):
Allows even kernel size
"""
- dropout_broadcast_axes: Optional[str] = None
+ dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] = None
def check_valid(self):
assert self.kernel_size % 2 == 1, "ConformerConvolutionV1 only supports odd kernel sizes"
- assert self.dropout_broadcast_axes is None or self.dropout_broadcast_axes in [
+ assert self.dropout_broadcast_axes in [
+ None,
"B",
"T",
"BT",
@@ -124,10 +126,7 @@ def __init__(self, model_cfg: ConformerConvolutionV2Config):
"""
super().__init__(model_cfg)
- self.dropout_broadcast_axes = model_cfg.dropout_broadcast_axes
- self.dropout = (
- nn.Dropout1d(model_cfg.dropout) if model_cfg.dropout_broadcast_axes else nn.Dropout(model_cfg.dropout)
- )
+ self.dropout = BroadcastDropout(model_cfg.dropout, dropout_broadcast_axes=model_cfg.dropout_broadcast_axes)
def forward(self, tensor: torch.Tensor) -> torch.Tensor:
"""
@@ -148,20 +147,6 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
tensor = self.activation(tensor)
tensor = self.pointwise_conv2(tensor)
- if self.dropout_broadcast_axes is None:
- tensor = self.dropout(tensor)
- elif self.dropout_broadcast_axes == "T":
- tensor = self.dropout(tensor.transpose(1, 2)).transpose(1, 2)
- elif self.dropout_broadcast_axes == "B":
- tensor = self.dropout(tensor.permute(1, 2, 0)).permute(2, 0, 1)
- elif self.dropout_broadcast_axes == "BT":
- batch_dim_size = tensor.shape[0]
- feature_dim_size = tensor.shape[-1]
-
- tensor = (
- self.dropout(tensor.reshape(-1, feature_dim_size).transpose(0, 1))
- .transpose(0, 1)
- .reshape(batch_dim_size, -1, feature_dim_size)
- )
+ tensor = self.dropout(tensor)
return tensor
diff --git a/i6_models/parts/conformer/feedforward.py b/i6_models/parts/conformer/feedforward.py
index 9628ce60..c58840dd 100644
--- a/i6_models/parts/conformer/feedforward.py
+++ b/i6_models/parts/conformer/feedforward.py
@@ -8,12 +8,13 @@
]
from dataclasses import dataclass
-from typing import Callable, Optional
+from typing import Callable, Optional, Literal
import torch
from torch import nn
from i6_models.config import ModelConfiguration
+from i6_models.parts.dropout import BroadcastDropout
@dataclass
@@ -68,10 +69,11 @@ class ConformerPositionwiseFeedForwardV2Config(ConformerPositionwiseFeedForwardV
setting to None to disable broadcasting
"""
- dropout_broadcast_axes: Optional[str] = None
+ dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] = None
def check_valid(self):
- assert self.dropout_broadcast_axes is None or self.dropout_broadcast_axes in [
+ assert self.dropout_broadcast_axes in [
+ None,
"B",
"T",
"BT",
@@ -90,26 +92,7 @@ class ConformerPositionwiseFeedForwardV2(ConformerPositionwiseFeedForwardV1):
def __init__(self, cfg: ConformerPositionwiseFeedForwardV2Config):
super().__init__(cfg)
- self.dropout = nn.Dropout1d(cfg.dropout) if cfg.dropout_broadcast_axes else nn.Dropout(cfg.dropout)
- self.dropout_broadcast_axes = cfg.dropout_broadcast_axes
-
- def _broadcast_dropout(self, tensor: torch.Tensor) -> torch.Tensor:
- if self.dropout_broadcast_axes is None:
- tensor = self.dropout(tensor)
- elif self.dropout_broadcast_axes == "T":
- tensor = self.dropout(tensor.transpose(1, 2)).transpose(1, 2)
- elif self.dropout_broadcast_axes == "B":
- tensor = self.dropout(tensor.permute(1, 2, 0)).permute(2, 0, 1)
- elif self.dropout_broadcast_axes == "BT":
- batch_dim_size = tensor.shape[0]
- feature_dim_size = tensor.shape[-1]
-
- tensor = (
- self.dropout(tensor.reshape(-1, feature_dim_size).transpose(0, 1))
- .transpose(0, 1)
- .reshape(batch_dim_size, -1, feature_dim_size)
- )
- return tensor
+ self.dropout = BroadcastDropout(cfg.dropout, dropout_broadcast_axes=cfg.dropout_broadcast_axes)
def forward(self, tensor: torch.Tensor) -> torch.Tensor:
"""
@@ -120,8 +103,8 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
tensor = self.linear_ff(tensor) # [B,T,F]
tensor = self.activation(tensor) # [B,T,F]
- tensor = self._broadcast_dropout(tensor) # [B,T,F]
+ tensor = self.dropout(tensor) # [B,T,F]
tensor = self.linear_out(tensor) # [B,T,F]
- tensor = self._broadcast_dropout(tensor) # [B,T,F]
+ tensor = self.dropout(tensor) # [B,T,F]
return tensor
diff --git a/i6_models/parts/conformer/mhsa.py b/i6_models/parts/conformer/mhsa.py
index 8eb39327..1584df23 100644
--- a/i6_models/parts/conformer/mhsa.py
+++ b/i6_models/parts/conformer/mhsa.py
@@ -1,13 +1,14 @@
from __future__ import annotations
__all__ = ["ConformerMHSAV1", "ConformerMHSAV1Config", "ConformerMHSAV2", "ConformerMHSAV2Config"]
+
from dataclasses import dataclass
+from typing import Optional, Literal
import torch
from i6_models.config import ModelConfiguration
from i6_models.util import compat
-
-from typing import Optional
+from i6_models.parts.dropout import BroadcastDropout
@dataclass
@@ -72,10 +73,11 @@ class ConformerMHSAV2Config(ConformerMHSAV1Config):
setting to None to disable broadcasting
"""
- dropout_broadcast_axes: Optional[str] = None
+ dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] = None
def check_valid(self):
- assert self.dropout_broadcast_axes is None or self.dropout_broadcast_axes in [
+ assert self.dropout_broadcast_axes in [
+ None,
"B",
"T",
"BT",
@@ -95,8 +97,7 @@ def __init__(self, cfg: ConformerMHSAV2Config):
super().__init__(cfg)
- self.dropout = torch.nn.Dropout1d(cfg.dropout) if cfg.dropout_broadcast_axes else torch.nn.Dropout(cfg.dropout)
- self.dropout_broadcast_axes = cfg.dropout_broadcast_axes
+ self.dropout = BroadcastDropout(cfg.dropout, dropout_broadcast_axes=cfg.dropout_broadcast_axes)
def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> torch.Tensor:
"""
@@ -113,20 +114,6 @@ def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> to
output_tensor, output_tensor, output_tensor, key_padding_mask=inv_sequence_mask, need_weights=False
) # [B,T,F]
- if self.dropout_broadcast_axes is None:
- output_tensor = self.dropout(output_tensor)
- elif self.dropout_broadcast_axes == "T":
- output_tensor = self.dropout(output_tensor.transpose(1, 2)).transpose(1, 2)
- elif self.dropout_broadcast_axes == "B":
- output_tensor = self.dropout(output_tensor.permute(1, 2, 0)).permute(2, 0, 1)
- elif self.dropout_broadcast_axes == "BT":
- batch_dim_size = output_tensor.shape[0]
- feature_dim_size = output_tensor.shape[-1]
-
- output_tensor = (
- self.dropout(output_tensor.reshape(-1, feature_dim_size).transpose(0, 1))
- .transpose(0, 1)
- .reshape(batch_dim_size, -1, feature_dim_size)
- )
+ output_tensor = self.dropout(output_tensor)
return output_tensor # [B,T,F]
diff --git a/i6_models/parts/conformer/mhsa_rel_pos.py b/i6_models/parts/conformer/mhsa_rel_pos.py
index 918f9f36..f2516c91 100644
--- a/i6_models/parts/conformer/mhsa_rel_pos.py
+++ b/i6_models/parts/conformer/mhsa_rel_pos.py
@@ -1,11 +1,10 @@
from __future__ import annotations
-
__all__ = ["ConformerMHSARelPosV1", "ConformerMHSARelPosV1Config"]
from dataclasses import dataclass
import math
-from typing import Optional
+from typing import Optional, Literal
import torch
from torch import nn
@@ -13,6 +12,7 @@
from i6_models.config import ModelConfiguration
from i6_models.util import compat
+from i6_models.parts.dropout import BroadcastDropout
@dataclass
@@ -21,9 +21,12 @@ class ConformerMHSARelPosV1Config(ModelConfiguration):
Attributes:
input_dim: input dim and total dimension for query/key and value projections, should be divisible by `num_att_heads`
num_att_heads: number of attention heads
+ with_bias: whether to add bias to qkv and output linear projections
att_weights_dropout: attention weights dropout
learnable_pos_emb: whether to use learnable relative positional embeddings instead of fixed sinusoidal ones
rel_pos_clip: maximal relative position for embedding
+ with_linear_pos: whether to linearly transform the positional embeddings
+ separate_pos_emb_per_head: whether to apply a separate linear transformation to the positional embeddings for each head
with_pos_bias: whether to add additional position bias terms to the attention scores
pos_emb_dropout: dropout for the positional embeddings
dropout: multi-headed self attention output dropout
@@ -33,18 +36,22 @@ class ConformerMHSARelPosV1Config(ModelConfiguration):
input_dim: int
num_att_heads: int
+ with_bias: bool
att_weights_dropout: float
+ learnable_pos_emb: bool
+ rel_pos_clip: Optional[int]
+ with_linear_pos: bool
+ with_pos_bias: bool
+ separate_pos_emb_per_head: bool
+ pos_emb_dropout: float
dropout: float
- learnable_pos_emb: bool = True
- rel_pos_clip: Optional[int] = None
- with_pos_bias: bool = False
- pos_emb_dropout: float = 0.0
- dropout_broadcast_axes: Optional[str] = None
+ dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] = None
def __post_init__(self) -> None:
super().__post_init__()
assert self.input_dim % self.num_att_heads == 0, "input_dim must be divisible by num_att_heads"
- assert self.dropout_broadcast_axes is None or self.dropout_broadcast_axes in [
+ assert self.dropout_broadcast_axes in [
+ None,
"B",
"T",
"BT",
@@ -71,6 +78,7 @@ def __init__(self, cfg: ConformerMHSARelPosV1Config):
self.learnable_pos_emb = cfg.learnable_pos_emb
self.rel_pos_clip = cfg.rel_pos_clip
+ self.separate_pos_emb_per_head = cfg.separate_pos_emb_per_head
self.with_pos_bias = cfg.with_pos_bias
self.pos_emb_dropout = nn.Dropout(cfg.pos_emb_dropout)
@@ -81,43 +89,42 @@ def __init__(self, cfg: ConformerMHSARelPosV1Config):
assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
# projection matrices
- self.q_proj_weight = nn.parameter.Parameter(torch.empty((self.embed_dim, self.embed_dim)))
- self.k_proj_weight = nn.parameter.Parameter(torch.empty((self.embed_dim, self.embed_dim)))
- self.v_proj_weight = nn.parameter.Parameter(torch.empty((self.embed_dim, self.embed_dim)))
-
- self.in_proj_bias = nn.parameter.Parameter(torch.empty(3 * self.embed_dim))
-
- self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
+ self.qkv_proj = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=cfg.with_bias)
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=cfg.with_bias)
self.register_parameter("rel_pos_embeddings", None)
self.register_parameter("pos_bias_u", None)
self.register_parameter("pos_bias_v", None)
+ self.pos_emb_dim = (
+ self.embed_dim if cfg.with_linear_pos or cfg.separate_pos_emb_per_head else self.embed_dim_per_head
+ )
if self.learnable_pos_emb:
- self.rel_pos_embeddings = nn.parameter.Parameter(
- torch.empty(self.rel_pos_clip * 2 + 1, self.embed_dim // self.num_heads)
+ self.rel_pos_embeddings = nn.parameter.Parameter(torch.empty(self.rel_pos_clip * 2 + 1, self.pos_emb_dim))
+ if cfg.with_linear_pos:
+ self.linear_pos = nn.Linear(
+ self.pos_emb_dim,
+ self.embed_dim if cfg.separate_pos_emb_per_head else self.embed_dim_per_head,
+ bias=False,
)
+ else:
+ self.linear_pos = nn.Identity()
+
if self.with_pos_bias:
self.pos_bias_u = nn.parameter.Parameter(torch.empty(self.num_heads, self.embed_dim_per_head))
self.pos_bias_v = nn.parameter.Parameter(torch.empty(self.num_heads, self.embed_dim_per_head))
- self.dropout = nn.Dropout1d(cfg.dropout) if cfg.dropout_broadcast_axes else nn.Dropout(cfg.dropout)
- self.dropout_broadcast_axes = cfg.dropout_broadcast_axes
+ self.dropout = BroadcastDropout(cfg.dropout, dropout_broadcast_axes=cfg.dropout_broadcast_axes)
- self._reset_parameters() # initialize parameters
+ self._reset_parameters()
def _reset_parameters(self):
- nn.init.xavier_uniform_(self.q_proj_weight)
- nn.init.xavier_uniform_(self.k_proj_weight)
- nn.init.xavier_uniform_(self.v_proj_weight)
-
- # TODO: choose kind of initialization
if self.learnable_pos_emb:
- nn.init.normal_(self.rel_pos_embeddings)
+ nn.init.xavier_normal_(self.rel_pos_embeddings)
if self.with_pos_bias:
- nn.init.constant_(self.pos_bias_u, 0.0)
- nn.init.constant_(self.pos_bias_v, 0.0)
- nn.init.constant_(self.in_proj_bias, 0.0)
+ # init taken from the ESPnet default
+ nn.init.xavier_uniform_(self.pos_bias_u)
+ nn.init.xavier_uniform_(self.pos_bias_v)
def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> torch.Tensor:
"""
@@ -138,16 +145,10 @@ def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> to
torch.zeros_like(inv_sequence_mask, dtype=input_tensor.dtype)
.masked_fill(inv_sequence_mask, float("-inf"))
.view(batch_dim_size, 1, 1, time_dim_size)
- .expand(-1, self.num_heads, -1, -1)
- ) # [B, #heads, 1, T']
+ ) # [B, 1, 1, T']
# query, key and value sequences
- bias_k, bias_q, bias_v = self.in_proj_bias.chunk(3)
-
- query_seq = F.linear(output_tensor, self.q_proj_weight, bias_q) # [B, T, #heads * F']
- key_seq = F.linear(output_tensor, self.k_proj_weight, bias_k)
- value_seq = F.linear(output_tensor, self.v_proj_weight, bias_v)
-
+ query_seq, key_seq, value_seq = self.qkv_proj(output_tensor).chunk(3, dim=-1) # [B, T, #heads * F']
q = query_seq.view(batch_dim_size, -1, self.num_heads, self.embed_dim_per_head) # [B, T, #heads, F']
k = key_seq.view(batch_dim_size, -1, self.num_heads, self.embed_dim_per_head) # [B, T', #heads, F']
@@ -160,17 +161,28 @@ def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> to
final_mat = distance_mat_clipped + self.rel_pos_clip
- rel_pos_embeddings = self.rel_pos_embeddings[final_mat] # [T, T', F']
+ rel_pos_embeddings = self.rel_pos_embeddings[final_mat] # [T, T', pos_emb_dim]
else:
rel_pos_embeddings = self._sinusoidal_pe(
torch.arange(time_dim_size - 1, -time_dim_size, -1, device=input_tensor.device, dtype=torch.float32),
- self.embed_dim_per_head,
- ).expand(
- time_dim_size, 2 * time_dim_size - 1, self.embed_dim_per_head
- ) # [T, T+T'-1, F']
+ self.pos_emb_dim,
+ ).view(
+ 1, 2 * time_dim_size - 1, self.pos_emb_dim
+ ) # [1, T+T'-1, pos_emb_dim]
# dropout relative positional embeddings
- rel_pos_embeddings = self.pos_emb_dropout(rel_pos_embeddings)
+ rel_pos_embeddings = self.pos_emb_dropout(
+ rel_pos_embeddings
+ ) # [T, T', pos_emb_dim] or [1, T+T'-1, pos_emb_dim]
+ rel_pos_embeddings = rel_pos_embeddings.unsqueeze(2) # [T, T', 1, pos_emb_dim] or [1, T+T'-1, 1, pos_emb_dim]
+
+ # linear transformation or identity
+ rel_pos_embeddings = self.linear_pos(rel_pos_embeddings) # [T, T', 1, F'|F] or [1, T+T'-1, 1, F'|F]
+
+ if self.separate_pos_emb_per_head:
+ rel_pos_embeddings = rel_pos_embeddings.squeeze(2).reshape(
+ *rel_pos_embeddings.shape[:2], -1, self.embed_dim_per_head
+ ) # [T, T', #heads, F'] or [1, T+T'-1, #heads, F']
q_with_bias_u = q + self.pos_bias_u if self.with_pos_bias else q # [B, T, #heads, F']
q_with_bias_v = q + self.pos_bias_v if self.with_pos_bias else q
@@ -179,10 +191,12 @@ def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> to
attn_ac = torch.einsum("bihf, bjhf -> bhij", q_with_bias_u, k) # [B, #heads, T, T']
# attention matrix b and d
- attn_bd = torch.einsum("bihf, ijf -> bhij", q_with_bias_v, rel_pos_embeddings)
+ attn_bd = torch.einsum(
+ "bihf, ijhf -> bhij", q_with_bias_v, rel_pos_embeddings
+ ) # [B, #heads, T, T'] or [B, #heads, T, T+T'-1]
if not self.learnable_pos_emb:
- attn_bd = self._rel_shift_bhij(attn_bd, k_len=time_dim_size)
+ attn_bd = self._rel_shift_bhij(attn_bd, k_len=time_dim_size) # [B, #heads, T, T']
attn = attn_ac + attn_bd + mask # [B, #heads, T, T']
attn_scaled = attn * (math.sqrt(1.0 / float(self.embed_dim_per_head))) # [B, #heads, T, T']
@@ -201,21 +215,7 @@ def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> to
output_tensor = self.out_proj(attn_output)
- if self.dropout_broadcast_axes is None:
- output_tensor = self.dropout(output_tensor)
- elif self.dropout_broadcast_axes == "T":
- output_tensor = self.dropout(output_tensor.transpose(1, 2)).transpose(1, 2)
- elif self.dropout_broadcast_axes == "B":
- output_tensor = self.dropout(output_tensor.permute(1, 2, 0)).permute(2, 0, 1)
- elif self.dropout_broadcast_axes == "BT":
- batch_dim_size = output_tensor.shape[0]
- feature_dim_size = output_tensor.shape[-1]
-
- output_tensor = (
- self.output(output_tensor.reshape(-1, feature_dim_size).transpose(0, 1))
- .transpose(0, 1)
- .reshape(batch_dim_size, -1, feature_dim_size)
- )
+ output_tensor = self.dropout(output_tensor)
return output_tensor # [B,T,F]
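
For reference, a minimal standalone sketch (not part of the patch, sizes are illustrative) of how the new `with_linear_pos` and `separate_pos_emb_per_head` options shape the relative positional embeddings before the `bihf, ijhf -> bhij` einsum: the embeddings either keep a singleton head axis that broadcasts against `#heads`, or are reshaped into one slice per head.

import itertools

import torch
from torch import nn

B, T, num_heads, embed_dim = 2, 5, 4, 16
embed_dim_per_head = embed_dim // num_heads

for with_linear_pos, separate_pos_emb_per_head in itertools.product([False, True], repeat=2):
    pos_emb_dim = embed_dim if with_linear_pos or separate_pos_emb_per_head else embed_dim_per_head
    if with_linear_pos:
        linear_pos = nn.Linear(
            pos_emb_dim, embed_dim if separate_pos_emb_per_head else embed_dim_per_head, bias=False
        )
    else:
        linear_pos = nn.Identity()

    rel_pos = torch.randn(T, T, pos_emb_dim)  # [T, T', pos_emb_dim] (learnable case)
    rel_pos = linear_pos(rel_pos.unsqueeze(2))  # [T, T', 1, F'|F]
    if separate_pos_emb_per_head:
        # split the combined embedding dim into one slice per head
        rel_pos = rel_pos.squeeze(2).reshape(*rel_pos.shape[:2], -1, embed_dim_per_head)  # [T, T', #heads, F']

    q = torch.randn(B, T, num_heads, embed_dim_per_head)  # [B, T, #heads, F']
    # a singleton head axis in rel_pos broadcasts against #heads inside the einsum
    attn_bd = torch.einsum("bihf, ijhf -> bhij", q, rel_pos)
    assert attn_bd.shape == (B, num_heads, T, T)
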
diff --git a/i6_models/parts/dropout.py b/i6_models/parts/dropout.py
new file mode 100644
index 00000000..d240212f
--- /dev/null
+++ b/i6_models/parts/dropout.py
@@ -0,0 +1,55 @@
+from typing import Optional, Literal
+
+import torch
+from torch import nn
+
+
+class BroadcastDropout(nn.Module):
+ """
+ custom dropout module supporting broadcasting of the dropout mask
+ supported variants are:
+ - no broadcasting (default): dropout_broadcast_axes=None
+ - broadcast over the batch axis: dropout_broadcast_axes='B'
+ - broadcast over the time axis: dropout_broadcast_axes='T'
+ - broadcast over the batch and time axes: dropout_broadcast_axes='BT'
+ """
+
+ def __init__(self, p: float, dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] = None):
+ super().__init__()
+
+ self.p = p
+ assert dropout_broadcast_axes is None or dropout_broadcast_axes in [
+ "B",
+ "T",
+ "BT",
+ ], "invalid value, supported are None, 'B', 'T' and 'BT'"
+ self.dropout_broadcast_axes = dropout_broadcast_axes
+
+ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
+ """
+ assumes an input tensor of shape [B, T, F]
+ returns a tensor of shape [B, T, F]
+ """
+ if self.dropout_broadcast_axes is None:
+ tensor = torch.nn.functional.dropout(tensor, p=self.p, training=self.training)
+ elif self.dropout_broadcast_axes == "T": # [B, T, F] -> [B, F, T] -> [B, T, F]
+ tensor = torch.nn.functional.dropout1d(tensor.transpose(1, 2), p=self.p, training=self.training).transpose(
+ 1, 2
+ )
+ elif self.dropout_broadcast_axes == "B": # [B, T, F] -> [T, F, B] -> [B, T, F]
+ tensor = torch.nn.functional.dropout1d(tensor.permute(1, 2, 0), p=self.p, training=self.training).permute(
+ 2, 0, 1
+ )
+ elif self.dropout_broadcast_axes == "BT": # [B, T, F] -> [B*T, F] -> [F, B*T] -> [B*T, F] -> [B, T, F]
+ batch_dim_size = tensor.shape[0]
+ feature_dim_size = tensor.shape[-1]
+
+ tensor = (
+ torch.nn.functional.dropout1d(
+ tensor.reshape(-1, feature_dim_size).transpose(0, 1), p=self.p, training=self.training
+ )
+ .transpose(0, 1)
+ .reshape(batch_dim_size, -1, feature_dim_size)
+ )
+
+ return tensor
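
A hedged usage sketch of the new module (assuming the import path added in this patch): with `dropout_broadcast_axes="T"` the dropout mask is shared along the time axis, so each (batch, feature) slice is either kept as a whole (rescaled by 1/(1-p)) or zeroed as a whole.

import torch

from i6_models.parts.dropout import BroadcastDropout

drop = BroadcastDropout(p=0.5, dropout_broadcast_axes="T")
drop.train()  # dropout is only active in training mode

x = torch.ones(3, 7, 8)  # [B, T, F]
y = drop(x)

mask = y != 0  # [B, T, F]
# every (batch, feature) pair is dropped or kept along the whole time axis
assert torch.all(mask.all(dim=1) | (~mask).all(dim=1))
# kept entries are rescaled by 1 / (1 - p)
assert torch.all((y == 0.0) | (y == 2.0))
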
diff --git a/tests/test_conformer_rel_pos.py b/tests/test_conformer_rel_pos.py
index 78e4439b..88b2dba3 100644
--- a/tests/test_conformer_rel_pos.py
+++ b/tests/test_conformer_rel_pos.py
@@ -76,10 +76,13 @@ def get_output_shape(
input_shape,
seq_len,
input_dim,
+ with_bias=True,
num_att_heads=8,
att_weights_dropout=0.1,
dropout=0.1,
learnable_pos_emb=True,
+ with_linear_pos=False,
+ separate_pos_emb_per_head=False,
rel_pos_clip=16,
with_pos_bias=False,
pos_emb_dropout=0.0,
@@ -90,9 +93,12 @@ def get_output_shape(
cfg = ConformerMHSARelPosV1Config(
input_dim=input_dim,
num_att_heads=num_att_heads,
+ with_bias=with_bias,
att_weights_dropout=att_weights_dropout,
dropout=dropout,
learnable_pos_emb=learnable_pos_emb,
+ with_linear_pos=with_linear_pos,
+ separate_pos_emb_per_head=separate_pos_emb_per_head,
rel_pos_clip=rel_pos_clip,
with_pos_bias=with_pos_bias,
pos_emb_dropout=pos_emb_dropout,
@@ -110,7 +116,9 @@ def get_output_shape(
input_shape = [4, 15, 32] # B,T,F
seq_len = [15, 12, 10, 15]
- for learnable_pos_emb, with_pos_bias, pos_emb_dropout in product([True, False], [True, False], [0.0, 0.1]):
+ for learnable_pos_emb, with_pos_bias, pos_emb_dropout, with_linear_pos, separate_pos_emb_per_head in product(
+ [True, False], [True, False], [0.0, 0.1], [True, False], [True, False]
+ ):
assert get_output_shape(
input_shape,
seq_len,
@@ -118,4 +126,6 @@ def get_output_shape(
learnable_pos_emb=learnable_pos_emb,
with_pos_bias=with_pos_bias,
pos_emb_dropout=pos_emb_dropout,
+ with_linear_pos=with_linear_pos,
+ separate_pos_emb_per_head=separate_pos_emb_per_head,
) == [4, 15, 32]
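
Since the defaults were removed from the config, every field now has to be passed explicitly outside the test helper as well; a hedged construction sketch, mirroring the values used in the test above and assuming the module layout of this patch:

import torch

from i6_models.parts.conformer.mhsa_rel_pos import ConformerMHSARelPosV1, ConformerMHSARelPosV1Config

cfg = ConformerMHSARelPosV1Config(
    input_dim=32,
    num_att_heads=8,
    with_bias=True,
    att_weights_dropout=0.1,
    learnable_pos_emb=True,
    rel_pos_clip=16,
    with_linear_pos=False,
    with_pos_bias=False,
    separate_pos_emb_per_head=False,
    pos_emb_dropout=0.0,
    dropout=0.1,
    dropout_broadcast_axes=None,
)
mhsa = ConformerMHSARelPosV1(cfg)

features = torch.rand(4, 15, 32)  # [B, T, F]
sequence_mask = torch.ones(4, 15, dtype=torch.bool)  # True marks positions inside the sequence
out = mhsa(features, sequence_mask)  # [B, T, F]
assert out.shape == features.shape
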