
Adding UNet Model #210

Merged: 30 commits, merged on Jan 27, 2023. (The diff below shows changes from 21 of the 30 commits.)

Commits
- ba54cf0 model implemented (shivance, Dec 27, 2022)
- 11c50d9 adding documentation (shivance, Dec 27, 2022)
- ca73586 ran juliaformatter (shivance, Dec 28, 2022)
- 552a8fd removed custom forward pass using Parallel (shivance, Jan 1, 2023)
- c577aed removing _random_normal (shivance, Jan 1, 2023)
- fb642c4 incorporating suggested changes (shivance, Jan 2, 2023)
- 7c7b1ee Revert "ran juliaformatter" (shivance, Jan 3, 2023)
- 99f07ad adapting to fastai's unet impl (shivance, Jan 10, 2023)
- fc756d9 undoing utilities formatting (shivance, Jan 10, 2023)
- 60b082c formatting + documentation + func signature (shivance, Jan 10, 2023)
- 2f1cc6d adding unit tests for unet (shivance, Jan 10, 2023)
- 8d2ba2b configuring CI (shivance, Jan 10, 2023)
- 77a3148 configuring CI (shivance, Jan 10, 2023)
- 8aebd14 Merge branch 'master' into unet (shivance, Jan 10, 2023)
- 429096b Update convnets.jl (shivance, Jan 10, 2023)
- d761126 Update convnets.jl (shivance, Jan 10, 2023)
- 1b5d2b7 updated test (shivance, Jan 11, 2023)
- 354e3c4 minor fixes (shivance, Jan 12, 2023)
- 6494be7 typing fix (shivance, Jan 12, 2023)
- 2d68f61 Update src/utilities.jl (shivance, Jan 12, 2023)
- 627480f fixing ci (shivance, Jan 12, 2023)
- 4012fb2 renaming (shivance, Jan 16, 2023)
- 016cef4 fixing test (shivance, Jan 22, 2023)
- 6097c57 Update .github/workflows/CI.yml (shivance, Jan 22, 2023)
- 98b4c30 Update src/convnets/unet.jl (shivance, Jan 22, 2023)
- 54c334f Update src/convnets/unet.jl (shivance, Jan 22, 2023)
- 4fae8d6 incorporating suggestions (shivance, Jan 22, 2023)
- 4735dff minor change (shivance, Jan 22, 2023)
- 3bebe5a minor edit (shivance, Jan 22, 2023)
- 65aa5e8 Update src/convnets/unet.jl (shivance, Jan 26, 2023)
.github/workflows/CI.yml (2 changes: 1 addition & 1 deletion)

@@ -39,7 +39,7 @@ jobs:
           - 'r"/*/SEResNet*"'
           - '[r"Res2Net", r"Res2NeXt"]'
           - '"Inception"'
-          - '"DenseNet"'
+          - '["UNet", "DenseNet"]'
           - '["ConvNeXt", "ConvMixer"]'
           - 'r"Mixers"'
           - 'r"ViTs"'
docs/src/index.md (1 change: 1 addition & 0 deletions)

@@ -36,6 +36,7 @@ julia> ]add Metalhead
 | [ViT](https://arxiv.org/abs/2010.11929) | [`ViT`](@ref) | N |
 | [ConvNeXt](https://arxiv.org/abs/2201.03545) | [`ConvNeXt`](@ref) | N |
 | [ConvMixer](https://arxiv.org/abs/2201.09792) | [`ConvMixer`](@ref) | N |
+| [UNet](https://arxiv.org/abs/1505.04597v1) | [`UNet`](@ref) | N |

 To contribute new models, see our [contributing docs](@ref Contributing-to-Metalhead.jl).
src/Metalhead.jl (5 changes: 3 additions & 2 deletions)

@@ -53,6 +53,7 @@ include("convnets/densenet.jl")
 include("convnets/squeezenet.jl")
 include("convnets/convnext.jl")
 include("convnets/convmixer.jl")
+include("convnets/unet.jl")

 # Mixers
 include("mixers/core.jl")

@@ -73,15 +74,15 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19,
        GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception,
        SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, MNASNet,
        EfficientNet, EfficientNetv2, ConvMixer, ConvNeXt,
-       MLPMixer, ResMLP, gMLP, ViT
+       MLPMixer, ResMLP, gMLP, ViT, UNet

 # use Flux._big_show to pretty print large models
 for T in (:AlexNet, :VGG, :SqueezeNet, :ResNet, :WideResNet, :ResNeXt,
           :SEResNet, :SEResNeXt, :Res2Net, :Res2NeXt, :GoogLeNet, :DenseNet,
           :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception,
           :MobileNetv1, :MobileNetv2, :MobileNetv3, :MNASNet,
           :EfficientNet, :EfficientNetv2, :ConvMixer, :ConvNeXt,
-          :MLPMixer, :ResMLP, :gMLP, :ViT)
+          :MLPMixer, :ResMLP, :gMLP, :ViT, :UNet)
     @eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
 end
src/convnets/unet.jl (123 changes: 123 additions & 0 deletions, new file)

@@ -0,0 +1,123 @@
function PixelShuffleICNR(inplanes, outplanes; r = 2)
    return Chain(Chain(basic_conv_bn((1, 1), inplanes, outplanes * r^2)...),
                 Flux.PixelShuffle(r))
end

function UNetCombineLayer(inplanes, outplanes)
return Chain(Chain(basic_conv_bn((3, 3), inplanes, outplanes; pad = 1)...),
Chain(basic_conv_bn((3, 3), outplanes, outplanes; pad = 1)...))
end

function UNetMiddleBlock(inplanes)
return Chain(Chain(basic_conv_bn((3, 3), inplanes, 2 * inplanes; pad = 1)...),
Chain(basic_conv_bn((3, 3), 2 * inplanes, inplanes; pad = 1)...))
end

function UNetFinalBlock(inplanes, outplanes)
return Chain(basicblock(inplanes, inplanes; reduction_factor = 1),
Chain(basic_conv_bn((1, 1), inplanes, outplanes)...))
end
Review comment (Member): These functions should be lower case so that they aren't confused with constructors/types.


function unetlayers(layers, sz; outplanes = nothing, skip_upscale = 0,
m_middle = _ -> (identity,))
isempty(layers) && return m_middle(sz[end - 1])

layer, layers = layers[1], layers[2:end]
outsz = Flux.outputsize(layer, sz)
does_downscale = sz[1] ÷ 2 == outsz[1]

if !does_downscale
return Chain(layer, unetlayers(layers, outsz; outplanes, skip_upscale)...)
elseif does_downscale && skip_upscale > 0
return Chain(layer,
unetlayers(layers, outsz; skip_upscale = skip_upscale - 1,
outplanes)...)
else
childunet = Chain(unetlayers(layers, outsz; skip_upscale)...)
outsz = Flux.outputsize(childunet, outsz)

inplanes = sz[end - 1]
midplanes = outsz[end - 1]
outplanes = isnothing(outplanes) ? inplanes : outplanes

return UNetBlock(Chain(layer, childunet),
inplanes, midplanes, outplanes)
end
end
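
To make the recursion above easier to follow: unetlayers walks the flattened backbone layer by layer, uses Flux.outputsize to detect stages that halve the spatial size, and wraps each downscaling stage (together with everything after it) in a skip-connected UNetBlock. A toy trace follows; it is illustrative only and not part of the diff, and the conv1/conv2 names are hypothetical:

# Toy backbone where each stage halves the spatial size:
backbone = Chain(Conv((3, 3), 3 => 16; stride = 2, pad = 1),   # 256 -> 128
                 Conv((3, 3), 16 => 32; stride = 2, pad = 1))  # 128 -> 64
# unetlayers(collect(iterlayers(backbone)), (256, 256, 3, 1)) then builds,
# roughly, UNetBlock(Chain(conv1, UNetBlock(Chain(conv2, middle), ...)), ...),
# so encoder and decoder features are paired at matching resolutions.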

function UNetBlock(m_child, inplanes, midplanes, outplanes = 2 * inplanes)
return Chain(SkipConnection(Chain(m_child,
PixelShuffleICNR(midplanes, midplanes)),
Parallel(cat_channels, identity, BatchNorm(inplanes))),
xs -> relu.(xs),
UNetCombineLayer(inplanes + midplanes, outplanes))
end

Review comment (Member), suggested change:

-    return Chain(SkipConnection(Chain(m_child,
-                                      PixelShuffleICNR(midplanes, midplanes)),
-                                 Parallel(cat_channels, identity, BatchNorm(inplanes))),
-                 xs -> relu.(xs),
-                 UNetCombineLayer(inplanes + midplanes, outplanes))
+    return Chain(SkipConnection(Chain(m_child,
+                                      PixelShuffleICNR(midplanes, midplanes)),
+                                 Parallel(cat_channels, identity, BatchNorm(inplanes))),
+                 relu,
+                 UNetCombineLayer(inplanes + midplanes, outplanes))
Review comment (Member): Also another lower case function name.
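
For readers tracing shapes through UNetBlock, a summary of the channel bookkeeping, inferred from the code above and not part of the diff:

# Channel flow through UNetBlock(m_child, inplanes, midplanes, outplanes):
#   input x: inplanes channels
#   decoder path: m_child(x) |> PixelShuffleICNR -> midplanes channels, upscaled 2x
#   skip path (SkipConnection): BatchNorm(inplanes)(x) -> inplanes channels
#   cat_channels joins the two paths -> inplanes + midplanes channels
#   relu, then UNetCombineLayer(inplanes + midplanes, outplanes) -> outplanes channels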


"""
unet(backbone; inputsize::NTuple{4, Integer}, outplanes::Integer = 3,
final::Any = UNetFinalBlock, fdownscale::Integer = 0, kwargs...)

Creates a UNet model with specified backbone. Backbone of Any Metalhead model
can be used as encoder.
([reference](https://arxiv.org/abs/1505.04597)).
shivance marked this conversation as resolved.
Show resolved Hide resolved

# Arguments

- `backbone`: The backbone layers to be used in encoder.
For example, `Metalhead.backbone(Metalhead.ResNet(18))` can be passed
to instantiate a UNet with layers of resnet18 as encoder.
- `inputsize`: size of input image
- `outplanes`: number of output feature planes
- `final`: final block as described in original paper
- `fdownscale`: downscale factor
"""
function unet(backbone, inputsize::NTuple{4, Integer}, outplanes::Integer,
              final::Any = UNetFinalBlock, fdownscale::Integer = 0; kwargs...)
backbonelayers = collect(iterlayers(backbone))
layers = unetlayers(backbonelayers, inputsize; m_middle = UNetMiddleBlock,
skip_upscale = fdownscale, kwargs...)

outsz = Flux.outputsize(layers, inputsize)
layers = Chain(layers, final(outsz[end - 1], outplanes))

return layers
end

"""
UNet(backbone, inputsize::NTuple{4, Integer}, outplanes::Integer = 3;
pretrain::Bool = false)

Creates a UNet model with specified backbone. Backbone of Any Metalhead model can be used as
encoder.
([reference](https://arxiv.org/abs/1505.04597)).

# Arguments

- `backbone`: The backbone layers to be used in encoder.
For example, `Metalhead.backbone(Metalhead.ResNet(18))` can be passed to instantiate a UNet with layers of
resnet18 as encoder.
- `inputsize`: size of input image
ToucheSir marked this conversation as resolved.
Show resolved Hide resolved
- `outplanes`: number of output feature planes.
- `pretrain`: Whether to load the pre-trained weights for ImageNet

!!! warning

`UNet` does not currently support pretrained weights.

See also [`Metalhead.UNet`](@ref).
shivance marked this conversation as resolved.
Show resolved Hide resolved
"""
struct UNet
layers::Any
end
@functor UNet

Review comment (Member): We should have a default backbone that matches the original paper. Or is something else used more often in practice (@theabhirath or @lorenzoh)?

Review comment (Member): Most of the original CNN series (read: VGG, ResNet, Inception, DenseNet) are used very commonly, but that choice should be on the user since the backbone is often domain dependent (not sure why, but I've seen more DenseNets for medical use cases and more Inceptions for autonomous vehicles and the like). So having the default backbone match the paper would be best.

Review comment (Member): inputsize should probably be passed in as imsize and inchannels, matching the other conv-based models in Metalhead.

function UNet(backbone, inputsize::NTuple{4, Integer}, outplanes::Integer = 3;
              pretrain::Bool = false)
layers = unet(backbone, inputsize, outplanes)
if pretrain
loadpretrain!(layers, string("UNet"))
end
return UNet(layers)
end

(m::UNet)(x::AbstractArray) = m.layers(x)
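
A minimal usage sketch for the new model, mirroring the test added below (assumes this PR's branch; the ResNet-18 backbone and the (256, 256, 3, 1) input size are just illustrations):

using Flux, Metalhead

# Build the encoder from an existing Metalhead model, then wrap it in a UNet
# mapping 3 input channels to 10 output feature planes.
encoder = Metalhead.backbone(ResNet(18))
model = UNet(encoder, (256, 256, 3, 1), 10)

# Flux uses WHCN layout (width, height, channels, batch).
x = rand(Float32, 256, 256, 3, 1)
y = model(x)
@assert size(y) == (256, 256, 10, 1)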
src/utilities.jl (4 changes: 4 additions & 0 deletions)

@@ -79,3 +79,7 @@ linear_scheduler(drop_prob::Nothing; depth::Integer) = fill(drop_prob, depth)
 function _checkconfig(config, configs)
     @assert config in configs "Invalid configuration. Must be one of $(sort(collect(configs)))."
 end
+
+# Utility function to return an iterator over layers, adapted from FastAI.jl
+iterlayers(m::Chain) = Iterators.flatten(iterlayers(l) for l in m.layers)
+iterlayers(m) = (m,)

Review comment (Member): Call this flatten_chains instead because it only flattens Chain.
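
A quick sketch of what this helper does (illustrative example, not part of the diff; iterlayers is unexported, hence the qualified name):

using Flux, Metalhead

m = Chain(Chain(Dense(2 => 3), relu), Dense(3 => 1))

# iterlayers recursively flattens nested Chains while leaving other layers
# untouched, so the UNet builder sees the backbone as a flat layer sequence.
collect(Metalhead.iterlayers(m))  # Dense(2 => 3), relu, Dense(3 => 1)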

test/convnets.jl (8 changes: 8 additions & 0 deletions)

@@ -347,3 +347,11 @@ end
         _gc()
     end
 end
+
+@testset "UNet" begin
+    encoder = Metalhead.backbone(ResNet(18))
+    model = UNet(encoder, (256, 256, 3, 1), 10)
+    @test size(model(x_256)) == (256, 256, 10, 1)
+    @test gradtest(model, x_256)
+    _gc()
+end
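
Here x_256 and gradtest come from elsewhere in the Metalhead test suite. A self-contained stand-in for this testset might look as follows; the definition of x_256 and the gradient check are assumptions, not part of the diff:

using Flux, Metalhead, Test

encoder = Metalhead.backbone(ResNet(18))
model = UNet(encoder, (256, 256, 3, 1), 10)

x_256 = rand(Float32, 256, 256, 3, 1)  # assumed definition of the suite's x_256
@test size(model(x_256)) == (256, 256, 10, 1)

# Minimal stand-in for gradtest: check the model differentiates end to end.
grads = Flux.gradient(m -> sum(m(x_256)), model)
@test grads[1] !== nothing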