diff --git a/src/Metalhead.jl b/src/Metalhead.jl index efe1fafa0..5c9cac03a 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -41,7 +41,7 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, ResNeXt, DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201, GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception, - SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, + SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet, MLPMixer, ResMLP, gMLP, ViT, ConvMixer, ConvNeXt diff --git a/src/convnets/efficientnet.jl b/src/convnets/efficientnet.jl index 4e28072dd..95933912d 100644 --- a/src/convnets/efficientnet.jl +++ b/src/convnets/efficientnet.jl @@ -24,9 +24,8 @@ function efficientnet(scalings, block_config; inchannels = 3, nclasses = 1000, max_width = 1280) wscale, dscale = scalings out_channels = _round_channels(32, 8) - stem = Chain(Conv((3, 3), inchannels => out_channels; - bias = false, stride = 2, pad = SamePad()), - BatchNorm(out_channels, swish)) + stem = conv_bn((3, 3), inchannels, out_channels, swish; + bias = false, stride = 2, pad = SamePad()) blocks = [] for (n, k, s, e, i, o) in block_config