Skip to content

Commit

Permalink
Merge pull request #270 from FluxML/a2/vgg-fix
Browse files Browse the repository at this point in the history
Deprecate older VGG API
  • Loading branch information
theabhirath authored Jan 5, 2024
2 parents 2ad3337 + 1207f84 commit eed0c64
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 58 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Metalhead"
uuid = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
version = "0.9.1"
version = "0.9.2"

[deps]
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
Expand Down
3 changes: 3 additions & 0 deletions src/Metalhead.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ include("vit-based/vit.jl")
# Load pretrained weights
include("pretrain.jl")

# deprecated
include("deprecations.jl")

# export model functions
export AlexNet, VGG, ResNet, WideResNet, ResNeXt, DenseNet,
GoogLeNet, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception,
Expand Down
88 changes: 31 additions & 57 deletions src/convnets/vgg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,11 @@ A VGG block of convolution layers
- `batchnorm`: set to `true` to include batch normalization after each convolution
"""
function vgg_block(ifilters::Integer, ofilters::Integer, depth::Integer, batchnorm::Bool)
k = (3, 3)
p = (1, 1)
layers = []
for _ in 1:depth
if batchnorm
append!(layers, conv_norm(k, ifilters, ofilters; pad = p))
else
push!(layers, Conv(k, ifilters => ofilters, relu; pad = p))
end
ifilters = ofilters
norm_layer = batchnorm ? BatchNorm : identity
layers = [conv_norm((3, 3), ifilters, ofilters; pad = (1, 1), norm_layer)...]
for i in 2:depth
append!(layers, conv_norm((3, 3), ofilters, ofilters; pad = (1, 1), norm_layer))
end
ifilters = ofilters
return layers
end

Expand Down Expand Up @@ -77,7 +70,8 @@ function vgg_classifier_layers(imsize::NTuple{3, <:Integer}, nclasses::Integer,
end

"""
vgg(imsize; config, inchannels, batchnorm = false, nclasses, fcsize, dropout_prob)
vgg(imsize::Dims{2}; config, batchnorm::Bool = false, fcsize::Integer = 4096,
dropout_prob = 0.0, inchannels::Integer = 3, nclasses::Integer = 1000)
Create a VGG model
([reference](https://arxiv.org/abs/1409.1556v6)).
Expand All @@ -102,44 +96,10 @@ function vgg(imsize::Dims{2}; config, batchnorm::Bool = false, fcsize::Integer =
return Chain(Chain(conv...), class)
end

const VGG_CONV_CONFIGS = Dict(:A => [(64, 1), (128, 1), (256, 2), (512, 2), (512, 2)],
:B => [(64, 2), (128, 2), (256, 2), (512, 2), (512, 2)],
:D => [(64, 2), (128, 2), (256, 3), (512, 3), (512, 3)],
:E => [(64, 2), (128, 2), (256, 4), (512, 4), (512, 4)])

const VGG_CONFIGS = Dict(11 => :A, 13 => :B, 16 => :D, 19 => :E)

"""
VGG(imsize::Dims{2}; config, inchannels, batchnorm = false, nclasses, fcsize, dropout_prob)
Construct a VGG model with the specified input image size. Typically, the image size is `(224, 224)`.
## Keyword Arguments:
- `config` : VGG convolutional block configuration. It is defined as a vector of tuples
`(output_channels, num_convolutions)` for each block
- `inchannels`: number of input channels
- `batchnorm`: set to `true` to use batch normalization after each convolution
- `nclasses`: number of output classes
- `fcsize`: intermediate fully connected layer size
(see [`Metalhead.vgg_classifier_layers`](@ref))
- `dropout_prob`: dropout level between fully connected layers
"""
struct VGG
layers::Any
end
@functor VGG

function VGG(imsize::Dims{2}; config, batchnorm::Bool = false, dropout_prob = 0.5,
inchannels::Integer = 3, nclasses::Integer = 1000)
layers = vgg(imsize; config, inchannels, batchnorm, nclasses, dropout_prob)
return VGG(layers)
end

(m::VGG)(x) = m.layers(x)

backbone(m::VGG) = m.layers[1]
classifier(m::VGG) = m.layers[2]
const VGG_CONFIGS = Dict(11 => [(64, 1), (128, 1), (256, 2), (512, 2), (512, 2)],
13 => [(64, 2), (128, 2), (256, 2), (512, 2), (512, 2)],
16 => [(64, 2), (128, 2), (256, 3), (512, 3), (512, 3)],
19 => [(64, 2), (128, 2), (256, 4), (512, 4), (512, 4)])

"""
VGG(depth::Integer; pretrain::Bool = false, batchnorm::Bool = false,
Expand All @@ -148,21 +108,30 @@ classifier(m::VGG) = m.layers[2]
Create a VGG style model with specified `depth`.
([reference](https://arxiv.org/abs/1409.1556v6)).
!!! warning
`VGG` does not currently support pretrained weights for the `batchnorm = true` option.
# Arguments
- `depth`: the depth of the VGG model. Must be one of [11, 13, 16, 19].
- `pretrain`: set to `true` to load pre-trained model weights for ImageNet
- `batchnorm`: set to `true` to use batch normalization after each convolution
- `inchannels`: number of input channels
- `nclasses`: number of output classes
- `depth`: the depth of the VGG model. Must be one of [11, 13, 16, 19].
- `pretrain`: set to `true` to load pre-trained model weights for ImageNet
- `batchnorm`: set to `true` to use batch normalization after each convolution
- `inchannels`: number of input channels
- `nclasses`: number of output classes
See also [`vgg`](@ref).
"""
struct VGG
layers::Any
end
@functor VGG

function VGG(depth::Integer; pretrain::Bool = false, batchnorm::Bool = false,
inchannels::Integer = 3, nclasses::Integer = 1000)
_checkconfig(depth, keys(VGG_CONFIGS))
model = VGG((224, 224); config = VGG_CONV_CONFIGS[VGG_CONFIGS[depth]], batchnorm,
inchannels, nclasses)
layers = vgg((224, 224); config = VGG_CONFIGS[depth], batchnorm, inchannels, nclasses)
model = VGG(layers)
if pretrain
artifact_name = string("vgg", depth)
if batchnorm
Expand All @@ -174,3 +143,8 @@ function VGG(depth::Integer; pretrain::Bool = false, batchnorm::Bool = false,
end
return model
end

(m::VGG)(x) = m.layers(x)

backbone(m::VGG) = m.layers[1]
classifier(m::VGG) = m.layers[2]
9 changes: 9 additions & 0 deletions src/deprecations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Deprecated; to be removed in a future release
function VGG(imsize::Dims{2}; config, batchnorm::Bool = false, dropout_prob = 0.5,
inchannels::Integer = 3, nclasses::Integer = 1000)
Base.depwarn("The `VGG(imsize; config, inchannels, batchnorm, nclasses)` constructor
will be deprecated in a future release. Please use `vgg(imsize; config,
inchannels, batchnorm, nclasses)` instead for the same functionality.", :VGG)
layers = vgg(imsize; config, inchannels, batchnorm, nclasses, dropout_prob)
return VGG(layers)
end

2 comments on commit eed0c64

@theabhirath
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/98248

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.9.2 -m "<description of version>" eed0c64c4fc9b8de5d8b733af26421140a27fbec
git push origin v0.9.2

Please sign in to comment.