From 0b9fb9d08ac3e7bffa9c0c057adc38ad6e53832f Mon Sep 17 00:00:00 2001 From: pzheng Date: Fri, 6 Sep 2024 11:10:34 +0200 Subject: [PATCH] remove default, update docstring and test case --- i6_models/parts/conformer/convolution.py | 3 +-- i6_models/parts/conformer/feedforward.py | 9 +++++++-- i6_models/parts/conformer/mhsa.py | 2 +- i6_models/parts/conformer/mhsa_rel_pos.py | 2 +- requirements_dev.txt | 2 +- tests/test_conformer_rel_pos.py | 5 +++-- 6 files changed, 14 insertions(+), 9 deletions(-) diff --git a/i6_models/parts/conformer/convolution.py b/i6_models/parts/conformer/convolution.py index 87c56bc3..a7566855 100644 --- a/i6_models/parts/conformer/convolution.py +++ b/i6_models/parts/conformer/convolution.py @@ -99,10 +99,9 @@ class ConformerConvolutionV2Config(ConformerConvolutionV1Config): New attribute: dropout_broadcast_axes: string of axes to which dropout is broadcast, e.g. "T" for broadcasting to the time axis setting to None to disable broadcasting - Allows even kernel size """ - dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] = None + dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] def check_valid(self): assert self.kernel_size % 2 == 1, "ConformerConvolutionV1 only supports odd kernel sizes" diff --git a/i6_models/parts/conformer/feedforward.py b/i6_models/parts/conformer/feedforward.py index c58840dd..9d7dda4b 100644 --- a/i6_models/parts/conformer/feedforward.py +++ b/i6_models/parts/conformer/feedforward.py @@ -62,14 +62,19 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor: @dataclass -class ConformerPositionwiseFeedForwardV2Config(ConformerPositionwiseFeedForwardV1Config): +class ConformerPositionwiseFeedForwardV2Config(ModelConfiguration): """ New attribute: dropout_broadcast_axes: string of axes to which dropout is broadcast, e.g. "T" for broadcasting to the time axis setting to None to disable broadcasting + Default value for `activation` removed """ - dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] = None + input_dim: int + hidden_dim: int + dropout: float + activation: Callable[[torch.Tensor], torch.Tensor] + dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] def check_valid(self): assert self.dropout_broadcast_axes in [ diff --git a/i6_models/parts/conformer/mhsa.py b/i6_models/parts/conformer/mhsa.py index 1584df23..2a67defe 100644 --- a/i6_models/parts/conformer/mhsa.py +++ b/i6_models/parts/conformer/mhsa.py @@ -73,7 +73,7 @@ class ConformerMHSAV2Config(ConformerMHSAV1Config): setting to None to disable broadcasting """ - dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] = None + dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] def check_valid(self): assert self.dropout_broadcast_axes in [ diff --git a/i6_models/parts/conformer/mhsa_rel_pos.py b/i6_models/parts/conformer/mhsa_rel_pos.py index 0a1c8d2c..3118cacf 100644 --- a/i6_models/parts/conformer/mhsa_rel_pos.py +++ b/i6_models/parts/conformer/mhsa_rel_pos.py @@ -45,7 +45,7 @@ class ConformerMHSARelPosV1Config(ModelConfiguration): separate_pos_emb_per_head: bool pos_emb_dropout: float dropout: float - dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] = None + dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] def __post_init__(self) -> None: super().__post_init__() diff --git a/requirements_dev.txt b/requirements_dev.txt index fa744081..26a6a0cf 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -1,4 +1,4 @@ onnx onnxruntime espnet -typeguard +typeguard >= 4.3.0 diff --git a/tests/test_conformer_rel_pos.py b/tests/test_conformer_rel_pos.py index 08d85d64..3aafb32c 100644 --- a/tests/test_conformer_rel_pos.py +++ b/tests/test_conformer_rel_pos.py @@ -140,6 +140,8 @@ def test_ConformerMHSARelPosV1_against_Espnet(): dropout_rate = 0.1 batch_dim_size = 4 time_dim_size = 50 + seq_len = torch.Tensor([50, 10, 20, 40]) + sequence_mask = torch.less(torch.arange(time_dim_size)[None, :], seq_len[:, None]) espnet_mhsa_module = RelPositionMultiHeadedAttention( num_heads=num_heads, embed_size=embed_size, dropout_rate=dropout_rate @@ -186,7 +188,6 @@ def test_ConformerMHSARelPosV1_against_Espnet(): ) input_tensor = torch.rand((batch_dim_size, time_dim_size, embed_size)) - sequence_mask = torch.ones((batch_dim_size, time_dim_size)) inv_sequence_mask = torch.logical_not(sequence_mask) input_tensor_layernorm = own_mhsa_module.layernorm(input_tensor) @@ -202,4 +203,4 @@ def test_ConformerMHSARelPosV1_against_Espnet(): own_output_tensor = own_mhsa_module(input_tensor, sequence_mask=sequence_mask) - assert torch.allclose(espnet_output_tensor, own_output_tensor, rtol=1e-03) + assert torch.allclose(espnet_output_tensor, own_output_tensor, rtol=1e-03, atol=1e-6)