Skip to content

Commit

Permalink
remove default, update docstring and test case
Browse files: browse the repository at this point in the history
  • Loading branch information
kuacakuaca committed Sep 6, 2024
1 parent 5bf24d2 commit 0b9fb9d
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 9 deletions.
3 changes: 1 addition & 2 deletions i6_models/parts/conformer/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,9 @@ class ConformerConvolutionV2Config(ConformerConvolutionV1Config):
New attribute:
dropout_broadcast_axes: string of axes to which dropout is broadcast, e.g. "T" for broadcasting to the time axis
setting to None to disable broadcasting
Allows even kernel size
"""

dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] = None
dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]]

def check_valid(self):
assert self.kernel_size % 2 == 1, "ConformerConvolutionV1 only supports odd kernel sizes"
Expand Down
9 changes: 7 additions & 2 deletions i6_models/parts/conformer/feedforward.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,19 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor:


@dataclass
class ConformerPositionwiseFeedForwardV2Config(ConformerPositionwiseFeedForwardV1Config):
class ConformerPositionwiseFeedForwardV2Config(ModelConfiguration):
"""
New attribute:
dropout_broadcast_axes: string of axes to which dropout is broadcast, e.g. "T" for broadcasting to the time axis
setting to None to disable broadcasting
Default value for `activation` removed
"""

dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] = None
input_dim: int
hidden_dim: int
dropout: float
activation: Callable[[torch.Tensor], torch.Tensor]
dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]]

def check_valid(self):
assert self.dropout_broadcast_axes in [
Expand Down
2 changes: 1 addition & 1 deletion i6_models/parts/conformer/mhsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class ConformerMHSAV2Config(ConformerMHSAV1Config):
setting to None to disable broadcasting
"""

dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] = None
dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]]

def check_valid(self):
assert self.dropout_broadcast_axes in [
Expand Down
2 changes: 1 addition & 1 deletion i6_models/parts/conformer/mhsa_rel_pos.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class ConformerMHSARelPosV1Config(ModelConfiguration):
separate_pos_emb_per_head: bool
pos_emb_dropout: float
dropout: float
dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] = None
dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]]

def __post_init__(self) -> None:
super().__post_init__()
Expand Down
2 changes: 1 addition & 1 deletion requirements_dev.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
onnx
onnxruntime
espnet
typeguard
typeguard >= 4.3.0
5 changes: 3 additions & 2 deletions tests/test_conformer_rel_pos.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ def test_ConformerMHSARelPosV1_against_Espnet():
dropout_rate = 0.1
batch_dim_size = 4
time_dim_size = 50
seq_len = torch.Tensor([50, 10, 20, 40])
sequence_mask = torch.less(torch.arange(time_dim_size)[None, :], seq_len[:, None])

espnet_mhsa_module = RelPositionMultiHeadedAttention(
num_heads=num_heads, embed_size=embed_size, dropout_rate=dropout_rate
Expand Down Expand Up @@ -186,7 +188,6 @@ def test_ConformerMHSARelPosV1_against_Espnet():
)

input_tensor = torch.rand((batch_dim_size, time_dim_size, embed_size))
sequence_mask = torch.ones((batch_dim_size, time_dim_size))
inv_sequence_mask = torch.logical_not(sequence_mask)

input_tensor_layernorm = own_mhsa_module.layernorm(input_tensor)
Expand All @@ -202,4 +203,4 @@ def test_ConformerMHSARelPosV1_against_Espnet():

own_output_tensor = own_mhsa_module(input_tensor, sequence_mask=sequence_mask)

assert torch.allclose(espnet_output_tensor, own_output_tensor, rtol=1e-03)
assert torch.allclose(espnet_output_tensor, own_output_tensor, rtol=1e-03, atol=1e-6)

0 comments on commit 0b9fb9d

Please sign in to comment.