From 0b9fb9d08ac3e7bffa9c0c057adc38ad6e53832f Mon Sep 17 00:00:00 2001
From: pzheng
Date: Fri, 6 Sep 2024 11:10:34 +0200
Subject: [PATCH] Remove default values for dropout_broadcast_axes, update docstrings and test case
---
i6_models/parts/conformer/convolution.py | 3 +--
i6_models/parts/conformer/feedforward.py | 9 +++++++--
i6_models/parts/conformer/mhsa.py | 2 +-
i6_models/parts/conformer/mhsa_rel_pos.py | 2 +-
requirements_dev.txt | 2 +-
tests/test_conformer_rel_pos.py | 5 +++--
6 files changed, 14 insertions(+), 9 deletions(-)
diff --git a/i6_models/parts/conformer/convolution.py b/i6_models/parts/conformer/convolution.py
index 87c56bc3..a7566855 100644
--- a/i6_models/parts/conformer/convolution.py
+++ b/i6_models/parts/conformer/convolution.py
@@ -99,10 +99,9 @@ class ConformerConvolutionV2Config(ConformerConvolutionV1Config):
New attribute:
dropout_broadcast_axes: string of axes to which dropout is broadcast, e.g. "T" for broadcasting to the time axis
setting to None to disable broadcasting
- Allows even kernel size
"""
- dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] = None
+ dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]]
def check_valid(self):
assert self.kernel_size % 2 == 1, "ConformerConvolutionV1 only supports odd kernel sizes"
diff --git a/i6_models/parts/conformer/feedforward.py b/i6_models/parts/conformer/feedforward.py
index c58840dd..9d7dda4b 100644
--- a/i6_models/parts/conformer/feedforward.py
+++ b/i6_models/parts/conformer/feedforward.py
@@ -62,14 +62,19 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
@dataclass
-class ConformerPositionwiseFeedForwardV2Config(ConformerPositionwiseFeedForwardV1Config):
+class ConformerPositionwiseFeedForwardV2Config(ModelConfiguration):
"""
New attribute:
dropout_broadcast_axes: string of axes to which dropout is broadcast, e.g. "T" for broadcasting to the time axis
setting to None to disable broadcasting
+ Default value for `activation` removed
"""
- dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] = None
+ input_dim: int
+ hidden_dim: int
+ dropout: float
+ activation: Callable[[torch.Tensor], torch.Tensor]
+ dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]]
def check_valid(self):
assert self.dropout_broadcast_axes in [
diff --git a/i6_models/parts/conformer/mhsa.py b/i6_models/parts/conformer/mhsa.py
index 1584df23..2a67defe 100644
--- a/i6_models/parts/conformer/mhsa.py
+++ b/i6_models/parts/conformer/mhsa.py
@@ -73,7 +73,7 @@ class ConformerMHSAV2Config(ConformerMHSAV1Config):
setting to None to disable broadcasting
"""
- dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] = None
+ dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]]
def check_valid(self):
assert self.dropout_broadcast_axes in [
diff --git a/i6_models/parts/conformer/mhsa_rel_pos.py b/i6_models/parts/conformer/mhsa_rel_pos.py
index 0a1c8d2c..3118cacf 100644
--- a/i6_models/parts/conformer/mhsa_rel_pos.py
+++ b/i6_models/parts/conformer/mhsa_rel_pos.py
@@ -45,7 +45,7 @@ class ConformerMHSARelPosV1Config(ModelConfiguration):
separate_pos_emb_per_head: bool
pos_emb_dropout: float
dropout: float
- dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] = None
+ dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]]
def __post_init__(self) -> None:
super().__post_init__()
diff --git a/requirements_dev.txt b/requirements_dev.txt
index fa744081..26a6a0cf 100644
--- a/requirements_dev.txt
+++ b/requirements_dev.txt
@@ -1,4 +1,4 @@
onnx
onnxruntime
espnet
-typeguard
+typeguard >= 4.3.0
diff --git a/tests/test_conformer_rel_pos.py b/tests/test_conformer_rel_pos.py
index 08d85d64..3aafb32c 100644
--- a/tests/test_conformer_rel_pos.py
+++ b/tests/test_conformer_rel_pos.py
@@ -140,6 +140,8 @@ def test_ConformerMHSARelPosV1_against_Espnet():
dropout_rate = 0.1
batch_dim_size = 4
time_dim_size = 50
+ seq_len = torch.Tensor([50, 10, 20, 40])
+ sequence_mask = torch.less(torch.arange(time_dim_size)[None, :], seq_len[:, None])
espnet_mhsa_module = RelPositionMultiHeadedAttention(
num_heads=num_heads, embed_size=embed_size, dropout_rate=dropout_rate
@@ -186,7 +188,6 @@ def test_ConformerMHSARelPosV1_against_Espnet():
)
input_tensor = torch.rand((batch_dim_size, time_dim_size, embed_size))
- sequence_mask = torch.ones((batch_dim_size, time_dim_size))
inv_sequence_mask = torch.logical_not(sequence_mask)
input_tensor_layernorm = own_mhsa_module.layernorm(input_tensor)
@@ -202,4 +203,4 @@ def test_ConformerMHSARelPosV1_against_Espnet():
own_output_tensor = own_mhsa_module(input_tensor, sequence_mask=sequence_mask)
- assert torch.allclose(espnet_output_tensor, own_output_tensor, rtol=1e-03)
+ assert torch.allclose(espnet_output_tensor, own_output_tensor, rtol=1e-03, atol=1e-6)