Merge pull request #168 from theabhirath/style-3
darsnack authored Jun 11, 2022
2 parents f97a61d + d4f1d07 commit ee344a9
Showing 29 changed files with 1,495 additions and 1,382 deletions.
9 changes: 9 additions & 0 deletions .JuliaFormatter.toml
@@ -0,0 +1,9 @@
+style = "sciml"
+whitespace_in_kwargs = true
+format_docstrings = true
+always_for_in = true
+join_lines_based_on_source = true
+separate_kwargs_with_semicolon = true
+always_use_return = true
+margin = 92
+indent = 4
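
As a usage sketch (hedged): with the JuliaFormatter package installed, this configuration is picked up from the repository root, e.g.

using JuliaFormatter

# Format every Julia file under the current directory, reading options
# from the .JuliaFormatter.toml above; returns true when no file
# needed changes.
format(".")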
4 changes: 4 additions & 0 deletions .git-blame-ignore-revs
@@ -0,0 +1,4 @@
+# .git-blame-ignore-revs
+# Switched to SciML style for code
+d5d28f0ef6e1e253ecf3fdbbec2f511836c8767b
+70d639de532b046980cbea8d17fb1829e04cccfe
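
A hedged note: GitHub's blame view reads this file automatically, but a local `git blame` (git 2.23+) only honors it when pointed at it, e.g. via `git config blame.ignoreRevsFile .git-blame-ignore-revs`.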
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,6 +1,6 @@
name = "Metalhead"
uuid = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
version = "0.7.1"
version = "0.7.2-DEV"

[deps]
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
@@ -16,7 +16,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
BSON = "0.3.2"
Flux = "0.13"
Functors = "0.2"
MLUtils = "0.2"
MLUtils = "0.2.6"
NNlib = "0.7.34, 0.8"
julia = "1.6"
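A hedged note on the compat bump: under Pkg's default caret semantics, `MLUtils = "0.2.6"` allows `[0.2.6, 0.3)`, so this raises the floor to 0.2.6 without admitting 0.3.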

25 changes: 13 additions & 12 deletions src/Metalhead.jl
@@ -37,22 +37,23 @@ include("vit-based/vit.jl")

include("pretrain.jl")

-export AlexNet,
-VGG, VGG11, VGG13, VGG16, VGG19,
-ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152,
-GoogLeNet, Inception3, SqueezeNet,
-DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
-ResNeXt,
-MobileNetv1, MobileNetv2, MobileNetv3,
-MLPMixer, ResMLP, gMLP,
-ViT,
-ConvNeXt, ConvMixer
+export AlexNet,
+VGG, VGG11, VGG13, VGG16, VGG19,
+ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152,
+GoogLeNet, Inception3, SqueezeNet,
+DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
+ResNeXt,
+MobileNetv1, MobileNetv2, MobileNetv3,
+MLPMixer, ResMLP, gMLP,
+ViT,
+ConvNeXt, ConvMixer

# use Flux._big_show to pretty print large models
-for T in (:AlexNet, :VGG, :ResNet, :GoogLeNet, :Inception3, :SqueezeNet, :DenseNet, :ResNeXt,
+for T in (:AlexNet, :VGG, :ResNet, :GoogLeNet, :Inception3, :SqueezeNet, :DenseNet,
+:ResNeXt,
:MobileNetv1, :MobileNetv2, :MobileNetv3,
:MLPMixer, :ResMLP, :gMLP, :ViT, :ConvNeXt, :ConvMixer)
-@eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
+@eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
end

end # module
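
As a hedged usage sketch of the exports above (random weights; assumes the usual forward-pass definitions in the model files):

using Metalhead

model = ResNet18()                 # any exported constructor works the same way
x = rand(Float32, 224, 224, 3, 1)  # one 224×224 RGB image in Flux's WHCN layout
y = model(x)                       # logits of size (1000, 1)
model                              # in the REPL, displayed via Flux._big_show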
52 changes: 27 additions & 25 deletions src/convnets/alexnet.jl
@@ -5,26 +5,27 @@ Create an AlexNet model
([reference](https://papers.nips.cc/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf)).
# Arguments
-- `nclasses`: the number of output classes
+- `nclasses`: the number of output classes
"""
function alexnet(; nclasses = 1000)
-layers = Chain(Chain(Conv((11, 11), 3 => 64, stride = (4, 4), relu, pad = (2, 2)),
-MaxPool((3, 3), stride = (2, 2)),
-Conv((5, 5), 64 => 192, relu, pad = (2, 2)),
-MaxPool((3, 3), stride = (2, 2)),
-Conv((3, 3), 192 => 384, relu, pad = (1, 1)),
-Conv((3, 3), 384 => 256, relu, pad = (1, 1)),
-Conv((3, 3), 256 => 256, relu, pad = (1, 1)),
-MaxPool((3, 3), stride = (2, 2)),
-AdaptiveMeanPool((6,6))),
-Chain(MLUtils.flatten,
-Dropout(0.5),
-Dense(256 * 6 * 6, 4096, relu),
-Dropout(0.5),
-Dense(4096, 4096, relu),
-Dense(4096, nclasses)))
-
-return layers
+layers = Chain(Chain(Conv((11, 11), 3 => 64, relu; stride = (4, 4), pad = (2, 2)),
+MaxPool((3, 3); stride = (2, 2)),
+Conv((5, 5), 64 => 192, relu; pad = (2, 2)),
+MaxPool((3, 3); stride = (2, 2)),
+Conv((3, 3), 192 => 384, relu; pad = (1, 1)),
+Conv((3, 3), 384 => 256, relu; pad = (1, 1)),
+Conv((3, 3), 256 => 256, relu; pad = (1, 1)),
+MaxPool((3, 3); stride = (2, 2)),
+AdaptiveMeanPool((6, 6))),
+Chain(MLUtils.flatten,
+Dropout(0.5),
+Dense(256 * 6 * 6, 4096, relu),
+Dropout(0.5),
+Dense(4096, 4096, relu),
+Dense(4096, nclasses)))
+
+return layers
end

"""
@@ -34,21 +35,22 @@ Create a `AlexNet`.
See also [`alexnet`](#).
!!! warning
`AlexNet` does not currently support pretrained weights.
# Arguments
-- `pretrain`: set to `true` to load pre-trained weights for ImageNet
-- `nclasses`: the number of output classes
+- `pretrain`: set to `true` to load pre-trained weights for ImageNet
+- `nclasses`: the number of output classes
"""
struct AlexNet
-layers
+layers::Any
end

function AlexNet(; pretrain = false, nclasses = 1000)
-layers = alexnet(nclasses = nclasses)
-pretrain && loadpretrain!(layers, "AlexNet")
-
-AlexNet(layers)
+layers = alexnet(; nclasses = nclasses)
+pretrain && loadpretrain!(layers, "AlexNet")
+return AlexNet(layers)
end

@functor AlexNet
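
A hedged usage sketch (random weights, per the warning above that pretrained weights are unsupported; `m.layers` is the field shown in the struct):

using Metalhead

m = AlexNet(; nclasses = 10)
x = rand(Float32, 224, 224, 3, 1)  # one 224×224 RGB image, WHCN layout
size(m.layers(x))                  # (10, 1); AdaptiveMeanPool fixes the 6×6 grid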
63 changes: 35 additions & 28 deletions src/convnets/convmixer.jl
@@ -6,30 +6,35 @@ Creates a ConvMixer model.
([reference](https://arxiv.org/abs/2201.09792))
# Arguments
-- `planes`: number of planes in the output of each block
-- `depth`: number of layers
-- `inchannels`: number of channels in the input
-- `kernel_size`: kernel size of the convolutional layers
-- `patch_size`: size of the patches
-- `activation`: activation function used after the convolutional layers
-- `nclasses`: number of classes in the output
+- `planes`: number of planes in the output of each block
+- `depth`: number of layers
+- `inchannels`: number of channels in the input
+- `kernel_size`: kernel size of the convolutional layers
+- `patch_size`: size of the patches
+- `activation`: activation function used after the convolutional layers
+- `nclasses`: number of classes in the output
"""
function convmixer(planes, depth; inchannels = 3, kernel_size = (9, 9),
patch_size::Dims{2} = (7, 7), activation = gelu, nclasses = 1000)
-stem = conv_bn(patch_size, inchannels, planes, activation; preact = true, stride = patch_size[1])
-blocks = [Chain(SkipConnection(Chain(conv_bn(kernel_size, planes, planes, activation;
-preact = true, groups = planes, pad = SamePad())), +),
-conv_bn((1, 1), planes, planes, activation; preact = true)...) for _ in 1:depth]
-head = Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(planes, nclasses))
-return Chain(Chain(stem..., Chain(blocks)), head)
+stem = conv_bn(patch_size, inchannels, planes, activation; preact = true,
+stride = patch_size[1])
+blocks = [Chain(SkipConnection(Chain(conv_bn(kernel_size, planes, planes, activation;
+preact = true, groups = planes,
+pad = SamePad())), +),
+conv_bn((1, 1), planes, planes, activation; preact = true)...)
+for _ in 1:depth]
+head = Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(planes, nclasses))
+return Chain(Chain(stem..., Chain(blocks)), head)
end

convmixer_config = Dict(:base => Dict(:planes => 1536, :depth => 20, :kernel_size => (9, 9),
-:patch_size => (7, 7)),
+:patch_size => (7, 7)),
:small => Dict(:planes => 768, :depth => 32, :kernel_size => (7, 7),
-:patch_size => (7, 7)),
-:large => Dict(:planes => 1024, :depth => 20, :kernel_size => (9, 9),
-:patch_size => (7, 7)))
+:patch_size => (7, 7)),
+:large => Dict(:planes => 1024, :depth => 20,
+:kernel_size => (9, 9),
+:patch_size => (7, 7)))

"""
ConvMixer(mode::Symbol = :base; inchannels = 3, activation = gelu, nclasses = 1000)
@@ -38,22 +43,24 @@ Creates a ConvMixer model.
([reference](https://arxiv.org/abs/2201.09792))
# Arguments
-- `mode`: the mode of the model, either `:base`, `:small` or `:large`
-- `inchannels`: number of channels in the input
-- `activation`: activation function used after the convolutional layers
-- `nclasses`: number of classes in the output
+- `mode`: the mode of the model, either `:base`, `:small` or `:large`
+- `inchannels`: number of channels in the input
+- `activation`: activation function used after the convolutional layers
+- `nclasses`: number of classes in the output
"""
struct ConvMixer
-layers
+layers::Any
end

function ConvMixer(mode::Symbol = :base; inchannels = 3, activation = gelu, nclasses = 1000)
-planes = convmixer_config[mode][:planes]
-depth = convmixer_config[mode][:depth]
-kernel_size = convmixer_config[mode][:kernel_size]
-patch_size = convmixer_config[mode][:patch_size]
-layers = convmixer(planes, depth; inchannels, kernel_size, patch_size, activation, nclasses)
-return ConvMixer(layers)
+planes = convmixer_config[mode][:planes]
+depth = convmixer_config[mode][:depth]
+kernel_size = convmixer_config[mode][:kernel_size]
+patch_size = convmixer_config[mode][:patch_size]
+layers = convmixer(planes, depth; inchannels, kernel_size, patch_size, activation,
+nclasses)
+return ConvMixer(layers)
end

@functor ConvMixer
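
A hedged usage sketch of the constructor above (`m.layers` is the field shown in the struct):

using Metalhead

m = ConvMixer(:small)              # 768 planes, depth 32, per convmixer_config
x = rand(Float32, 224, 224, 3, 1)
size(m.layers(x))                  # (1000, 1); 7×7 patches give 224 ÷ 7 = 32 tokens per side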
133 changes: 71 additions & 62 deletions src/convnets/convnext.jl
@@ -5,19 +5,20 @@ Creates a single block of ConvNeXt.
([reference](https://arxiv.org/abs/2201.03545))
# Arguments:
-- `planes`: number of input channels.
-- `drop_path_rate`: Stochastic depth rate.
-- `λ`: Init value for LayerScale
+- `planes`: number of input channels.
+- `drop_path_rate`: Stochastic depth rate.
+- `λ`: Init value for LayerScale
"""
-function convnextblock(planes, drop_path_rate = 0., λ = 1f-6)
-layers = SkipConnection(Chain(DepthwiseConv((7, 7), planes => planes; pad = 3),
-swapdims((3, 1, 2, 4)),
-LayerNorm(planes; ϵ = 1f-6),
-mlp_block(planes, 4 * planes),
-LayerScale(planes, λ),
-swapdims((2, 3, 1, 4)),
-DropPath(drop_path_rate)), +)
-return layers
+function convnextblock(planes, drop_path_rate = 0.0, λ = 1.0f-6)
+layers = SkipConnection(Chain(DepthwiseConv((7, 7), planes => planes; pad = 3),
+swapdims((3, 1, 2, 4)),
+LayerNorm(planes; ϵ = 1.0f-6),
+mlp_block(planes, 4 * planes),
+LayerScale(planes, λ),
+swapdims((2, 3, 1, 4)),
+DropPath(drop_path_rate)), +)
+return layers
end

"""
@@ -27,52 +28,59 @@ Creates the layers for a ConvNeXt model.
([reference](https://arxiv.org/abs/2201.03545))
# Arguments:
-- `inchannels`: number of input channels.
-- `depths`: list with configuration for depth of each block
-- `planes`: list with configuration for number of output channels in each block
-- `drop_path_rate`: Stochastic depth rate.
-- `λ`: Init value for [LayerScale](https://arxiv.org/abs/2103.17239)
-- `nclasses`: number of output classes
+- `inchannels`: number of input channels.
+- `depths`: list with configuration for depth of each block
+- `planes`: list with configuration for number of output channels in each block
+- `drop_path_rate`: Stochastic depth rate.
+- `λ`: Init value for [LayerScale](https://arxiv.org/abs/2103.17239)
+- `nclasses`: number of output classes
"""
-function convnext(depths, planes; inchannels = 3, drop_path_rate = 0., λ = 1f-6, nclasses = 1000)
-@assert length(depths) == length(planes) "`planes` should have exactly one value for each block"
-
-downsample_layers = []
-stem = Chain(Conv((4, 4), inchannels => planes[1]; stride = 4),
-ChannelLayerNorm(planes[1]; ϵ = 1f-6))
-push!(downsample_layers, stem)
-for m in 1:length(depths) - 1
-downsample_layer = Chain(ChannelLayerNorm(planes[m]; ϵ = 1f-6),
-Conv((2, 2), planes[m] => planes[m + 1]; stride = 2))
-push!(downsample_layers, downsample_layer)
-end
-
-stages = []
-dp_rates = LinRange{Float32}(0., drop_path_rate, sum(depths))
-cur = 0
-for i in 1:length(depths)
-push!(stages, [convnextblock(planes[i], dp_rates[cur + j], λ) for j in 1:depths[i]])
-cur += depths[i]
-end
-
-backbone = collect(Iterators.flatten(Iterators.flatten(zip(downsample_layers, stages))))
-head = Chain(GlobalMeanPool(),
-MLUtils.flatten,
-LayerNorm(planes[end]),
-Dense(planes[end], nclasses))
-
-return Chain(Chain(backbone), head)
+function convnext(depths, planes; inchannels = 3, drop_path_rate = 0.0, λ = 1.0f-6,
+nclasses = 1000)
+@assert length(depths)==length(planes) "`planes` should have exactly one value for each block"
+
+downsample_layers = []
+stem = Chain(Conv((4, 4), inchannels => planes[1]; stride = 4),
+ChannelLayerNorm(planes[1]; ϵ = 1.0f-6))
+push!(downsample_layers, stem)
+for m in 1:(length(depths) - 1)
+downsample_layer = Chain(ChannelLayerNorm(planes[m]; ϵ = 1.0f-6),
+Conv((2, 2), planes[m] => planes[m + 1]; stride = 2))
+push!(downsample_layers, downsample_layer)
+end
+
+stages = []
+dp_rates = LinRange{Float32}(0.0, drop_path_rate, sum(depths))
+cur = 0
+for i in 1:length(depths)
+push!(stages, [convnextblock(planes[i], dp_rates[cur + j], λ) for j in 1:depths[i]])
+cur += depths[i]
+end
+
+backbone = collect(Iterators.flatten(Iterators.flatten(zip(downsample_layers, stages))))
+head = Chain(GlobalMeanPool(),
+MLUtils.flatten,
+LayerNorm(planes[end]),
+Dense(planes[end], nclasses))
+
+return Chain(Chain(backbone), head)
end
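
For intuition, the stochastic-depth schedule above is a linear ramp over all blocks; a hedged sketch with hypothetical values:

# 18 per-block drop rates climbing from 0 to 0.1, for depths [3, 3, 9, 3]
dp_rates = LinRange{Float32}(0.0, 0.1, sum([3, 3, 9, 3]))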

# Configurations for ConvNeXt models
-convnext_configs = Dict(:tiny => Dict(:depths => [3, 3, 9, 3], :planes => [96, 192, 384, 768]),
-:small => Dict(:depths => [3, 3, 27, 3], :planes => [96, 192, 384, 768]),
-:base => Dict(:depths => [3, 3, 27, 3], :planes => [128, 256, 512, 1024]),
-:large => Dict(:depths => [3, 3, 27, 3], :planes => [192, 384, 768, 1536]),
-:xlarge => Dict(:depths => [3, 3, 27, 3], :planes => [256, 512, 1024, 2048]))
+convnext_configs = Dict(:tiny => Dict(:depths => [3, 3, 9, 3],
+:planes => [96, 192, 384, 768]),
+:small => Dict(:depths => [3, 3, 27, 3],
+:planes => [96, 192, 384, 768]),
+:base => Dict(:depths => [3, 3, 27, 3],
+:planes => [128, 256, 512, 1024]),
+:large => Dict(:depths => [3, 3, 27, 3],
+:planes => [192, 384, 768, 1536]),
+:xlarge => Dict(:depths => [3, 3, 27, 3],
+:planes => [256, 512, 1024, 2048]))

struct ConvNeXt
-layers
+layers::Any
end

"""
@@ -82,20 +90,21 @@ Creates a ConvNeXt model.
([reference](https://arxiv.org/abs/2201.03545))
# Arguments:
-- `inchannels`: number of input channels.
-- `drop_path_rate`: Stochastic depth rate.
-- `λ`: Init value for [LayerScale](https://arxiv.org/abs/2103.17239)
-- `nclasses`: number of output classes
+- `inchannels`: number of input channels.
+- `drop_path_rate`: Stochastic depth rate.
+- `λ`: Init value for [LayerScale](https://arxiv.org/abs/2103.17239)
+- `nclasses`: number of output classes
See also [`Metalhead.convnext`](#).
"""
-function ConvNeXt(mode::Symbol = :base; inchannels = 3, drop_path_rate = 0., λ = 1f-6,
+function ConvNeXt(mode::Symbol = :base; inchannels = 3, drop_path_rate = 0.0, λ = 1.0f-6,
nclasses = 1000)
-@assert mode in keys(convnext_configs) "`size` must be one of $(collect(keys(convnext_configs)))"
-depths = convnext_configs[mode][:depths]
-planes = convnext_configs[mode][:planes]
-layers = convnext(depths, planes; inchannels, drop_path_rate, λ, nclasses)
-return ConvNeXt(layers)
+@assert mode in keys(convnext_configs) "`size` must be one of $(collect(keys(convnext_configs)))"
+depths = convnext_configs[mode][:depths]
+planes = convnext_configs[mode][:planes]
+layers = convnext(depths, planes; inchannels, drop_path_rate, λ, nclasses)
+return ConvNeXt(layers)
end

(m::ConvNeXt)(x) = m.layers(x)
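
A hedged usage sketch, exercising the forward definition just above:

using Metalhead

m = ConvNeXt(:tiny; nclasses = 100)
x = rand(Float32, 224, 224, 3, 1)
size(m(x))                         # (100, 1)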