From 6d96f6d71cd78a150a5925fc8b6d714ac51217a7 Mon Sep 17 00:00:00 2001
From: Anton Smirnov
Date: Tue, 8 Feb 2022 01:30:02 +0200
Subject: [PATCH 1/3] Implement EfficientNet

---
 src/Metalhead.jl                           |  4 +-
 src/convnets/efficientnet/efficientnet.jl  | 62 +++++++++++++++++
 src/convnets/efficientnet/mb.jl            | 82 +++++++++++++++++++++++
 src/convnets/efficientnet/params.jl        | 60 +++++++++++++++++
 test/convnets.jl                           |  7 ++
 5 files changed, 214 insertions(+), 1 deletion(-)
 create mode 100644 src/convnets/efficientnet/efficientnet.jl
 create mode 100644 src/convnets/efficientnet/mb.jl
 create mode 100644 src/convnets/efficientnet/params.jl

diff --git a/src/Metalhead.jl b/src/Metalhead.jl
index 977b6183f..39488155f 100644
--- a/src/Metalhead.jl
+++ b/src/Metalhead.jl
@@ -22,6 +22,7 @@ include("convnets/resnext.jl")
 include("convnets/densenet.jl")
 include("convnets/squeezenet.jl")
 include("convnets/mobilenet.jl")
+include("convnets/efficientnet/efficientnet.jl")
 
 # Other models
 include("other/mlpmixer.jl")
@@ -33,11 +34,12 @@ export AlexNet,
        DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
        ResNeXt, MobileNetv2, MobileNetv3,
+       EfficientNet,
       MLPMixer
 
 # use Flux._big_show to pretty print large models
 for T in (:AlexNet, :VGG, :ResNet, :GoogLeNet, :Inception3, :SqueezeNet, :DenseNet, :ResNeXt,
-          :MobileNetv2, :MobileNetv3, :MLPMixer)
+          :MobileNetv2, :MobileNetv3, :EfficientNet, :MLPMixer)
     @eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
 end
diff --git a/src/convnets/efficientnet/efficientnet.jl b/src/convnets/efficientnet/efficientnet.jl
new file mode 100644
index 000000000..983ce4529
--- /dev/null
+++ b/src/convnets/efficientnet/efficientnet.jl
@@ -0,0 +1,62 @@
+include("params.jl")
+include("mb.jl")
+
+struct EfficientNet{S, B, H, P, F}
+    stem::S
+    blocks::B
+
+    head::H
+    pooling::P
+    top::F
+end
+Flux.@functor EfficientNet
+
+"""
+    EfficientNet(block_params, global_params; in_channels = 3)
+
+Construct an EfficientNet model
+([reference](https://arxiv.org/abs/1905.11946)).
+"""
+function EfficientNet(
+    model_name, block_params, global_params; in_channels, n_classes, pretrain,
+)
+    pad, bias = SamePad(), false
+    out_channels = round_filter(32, global_params)
+    stem = Chain(
+        Conv((3, 3), in_channels=>out_channels; bias, stride=2, pad),
+        BatchNorm(out_channels, swish))
+
+    blocks = MBConv[]
+    for bp in block_params
+        in_channels = round_filter(bp.in_channels, global_params)
+        out_channels = round_filter(bp.out_channels, global_params)
+        repeat = global_params.depth_coef ≈ 1 ?
+            bp.repeat : ceil(Int64, global_params.depth_coef * bp.repeat)
+
+        push!(blocks, MBConv(
+            in_channels, out_channels, bp.kernel, bp.stride;
+            expansion_ratio=bp.expansion_ratio))
+        for _ in 1:(repeat - 1)
+            push!(blocks, MBConv(
+                out_channels, out_channels, bp.kernel, 1;
+                expansion_ratio=bp.expansion_ratio))
+        end
+    end
+    blocks = Chain(blocks...)
+
+    head_out_channels = round_filter(1280, global_params)
+    head = Chain(
+        Conv((1, 1), out_channels=>head_out_channels; bias, pad),
+        BatchNorm(head_out_channels, swish))
+
+    top = n_classes ≡ nothing ?
+        identity : (Dense(head_out_channels, n_classes) ∘ Flux.flatten)
+    model = EfficientNet(stem, blocks, head, AdaptiveMeanPool((1, 1)), top)
+    pretrain && loadpretrain!(model, "EfficientNet" * model_name)
+    model
+end
+
+EfficientNet(model_name::String; in_channels = 3, n_classes = 1000, pretrain = false) =
+    EfficientNet(model_name, get_efficientnet_params(model_name)...; in_channels, n_classes, pretrain)
+
+(m::EfficientNet)(x) = m.top(m.pooling(m.head(m.blocks(m.stem(x)))))
diff --git a/src/convnets/efficientnet/mb.jl b/src/convnets/efficientnet/mb.jl
new file mode 100644
index 000000000..466daf2d9
--- /dev/null
+++ b/src/convnets/efficientnet/mb.jl
@@ -0,0 +1,82 @@
+struct MBConv{E, D, X, P}
+    expansion::E
+    depthwise::D
+    excitation::X
+    projection::P
+
+    do_expansion::Bool
+    do_excitation::Bool
+    do_skip::Bool
+end
+Flux.@functor MBConv
+
+"""
+    MBConv(
+        in_channels, out_channels, kernel, stride;
+        expansion_ratio, se_ratio)
+
+Mobile Inverted Residual Bottleneck block
+([reference](https://arxiv.org/abs/1801.04381)).
+
+# Arguments
+- `in_channels`: Number of input channels.
+- `out_channels`: Number of output channels.
+- `kernel`: Size of the kernel for the depthwise conv phase.
+- `stride`: Size of the stride for the depthwise conv phase.
+- `expansion_ratio`:
+    Expansion ratio defines the number of hidden channels:
+    `mid_channels = ceil(Int, in_channels * expansion_ratio)`.
+    Set to `1` to disable the expansion phase.
+- `se_ratio`:
+    Squeeze-Excitation ratio. Should be in the `(0, 1]` range.
+    Set to a non-positive value (e.g. `-1`) to disable.
+"""
+function MBConv(
+    in_channels, out_channels, kernel, stride;
+    expansion_ratio, se_ratio = 0.25,
+)
+    do_skip = stride == 1 && in_channels == out_channels
+    do_expansion, do_excitation = expansion_ratio != 1, 0 < se_ratio ≤ 1
+    pad, bias = SamePad(), false
+
+    mid_channels = ceil(Int, in_channels * expansion_ratio)
+    expansion = do_expansion ?
+        Chain(
+            Conv((1, 1), in_channels=>mid_channels; bias, pad),
+            BatchNorm(mid_channels, swish)) :
+        identity
+
+    depthwise = Chain(
+        Conv(kernel, mid_channels=>mid_channels; bias, stride, pad, groups=mid_channels),
+        BatchNorm(mid_channels, swish))
+
+    if do_excitation
+        n_squeezed_channels = max(1, ceil(Int, in_channels * se_ratio))
+        excitation = Chain(
+            AdaptiveMeanPool((1, 1)),
+            Conv((1, 1), mid_channels=>n_squeezed_channels, swish; pad),
+            Conv((1, 1), n_squeezed_channels=>mid_channels; pad))
+    else
+        excitation = identity
+    end
+
+    projection = Chain(
+        Conv((1, 1), mid_channels=>out_channels; pad, bias),
+        BatchNorm(out_channels))
+    MBConv(
+        expansion, depthwise, excitation, projection, do_expansion,
+        do_excitation, do_skip)
+end
+
+function (m::MBConv)(x)
+    o = m.depthwise(m.expansion(x))
+
+    if m.do_excitation
+        o = σ.(m.excitation(o)) .* o
+    end
+    o = m.projection(o)
+    if m.do_skip
+        o = o + x
+    end
+    o
+end
diff --git a/src/convnets/efficientnet/params.jl b/src/convnets/efficientnet/params.jl
new file mode 100644
index 000000000..9a442e56b
--- /dev/null
+++ b/src/convnets/efficientnet/params.jl
@@ -0,0 +1,60 @@
+struct BlockParams
+    repeat::Int
+    kernel::Tuple{Int, Int}
+    stride::Int
+    expansion_ratio::Int
+    in_channels::Int
+    out_channels::Int
+end
+
+struct GlobalParams
+    width_coef::Real
+    depth_coef::Real
+    image_size::Tuple{Int, Int}
+
+    depth_divisor::Int
+    min_depth::Union{Nothing, Int}
+end
+
+# (width_coefficient, depth_coefficient, resolution)
+get_efficientnet_coefficients(model_name::String) =
+    Dict(
+        "b0" => (1.0, 1.0, 224),
+        "b1" => (1.0, 1.1, 240),
+        "b2" => (1.1, 1.2, 260),
+        "b3" => (1.2, 1.4, 300),
+        "b4" => (1.4, 1.8, 380),
+        "b5" => (1.6, 2.2, 456),
+        "b6" => (1.8, 2.6, 528),
+        "b7" => (2.0, 3.1, 600),
+        "b8" => (2.2, 3.6, 672),
+        "l2" => (4.3, 5.3, 800))[model_name]
+
+function get_efficientnet_params(model_name)
+    block_params = [
+        BlockParams(1, (3, 3), 1, 1, 32, 16),
+        BlockParams(2, (3, 3), 2, 6, 16, 24),
+        BlockParams(2, (5, 5), 2, 6, 24, 40),
+        BlockParams(3, (3, 3), 2, 6, 40, 80),
+        BlockParams(3, (5, 5), 1, 6, 80, 112),
+        BlockParams(4, (5, 5), 2, 6, 112, 192),
+        BlockParams(1, (3, 3), 1, 6, 192, 320)]
+
+    width_coef, depth_coef, resolution = get_efficientnet_coefficients(model_name)
+    global_params = GlobalParams(
+        width_coef, depth_coef, (resolution, resolution), 8, nothing)
+    block_params, global_params
+end
+
+function round_filter(filters, global_params::GlobalParams)
+    global_params.width_coef ≈ 1 && return filters
+
+    depth_divisor = global_params.depth_divisor
+    filters *= global_params.width_coef
+    min_depth = global_params.min_depth
+    min_depth = min_depth ≡ nothing ? depth_divisor : min_depth
+
+    new_filters = max(min_depth, (floor(Int, filters + depth_divisor / 2) ÷ depth_divisor) * depth_divisor)
+    new_filters < 0.9 * filters && (new_filters += global_params.depth_divisor)
+    new_filters
+end
diff --git a/test/convnets.jl b/test/convnets.jl
index e9a99748d..2bc016b20 100644
--- a/test/convnets.jl
+++ b/test/convnets.jl
@@ -125,3 +125,10 @@ end
         end
     end
 end
+
+@testset "EfficientNet" begin
+    m = EfficientNet("b0")
+    @test size(m(rand(Float32, 224, 224, 3, 2))) == (1000, 2)
+    @test_throws ArgumentError (EfficientNet("b0"; pretrain = true); true)
+    @test_skip gradtest(m, rand(Float32, 224, 224, 3, 2))
+end
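The `round_filter` helper above implements EfficientNet's width scaling: the base channel count is multiplied by the width coefficient and snapped to a multiple of `depth_divisor`, while never dropping more than 10% below the scaled value. A minimal standalone sketch of that arithmetic, assuming `min_depth` is left at its `nothing` default so the divisor acts as the floor (`round_filter_demo` is an illustrative name, not part of the patch):

    # Sketch of the rounding rule in params.jl (hypothetical standalone form).
    function round_filter_demo(filters::Int; width_coef, depth_divisor = 8)
        width_coef ≈ 1 && return filters
        filters *= width_coef
        # Round to the nearest multiple of depth_divisor, flooring at the divisor.
        new_filters = max(depth_divisor,
            (floor(Int, filters + depth_divisor / 2) ÷ depth_divisor) * depth_divisor)
        # Never round away more than 10% of the scaled channel count.
        new_filters < 0.9 * filters && (new_filters += depth_divisor)
        new_filters
    end

    round_filter_demo(32; width_coef = 1.0)  # b0 stem: 32 (coefficient ≈ 1, unscaled)
    round_filter_demo(32; width_coef = 1.4)  # b4 stem: 32 * 1.4 = 44.8, rounded to 48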
From 5f3b72c8bfab91e37e85e4256608c4ba537f8c8c Mon Sep 17 00:00:00 2001
From: Anton Smirnov
Date: Sat, 12 Feb 2022 00:15:00 +0200
Subject: [PATCH 2/3] Add docs & todo for tests

---
 Project.toml                               |  1 +
 src/Metalhead.jl                           |  2 ++
 src/convnets/efficientnet/efficientnet.jl  | 32 +++++++++++++++--------
 src/convnets/efficientnet/params.jl        |  3 +--
 test/convnets.jl                           |  9 +++++--
 5 files changed, 32 insertions(+), 15 deletions(-)

diff --git a/Project.toml b/Project.toml
index 2411ae9be..279b5f9c5 100644
--- a/Project.toml
+++ b/Project.toml
@@ -9,6 +9,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
+Pickle = "fbb45041-c46e-462f-888f-7c521cafbc2c"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [compat]
diff --git a/src/Metalhead.jl b/src/Metalhead.jl
index 39488155f..b99da5e1c 100644
--- a/src/Metalhead.jl
+++ b/src/Metalhead.jl
@@ -6,6 +6,8 @@ using Functors
 using BSON
 using Artifacts, LazyArtifacts
 using Statistics
+using Pickle
+using Downloads: download
 
 import Functors
 
diff --git a/src/convnets/efficientnet/efficientnet.jl b/src/convnets/efficientnet/efficientnet.jl
index 983ce4529..d2fa84a5a 100644
--- a/src/convnets/efficientnet/efficientnet.jl
+++ b/src/convnets/efficientnet/efficientnet.jl
@@ -11,12 +11,6 @@ struct EfficientNet{S, B, H, P, F}
 end
 Flux.@functor EfficientNet
 
-"""
-    EfficientNet(block_params, global_params; in_channels = 3)
-
-Construct an EfficientNet model
-([reference](https://arxiv.org/abs/1905.11946)).
-"""
 function EfficientNet(
     model_name, block_params, global_params; in_channels, n_classes, pretrain,
 )
@@ -49,14 +43,30 @@ function EfficientNet(
         Conv((1, 1), out_channels=>head_out_channels; bias, pad),
         BatchNorm(head_out_channels, swish))
 
-    top = n_classes ≡ nothing ?
-        identity : (Dense(head_out_channels, n_classes) ∘ Flux.flatten)
+    top = Dense(head_out_channels, n_classes)
     model = EfficientNet(stem, blocks, head, AdaptiveMeanPool((1, 1)), top)
     pretrain && loadpretrain!(model, "EfficientNet" * model_name)
     model
 end
 
-EfficientNet(model_name::String; in_channels = 3, n_classes = 1000, pretrain = false) =
-    EfficientNet(model_name, get_efficientnet_params(model_name)...; in_channels, n_classes, pretrain)
+"""
+    EfficientNet(model_name::String; in_channels = 3, n_classes = 1000, pretrain = false)
+
+Construct an EfficientNet model
+([reference](https://arxiv.org/abs/1905.11946)).
+
+# Arguments
+- `model_name::String`: Name of the model. Accepts `b0`-`b8` names.
+- `in_channels::Int`: Number of input channels. Default is `3`.
+- `n_classes::Int`: Number of output classes. Default is `1000`.
+- `pretrain::Bool`: Whether to load ImageNet pretrained weights.
+    Default is `false`.
+"""
+EfficientNet(
+    model_name::String; in_channels::Int = 3,
+    n_classes::Int = 1000, pretrain::Bool = false,
+) = EfficientNet(
+    model_name, get_efficientnet_params(model_name)...;
+    in_channels, n_classes, pretrain)
 
-(m::EfficientNet)(x) = m.top(m.pooling(m.head(m.blocks(m.stem(x)))))
+(m::EfficientNet)(x) = m.top(Flux.flatten(m.pooling(m.head(m.blocks(m.stem(x))))))
diff --git a/src/convnets/efficientnet/params.jl b/src/convnets/efficientnet/params.jl
index 9a442e56b..45e0db198 100644
--- a/src/convnets/efficientnet/params.jl
+++ b/src/convnets/efficientnet/params.jl
@@ -27,8 +27,7 @@ get_efficientnet_coefficients(model_name::String) =
         "b5" => (1.6, 2.2, 456),
         "b6" => (1.8, 2.6, 528),
         "b7" => (2.0, 3.1, 600),
-        "b8" => (2.2, 3.6, 672),
-        "l2" => (4.3, 5.3, 800))[model_name]
+        "b8" => (2.2, 3.6, 672))[model_name]
 
 function get_efficientnet_params(model_name)
     block_params = [
diff --git a/test/convnets.jl b/test/convnets.jl
index 2bc016b20..f7f558712 100644
--- a/test/convnets.jl
+++ b/test/convnets.jl
@@ -127,8 +127,13 @@ end
 end
 
 @testset "EfficientNet" begin
-    m = EfficientNet("b0")
-    @test size(m(rand(Float32, 224, 224, 3, 2))) == (1000, 2)
+    m = EfficientNet("b4")
+    x = rand(Float32, 224, 224, 3, 2)
+    @test size(m(x)) == (1000, 2)
     @test_throws ArgumentError (EfficientNet("b0"; pretrain = true); true)
     @test_skip gradtest(m, rand(Float32, 224, 224, 3, 2))
+
+    # TODO: add test for inferrability once the new version of Flux is released with
+    # https://github.com/FluxML/Flux.jl/pull/1856
+    # @inferred m(x)
 end

From 3c845c88def2c3a5f8dff41225402c53e2c3ad73 Mon Sep 17 00:00:00 2001
From: Anton Smirnov
Date: Sat, 12 Feb 2022 00:16:33 +0200
Subject: [PATCH 3/3] Remove Pickle deps

---
 Project.toml     | 1 -
 src/Metalhead.jl | 2 --
 2 files changed, 3 deletions(-)

diff --git a/Project.toml b/Project.toml
index 279b5f9c5..2411ae9be 100644
--- a/Project.toml
+++ b/Project.toml
@@ -9,7 +9,6 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-Pickle = "fbb45041-c46e-462f-888f-7c521cafbc2c"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [compat]
diff --git a/src/Metalhead.jl b/src/Metalhead.jl
index b99da5e1c..39488155f 100644
--- a/src/Metalhead.jl
+++ b/src/Metalhead.jl
@@ -6,8 +6,6 @@ using Functors
 using BSON
 using Artifacts, LazyArtifacts
 using Statistics
-using Pickle
-using Downloads: download
 
 import Functors
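As a reference for the API after the final patch, a short usage sketch mirroring the test suite; the 10-class variant is an illustrative assumption, not something the tests exercise:

    using Flux, Metalhead

    m = EfficientNet("b0")             # ImageNet-style head: Dense(1280, 1000)
    x = rand(Float32, 224, 224, 3, 2)  # WHCN batch of two 224x224 RGB images
    size(m(x))                         # (1000, 2)

    # `n_classes` only resizes the final Dense layer; the backbone is unchanged.
    m10 = EfficientNet("b0"; n_classes = 10)
    size(m10(x))                       # (10, 2)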