Merge pull request #210 from shivance/unet
Adding UNet implementation
darsnack authored Jan 27, 2023
2 parents f4097f7 + 65aa5e8 commit 80ab995
Showing 6 changed files with 152 additions and 2 deletions.
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
@@ -40,6 +40,7 @@ jobs:
- '[r"Res2Net", r"Res2NeXt"]'
- '"Inception"'
- '"DenseNet"'
- '"UNet"'
- '["ConvNeXt", "ConvMixer"]'
- 'r"Mixers"'
- 'r"ViTs"'
1 change: 1 addition & 0 deletions docs/src/index.md
@@ -36,6 +36,7 @@ julia> ]add Metalhead
| [ViT](https://arxiv.org/abs/2010.11929) | [`ViT`](@ref) | N |
| [ConvNeXt](https://arxiv.org/abs/2201.03545) | [`ConvNeXt`](@ref) | N |
| [ConvMixer](https://arxiv.org/abs/2201.09792) | [`ConvMixer`](@ref) | N |
| [UNet](https://arxiv.org/abs/1505.04597v1) | [`UNet`](@ref) | N |

To contribute new models, see our [contributing docs](@ref Contributing-to-Metalhead.jl).

5 changes: 3 additions & 2 deletions src/Metalhead.jl
@@ -53,6 +53,7 @@ include("convnets/densenet.jl")
include("convnets/squeezenet.jl")
include("convnets/convnext.jl")
include("convnets/convmixer.jl")
include("convnets/unet.jl")

# Mixers
include("mixers/core.jl")
@@ -73,15 +74,15 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19,
GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception,
SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, MNASNet,
EfficientNet, EfficientNetv2, ConvMixer, ConvNeXt,
-       MLPMixer, ResMLP, gMLP, ViT
+       MLPMixer, ResMLP, gMLP, ViT, UNet

# use Flux._big_show to pretty print large models
for T in (:AlexNet, :VGG, :SqueezeNet, :ResNet, :WideResNet, :ResNeXt,
:SEResNet, :SEResNeXt, :Res2Net, :Res2NeXt, :GoogLeNet, :DenseNet,
:Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception,
:MobileNetv1, :MobileNetv2, :MobileNetv3, :MNASNet,
:EfficientNet, :EfficientNetv2, :ConvMixer, :ConvNeXt,
-          :MLPMixer, :ResMLP, :gMLP, :ViT)
+          :MLPMixer, :ResMLP, :gMLP, :ViT, :UNet)
@eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
end

125 changes: 125 additions & 0 deletions src/convnets/unet.jl
@@ -0,0 +1,125 @@
# Upsampling block: a 1×1 conv + BN expands the channels by a factor of r^2,
# and a pixel shuffle then trades those channels for an r× larger spatial size.
function pixel_shuffle_icnr(inplanes, outplanes; r = 2)
    return Chain(Chain(basic_conv_bn((1, 1), inplanes, outplanes * (r^2))...),
                 Flux.PixelShuffle(r))
end

# Fusion block: two stacked 3×3 conv + BN layers applied after the skip and
# upsampled features have been concatenated.
function unet_combine_layer(inplanes, outplanes)
    return Chain(Chain(basic_conv_bn((3, 3), inplanes, outplanes; pad = 1)...),
                 Chain(basic_conv_bn((3, 3), outplanes, outplanes; pad = 1)...))
end

# Bottleneck block at the lowest resolution: expand the channels 2×, then
# reduce them back to `inplanes`.
function unet_middle_block(inplanes)
    return Chain(Chain(basic_conv_bn((3, 3), inplanes, 2 * inplanes; pad = 1)...),
                 Chain(basic_conv_bn((3, 3), 2 * inplanes, inplanes; pad = 1)...))
end

# Output head: a residual basic block followed by a 1×1 conv that maps to the
# requested number of output planes.
function unet_final_block(inplanes, outplanes)
    return Chain(basicblock(inplanes, inplanes; reduction_factor = 1),
                 Chain(basic_conv_bn((1, 1), inplanes, outplanes)...))
end

# One decoder stage: run the child U-Net and upsample its output, concatenate
# it channel-wise with a batch-normalized copy of the stage input (the skip
# connection), then fuse the result with two 3×3 convolutions.
function unet_block(m_child, inplanes, midplanes, outplanes = 2 * inplanes)
    return Chain(SkipConnection(Chain(m_child,
                                      pixel_shuffle_icnr(midplanes, midplanes)),
                                Parallel(cat_channels, identity, BatchNorm(inplanes))),
                 relu,
                 unet_combine_layer(inplanes + midplanes, outplanes))
end

# Recursively wraps the flattened encoder layers into nested U-Net stages.
# Layers that preserve the spatial size are chained through unchanged; each
# layer that halves it opens a new stage whose skip connection spans everything
# below it. `skip_upscale` leaves that many initial downscaling stages unpaired,
# and `m_middle` builds the block at the lowest resolution.
function unetlayers(layers, sz; outplanes = nothing, skip_upscale = 0,
                    m_middle = _ -> (identity,))
isempty(layers) && return m_middle(sz[end - 1])

layer, layers = layers[1], layers[2:end]
outsz = Flux.outputsize(layer, sz)
does_downscale = sz[1] ÷ 2 == outsz[1]

if !does_downscale
return Chain(layer, unetlayers(layers, outsz; outplanes, skip_upscale)...)
elseif does_downscale && skip_upscale > 0
return Chain(layer,
unetlayers(layers, outsz; skip_upscale = skip_upscale - 1,
outplanes)...)
else
childunet = Chain(unetlayers(layers, outsz; skip_upscale)...)
outsz = Flux.outputsize(childunet, outsz)

inplanes = sz[end - 1]
midplanes = outsz[end - 1]
outplanes = isnothing(outplanes) ? inplanes : outplanes

return unet_block(Chain(layer, childunet),
inplanes, midplanes, outplanes)
end
end
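As an aside, a minimal sketch (separate from the diff, assuming Flux is loaded) of the `does_downscale` test above: `Flux.outputsize` propagates a dummy input size through a layer without allocating data, and a layer opens a new U-Net stage when it halves the spatial width.

```julia
using Flux

layer = Conv((3, 3), 3 => 16; stride = 2, pad = 1)
sz = (256, 256, 3, 1)                   # WHCN input size
outsz = Flux.outputsize(layer, sz)      # (128, 128, 16, 1)
does_downscale = sz[1] ÷ 2 == outsz[1]  # true: this layer starts a new stage
```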

"""
unet(encoder_backbone, imgdims, outplanes::Integer, final::Any = unet_final_block,
fdownscale::Integer = 0)
Creates a UNet model with specified convolutional backbone.
Backbone of any Metalhead ResNet-like model can be used as encoder
([reference](https://arxiv.org/abs/1505.04597)).
# Arguments
- `encoder_backbone`: The backbone layers of specified model to be used as encoder.
For example, `Metalhead.backbone(Metalhead.ResNet(18))` can be passed
to instantiate a UNet with layers of resnet18 as encoder.
- `inputsize`: size of input image
- `outplanes`: number of output feature planes
- `final`: final block as described in original paper
- `fdownscale`: downscale factor
"""
function unet(encoder_backbone, imgdims, outplanes::Integer,
final::Any = unet_final_block, fdownscale::Integer = 0)
backbonelayers = collect(flatten_chains(encoder_backbone))
layers = unetlayers(backbonelayers, imgdims; m_middle = unet_middle_block,
skip_upscale = fdownscale)

outsz = Flux.outputsize(layers, imgdims)
layers = Chain(layers, final(outsz[end - 1], outplanes))

return layers
end
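A minimal usage sketch of this builder (assuming `Flux` and `Metalhead` are loaded; the input dimensions and plane count are illustrative):

```julia
using Flux, Metalhead

encoder = Metalhead.backbone(Metalhead.ResNet(18))
model = Metalhead.unet(encoder, (256, 256, 3, 1), 10)

x = rand(Float32, 256, 256, 3, 1)
size(model(x))  # (256, 256, 10, 1) -- spatial size is preserved
```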

"""
UNet(imsize::Dims{2} = (256, 256), inchannels::Integer = 3, outplanes::Integer = 3,
encoder_backbone = Metalhead.backbone(DenseNet(121)); pretrain::Bool = false)
Creates a UNet model with an encoder built of specified backbone. By default it uses
[`DenseNet`](@ref) backbone, but any ResNet-like Metalhead model can be used for the encoder.
([reference](https://arxiv.org/abs/1505.04597)).
# Arguments
- `imsize`: size of input image
- `inchannels`: number of channels in input image
- `outplanes`: number of output feature planes.
- `encoder_backbone`: The backbone layers of specified model to be used as encoder.
For example, `Metalhead.backbone(Metalhead.ResNet(18))` can be passed to instantiate a UNet with layers of
resnet18 as encoder.
- `pretrain`: Whether to load the pre-trained weights for ImageNet
!!! warning
`UNet` does not currently support pretrained weights.
See also [`Metalhead.unet`](@ref).
"""
struct UNet
layers::Any
end
@functor UNet

function UNet(imsize::Dims{2} = (256, 256), inchannels::Integer = 3, outplanes::Integer = 3,
encoder_backbone = Metalhead.backbone(DenseNet(121)); pretrain::Bool = false)
layers = unet(encoder_backbone, (imsize..., inchannels, 1), outplanes)

if pretrain
loadpretrain!(layers, string("UNet"))
end
return UNet(layers)
end

(m::UNet)(x::AbstractArray) = m.layers(x)
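A hedged usage sketch of the exported constructor, mirroring the tests added below (the default configuration uses a DenseNet-121 encoder):

```julia
using Flux, Metalhead

m_default = UNet()  # (256, 256) input, 3 channels in, 3 planes out
m_resnet = UNet((256, 256), 3, 10, Metalhead.backbone(ResNet(18)))

x = rand(Float32, 256, 256, 3, 1)
size(m_default(x))  # (256, 256, 3, 1)
size(m_resnet(x))   # (256, 256, 10, 1)
```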
11 changes: 11 additions & 0 deletions src/utilities.jl
@@ -79,3 +79,14 @@ linear_scheduler(drop_prob::Nothing; depth::Integer) = fill(drop_prob, depth)
function _checkconfig(config, configs)
@assert config in configs "Invalid configuration. Must be one of $(sort(collect(configs)))."
end

"""
flatten_chains(m::Chain)
flatten_chains(m)
Convenience function for traversing nested layers of a Chain object and flatten them
into a single iterator.
"""
flatten_chains(m::Chain) = Iterators.flatten(flatten_chains(l) for l in m.layers)
flatten_chains(m) = (m,)
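A small sketch of the flattening behavior (the layer shapes are illustrative):

```julia
using Flux, Metalhead

c = Chain(Dense(2 => 3), Chain(Dense(3 => 3), Chain(Dense(3 => 1))))
leaves = collect(Metalhead.flatten_chains(c))
length(leaves)  # 3 -- the Dense layers, with the Chain nesting removed
```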

11 changes: 11 additions & 0 deletions test/convnets.jl
@@ -347,3 +347,14 @@ end
_gc()
end
end

@testset "UNet" begin
encoder = Metalhead.backbone(ResNet(18))
model = UNet((256, 256), 3, 10, encoder)
@test size(model(x_256)) == (256, 256, 10, 1)
@test gradtest(model, x_256)

model = UNet()
@test size(model(x_256)) == (256, 256, 3, 1)
_gc()
end
