Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement EfficientNet #113

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/Metalhead.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using MLUtils
import Functors

include("utilities.jl")
include("layers.jl")
include("layers.jl")

# CNN models
include("convnets/alexnet.jl")
Expand All @@ -23,6 +23,7 @@ include("convnets/resnext.jl")
include("convnets/densenet.jl")
include("convnets/squeezenet.jl")
include("convnets/mobilenet.jl")
include("convnets/efficientnet/efficientnet.jl")

# Other models
include("other/mlpmixer.jl")
Expand All @@ -37,12 +38,13 @@ export AlexNet,
DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
ResNeXt,
MobileNetv2, MobileNetv3,
EfficientNet,
MLPMixer,
ViT

# use Flux._big_show to pretty print large models
for T in (:AlexNet, :VGG, :ResNet, :GoogLeNet, :Inception3, :SqueezeNet, :DenseNet, :ResNeXt,
:MobileNetv2, :MobileNetv3, :MLPMixer, :ViT)
for T in (:AlexNet, :VGG, :ResNet, :GoogLeNet, :Inception3, :SqueezeNet, :DenseNet, :ResNeXt,
:MobileNetv2, :MobileNetv3, :EfficientNet, :MLPMixer)
@eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
end

Expand Down
72 changes: 72 additions & 0 deletions src/convnets/efficientnet/efficientnet.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
include("params.jl")
include("mb.jl")

# EfficientNet model container ([reference](https://arxiv.org/abs/1905.11946)).
# Fields hold the successive stages of the forward pass, in order:
#  - `stem`: initial convolution + batch norm applied to the raw input.
#  - `blocks`: chain of MBConv inverted-residual blocks.
#  - `head`: final pointwise convolution + batch norm.
#  - `pooling`: global pooling collapsing spatial dimensions.
#  - `top`: classifier layer producing the class logits.
struct EfficientNet{S, B, H, P, F}
    stem::S
    blocks::B

    head::H
    pooling::P
    top::F
end
# Register with Functors so Flux can traverse parameters (gpu/cpu moves, training).
Flux.@functor EfficientNet

"""
    EfficientNet(model_name, block_params, global_params; in_channels, n_classes, pretrain)

Assemble a full EfficientNet from per-stage block parameters and global
scaling parameters. Channel counts are scaled with `round_filter` and the
per-stage repeat counts with `global_params.depth_coef`.

# Arguments
- `model_name`: variant suffix (e.g. `"b0"`), used to locate pretrained weights.
- `block_params`: iterable of per-stage `BlockParams`.
- `global_params`: `GlobalParams` with width/depth scaling coefficients.
- `in_channels`: number of channels in the input image.
- `n_classes`: number of output classes of the classifier.
- `pretrain`: when `true`, load pretrained weights via `loadpretrain!`.
"""
function EfficientNet(
    model_name, block_params, global_params; in_channels, n_classes, pretrain,
)
    pad, bias = SamePad(), false

    # Stem: strided 3x3 convolution bringing the input to the first width.
    stem_width = round_filter(32, global_params)
    stem = Chain(
        Conv((3, 3), in_channels => stem_width; bias, stride = 2, pad),
        BatchNorm(stem_width, swish))

    # Body: every stage contributes one strided block followed by
    # `n_repeat - 1` stride-1 blocks operating at the stage's output width.
    mbconv_layers = MBConv[]
    last_width = stem_width
    for bp in block_params
        stage_in = round_filter(bp.in_channels, global_params)
        stage_out = round_filter(bp.out_channels, global_params)
        # Depth scaling: round the repeat count up unless the coefficient is 1.
        n_repeat = global_params.depth_coef ≈ 1 ?
            bp.repeat : ceil(Int64, global_params.depth_coef * bp.repeat)

        push!(mbconv_layers, MBConv(
            stage_in, stage_out, bp.kernel, bp.stride;
            expansion_ratio = bp.expansion_ratio))
        for _ in 2:n_repeat
            push!(mbconv_layers, MBConv(
                stage_out, stage_out, bp.kernel, 1;
                expansion_ratio = bp.expansion_ratio))
        end
        last_width = stage_out
    end

    # Head: pointwise convolution expanding to the classifier width.
    head_width = round_filter(1280, global_params)
    head = Chain(
        Conv((1, 1), last_width => head_width; bias, pad),
        BatchNorm(head_width, swish))

    classifier = Dense(head_width, n_classes)
    model = EfficientNet(
        stem, Chain(mbconv_layers...), head, AdaptiveMeanPool((1, 1)), classifier)
    pretrain && loadpretrain!(model, "EfficientNet" * model_name)
    return model
end

