Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Judyxujj committed Oct 20, 2023
1 parent 370fb10 commit f26d94c
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 9 deletions.
6 changes: 4 additions & 2 deletions i6_models/parts/frontend/generic_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ class GenericFrontendV1Config(ModelConfiguration):
in_features: number of input features to module
layer_ordering: the ordering of the front end layer sequences, the ordering element must be selected from FrontendLayerType
e.g. the ordering of VGG4LayerActFrontendV1 would be [FrontendLayerType.Conv2d, FrontendLayerType.Activation,
FrontendLayerType.Pool2d, FrontendLayerType.Conv2d, FrontendLayerType.Conv2d, FrontendLayerType.Activation, FrontendLayerType.Pool2d]
FrontendLayerType.Pool2d, FrontendLayerType.Conv2d, FrontendLayerType.Conv2d, FrontendLayerType.Activation,
FrontendLayerType.Pool2d]
conv_kernel_sizes: kernel sizes for each conv layer
conv_strides: stride sizes for each conv layer
conv_paddings: paddings sizes for each conv layer
Expand Down Expand Up @@ -102,7 +103,8 @@ class GenericFrontendV1(nn.Module):
def __init__(self, model_cfg: GenericFrontendV1Config):
"""
Generic Front-End
can be used to generate customized frontend by combine convolutional and pooling layers, as well as activation functions different
can be used to generate customized frontend by combining convolutional and pooling layers, as well as activation
functions differently
To get the ESPnet case, for example Conv2dSubsampling6, use these options
layer_ordering = [FrontendLayerType.Conv2d, FrontendLayerType.Conv2d]
Expand Down
7 changes: 0 additions & 7 deletions tests/test_generic_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@
import torch
from torch import nn

import sys

sys.path.insert(0, "/Users/jxu/Desktop/PR/i6_models")

from i6_models.parts.frontend.generic_frontend import FrontendLayerType, GenericFrontendV1, GenericFrontendV1Config


Expand Down Expand Up @@ -169,6 +165,3 @@ def get_output_shape(test_parameters: GenericFrontendV1TestParams):
)

return


test_generic_frontend_v1()

0 comments on commit f26d94c

Please sign in to comment.