Adding UNet Model #210

Merged
merged 30 commits into master on Jan 27, 2023
Changes from 8 commits
Commits
30 commits
ba54cf0
model implemented
shivance Dec 27, 2022
11c50d9
adding documentation
shivance Dec 27, 2022
ca73586
ran juliaformatter
shivance Dec 28, 2022
552a8fd
removed custom forward pass using Parallel
shivance Jan 1, 2023
c577aed
removing _random_normal
shivance Jan 1, 2023
fb642c4
incorporating suggested changes
shivance Jan 2, 2023
7c7b1ee
Revert "ran juliaformatter"
shivance Jan 3, 2023
99f07ad
adapting to fastai's unet impl
shivance Jan 10, 2023
fc756d9
undoing utilities formatting
shivance Jan 10, 2023
60b082c
formatting + documentation + func signature
shivance Jan 10, 2023
2f1cc6d
adding unit tests for unet
shivance Jan 10, 2023
8d2ba2b
configuring CI
shivance Jan 10, 2023
77a3148
configuring CI
shivance Jan 10, 2023
8aebd14
Merge branch 'master' into unet
shivance Jan 10, 2023
429096b
Update convnets.jl
shivance Jan 10, 2023
d761126
Update convnets.jl
shivance Jan 10, 2023
1b5d2b7
updated test
shivance Jan 11, 2023
354e3c4
minor fixes
shivance Jan 12, 2023
6494be7
typing fix
shivance Jan 12, 2023
2d68f61
Update src/utilities.jl
shivance Jan 12, 2023
627480f
fixing ci
shivance Jan 12, 2023
4012fb2
renaming:
shivance Jan 16, 2023
016cef4
fixing test
shivance Jan 22, 2023
6097c57
Update .github/workflows/CI.yml
shivance Jan 22, 2023
98b4c30
Update src/convnets/unet.jl
shivance Jan 22, 2023
54c334f
Update src/convnets/unet.jl
shivance Jan 22, 2023
4fae8d6
incorporating suggestions
shivance Jan 22, 2023
4735dff
minor change
shivance Jan 22, 2023
3bebe5a
minor edit
shivance Jan 22, 2023
65aa5e8
Update src/convnets/unet.jl
shivance Jan 26, 2023
1 change: 1 addition & 0 deletions docs/src/index.md
@@ -36,6 +36,7 @@ julia> ]add Metalhead
| [ViT](https://arxiv.org/abs/2010.11929) | [`ViT`](@ref) | N |
| [ConvNeXt](https://arxiv.org/abs/2201.03545) | [`ConvNeXt`](@ref) | N |
| [ConvMixer](https://arxiv.org/abs/2201.09792) | [`ConvMixer`](@ref) | N |
| [UNet](https://arxiv.org/abs/1505.04597v1) | [`UNet`](@ref) | N |

To contribute new models, see our [contributing docs](@ref Contributing-to-Metalhead.jl).

5 changes: 3 additions & 2 deletions src/Metalhead.jl
@@ -53,6 +53,7 @@ include("convnets/densenet.jl")
include("convnets/squeezenet.jl")
include("convnets/convnext.jl")
include("convnets/convmixer.jl")
include("convnets/unet.jl")

# Mixers
include("mixers/core.jl")
@@ -73,15 +74,15 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19,
GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception,
SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, MNASNet,
EfficientNet, EfficientNetv2, ConvMixer, ConvNeXt,
MLPMixer, ResMLP, gMLP, ViT
MLPMixer, ResMLP, gMLP, ViT, UNet

# use Flux._big_show to pretty print large models
for T in (:AlexNet, :VGG, :SqueezeNet, :ResNet, :WideResNet, :ResNeXt,
:SEResNet, :SEResNeXt, :Res2Net, :Res2NeXt, :GoogLeNet, :DenseNet,
:Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception,
:MobileNetv1, :MobileNetv2, :MobileNetv3, :MNASNet,
:EfficientNet, :EfficientNetv2, :ConvMixer, :ConvNeXt,
:MLPMixer, :ResMLP, :gMLP, :ViT)
:MLPMixer, :ResMLP, :gMLP, :ViT, :UNet)
@eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
end

97 changes: 97 additions & 0 deletions src/convnets/unet.jl
@@ -0,0 +1,97 @@
function PixelShuffleICNR(inplanes, outplanes; r = 2)
return Chain(basic_conv_bn((1, 1), inplanes, outplanes * (r^2)),
Flux.PixelShuffle(r))
end

function UNetCombineLayer(inplanes, outplanes)
return Chain(basic_conv_bn((3, 3), inplanes, outplanes; pad = 1),
basic_conv_bn((3, 3), outplanes, outplanes; pad = 1))
end

function UNetMiddleBlock(inplanes)
return Chain(basic_conv_bn((3, 3), inplanes, 2inplanes; pad = 1),
basic_conv_bn((3, 3), 2inplanes, inplanes; pad = 1))
end

function UNetFinalBlock(inplanes, outplanes)
return Chain(basicblock(inplanes, inplanes; reduction_factor = 1),
basic_conv_bn((1, 1), inplanes, outplanes))
end

function unetlayers(layers, sz; outplanes = nothing, skip_upscale = 0,
m_middle = _ -> (identity,))
isempty(layers) && return m_middle(sz[end - 1])

layer, layers = layers[1], layers[2:end]
outsz = Flux.outputsize(layer, sz)
does_downscale = sz[1] ÷ 2 == outsz[1]

if !does_downscale
return Chain(layer, unetlayers(layers, outsz; outplanes, skip_upscale)...)
elseif does_downscale && skip_upscale > 0
return Chain(layer,
unetlayers(layers, outsz; skip_upscale = skip_upscale - 1,
outplanes)...)
else
childunet = Chain(unetlayers(layers, outsz; skip_upscale)...)
outsz = Flux.outputsize(childunet, outsz)

inplanes = sz[end - 1]
midplanes = outsz[end - 1]
outplanes = isnothing(outplanes) ? inplanes : outplanes

return UNetBlock(Chain(layer, childunet),
inplanes, midplanes, outplanes)
end
end

function UNetBlock(m_child, inplanes, midplanes, outplanes = 2inplanes)
return Chain(;
upsample = SkipConnection(Chain(; child = m_child,
upsample = PixelShuffleICNR(midplanes,
midplanes)),
Parallel(cat_channels, identity,
BatchNorm(inplanes))),
act = xs -> relu.(xs),
combine = UNetCombineLayer(inplanes + midplanes, outplanes))
end

"""
UNet(backbone, inputsize, outplanes, final = UNetFinalBlock, fdownscale = 0)

Creates a UNet model with the specified `backbone`. The backbone of any Metalhead model
can be used as the encoder.

!!! warning

`UNet` does not currently support pretrained weights.

See also [`Metalhead.UNet`](@ref).
"""
struct UNet
layers::Any
end
@functor UNet

function UNet(backbone,
inputsize,
outplanes,
final = UNetFinalBlock,
fdownscale::Integer = 0,
kwargs...)
backbonelayers = collect(iterlayers(backbone))
layers = unetlayers(backbonelayers,
inputsize;
m_middle = UNetMiddleBlock,
skip_upscale = fdownscale,
kwargs...)

outsz = Flux.outputsize(layers, inputsize)
layers = Chain(layers, final(outsz[end - 1], outplanes))

return UNet(layers)
end

(m::UNet)(x::AbstractArray) = m.layers(x)
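As an aside for anyone trying this branch locally, below is a minimal usage sketch of the constructor added above. It is illustrative only and not part of the diff: it assumes `Metalhead.backbone(ResNet(18))` yields a `Chain` encoder that `iterlayers` can flatten, uses a 256×256 RGB input with an explicit batch dimension (since `inputsize` is fed straight to `Flux.outputsize`), and passes only positional arguments because the trailing `kwargs...` sits in the positional list at this revision.

```julia
using Flux, Metalhead

# Illustrative sketch, not part of this PR. Assumes the ResNet feature
# extractor works as a UNet encoder and that the construction above succeeds.
encoder = Metalhead.backbone(ResNet(18))   # convolutional encoder without the classifier head
inputsize = (256, 256, 3, 1)               # width, height, channels, batch (WHCN)

model = UNet(encoder, inputsize, 3)        # 3 output channels, e.g. a 3-class per-pixel map

x = rand(Float32, inputsize...)
y = model(x)
size(y)  # expected to match the input spatial size, i.e. (256, 256, 3, 1)
```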
43 changes: 23 additions & 20 deletions src/utilities.jl
@@ -4,13 +4,13 @@ seconddimmean(x) = dropdims(mean(x; dims = 2); dims = 2)
# utility function for making sure that all layers have a channel size divisible by 8
# used by MobileNet variants
function _round_channels(channels::Number, divisor::Integer = 8, min_value::Integer = 0)
new_channels = max(min_value, floor(Int, channels + divisor / 2) ÷ divisor * divisor)
# Make sure that round down does not go down by more than 10%
return new_channels < 0.9 * channels ? new_channels + divisor : new_channels
new_channels = max(min_value, floor(Int, channels + divisor / 2) ÷ divisor * divisor)
# Make sure that round down does not go down by more than 10%
return new_channels < 0.9 * channels ? new_channels + divisor : new_channels
Member: all the indentation changes in this file are bad

end

"""
addact(activation = relu, xs...)
addact(activation = relu, xs...)

Convenience function for applying an activation function to the output after
summing up the input arrays. Useful as the `connection` argument for the block
@@ -19,7 +19,7 @@ function in [`resnet`](@ref).
addact(activation = relu, xs...) = activation(sum(xs))

"""
actadd(activation = relu, xs...)
actadd(activation = relu, xs...)

Convenience function for adding input arrays after applying an activation
function to them. Useful as the `connection` argument for the block function in
@@ -28,7 +28,7 @@ function to them. Useful as the `connection` argument for the block function in
actadd(activation = relu, xs...) = sum(activation.(x) for x in xs)

"""
cat_channels(x, y, zs...)
cat_channels(x, y, zs...)

Concatenate `x` and `y` (and any `z`s) along the channel dimension (third dimension).
Equivalent to `cat(x, y, zs...; dims=3)`.
@@ -40,7 +40,7 @@ cat_channels(x::Tuple, y::AbstractArray...) = cat_channels(x..., y...)
cat_channels(x::Tuple) = cat_channels(x...)

"""
swapdims(perm)
swapdims(perm)

Convenience function for permuting the dimensions of an array.
`perm` is a vector or tuple specifying a permutation of the input dimensions.
@@ -50,32 +50,35 @@ swapdims(perm) = Base.Fix2(permutedims, perm)

# Utility function for pretty printing large models
function _maybe_big_show(io, model)
if isdefined(Flux, :_big_show)
if isnothing(get(io, :typeinfo, nothing)) # e.g. top level in REPL
Flux._big_show(io, model)
else
show(io, model)
end
else
show(io, model)
end
if isdefined(Flux, :_big_show)
if isnothing(get(io, :typeinfo, nothing)) # e.g. top level in REPL
Flux._big_show(io, model)
else
show(io, model)
end
else
show(io, model)
end
end

"""
linear_scheduler(drop_prob = 0.0; start_value = 0.0, depth)
linear_scheduler(drop_prob::Nothing; depth::Integer)
linear_scheduler(drop_prob = 0.0; start_value = 0.0, depth)
linear_scheduler(drop_prob::Nothing; depth::Integer)

Returns the dropout probabilities for a given depth using the linear scaling rule. Note
that this returns evenly spaced values between `start_value` and `drop_prob`, not including
`drop_prob`. If `drop_prob` is `nothing`, it returns a `Vector` of length `depth` with all
values equal to `nothing`.
"""
function linear_scheduler(drop_prob = 0.0; depth::Integer, start_value = 0.0)
return LinRange(start_value, drop_prob, depth + 1)[1:depth]
return LinRange(start_value, drop_prob, depth + 1)[1:depth]
end
linear_scheduler(drop_prob::Nothing; depth::Integer) = fill(drop_prob, depth)

# Utility function for depth and configuration checks in models
function _checkconfig(config, configs)
@assert config in configs "Invalid configuration. Must be one of $(sort(collect(configs)))."
@assert config in configs "Invalid configuration. Must be one of $(sort(collect(configs)))."
end

# Utility function to return Iterator over layers, adopted from FastAI.jl
iterlayers(m::Chain) = Iterators.flatten(iterlayers(l) for l in m.layers)
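Finally, a short illustrative snippet for the `iterlayers` utility added at the end of src/utilities.jl. This is not part of the diff; it assumes a fallback method for non-`Chain` layers (something like `iterlayers(m) = (m,)`, in the spirit of FastAI.jl, from which the comment above says the helper is adopted) exists outside the hunk shown, since `unetlayers` expects a flat sequence of layers.

```julia
using Flux

# Standalone sketch of the utility. The non-Chain fallback below is an
# assumption; it is not visible in the hunk above.
iterlayers(m::Chain) = Iterators.flatten(iterlayers(l) for l in m.layers)
iterlayers(m) = (m,)

nested = Chain(Chain(Conv((3, 3), 3 => 16), BatchNorm(16)), MaxPool((2, 2)))
collect(iterlayers(nested))
# 3-element Vector{Any}: Conv((3, 3), 3 => 16), BatchNorm(16), MaxPool((2, 2))
```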