Skip to content

Commit

Permalink
Apply minor fixes to GenericFrontend code (#62)
Browse files Browse the repository at this point in the history
  • Loading branch information
SimBe195 authored Nov 6, 2024
1 parent 6df0b86 commit e22c46a
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions i6_models/parts/frontend/generic_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ def check_valid(self):

assert len(self.layer_ordering) == num_convs + num_pools + num_activations, "Number of total layers mismatch!"

for kernel_sizes in filter(None, [self.conv_kernel_sizes, self.pool_kernel_sizes]):
for kernel_size in kernel_sizes:
if self.conv_kernel_sizes is not None:
for kernel_size in self.conv_kernel_sizes:
assert all(k % 2 for k in kernel_size), "ConformerVGGFrontendV1 only supports odd kernel sizes"

def __post__init__(self):
Expand Down Expand Up @@ -132,7 +132,7 @@ def __init__(self, model_cfg: GenericFrontendV1Config):
if layer_type == FrontendLayerType.Conv2d:
conv_out_dim = model_cfg.conv_out_dims[conv_layer_index]
conv_kernel_size = model_cfg.conv_kernel_sizes[conv_layer_index]
conv_stride = 1 if model_cfg.conv_strides is None else model_cfg.conv_strides[conv_layer_index]
conv_stride = (1, 1) if model_cfg.conv_strides is None else model_cfg.conv_strides[conv_layer_index]
conv_padding = (
get_same_padding(conv_kernel_size)
if model_cfg.conv_paddings is None
Expand Down Expand Up @@ -177,7 +177,7 @@ def __init__(self, model_cfg: GenericFrontendV1Config):
last_feat_dim = calculate_output_dim(
in_dim=last_feat_dim,
filter_size=pool_kernel_size[1],
stride=pool_stride[1] or pool_kernel_size[1],
stride=(pool_stride or pool_kernel_size)[1],
padding=pool_padding[1],
)
pool_layer_index += 1
Expand Down

0 comments on commit e22c46a

Please sign in to comment.