"""
EfficientNet(block_params, global_params; in_channels = 3)

Construct an EfficientNet model
([reference](https://arxiv.org/abs/1905.11946)).

# Arguments
- `model_name::String`: Name of the model. Accepts `b0`-`b8` names.
- `in_channels::Int`: Number of input channels. Default is `3`.
- `n_classes::Int`: Number of output classes. Default is `1000`.
- `pretrain::Bool`: Whether to load ImageNet pretrained weights.
Default is `false`.
"""
EfficientNet(
model_name::String; in_channels::Int = 3,
n_classes::Int = 1000, pretrain::Bool = false,
) = EfficientNet(
model_name, get_efficientnet_params(model_name)...;
in_channels, n_classes, pretrain)

# Forward pass: stem → MBConv blocks → head → global pool → flatten → classifier.
function (m::EfficientNet)(x)
    features = m.head(m.blocks(m.stem(x)))
    pooled = m.pooling(features)
    return m.top(Flux.flatten(pooled))
end
82 changes: 82 additions & 0 deletions src/convnets/efficientnet/mb.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Mobile Inverted Residual Bottleneck block container
# ([reference](https://arxiv.org/abs/1801.04381)).
# The `do_*` flags select which optional phases run in the forward pass:
#  - `expansion`: pointwise conv widening the channels (or `identity`).
#  - `depthwise`: grouped conv operating per channel.
#  - `excitation`: squeeze-excitation attention branch (or `identity`).
#  - `projection`: pointwise conv back to the output width.
struct MBConv{E, D, X, P}
    expansion::E
    depthwise::D
    excitation::X
    projection::P

    do_expansion::Bool
    do_excitation::Bool
    do_skip::Bool
end
# Register with Functors so Flux can traverse the block's parameters.
Flux.@functor MBConv

"""
MBConv(
in_channels, out_channels, kernel, stride;
expansion_ratio, se_ratio)

Mobile Inverted Residual Bottleneck Block
([reference](https://arxiv.org/abs/1801.04381)).

# Arguments
- `in_channels`: Number of input channels.
- `out_channels`: Number of output channels.
- `expansion_ratio`:
Expansion ratio defines the number of output channels.
Set to `1` to disable expansion phase.
`out_channels = input_channels * expansion_ratio`.
- `kernel`: Size of the kernel for the depthwise conv phase.
- `stride`: Size of the stride for the depthwise conv phase.
- `se_ratio`:
Squeeze-Excitation ratio. Should be in `(0, 1]` range.
Set to `-1` to disable.
"""
function MBConv(
in_channels, out_channels, kernel, stride;
expansion_ratio, se_ratio = 0.25,
)
do_skip = stride == 1 && in_channels == out_channels
do_expansion, do_excitation = expansion_ratio != 1, 0 < se_ratio ≤ 1
pad, bias = SamePad(), false

mid_channels = ceil(Int, in_channels * expansion_ratio)
expansion = do_expansion ?
Chain(
Conv((1, 1), in_channels=>mid_channels; bias, pad),
BatchNorm(mid_channels, swish)) :
identity

depthwise = Chain(
Conv(kernel, mid_channels=>mid_channels; bias, stride, pad, groups=mid_channels),
BatchNorm(mid_channels, swish))

if do_excitation
n_squeezed_channels = max(1, ceil(Int, in_channels * se_ratio))
excitation = Chain(
AdaptiveMeanPool((1, 1)),
Conv((1, 1), mid_channels=>n_squeezed_channels, swish; pad),
Conv((1, 1), n_squeezed_channels=>mid_channels; pad))
else
excitation = identity
end

projection = Chain(
Conv((1, 1), mid_channels=>out_channels; pad, bias),
BatchNorm(out_channels))
MBConv(
expansion, depthwise, excitation, projection, do_expansion,
do_excitation, do_skip)
end

# Forward pass: expansion → depthwise conv → optional squeeze-excitation
# gating → projection → optional identity skip connection.
function (m::MBConv)(x)
    y = m.depthwise(m.expansion(x))
    if m.do_excitation
        # Channel attention: rescale features by sigmoid-activated SE weights.
        y = y .* σ.(m.excitation(y))
    end
    y = m.projection(y)
    return m.do_skip ? y + x : y
end
59 changes: 59 additions & 0 deletions src/convnets/efficientnet/params.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Per-stage configuration of the baseline (B0) EfficientNet architecture.
struct BlockParams
    repeat::Int                  # number of MBConv blocks in the stage (before depth scaling)
    kernel::Tuple{Int, Int}      # depthwise convolution kernel size
    stride::Int                  # stride of the stage's first block
    expansion_ratio::Int         # MBConv hidden-channel expansion factor
    in_channels::Int             # stage input channels (before width scaling)
    out_channels::Int            # stage output channels (before width scaling)
