Switch to SciML style, take 2 #163

Closed
wants to merge 10 commits
2 changes: 2 additions & 0 deletions .JuliaFormatter.toml
@@ -0,0 +1,2 @@
style = "sciml"
whitespace_in_kwargs = true
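
A hedged sketch of how this config gets applied (JuliaFormatter reads `.JuliaFormatter.toml` from the project root on its own; running from the repository top level is assumed):

using JuliaFormatter
format(".")   # formats files in place; returns true if every file already matched the style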
3 changes: 3 additions & 0 deletions .git-blame-ignore-revs
@@ -0,0 +1,3 @@
# .git-blame-ignore-revs
# Switched to SciML style for code
fd2869f57c66fa650547cd8581feeba9eda08b88
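
Note that `git blame` only honors this list once a clone opts in; with git 2.23 or newer that is a one-time `git config blame.ignoreRevsFile .git-blame-ignore-revs`.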
25 changes: 13 additions & 12 deletions src/Metalhead.jl
@@ -37,22 +37,23 @@ include("vit-based/vit.jl")

include("pretrain.jl")

export AlexNet,
VGG, VGG11, VGG13, VGG16, VGG19,
ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152,
GoogLeNet, Inception3, SqueezeNet,
DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
ResNeXt,
MobileNetv1, MobileNetv2, MobileNetv3,
MLPMixer, ResMLP, gMLP,
ViT,
ConvNeXt, ConvMixer
export AlexNet,
VGG, VGG11, VGG13, VGG16, VGG19,
ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152,
GoogLeNet, Inception3, SqueezeNet,
DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
ResNeXt,
MobileNetv1, MobileNetv2, MobileNetv3,
MLPMixer, ResMLP, gMLP,
ViT,
ConvNeXt, ConvMixer

# use Flux._big_show to pretty print large models
for T in (:AlexNet, :VGG, :ResNet, :GoogLeNet, :Inception3, :SqueezeNet, :DenseNet, :ResNeXt,
for T in (:AlexNet, :VGG, :ResNet, :GoogLeNet, :Inception3, :SqueezeNet, :DenseNet,
:ResNeXt,
:MobileNetv1, :MobileNetv2, :MobileNetv3,
:MLPMixer, :ResMLP, :gMLP, :ViT, :ConvNeXt, :ConvMixer)
@eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
@eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
end

end # module
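
For context, the `@eval` loop above generates one pretty-printing method per exported model type; as a sketch, the expansion for a single type looks like:

# what the loop emits for each T, shown here for AlexNet
Base.show(io::IO, ::MIME"text/plain", model::AlexNet) = _maybe_big_show(io, model)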
43 changes: 21 additions & 22 deletions src/convnets/alexnet.jl
@@ -8,23 +8,23 @@ Create an AlexNet model
- `nclasses`: the number of output classes
"""
function alexnet(; nclasses = 1000)
layers = Chain(Chain(Conv((11, 11), 3 => 64, stride = (4, 4), relu, pad = (2, 2)),
MaxPool((3, 3), stride = (2, 2)),
Conv((5, 5), 64 => 192, relu, pad = (2, 2)),
MaxPool((3, 3), stride = (2, 2)),
Conv((3, 3), 192 => 384, relu, pad = (1, 1)),
Conv((3, 3), 384 => 256, relu, pad = (1, 1)),
Conv((3, 3), 256 => 256, relu, pad = (1, 1)),
MaxPool((3, 3), stride = (2, 2)),
AdaptiveMeanPool((6,6))),
Chain(MLUtils.flatten,
Dropout(0.5),
Dense(256 * 6 * 6, 4096, relu),
Dropout(0.5),
Dense(4096, 4096, relu),
Dense(4096, nclasses)))

return layers
layers = Chain(Chain(Conv((11, 11), 3 => 64, stride = (4, 4), relu, pad = (2, 2)),
MaxPool((3, 3), stride = (2, 2)),
Conv((5, 5), 64 => 192, relu, pad = (2, 2)),
MaxPool((3, 3), stride = (2, 2)),
Conv((3, 3), 192 => 384, relu, pad = (1, 1)),
Conv((3, 3), 384 => 256, relu, pad = (1, 1)),
Conv((3, 3), 256 => 256, relu, pad = (1, 1)),
MaxPool((3, 3), stride = (2, 2)),
AdaptiveMeanPool((6, 6))),
Chain(MLUtils.flatten,
Dropout(0.5),
Dense(256 * 6 * 6, 4096, relu),
Dropout(0.5),
Dense(4096, 4096, relu),
Dense(4096, nclasses)))

return layers
end

"""
@@ -41,14 +41,13 @@ See also [`alexnet`](#).
- `nclasses`: the number of output classes
"""
struct AlexNet
layers
layers::Any
end

function AlexNet(; pretrain = false, nclasses = 1000)
layers = alexnet(nclasses = nclasses)
pretrain && loadpretrain!(layers, "AlexNet")

AlexNet(layers)
layers = alexnet(nclasses = nclasses)
pretrain && loadpretrain!(layers, "AlexNet")
AlexNet(layers)
end

@functor AlexNet
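
A minimal usage sketch; the 224×224 WHCN input is the conventional AlexNet/ImageNet size and the callable-model behavior mirrors the other types in this PR, both assumptions rather than facts shown in this diff:

using Metalhead
model = AlexNet(; nclasses = 1000)    # random weights; pretrain = true would call loadpretrain!
x = rand(Float32, 224, 224, 3, 1)     # one RGB image in WHCN layout (assumed input size)
y = model(x)                          # nclasses-element score vector per image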
39 changes: 22 additions & 17 deletions src/convnets/convmixer.jl
@@ -16,20 +16,24 @@ Creates a ConvMixer model.
"""
function convmixer(planes, depth; inchannels = 3, kernel_size = (9, 9),
patch_size::Dims{2} = (7, 7), activation = gelu, nclasses = 1000)
stem = conv_bn(patch_size, inchannels, planes, activation; preact = true, stride = patch_size[1])
blocks = [Chain(SkipConnection(Chain(conv_bn(kernel_size, planes, planes, activation;
preact = true, groups = planes, pad = SamePad())), +),
conv_bn((1, 1), planes, planes, activation; preact = true)...) for _ in 1:depth]
head = Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(planes, nclasses))
return Chain(Chain(stem..., Chain(blocks)), head)
stem = conv_bn(patch_size, inchannels, planes, activation; preact = true,
stride = patch_size[1])
blocks = [Chain(SkipConnection(Chain(conv_bn(kernel_size, planes, planes, activation;
preact = true, groups = planes,
pad = SamePad())), +),
conv_bn((1, 1), planes, planes, activation; preact = true)...)
for _ in 1:depth]
head = Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(planes, nclasses))
return Chain(Chain(stem..., Chain(blocks)), head)
end

convmixer_config = Dict(:base => Dict(:planes => 1536, :depth => 20, :kernel_size => (9, 9),
:patch_size => (7, 7)),
:patch_size => (7, 7)),
:small => Dict(:planes => 768, :depth => 32, :kernel_size => (7, 7),
:patch_size => (7, 7)),
:large => Dict(:planes => 1024, :depth => 20, :kernel_size => (9, 9),
:patch_size => (7, 7)))
:patch_size => (7, 7)),
:large => Dict(:planes => 1024, :depth => 20,
:kernel_size => (9, 9),
:patch_size => (7, 7)))

"""
ConvMixer(mode::Symbol = :base; inchannels = 3, activation = gelu, nclasses = 1000)
@@ -44,16 +44,17 @@ Creates a ConvMixer model.
- `nclasses`: number of classes in the output
"""
struct ConvMixer
layers
layers::Any
end

function ConvMixer(mode::Symbol = :base; inchannels = 3, activation = gelu, nclasses = 1000)
planes = convmixer_config[mode][:planes]
depth = convmixer_config[mode][:depth]
kernel_size = convmixer_config[mode][:kernel_size]
patch_size = convmixer_config[mode][:patch_size]
layers = convmixer(planes, depth; inchannels, kernel_size, patch_size, activation, nclasses)
return ConvMixer(layers)
planes = convmixer_config[mode][:planes]
depth = convmixer_config[mode][:depth]
kernel_size = convmixer_config[mode][:kernel_size]
patch_size = convmixer_config[mode][:patch_size]
layers = convmixer(planes, depth; inchannels, kernel_size, patch_size, activation,
nclasses)
return ConvMixer(layers)
end

@functor ConvMixer
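
For illustration, the `:small` entry of `convmixer_config` above yields (a sketch, with values read straight from that Dict):

m = ConvMixer(:small)   # planes = 768, depth = 32, kernel_size = (7, 7), patch_size = (7, 7)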
101 changes: 52 additions & 49 deletions src/convnets/convnext.jl
@@ -9,15 +9,15 @@ Creates a single block of ConvNeXt.
- `drop_path_rate`: Stochastic depth rate.
- `λ`: Init value for LayerScale
"""
function convnextblock(planes, drop_path_rate = 0., λ = 1f-6)
layers = SkipConnection(Chain(DepthwiseConv((7, 7), planes => planes; pad = 3),
swapdims((3, 1, 2, 4)),
LayerNorm(planes; ϵ = 1f-6),
mlp_block(planes, 4 * planes),
LayerScale(planes, λ),
swapdims((2, 3, 1, 4)),
DropPath(drop_path_rate)), +)
return layers
function convnextblock(planes, drop_path_rate = 0.0, λ = 1.0f-6)
layers = SkipConnection(Chain(DepthwiseConv((7, 7), planes => planes; pad = 3),
swapdims((3, 1, 2, 4)),
LayerNorm(planes; ϵ = 1.0f-6),
mlp_block(planes, 4 * planes),
LayerScale(planes, λ),
swapdims((2, 3, 1, 4)),
DropPath(drop_path_rate)), +)
return layers
end

"""
@@ -34,45 +34,48 @@ Creates the layers for a ConvNeXt model.
- `λ`: Init value for [LayerScale](https://arxiv.org/abs/2103.17239)
- `nclasses`: number of output classes
"""
function convnext(depths, planes; inchannels = 3, drop_path_rate = 0., λ = 1f-6, nclasses = 1000)
@assert length(depths) == length(planes) "`planes` should have exactly one value for each block"

downsample_layers = []
stem = Chain(Conv((4, 4), inchannels => planes[1]; stride = 4),
ChannelLayerNorm(planes[1]; ϵ = 1f-6))
push!(downsample_layers, stem)
for m in 1:length(depths) - 1
downsample_layer = Chain(ChannelLayerNorm(planes[m]; ϵ = 1f-6),
Conv((2, 2), planes[m] => planes[m + 1]; stride = 2))
push!(downsample_layers, downsample_layer)
end

stages = []
dp_rates = LinRange{Float32}(0., drop_path_rate, sum(depths))
cur = 0
for i in 1:length(depths)
push!(stages, [convnextblock(planes[i], dp_rates[cur + j], λ) for j in 1:depths[i]])
cur += depths[i]
end

backbone = collect(Iterators.flatten(Iterators.flatten(zip(downsample_layers, stages))))
head = Chain(GlobalMeanPool(),
MLUtils.flatten,
LayerNorm(planes[end]),
Dense(planes[end], nclasses))

return Chain(Chain(backbone), head)
function convnext(depths, planes; inchannels = 3, drop_path_rate = 0.0, λ = 1.0f-6,
nclasses = 1000)
@assert length(depths)==length(planes) "`planes` should have exactly one value for each block"
downsample_layers = []
stem = Chain(Conv((4, 4), inchannels => planes[1]; stride = 4),
ChannelLayerNorm(planes[1]; ϵ = 1.0f-6))
push!(downsample_layers, stem)
for m in 1:(length(depths) - 1)
downsample_layer = Chain(ChannelLayerNorm(planes[m]; ϵ = 1.0f-6),
Conv((2, 2), planes[m] => planes[m + 1]; stride = 2))
push!(downsample_layers, downsample_layer)
end
stages = []
dp_rates = LinRange{Float32}(0.0, drop_path_rate, sum(depths))
cur = 0
for i in 1:length(depths)
push!(stages, [convnextblock(planes[i], dp_rates[cur + j], λ) for j in 1:depths[i]])
cur += depths[i]
end
backbone = collect(Iterators.flatten(Iterators.flatten(zip(downsample_layers, stages))))
head = Chain(GlobalMeanPool(),
MLUtils.flatten,
LayerNorm(planes[end]),
Dense(planes[end], nclasses))

return Chain(Chain(backbone), head)
end
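
To make the stochastic-depth schedule concrete: `dp_rates` hands each of the `sum(depths)` blocks a linearly increasing drop probability. With the `:tiny` depths and an illustrative `drop_path_rate = 0.1`:

dp_rates = LinRange{Float32}(0.0, 0.1, sum([3, 3, 9, 3]))   # 18 rates from 0.0f0 to 0.1f0
# the first block keeps dp_rates[1] == 0.0f0; the last gets dp_rates[18] == 0.1f0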

# Configurations for ConvNeXt models
convnext_configs = Dict(:tiny => Dict(:depths => [3, 3, 9, 3], :planes => [96, 192, 384, 768]),
:small => Dict(:depths => [3, 3, 27, 3], :planes => [96, 192, 384, 768]),
:base => Dict(:depths => [3, 3, 27, 3], :planes => [128, 256, 512, 1024]),
:large => Dict(:depths => [3, 3, 27, 3], :planes => [192, 384, 768, 1536]),
:xlarge => Dict(:depths => [3, 3, 27, 3], :planes => [256, 512, 1024, 2048]))
convnext_configs = Dict(:tiny => Dict(:depths => [3, 3, 9, 3],
:planes => [96, 192, 384, 768]),
:small => Dict(:depths => [3, 3, 27, 3],
:planes => [96, 192, 384, 768]),
:base => Dict(:depths => [3, 3, 27, 3],
:planes => [128, 256, 512, 1024]),
:large => Dict(:depths => [3, 3, 27, 3],
:planes => [192, 384, 768, 1536]),
:xlarge => Dict(:depths => [3, 3, 27, 3],
:planes => [256, 512, 1024, 2048]))

struct ConvNeXt
layers
layers::Any
end

"""
@@ -89,13 +92,13 @@ Creates a ConvNeXt model.

See also [`Metalhead.convnext`](#).
"""
function ConvNeXt(mode::Symbol = :base; inchannels = 3, drop_path_rate = 0., λ = 1f-6,
function ConvNeXt(mode::Symbol = :base; inchannels = 3, drop_path_rate = 0.0, λ = 1.0f-6,
nclasses = 1000)
@assert mode in keys(convnext_configs) "`size` must be one of $(collect(keys(convnext_configs)))"
depths = convnext_configs[mode][:depths]
planes = convnext_configs[mode][:planes]
layers = convnext(depths, planes; inchannels, drop_path_rate, λ, nclasses)
return ConvNeXt(layers)
@assert mode in keys(convnext_configs) "`mode` must be one of $(collect(keys(convnext_configs)))"
depths = convnext_configs[mode][:depths]
planes = convnext_configs[mode][:planes]
layers = convnext(depths, planes; inchannels, drop_path_rate, λ, nclasses)
return ConvNeXt(layers)
end

(m::ConvNeXt)(x) = m.layers(x)
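
A constructor sketch using the `:tiny` entry of `convnext_configs` (input size assumed, as above):

m = ConvNeXt(:tiny; drop_path_rate = 0.1)   # depths [3, 3, 9, 3], planes [96, 192, 384, 768]
y = m(rand(Float32, 224, 224, 3, 1))        # dispatches through m.layers per the definition above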