end

# Model-wide scaling parameters shared by all stages of a variant.
struct GlobalParams
    width_coef::Real             # channel-count (width) scaling coefficient
    depth_coef::Real             # repeat-count (depth) scaling coefficient
    image_size::Tuple{Int, Int}  # nominal input resolution for the variant

    depth_divisor::Int           # scaled channel counts are rounded to multiples of this
    min_depth::Union{Nothing, Int}  # lower bound on channels; `nothing` falls back to depth_divisor
end

# (width_coefficient, depth_coefficient, resolution) per EfficientNet variant.
# Built once at load time instead of reconstructing the Dict on every lookup.
const EFFICIENTNET_COEFFICIENTS = Dict(
    "b0" => (1.0, 1.0, 224),
    "b1" => (1.0, 1.1, 240),
    "b2" => (1.1, 1.2, 260),
    "b3" => (1.2, 1.4, 300),
    "b4" => (1.4, 1.8, 380),
    "b5" => (1.6, 2.2, 456),
    "b6" => (1.8, 2.6, 528),
    "b7" => (2.0, 3.1, 600),
    "b8" => (2.2, 3.6, 672))

"""
    get_efficientnet_coefficients(model_name::String)

Return the `(width_coefficient, depth_coefficient, resolution)` scaling
triple for the EfficientNet variant `model_name` (`"b0"`-`"b8"`).
Throws a `KeyError` for unknown names.
"""
get_efficientnet_coefficients(model_name::String) =
    EFFICIENTNET_COEFFICIENTS[model_name]

"""
    get_efficientnet_params(model_name)

Return `(block_params, global_params)` for the requested EfficientNet
variant. The stage table is the fixed B0 baseline; the variant-specific
width/depth coefficients and resolution are carried in `GlobalParams`.
"""
function get_efficientnet_params(model_name)
    # Baseline (B0) stages:
    # (repeat, kernel, stride, expansion_ratio, in_channels, out_channels)
    baseline = (
        (1, (3, 3), 1, 1, 32, 16),
        (2, (3, 3), 2, 6, 16, 24),
        (2, (5, 5), 2, 6, 24, 40),
        (3, (3, 3), 2, 6, 40, 80),
        (3, (5, 5), 1, 6, 80, 112),
        (4, (5, 5), 2, 6, 112, 192),
        (1, (3, 3), 1, 6, 192, 320))
    block_params = [BlockParams(cfg...) for cfg in baseline]

    width_coef, depth_coef, resolution = get_efficientnet_coefficients(model_name)
    global_params = GlobalParams(
        width_coef, depth_coef, (resolution, resolution), 8, nothing)
    return block_params, global_params
end

"""
    round_filter(filters, global_params::GlobalParams)

Scale `filters` by `global_params.width_coef` and round to the nearest
multiple of `global_params.depth_divisor`, while never dropping below 90%
of the scaled value nor below the minimum depth.
"""
function round_filter(filters, global_params::GlobalParams)
    # A width coefficient of (approximately) 1 means no scaling at all.
    global_params.width_coef ≈ 1 && return filters

    divisor = global_params.depth_divisor
    scaled = filters * global_params.width_coef
    # When no explicit minimum depth is configured, the divisor is the floor.
    min_channels = something(global_params.min_depth, divisor)

    rounded = max(min_channels, (floor(Int, scaled + divisor / 2) ÷ divisor) * divisor)
    # Rounding down must not remove more than 10% of the scaled channels.
    rounded < 0.9 * scaled && (rounded += divisor)
    return rounded
end
12 changes: 12 additions & 0 deletions test/convnets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,15 @@ end
end
end
end

@testset "EfficientNet" begin
    # NOTE(review): "b4" has a nominal resolution of 380x380 in params.jl,
    # but the network appears fully convolutional, so a 224x224 input should
    # still produce (n_classes, batch) output — confirm this is intentional.
    m = EfficientNet("b4")
    x = rand(Float32, 224, 224, 3, 2)
    @test size(m(x)) == (1000, 2)
    # Requesting pretrained weights must throw — presumably raised by
    # `loadpretrain!` since no weights are published yet; verify.
    @test_throws ArgumentError (EfficientNet("b0"; pretrain = true); true)
    # Gradient test skipped for now (marked broken-by-design with @test_skip).
    @test_skip gradtest(m, rand(Float32, 224, 224, 3, 2))

    # TODO: add test for inferrability once the new version of Flux is released with
    # https://github.com/FluxML/Flux.jl/pull/1856
    # @inferred m(x)
end