From fc36979a7abcc7c8c2cd52217ad66d66c049bea2 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Thu, 12 Dec 2024 02:18:39 +0200 Subject: [PATCH 01/12] Add WeightNorm --- ext/FluxAMDGPUExt/FluxAMDGPUExt.jl | 3 +-- src/layers/normalise.jl | 41 ++++++++++++++++++++++++++++++ t.jl | 20 +++++++++++++++ 3 files changed, 62 insertions(+), 2 deletions(-) create mode 100644 t.jl diff --git a/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl b/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl index 85ce0365cb..2398b322d5 100644 --- a/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl +++ b/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl @@ -3,8 +3,7 @@ module FluxAMDGPUExt import ChainRulesCore import ChainRulesCore: NoTangent import Flux -import Flux: adapt_storage, fmap -import Flux: DenseConvDims, Conv, ConvTranspose, conv, conv_reshape_bias +import Flux: fmap, DenseConvDims, Conv, ConvTranspose, conv, conv_reshape_bias import NNlib using MLDataDevices using AMDGPU diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 85cece9477..5cf6a097ef 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -568,3 +568,44 @@ scale parameters, `false` otherwise. See [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`LayerNorm`](@ref). """ hasaffine(l::Union{BatchNorm, InstanceNorm, LayerNorm, GroupNorm}) = l.affine + +struct WeightNorm{which, dims, L, S} + layer::L + g::S +end +@layer WeightNorm + +function WeightNorm(layer, which::Symbol; dims = -1) + v = getfield(layer, which) + iszero(v) && error( + "`$which` field for `$(typeof(layer))` is all zero, which will result in NaN.") + + d = if dims isa Colon + 1:ndims(v) + elseif dims == -1 + dims = ndims(v) + else + dims + end + g = one.(sum(v; dims=d)) + WeightNorm{which, dims, typeof(layer), typeof(g)}(layer, g) +end + +(w::WeightNorm)(x) = weightnorm(w)(x) + +function weightnorm(wn::WeightNorm{which, dims}) where {which, dims} + # TODO support recursive WeightNorm + v = getfield(wn.layer, which) + w = weightnorm(v, wn.g; dims) + + fields, ctor = Functors.functor(wn.layer) + return ctor(merge( + fields, NamedTuple{(which,)}((w,)), + )) +end + +function weightnorm(v::AbstractArray, g::AbstractArray; dims) + n2 = sum(abs2, v; dims) + ϵ = eps(eltype(v)) + return @. 
v * g / sqrt(n2 + ϵ) +end diff --git a/t.jl b/t.jl new file mode 100644 index 0000000000..0cc91122ca --- /dev/null +++ b/t.jl @@ -0,0 +1,20 @@ +using Flux + +function main() + # x = rand(Float32, 20, 16) + # d = Dense(20 => 40) + x = rand(Float32, 128, 1, 16) + d = Conv((3,), 1 => 2) + + @show size(d.weight) + + wn = Flux.WeightNorm(d, :weight) + @show size(wn.g) + y1 = wn(x) + + w = Flux.weightnorm(wn) + y2 = w(x) + @assert y1 ≈ y2 + return +end +main() From f7f9ddc9d9bb9346b1f71d1203ff1a652959fd98 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Thu, 12 Dec 2024 15:29:12 +0200 Subject: [PATCH 02/12] Fixes --- docs/src/reference/models/layers.md | 2 + src/Flux.jl | 2 +- src/layers/normalise.jl | 109 +++++++++++++++++++++++----- src/layers/recurrent.jl | 1 + t.jl | 35 ++++++--- 5 files changed, 119 insertions(+), 30 deletions(-) diff --git a/docs/src/reference/models/layers.md b/docs/src/reference/models/layers.md index 355d3e7833..b57141ea2b 100644 --- a/docs/src/reference/models/layers.md +++ b/docs/src/reference/models/layers.md @@ -126,6 +126,8 @@ AlphaDropout LayerNorm InstanceNorm GroupNorm +WeightNorm +Flux.remove_weight_norms Flux.normalise ``` diff --git a/src/Flux.jl b/src/Flux.jl index 8fb2351aa2..5b3d3f7977 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -42,7 +42,7 @@ export Chain, Dense, Embedding, EmbeddingBag, SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv, AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool, Dropout, AlphaDropout, - LayerNorm, BatchNorm, InstanceNorm, GroupNorm, + LayerNorm, BatchNorm, InstanceNorm, GroupNorm, WeightNorm, MultiHeadAttention, Upsample, PixelShuffle, fmap, cpu, gpu, f32, f64, f16, rand32, randn32, zeros32, ones32, diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 5cf6a097ef..a87e435ace 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -569,34 +569,72 @@ See [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`Laye """ hasaffine(l::Union{BatchNorm, InstanceNorm, LayerNorm, GroupNorm}) = l.affine -struct WeightNorm{which, dims, L, S} +struct WeightNorm{which, dims, L, G, V} layer::L - g::S + g::G + v::V end @layer WeightNorm -function WeightNorm(layer, which::Symbol; dims = -1) - v = getfield(layer, which) - iszero(v) && error( +""" + WeightNorm(layer::L, which::Symbol = :weight; dims = -1) + +Apply weight normalization to a parameter given by `which` in a `layer`. + +``w = g \\frac{\\mathbf{v}}{\\lVert \\mathbf{v} \\rVert}`` + +Decouples the magnitude of a weight tensor from its direction. +By default, normalization is applied along the output channel `dim=-1` +(equivalent to `dims=ndims(w)`). + +### Example + +```jldoctest +julia> c = Conv((3,), 1 => 2); + +julia> wc = WeightNorm(c, :weight) +WeightNorm( + Conv((3,), 1 => 2), # 8 parameters + 3×1×1 Array{Float32,...}, # 3 parameters + 3×1×2 Array{Float32,...}, # 6 parameters +) # Total: 4 arrays, 17 parameters, 348 bytes. 
+ +julia> x = ones(Float32, 12, 1, 1); + +julia> c(x) ≈ wc(x) # forward pass is the same as with the original layer +true +``` + +# Reference + +Salimans & Kingma, _Weight Normalization_ (2016) https://arxiv.org/abs/1602.07868 +""" +function WeightNorm(layer::L, which::Symbol = :weight; dims = -1) where L + hasfield(L, which) || error("`$L` does not have field `:$which`.") + + x = getfield(layer, which) + iszero(x) && error( "`$which` field for `$(typeof(layer))` is all zero, which will result in NaN.") d = if dims isa Colon - 1:ndims(v) + 1:ndims(x) elseif dims == -1 - dims = ndims(v) + dims = ndims(x) else dims end - g = one.(sum(v; dims=d)) - WeightNorm{which, dims, typeof(layer), typeof(g)}(layer, g) + + g = sqrt.(sum(abs2, x; dims) .+ eps(eltype(x))) + v = x ./ g + WeightNorm{which, dims, L, typeof(g), typeof(v)}(layer, g, v) end -(w::WeightNorm)(x) = weightnorm(w)(x) +(w::WeightNorm)(x) = transform(w)(x) -function weightnorm(wn::WeightNorm{which, dims}) where {which, dims} - # TODO support recursive WeightNorm - v = getfield(wn.layer, which) - w = weightnorm(v, wn.g; dims) +function transform(wn::WeightNorm{which, dims}) where {which, dims} + ϵ = eps(eltype(wn.v)) + n2 = sum(abs2, wn.v; dims) + w = @. wn.g * wn.v / sqrt(n2 + ϵ) fields, ctor = Functors.functor(wn.layer) return ctor(merge( @@ -604,8 +642,43 @@ function weightnorm(wn::WeightNorm{which, dims}) where {which, dims} )) end -function weightnorm(v::AbstractArray, g::AbstractArray; dims) - n2 = sum(abs2, v; dims) - ϵ = eps(eltype(v)) - return @. v * g / sqrt(n2 + ϵ) +function Base.show(io::IO, w::WeightNorm{which, dims}) where {which, dims} + print(io, "WeightNorm(") + Base.show(io, w.layer) + print(io, ", :", which, "; dims=", dims) + print(io, ")") end + +""" + remove_weight_norms(x) + +Remove any [WeightNorm](@ref) parametrization in the model. + +### Example + +```jldoctest +julia> model = Chain( + WeightNorm(Conv((3,), 1 => 2), :weight), + WeightNorm(Conv((3,), 2 => 2), :weight), +) +Chain( + WeightNorm( + Conv((3,), 1 => 2), # 8 parameters + 3×1×1 Array{Float32,...}, # 3 parameters + 3×1×2 Array{Float32,...}, # 6 parameters + ), + WeightNorm( + Conv((3,), 2 => 2), # 14 parameters + 3×2×1 Array{Float32,...}, # 6 parameters + 3×2×2 Array{Float32,...}, # 12 parameters + ), +) # Total: 8 arrays, 49 parameters, 756 bytes. + +julia> Flux.remove_weight_norms(model) +Chain( + Conv((3,), 1 => 2), # 8 parameters + Conv((3,), 2 => 2), # 14 parameters +) # Total: 4 arrays, 22 parameters, 392 bytes. 
+``` +""" +remove_weight_norms(x) = fmap(transform, x; exclude=l -> l isa WeightNorm) diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 9386a3fc2d..40864be294 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -103,6 +103,7 @@ x = rand(Float32, 10) # Run forward res = rnn(x, h0) +``` """ initialstates(rnn::RNNCell) = zeros_like(rnn.Wh, size(rnn.Wh, 2)) diff --git a/t.jl b/t.jl index 0cc91122ca..cbddc0e451 100644 --- a/t.jl +++ b/t.jl @@ -1,20 +1,33 @@ using Flux +using Zygote function main() - # x = rand(Float32, 20, 16) - # d = Dense(20 => 40) - x = rand(Float32, 128, 1, 16) - d = Conv((3,), 1 => 2) + x = rand(Float32, 20, 1) + c = Dense(20 => 20) - @show size(d.weight) + # x = rand(Float32, 12, 1, 1) + # c = Conv((3,), 1 => 2) + y1 = c(x) + wn = WeightNorm(c, :weight) + @show wn + y2 = wn(x) - wn = Flux.WeightNorm(d, :weight) - @show size(wn.g) - y1 = wn(x) - - w = Flux.weightnorm(wn) - y2 = w(x) @assert y1 ≈ y2 + + g = Zygote.gradient(wn) do wn + sum(wn(x)) + end + display(g); println() + + model = Chain( + WeightNorm(Conv((3,), 1 => 2), :weight), + WeightNorm(Conv((3,), 2 => 2), :weight), + ) + @show model + # y1 = model(x) + + mm = Flux.remove_weight_norms(model) + @show mm return end main() From 76c73bf71689ef535b72c535dfbe236079c9925e Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Thu, 12 Dec 2024 16:48:38 +0200 Subject: [PATCH 03/12] Add WeightNorm tests --- NEWS.md | 3 + src/layers/normalise.jl | 6 +- t.jl | 33 -- test/layers/normalisation.jl | 974 ++++++++++++++++++----------------- 4 files changed, 514 insertions(+), 502 deletions(-) delete mode 100644 t.jl diff --git a/NEWS.md b/NEWS.md index b0a2a40136..2a40a64aec 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,6 +2,9 @@ See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a complete list of PRs merged before each release. +## v0.15.3 +* Add `WeightNorm` normalization layer. + ## v0.15.0 (December 2024) This release includes two **breaking changes**: - The recurrent layers have been thoroughly revised. See below and read the [documentation](https://fluxml.ai/Flux.jl/v0.15/guide/models/recurrence/) for details. 
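
For reference, a minimal stand-alone sketch (not part of the patch) of the reparameterization that the `WeightNorm` constructor and `transform` above implement. The array `W`, the `dims` choice, and the variable names are illustrative only; `g` stores the per-slice magnitude and `v` the direction, so the rebuilt weight matches the original one at construction time:

```julia
W = randn(Float32, 4, 3)            # stand-in for a layer's weight array
dims = ndims(W)                     # the patch's default `dims = -1`, i.e. the last dimension
ϵ = eps(eltype(W))

# What the `WeightNorm` constructor stores:
g = sqrt.(sum(abs2, W; dims) .+ ϵ)  # magnitude along `dims`
v = W ./ g                          # direction, with norm ≈ 1 along `dims`

# What `transform(::WeightNorm)` rebuilds on each forward pass:
n2 = sum(abs2, v; dims)
Weff = g .* v ./ sqrt.(n2 .+ ϵ)

Weff ≈ W                            # true: the forward pass is unchanged at construction
```

This is the same invariant that the `c(x) ≈ wc(x)` doctest in the `WeightNorm` docstring and the `y1 ≈ y2` check in `t.jl` exercise at the layer level.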
diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index a87e435ace..56a7622f95 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -610,11 +610,11 @@ true Salimans & Kingma, _Weight Normalization_ (2016) https://arxiv.org/abs/1602.07868 """ function WeightNorm(layer::L, which::Symbol = :weight; dims = -1) where L - hasfield(L, which) || error("`$L` does not have field `:$which`.") + hasfield(L, which) || throw(ArgumentError("`$L` does not have field `:$which`.")) x = getfield(layer, which) - iszero(x) && error( - "`$which` field for `$(typeof(layer))` is all zero, which will result in NaN.") + iszero(x) && throw(ArgumentError( + "`$which` field for `$(typeof(layer))` is all zero, which will result in NaN.")) d = if dims isa Colon 1:ndims(x) diff --git a/t.jl b/t.jl deleted file mode 100644 index cbddc0e451..0000000000 --- a/t.jl +++ /dev/null @@ -1,33 +0,0 @@ -using Flux -using Zygote - -function main() - x = rand(Float32, 20, 1) - c = Dense(20 => 20) - - # x = rand(Float32, 12, 1, 1) - # c = Conv((3,), 1 => 2) - y1 = c(x) - wn = WeightNorm(c, :weight) - @show wn - y2 = wn(x) - - @assert y1 ≈ y2 - - g = Zygote.gradient(wn) do wn - sum(wn(x)) - end - display(g); println() - - model = Chain( - WeightNorm(Conv((3,), 1 => 2), :weight), - WeightNorm(Conv((3,), 2 => 2), :weight), - ) - @show model - # y1 = model(x) - - mm = Flux.remove_weight_norms(model) - @show mm - return -end -main() diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl index b0c6584b84..254e6186f6 100644 --- a/test/layers/normalisation.jl +++ b/test/layers/normalisation.jl @@ -3,474 +3,516 @@ using Zygote: pullback, ForwardDiff evalwgrad(f, x...) = pullback(f, x...)[1] -@testset "Dropout" begin - @testset for rng_kwargs in ((), (; rng = MersenneTwister())) - x = [1.0+0im,2.0+1im,3.0+3im] - @test x == Dropout(0.1; rng_kwargs...)(x) - @test x == evalwgrad(Dropout(0; rng_kwargs...), x) - @test zero(x) == evalwgrad(Dropout(1; rng_kwargs...), x) - - x = [1.,2.,3.] - @test x == Dropout(0.1; rng_kwargs...)(x) - @test x == evalwgrad(Dropout(0; rng_kwargs...), x) - @test zero(x) == evalwgrad(Dropout(1; rng_kwargs...), x) - - x = rand(100) - m = Dropout(0.9; rng_kwargs...) - y = evalwgrad(m, x) - @test count(a->a==0, y) > 50 - testmode!(m, true) - y = evalwgrad(m, x) # should override istraining - @test count(a->a==0, y) == 0 - testmode!(m, false) - y = evalwgrad(m, x) - @test count(a->a==0, y) > 50 - - # Keyword active=false - m2 = Dropout(0.9; active=false, rng_kwargs...) - y2 = evalwgrad(m2, x) - @test count(iszero, y2) == 0 - - x = rand(Float32, 100) - m = Chain(Dense(100 => 100), - Dropout(0.9; rng_kwargs...)) - y = evalwgrad(m, x) - @test count(a->a == 0, y) > 50 - testmode!(m, true) - y = evalwgrad(m, x) # should override istraining - @test count(a->a == 0, y) == 0 - - x = rand(100, 50) - m = Dropout(0.5; dims = 2, rng_kwargs...) - y = m(x) - c = map(i->count(a->a==0, @view y[i, :]), 1:100) - @test minimum(c) == maximum(c) - m = Dropout(0.5; dims = 1, rng_kwargs...) - y = m(x) - c = map(i->count(a->a==0, @view y[:, i]), 1:50) - @test minimum(c) == maximum(c) - - # issue #1084 - m = Dropout(0.9; rng_kwargs...) 
- x = rand(100) - - testmode!(m) - y = m(x) - @test count(a->a == 0, y) == 0 - trainmode!(m) - y = m(x) - @test count(a->a == 0, y) > 50 - - y = Flux.dropout(values(rng_kwargs)..., x, 0.9) # , active=true) - @test count(a->a == 0, y) > 50 - - y = Flux.dropout(values(rng_kwargs)..., x, 0.9 * 0) # , active=false) - @test count(a->a == 0, y) == 0 - - # CPU RNGs map onto CPU ok - if isempty(rng_kwargs) - @test cpu(m).rng isa Random.TaskLocalRNG - else - @test cpu(m).rng === only(values(rng_kwargs)) - end - end - - @test Dropout(0.5; active=true).active === true - @test_throws Exception Dropout(0.5; active=:something_else) -end - -@testset "AlphaDropout" begin - @testset for rng_kwargs in ((), (; rng = MersenneTwister())) - x = [1., 2., 3.] - @test x == AlphaDropout(0.1; rng_kwargs...)(x) - @test x == evalwgrad(AlphaDropout(0; rng_kwargs...), x) - @test zero(x) == evalwgrad(AlphaDropout(1; rng_kwargs...), x) - - x = randn(1000) # large enough to prevent flaky test - m = AlphaDropout(0.5; rng_kwargs...) - - y = evalwgrad(m, x) - # Should preserve unit mean and variance - @test mean(y) ≈ 0 atol=0.2 - @test var(y) ≈ 1 atol=0.2 - - testmode!(m, true) # should override istraining - @test evalwgrad(m, x) == x - - testmode!(m, false) - y = evalwgrad(m, x) - @test mean(y) ≈ 0 atol=0.2 - @test var(y) ≈ 1 atol=0.2 - - # Known good value ranges - # Values taken from https://github.com/pytorch/pytorch/blob/v1.10.0/test/cpp/api/modules.cpp#L1337-L1338 - x = ones(100) - if isempty(rng_kwargs) - @test 40 < sum(evalwgrad(m, x)) < 130 - else - # FIXME: this breaks spuriously for MersenneTwister - @test_skip 40 < sum(evalwgrad(m, x)) < 130 - end - - # CPU RNGs map onto CPU ok - if isempty(rng_kwargs) - @test cpu(m).rng isa Random.TaskLocalRNG - else - @test cpu(m).rng === only(values(rng_kwargs)) - end - end - - @test AlphaDropout(0.5; active=true).active === true - @test_throws Exception AlphaDropout(0.5; active=:something_else) -end - -@testset "BatchNorm" begin - let m = BatchNorm(2), x = [1.0 3.0 5.0; - 2.0 4.0 6.0] - - @test Flux.hasaffine(m) == true - @test length(Flux.trainables(m)) == 2 - - @test m.β == [0, 0] # initβ(2) - @test m.γ == [1, 1] # initγ(2) - # initial m.σ is 1 - # initial m.μ is 0 - - y = evalwgrad(m, x) - @test isapprox(y, [-1.22474 0 1.22474; -1.22474 0 1.22474], atol = 1.0e-5) - # julia> x - # 2×3 Array{Float64,2}: - # 1.0 3.0 5.0 - # 2.0 4.0 6.0 - # - # μ of batch will be - # (1. + 3. + 5.) / 3 = 3 - # (2. + 4. + 6.) / 3 = 4 - # - # ∴ update rule with momentum: - # .1 * 3 + 0 = .3 - # .1 * 4 + 0 = .4 - @test m.μ ≈ reshape([0.3, 0.4], 2, 1) - - # julia> .1 .* var(x, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.] - # 2×1 Array{Float64,2}: - # 1.3 - # 1.3 - @test m.σ² ≈ .1 .* var(x, dims=2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.] 
- - x′ = m(x) - @test isapprox(x′[1], (1 .- 0.3) / sqrt(1.3), atol = 1.0e-5) - - @inferred m(x) - end - - let m = BatchNorm(2; track_stats=false), x = Float32[1.0 3.0 5.0; 2.0 4.0 6.0] - y = @inferred m(x) - m16 = f16(m) - y16 = @inferred m16(f16(x)) - @test eltype(y16) == Float16 - @test y16 ≈ y atol=1e-3 - end - - # with activation function - let m = BatchNorm(2, sigmoid), x = Float32[1.0 3.0 5.0; - 2.0 4.0 6.0] - y = m(x) - @test isapprox(y, sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-7) - @inferred m(x) - m16 = f16(m) - y16 = @inferred m16(f16(x)) - @test eltype(y16) == Float16 - @test y16 ≈ y atol=1e-3 - end - - let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:6), 3, 2, 1) - y = reshape(permutedims(x, [2, 1, 3]), 2, :) - y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3]) - @test m(x) == y - @inferred m(x) - end - - let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:12), 2, 3, 2, 1) - y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :) - y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4]) - @test m(x) == y - @inferred m(x) - end - - let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1) - y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :) - y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5]) - @test m(x) == y - @inferred m(x) - end - - let m = BatchNorm(32), x = randn(Float32, 416, 416, 32, 1); - m(x) - @test (@allocated m(x)) < 100_000_000 - @inferred m(x) - end - - @test length(Flux.trainables(BatchNorm(10))) == 2 - @test length(Flux.trainables(BatchNorm(10, affine=true))) == 2 - @test length(Flux.trainables(BatchNorm(10, affine=false))) == 0 - - @test BatchNorm(5; active=true).active === true - @test_throws Exception BatchNorm(5; active=:something_else) -end - -@testset "InstanceNorm" begin - # begin tests - let m = InstanceNorm(2; affine=true, track_stats=true), sizes = (3, 2, 2), - x = reshape(collect(1:prod(sizes)), sizes) - - @test length(Flux.trainables(m)) == 2 - x = Float32.(x) - @test m.β == [0, 0] # initβ(2) - @test m.γ == [1, 1] # initγ(2) - y = evalwgrad(m, x) - - #julia> x - #[:, :, 1] = - # 1.0 4.0 - # 2.0 5.0 - # 3.0 6.0 - # - #[:, :, 2] = - # 7.0 10.0 - # 8.0 11.0 - # 9.0 12.0 - # - # μ will be - # (1. + 2. + 3.) / 3 = 2. - # (4. + 5. + 6.) / 3 = 5. - # - # (7. + 8. + 9.) / 3 = 8. - # (10. + 11. + 12.) / 3 = 11. - # - # ∴ update rule with momentum: - # (1. - .1) * 0 + .1 * (2. + 8.) / 2 = .5 - # (1. - .1) * 0 + .1 * (5. + 11.) / 2 = .8 - N = ndims(x) - @test m.μ ≈ [0.5, 0.8] - n = prod(size(x,i) for i in 1:N-2) - corr = n / (n-1) - σ² = var(x, dims=1:N-2, corrected=false) - @test m.σ² ≈ 0.1*corr*vec(mean(σ², dims=N)) .+ 0.9 * 1 - - y = m(x) - @test length(m.μ) == 2 - @test length(m.σ²) == 2 - @test y ≈ (x .- reshape(m.μ, 1,2,1)) ./ sqrt.(reshape(m.σ², 1,2,1) .+ 1f-5) atol=1.0e-5 - - @inferred m(x) - end - - # with activation function - let m = InstanceNorm(2, sigmoid; affine=true, track_stats=true), sizes = (3, 2, 2), - x = reshape(collect(1:prod(sizes)), sizes) - x = Float64.(x) - affine_shape = collect(sizes) - affine_shape[[1,3]] .= 1 - - y = evalwgrad(m, x) - y = m(x) # inference time after a training step - μ = reshape(m.μ, affine_shape...) - σ² = reshape(m.σ², affine_shape...) 
- @test y ≈ sigmoid.((x .- μ) ./ sqrt.(σ² .+ m.ϵ)) atol=1.0e-7 - - @inferred m(x) - end - - # with activation function - let m = InstanceNorm(2, sigmoid; affine=true, track_stats=false), sizes = (3, 2, 2), - x = reshape(collect(1:prod(sizes)), sizes) - - @test Flux.hasaffine(m) == true - @test length(Flux.trainables(m)) == 2 - x = Float64.(x) - y = m(x) - μ = mean(x, dims=1) - σ² = var(x, dims=1, corrected=false) - @test y ≈ sigmoid.((x .- μ) ./ sqrt.(σ² .+ m.ϵ)) atol=1.0e-7 - - @inferred m(x) - end - - let m = InstanceNorm(2, sigmoid), sizes = (3, 2, 2), - x = reshape(collect(1:prod(sizes)), sizes) - @test Flux.hasaffine(m) == false - @test length(Flux.trainables(m)) == 0 - - x = Float64.(x) - y = m(x) - μ = mean(x, dims=1) - σ² = var(x, dims=1, corrected=false) - @test y ≈ sigmoid.((x .- μ) ./ sqrt.(σ² .+ m.ϵ)) atol=1.0e-7 - - @inferred m(x) - end - - - let m = trainmode!(InstanceNorm(2; affine=true)), sizes = (2, 4, 1, 2, 3), - x = Float32.(reshape(collect(1:prod(sizes)), sizes)) - y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3) - y = reshape(m(y), sizes...) - @test m(x) == y - - @inferred m(x) - end - - # check that μ, σ², and the output are the correct size for higher rank tensors - let m = InstanceNorm(2; affine=true,track_stats=true), sizes = (5, 5, 3, 4, 2, 6), - x = reshape(Float32.(collect(1:prod(sizes))), sizes) - y = evalwgrad(m, x) - @test size(m.μ) == (sizes[end - 1], ) - @test size(m.σ²) == (sizes[end - 1], ) - @test size(y) == sizes - - @inferred m(x) - end - - # show that instance norm is equal to batch norm when channel and batch dims are squashed - let m_inorm = trainmode!(InstanceNorm(2; affine=true)), m_bnorm = trainmode!(BatchNorm(12)), sizes = (5, 5, 3, 4, 2, 6), - x = reshape(Float32.(collect(1:prod(sizes))), sizes) - @test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes) - end - - let m = InstanceNorm(32), x = randn(Float32, 416, 416, 32, 1); - m(x) - @test (@allocated m(x)) < 100_000_000 - - @inferred m(x) - end - - @test length(Flux.trainables(InstanceNorm(10))) == 0 - @test length(Flux.trainables(InstanceNorm(10, affine=true))) == 2 - @test length(Flux.trainables(InstanceNorm(10, affine=false))) == 0 - - @test InstanceNorm(5; active=true).active === true - @test_throws Exception InstanceNorm(5; active=:something_else) -end - -@testset "LayerNorm" begin - x = rand(2,3) - @test LayerNorm(2)(x) ≈ Flux.normalise(x, dims=1) - x = rand(2,3,4) - @test LayerNorm(2)(x) ≈ Flux.normalise(x, dims=1) - x = rand(2,3,4,5) - @test LayerNorm(2)(x) ≈ Flux.normalise(x, dims=1) - x = rand(2) - @test LayerNorm(2, tanh)(x) ≈ tanh.(Flux.normalise(x, dims=1)) - - x = rand(2,3,4,5) - @test LayerNorm((2,3))(x) ≈ Flux.normalise(x, dims=(1,2)) - x = rand(2,3,4,5) - @test LayerNorm((2,3,4))(x) ≈ Flux.normalise(x, dims=1:3) - - m = LayerNorm((2,3,4)) - @test Flux.hasaffine(m) == true - @test length(Flux.trainables(m)) == 2 - m = LayerNorm((2,3,4), affine=false) - @test Flux.hasaffine(m) == false - @test length(Flux.trainables(m)) == 0 -end - -@testset "GroupNorm" begin - # begin tests - squeeze(x) = dropdims(x, dims = tuple(findall(size(x) .== 1)...)) # To remove all singular dimensions - - let m = GroupNorm(4,2), sizes = (3,4,2), - x = reshape(collect(1:prod(sizes)), sizes) - - @test length(Flux.trainables(m)) == 2 - x = Float32.(x) - @test m.β == [0, 0, 0, 0] # initβ(32) - @test m.γ == [1, 1, 1, 1] # initγ(32) - - ŷ = evalwgrad(m, x) +#@testset "Dropout" begin +# @testset for rng_kwargs in ((), (; rng = MersenneTwister())) +# x = 
[1.0+0im,2.0+1im,3.0+3im] +# @test x == Dropout(0.1; rng_kwargs...)(x) +# @test x == evalwgrad(Dropout(0; rng_kwargs...), x) +# @test zero(x) == evalwgrad(Dropout(1; rng_kwargs...), x) + +# x = [1.,2.,3.] +# @test x == Dropout(0.1; rng_kwargs...)(x) +# @test x == evalwgrad(Dropout(0; rng_kwargs...), x) +# @test zero(x) == evalwgrad(Dropout(1; rng_kwargs...), x) + +# x = rand(100) +# m = Dropout(0.9; rng_kwargs...) +# y = evalwgrad(m, x) +# @test count(a->a==0, y) > 50 +# testmode!(m, true) +# y = evalwgrad(m, x) # should override istraining +# @test count(a->a==0, y) == 0 +# testmode!(m, false) +# y = evalwgrad(m, x) +# @test count(a->a==0, y) > 50 + +# # Keyword active=false +# m2 = Dropout(0.9; active=false, rng_kwargs...) +# y2 = evalwgrad(m2, x) +# @test count(iszero, y2) == 0 + +# x = rand(Float32, 100) +# m = Chain(Dense(100 => 100), +# Dropout(0.9; rng_kwargs...)) +# y = evalwgrad(m, x) +# @test count(a->a == 0, y) > 50 +# testmode!(m, true) +# y = evalwgrad(m, x) # should override istraining +# @test count(a->a == 0, y) == 0 + +# x = rand(100, 50) +# m = Dropout(0.5; dims = 2, rng_kwargs...) +# y = m(x) +# c = map(i->count(a->a==0, @view y[i, :]), 1:100) +# @test minimum(c) == maximum(c) +# m = Dropout(0.5; dims = 1, rng_kwargs...) +# y = m(x) +# c = map(i->count(a->a==0, @view y[:, i]), 1:50) +# @test minimum(c) == maximum(c) + +# # issue #1084 +# m = Dropout(0.9; rng_kwargs...) +# x = rand(100) + +# testmode!(m) +# y = m(x) +# @test count(a->a == 0, y) == 0 +# trainmode!(m) +# y = m(x) +# @test count(a->a == 0, y) > 50 + +# y = Flux.dropout(values(rng_kwargs)..., x, 0.9) # , active=true) +# @test count(a->a == 0, y) > 50 + +# y = Flux.dropout(values(rng_kwargs)..., x, 0.9 * 0) # , active=false) +# @test count(a->a == 0, y) == 0 + +# # CPU RNGs map onto CPU ok +# if isempty(rng_kwargs) +# @test cpu(m).rng isa Random.TaskLocalRNG +# else +# @test cpu(m).rng === only(values(rng_kwargs)) +# end +# end + +# @test Dropout(0.5; active=true).active === true +# @test_throws Exception Dropout(0.5; active=:something_else) +#end + +#@testset "AlphaDropout" begin +# @testset for rng_kwargs in ((), (; rng = MersenneTwister())) +# x = [1., 2., 3.] +# @test x == AlphaDropout(0.1; rng_kwargs...)(x) +# @test x == evalwgrad(AlphaDropout(0; rng_kwargs...), x) +# @test zero(x) == evalwgrad(AlphaDropout(1; rng_kwargs...), x) + +# x = randn(1000) # large enough to prevent flaky test +# m = AlphaDropout(0.5; rng_kwargs...) 
+ +# y = evalwgrad(m, x) +# # Should preserve unit mean and variance +# @test mean(y) ≈ 0 atol=0.2 +# @test var(y) ≈ 1 atol=0.2 + +# testmode!(m, true) # should override istraining +# @test evalwgrad(m, x) == x + +# testmode!(m, false) +# y = evalwgrad(m, x) +# @test mean(y) ≈ 0 atol=0.2 +# @test var(y) ≈ 1 atol=0.2 + +# # Known good value ranges +# # Values taken from https://github.com/pytorch/pytorch/blob/v1.10.0/test/cpp/api/modules.cpp#L1337-L1338 +# x = ones(100) +# if isempty(rng_kwargs) +# @test 40 < sum(evalwgrad(m, x)) < 130 +# else +# # FIXME: this breaks spuriously for MersenneTwister +# @test_skip 40 < sum(evalwgrad(m, x)) < 130 +# end + +# # CPU RNGs map onto CPU ok +# if isempty(rng_kwargs) +# @test cpu(m).rng isa Random.TaskLocalRNG +# else +# @test cpu(m).rng === only(values(rng_kwargs)) +# end +# end + +# @test AlphaDropout(0.5; active=true).active === true +# @test_throws Exception AlphaDropout(0.5; active=:something_else) +#end + +#@testset "BatchNorm" begin +# let m = BatchNorm(2), x = [1.0 3.0 5.0; +# 2.0 4.0 6.0] + +# @test Flux.hasaffine(m) == true +# @test length(Flux.trainables(m)) == 2 + +# @test m.β == [0, 0] # initβ(2) +# @test m.γ == [1, 1] # initγ(2) +# # initial m.σ is 1 +# # initial m.μ is 0 + +# y = evalwgrad(m, x) +# @test isapprox(y, [-1.22474 0 1.22474; -1.22474 0 1.22474], atol = 1.0e-5) +# # julia> x +# # 2×3 Array{Float64,2}: +# # 1.0 3.0 5.0 +# # 2.0 4.0 6.0 +# # +# # μ of batch will be +# # (1. + 3. + 5.) / 3 = 3 +# # (2. + 4. + 6.) / 3 = 4 +# # +# # ∴ update rule with momentum: +# # .1 * 3 + 0 = .3 +# # .1 * 4 + 0 = .4 +# @test m.μ ≈ reshape([0.3, 0.4], 2, 1) + +# # julia> .1 .* var(x, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.] +# # 2×1 Array{Float64,2}: +# # 1.3 +# # 1.3 +# @test m.σ² ≈ .1 .* var(x, dims=2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.] 
+ +# x′ = m(x) +# @test isapprox(x′[1], (1 .- 0.3) / sqrt(1.3), atol = 1.0e-5) + +# @inferred m(x) +# end + +# let m = BatchNorm(2; track_stats=false), x = Float32[1.0 3.0 5.0; 2.0 4.0 6.0] +# y = @inferred m(x) +# m16 = f16(m) +# y16 = @inferred m16(f16(x)) +# @test eltype(y16) == Float16 +# @test y16 ≈ y atol=1e-3 +# end + +# # with activation function +# let m = BatchNorm(2, sigmoid), x = Float32[1.0 3.0 5.0; +# 2.0 4.0 6.0] +# y = m(x) +# @test isapprox(y, sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-7) +# @inferred m(x) +# m16 = f16(m) +# y16 = @inferred m16(f16(x)) +# @test eltype(y16) == Float16 +# @test y16 ≈ y atol=1e-3 +# end + +# let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:6), 3, 2, 1) +# y = reshape(permutedims(x, [2, 1, 3]), 2, :) +# y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3]) +# @test m(x) == y +# @inferred m(x) +# end + +# let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:12), 2, 3, 2, 1) +# y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :) +# y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4]) +# @test m(x) == y +# @inferred m(x) +# end + +# let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1) +# y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :) +# y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5]) +# @test m(x) == y +# @inferred m(x) +# end + +# let m = BatchNorm(32), x = randn(Float32, 416, 416, 32, 1); +# m(x) +# @test (@allocated m(x)) < 100_000_000 +# @inferred m(x) +# end + +# @test length(Flux.trainables(BatchNorm(10))) == 2 +# @test length(Flux.trainables(BatchNorm(10, affine=true))) == 2 +# @test length(Flux.trainables(BatchNorm(10, affine=false))) == 0 + +# @test BatchNorm(5; active=true).active === true +# @test_throws Exception BatchNorm(5; active=:something_else) +#end + +#@testset "InstanceNorm" begin +# # begin tests +# let m = InstanceNorm(2; affine=true, track_stats=true), sizes = (3, 2, 2), +# x = reshape(collect(1:prod(sizes)), sizes) + +# @test length(Flux.trainables(m)) == 2 +# x = Float32.(x) +# @test m.β == [0, 0] # initβ(2) +# @test m.γ == [1, 1] # initγ(2) +# y = evalwgrad(m, x) + +# #julia> x +# #[:, :, 1] = +# # 1.0 4.0 +# # 2.0 5.0 +# # 3.0 6.0 +# # +# #[:, :, 2] = +# # 7.0 10.0 +# # 8.0 11.0 +# # 9.0 12.0 +# # +# # μ will be +# # (1. + 2. + 3.) / 3 = 2. +# # (4. + 5. + 6.) / 3 = 5. +# # +# # (7. + 8. + 9.) / 3 = 8. +# # (10. + 11. + 12.) / 3 = 11. +# # +# # ∴ update rule with momentum: +# # (1. - .1) * 0 + .1 * (2. + 8.) / 2 = .5 +# # (1. - .1) * 0 + .1 * (5. + 11.) / 2 = .8 +# N = ndims(x) +# @test m.μ ≈ [0.5, 0.8] +# n = prod(size(x,i) for i in 1:N-2) +# corr = n / (n-1) +# σ² = var(x, dims=1:N-2, corrected=false) +# @test m.σ² ≈ 0.1*corr*vec(mean(σ², dims=N)) .+ 0.9 * 1 + +# y = m(x) +# @test length(m.μ) == 2 +# @test length(m.σ²) == 2 +# @test y ≈ (x .- reshape(m.μ, 1,2,1)) ./ sqrt.(reshape(m.σ², 1,2,1) .+ 1f-5) atol=1.0e-5 + +# @inferred m(x) +# end + +# # with activation function +# let m = InstanceNorm(2, sigmoid; affine=true, track_stats=true), sizes = (3, 2, 2), +# x = reshape(collect(1:prod(sizes)), sizes) +# x = Float64.(x) +# affine_shape = collect(sizes) +# affine_shape[[1,3]] .= 1 + +# y = evalwgrad(m, x) +# y = m(x) # inference time after a training step +# μ = reshape(m.μ, affine_shape...) +# σ² = reshape(m.σ², affine_shape...) 
+# @test y ≈ sigmoid.((x .- μ) ./ sqrt.(σ² .+ m.ϵ)) atol=1.0e-7 + +# @inferred m(x) +# end + +# # with activation function +# let m = InstanceNorm(2, sigmoid; affine=true, track_stats=false), sizes = (3, 2, 2), +# x = reshape(collect(1:prod(sizes)), sizes) + +# @test Flux.hasaffine(m) == true +# @test length(Flux.trainables(m)) == 2 +# x = Float64.(x) +# y = m(x) +# μ = mean(x, dims=1) +# σ² = var(x, dims=1, corrected=false) +# @test y ≈ sigmoid.((x .- μ) ./ sqrt.(σ² .+ m.ϵ)) atol=1.0e-7 + +# @inferred m(x) +# end + +# let m = InstanceNorm(2, sigmoid), sizes = (3, 2, 2), +# x = reshape(collect(1:prod(sizes)), sizes) +# @test Flux.hasaffine(m) == false +# @test length(Flux.trainables(m)) == 0 + +# x = Float64.(x) +# y = m(x) +# μ = mean(x, dims=1) +# σ² = var(x, dims=1, corrected=false) +# @test y ≈ sigmoid.((x .- μ) ./ sqrt.(σ² .+ m.ϵ)) atol=1.0e-7 + +# @inferred m(x) +# end + + +# let m = trainmode!(InstanceNorm(2; affine=true)), sizes = (2, 4, 1, 2, 3), +# x = Float32.(reshape(collect(1:prod(sizes)), sizes)) +# y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3) +# y = reshape(m(y), sizes...) +# @test m(x) == y + +# @inferred m(x) +# end + +# # check that μ, σ², and the output are the correct size for higher rank tensors +# let m = InstanceNorm(2; affine=true,track_stats=true), sizes = (5, 5, 3, 4, 2, 6), +# x = reshape(Float32.(collect(1:prod(sizes))), sizes) +# y = evalwgrad(m, x) +# @test size(m.μ) == (sizes[end - 1], ) +# @test size(m.σ²) == (sizes[end - 1], ) +# @test size(y) == sizes + +# @inferred m(x) +# end + +# # show that instance norm is equal to batch norm when channel and batch dims are squashed +# let m_inorm = trainmode!(InstanceNorm(2; affine=true)), m_bnorm = trainmode!(BatchNorm(12)), sizes = (5, 5, 3, 4, 2, 6), +# x = reshape(Float32.(collect(1:prod(sizes))), sizes) +# @test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes) +# end + +# let m = InstanceNorm(32), x = randn(Float32, 416, 416, 32, 1); +# m(x) +# @test (@allocated m(x)) < 100_000_000 + +# @inferred m(x) +# end + +# @test length(Flux.trainables(InstanceNorm(10))) == 0 +# @test length(Flux.trainables(InstanceNorm(10, affine=true))) == 2 +# @test length(Flux.trainables(InstanceNorm(10, affine=false))) == 0 + +# @test InstanceNorm(5; active=true).active === true +# @test_throws Exception InstanceNorm(5; active=:something_else) +#end + +#@testset "LayerNorm" begin +# x = rand(2,3) +# @test LayerNorm(2)(x) ≈ Flux.normalise(x, dims=1) +# x = rand(2,3,4) +# @test LayerNorm(2)(x) ≈ Flux.normalise(x, dims=1) +# x = rand(2,3,4,5) +# @test LayerNorm(2)(x) ≈ Flux.normalise(x, dims=1) +# x = rand(2) +# @test LayerNorm(2, tanh)(x) ≈ tanh.(Flux.normalise(x, dims=1)) + +# x = rand(2,3,4,5) +# @test LayerNorm((2,3))(x) ≈ Flux.normalise(x, dims=(1,2)) +# x = rand(2,3,4,5) +# @test LayerNorm((2,3,4))(x) ≈ Flux.normalise(x, dims=1:3) + +# m = LayerNorm((2,3,4)) +# @test Flux.hasaffine(m) == true +# @test length(Flux.trainables(m)) == 2 +# m = LayerNorm((2,3,4), affine=false) +# @test Flux.hasaffine(m) == false +# @test length(Flux.trainables(m)) == 0 +#end + +#@testset "GroupNorm" begin +# # begin tests +# squeeze(x) = dropdims(x, dims = tuple(findall(size(x) .== 1)...)) # To remove all singular dimensions + +# let m = GroupNorm(4,2), sizes = (3,4,2), +# x = reshape(collect(1:prod(sizes)), sizes) + +# @test length(Flux.trainables(m)) == 2 +# x = Float32.(x) +# @test m.β == [0, 0, 0, 0] # initβ(32) +# @test m.γ == [1, 1, 1, 1] # initγ(32) + +# ŷ = evalwgrad(m, x) - @test m.μ === nothing - @test 
m.σ² === nothing - ŷ = m(x) - y = [-1.4638476 0.29276943 -1.4638476 0.29276943; -0.87830865 0.87830853 -0.8783088 0.8783083; -0.29276967 1.4638474 -0.2927699 1.4638472;;; -1.4638476 0.29276943 -1.4638472 0.29276943; -0.8783083 0.8783083 -0.8783083 0.8783083; -0.29276943 1.4638472 -0.29276943 1.4638472] - - @test ŷ ≈ y atol=1.0e-5 - end - # with activation function - let m = GroupNorm(4,2, sigmoid), sizes = (3, 4, 2), - x = reshape(collect(1:prod(sizes)), sizes) +# @test m.μ === nothing +# @test m.σ² === nothing +# ŷ = m(x) +# y = [-1.4638476 0.29276943 -1.4638476 0.29276943; -0.87830865 0.87830853 -0.8783088 0.8783083; -0.29276967 1.4638474 -0.2927699 1.4638472;;; -1.4638476 0.29276943 -1.4638472 0.29276943; -0.8783083 0.8783083 -0.8783083 0.8783083; -0.29276943 1.4638472 -0.29276943 1.4638472] + +# @test ŷ ≈ y atol=1.0e-5 +# end +# # with activation function +# let m = GroupNorm(4,2, sigmoid), sizes = (3, 4, 2), +# x = reshape(collect(1:prod(sizes)), sizes) - x = Float32.(x) - μ_affine_shape = ones(Int,length(sizes) + 1) - μ_affine_shape[end-1] = 2 # Number of groups - - affine_shape = ones(Int,length(sizes) + 1) - affine_shape[end-2] = 2 # Channels per group - affine_shape[end-1] = 2 # Number of groups - affine_shape[1] = sizes[1] - affine_shape[end] = sizes[end] - - og_shape = size(x) - - ŷ = m(x) - y = [0.18787955 0.57267404 0.18787955 0.57267404; 0.2935284 0.70647156 0.29352835 0.70647156; 0.42732593 0.81212044 0.42732587 0.8121204;;; 0.18787955 0.57267404 0.1878796 0.57267404; 0.29352847 0.70647156 0.29352847 0.70647156; 0.42732602 0.8121204 0.42732602 0.8121204] - @test ŷ ≈ y atol=1e-7 - end - - let m = trainmode!(GroupNorm(2,2)), sizes = (2, 4, 1, 2, 3), - x = Float32.(reshape(collect(1:prod(sizes)), sizes)) - y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3) - y = reshape(m(y), sizes...) - @test m(x) == y - end - - # show that group norm is the same as instance norm when the group size is the same as the number of channels - let IN = trainmode!(InstanceNorm(4; affine=true)), GN = trainmode!(GroupNorm(4,4)), sizes = (2,2,3,4,5), - x = Float32.(reshape(collect(1:prod(sizes)), sizes)) - @test IN(x) ≈ GN(x) - end - - # show that group norm is the same as batch norm for a group of size 1 and batch of size 1 - let BN = trainmode!(BatchNorm(4)), GN = trainmode!(GroupNorm(4,4)), sizes = (2,2,3,4,1), - x = Float32.(reshape(collect(1:prod(sizes)), sizes)) - @test BN(x) ≈ GN(x) - end - - @test GroupNorm(5, 5; active=true).active === true - @test_throws Exception GroupNorm(5, 5; active=:something_else) +# x = Float32.(x) +# μ_affine_shape = ones(Int,length(sizes) + 1) +# μ_affine_shape[end-1] = 2 # Number of groups + +# affine_shape = ones(Int,length(sizes) + 1) +# affine_shape[end-2] = 2 # Channels per group +# affine_shape[end-1] = 2 # Number of groups +# affine_shape[1] = sizes[1] +# affine_shape[end] = sizes[end] + +# og_shape = size(x) + +# ŷ = m(x) +# y = [0.18787955 0.57267404 0.18787955 0.57267404; 0.2935284 0.70647156 0.29352835 0.70647156; 0.42732593 0.81212044 0.42732587 0.8121204;;; 0.18787955 0.57267404 0.1878796 0.57267404; 0.29352847 0.70647156 0.29352847 0.70647156; 0.42732602 0.8121204 0.42732602 0.8121204] +# @test ŷ ≈ y atol=1e-7 +# end + +# let m = trainmode!(GroupNorm(2,2)), sizes = (2, 4, 1, 2, 3), +# x = Float32.(reshape(collect(1:prod(sizes)), sizes)) +# y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3) +# y = reshape(m(y), sizes...) 
+# @test m(x) == y +# end + +# # show that group norm is the same as instance norm when the group size is the same as the number of channels +# let IN = trainmode!(InstanceNorm(4; affine=true)), GN = trainmode!(GroupNorm(4,4)), sizes = (2,2,3,4,5), +# x = Float32.(reshape(collect(1:prod(sizes)), sizes)) +# @test IN(x) ≈ GN(x) +# end + +# # show that group norm is the same as batch norm for a group of size 1 and batch of size 1 +# let BN = trainmode!(BatchNorm(4)), GN = trainmode!(GroupNorm(4,4)), sizes = (2,2,3,4,1), +# x = Float32.(reshape(collect(1:prod(sizes)), sizes)) +# @test BN(x) ≈ GN(x) +# end + +# @test GroupNorm(5, 5; active=true).active === true +# @test_throws Exception GroupNorm(5, 5; active=:something_else) +#end + +@testset "WeightNorm" begin + x = rand(Float32, 1, 3) + m = Dense(1 => 2) + mn = WeightNorm(m) + @test m(x) ≈ mn(x) + + @test_throws ArgumentError WeightNorm(m, :weights) + @test_throws "does not have field" WeightNorm(m, :weights) + + @test_throws ArgumentError WeightNorm(m, :bias) + @test_throws "is all zero" WeightNorm(m, :bias) + + g = (Zygote.gradient(mn) do mn + sum(mn(x)) + end)[1] + @test g.layer.weight ≡ nothing # Original weight does not participate. + @test g.layer.bias ≢ nothing + @test g.g ≢ nothing + @test g.v ≢ nothing + + om = Flux.remove_weight_norms(mn) + @test om isa Dense + @test om.weight ≈ m.weight + @test om.bias ≈ m.bias + + # Test with Chain. + + c = Chain( + WeightNorm(Conv((3,), 1 => 2)), + WeightNorm(Conv((3,), 2 => 2)), + ) + @test c[1] isa WeightNorm + @test c[2] isa WeightNorm + + oc = Flux.remove_weight_norms(c) + @test oc[1] isa Conv + @test oc[2] isa Conv + + x = rand(Float32, 12, 1, 1) + @test c(x) ≈ oc(x) end -@testset "second derivatives" begin - m1 = Dropout(0.5) - @test Zygote.hessian_reverse(sum∘m1, [1.0,2.0,3.0]) == zeros(3, 3) - - m2 = Chain(BatchNorm(3), sum) - @test Zygote.hessian_reverse(m2, Float32[1 2; 3 4; 5 6]) == zeros(Float32, 6, 6) broken = VERSION >= v"1.11" -end - -@testset "ForwardDiff" begin - bn = BatchNorm(3) - @test ForwardDiff.jacobian(bn, rand(Float32, 3, 4)) isa Matrix{Float32} - # iszero(bn.μ) # is true. But ideally would not be, if Flux would automatically choose trainmode - Flux.trainmode!(bn) - # This was an error, https://github.com/FluxML/Flux.jl/issues/2122 - @test ForwardDiff.jacobian(bn, rand(Float32, 3, 4)) isa Matrix{Float32} - @test !iszero(bn.μ) - - # Easy case of 2122, gradient with x - x5 = rand(Float32, 5, 3) - bn1 = BatchNorm(5, relu) - bn2 = BatchNorm(5, relu) - g1 = Zygote.gradient(x -> sum(abs2, bn1(x)), x5)[1] - g2 = ForwardDiff.gradient(x -> sum(abs2, bn2(x)), x5) - @test g1 ≈ g2 - - # Harder case? - v1, re1 = Flux.destructure(BatchNorm(5, relu)); - g1 = Zygote.gradient(v -> sum(abs2, re1(v)(x5)), v1)[1] - - v2, re2 = Flux.destructure(BatchNorm(5, relu)); - g2 = ForwardDiff.gradient(v -> sum(abs2, re2(v)(x5)), v2) -end +# @testset "second derivatives" begin +# m1 = Dropout(0.5) +# @test Zygote.hessian_reverse(sum∘m1, [1.0,2.0,3.0]) == zeros(3, 3) + +# m2 = Chain(BatchNorm(3), sum) +# @test Zygote.hessian_reverse(m2, Float32[1 2; 3 4; 5 6]) == zeros(Float32, 6, 6) broken = VERSION >= v"1.11" +# end + +# @testset "ForwardDiff" begin +# bn = BatchNorm(3) +# @test ForwardDiff.jacobian(bn, rand(Float32, 3, 4)) isa Matrix{Float32} +# # iszero(bn.μ) # is true. 
But ideally would not be, if Flux would automatically choose trainmode +# Flux.trainmode!(bn) +# # This was an error, https://github.com/FluxML/Flux.jl/issues/2122 +# @test ForwardDiff.jacobian(bn, rand(Float32, 3, 4)) isa Matrix{Float32} +# @test !iszero(bn.μ) + +# # Easy case of 2122, gradient with x +# x5 = rand(Float32, 5, 3) +# bn1 = BatchNorm(5, relu) +# bn2 = BatchNorm(5, relu) +# g1 = Zygote.gradient(x -> sum(abs2, bn1(x)), x5)[1] +# g2 = ForwardDiff.gradient(x -> sum(abs2, bn2(x)), x5) +# @test g1 ≈ g2 + +# # Harder case? +# v1, re1 = Flux.destructure(BatchNorm(5, relu)); +# g1 = Zygote.gradient(v -> sum(abs2, re1(v)(x5)), v1)[1] + +# v2, re2 = Flux.destructure(BatchNorm(5, relu)); +# g2 = ForwardDiff.gradient(v -> sum(abs2, re2(v)(x5)), v2) +# end From 36dc721c61b46955f62558fcda08de6372ba1be6 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Thu, 12 Dec 2024 17:09:44 +0200 Subject: [PATCH 04/12] More tests --- test/layers/normalisation.jl | 947 ++++++++++++++++++----------------- 1 file changed, 480 insertions(+), 467 deletions(-) diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl index 254e6186f6..2ea46bf03e 100644 --- a/test/layers/normalisation.jl +++ b/test/layers/normalisation.jl @@ -3,443 +3,443 @@ using Zygote: pullback, ForwardDiff evalwgrad(f, x...) = pullback(f, x...)[1] -#@testset "Dropout" begin -# @testset for rng_kwargs in ((), (; rng = MersenneTwister())) -# x = [1.0+0im,2.0+1im,3.0+3im] -# @test x == Dropout(0.1; rng_kwargs...)(x) -# @test x == evalwgrad(Dropout(0; rng_kwargs...), x) -# @test zero(x) == evalwgrad(Dropout(1; rng_kwargs...), x) - -# x = [1.,2.,3.] -# @test x == Dropout(0.1; rng_kwargs...)(x) -# @test x == evalwgrad(Dropout(0; rng_kwargs...), x) -# @test zero(x) == evalwgrad(Dropout(1; rng_kwargs...), x) - -# x = rand(100) -# m = Dropout(0.9; rng_kwargs...) -# y = evalwgrad(m, x) -# @test count(a->a==0, y) > 50 -# testmode!(m, true) -# y = evalwgrad(m, x) # should override istraining -# @test count(a->a==0, y) == 0 -# testmode!(m, false) -# y = evalwgrad(m, x) -# @test count(a->a==0, y) > 50 - -# # Keyword active=false -# m2 = Dropout(0.9; active=false, rng_kwargs...) -# y2 = evalwgrad(m2, x) -# @test count(iszero, y2) == 0 - -# x = rand(Float32, 100) -# m = Chain(Dense(100 => 100), -# Dropout(0.9; rng_kwargs...)) -# y = evalwgrad(m, x) -# @test count(a->a == 0, y) > 50 -# testmode!(m, true) -# y = evalwgrad(m, x) # should override istraining -# @test count(a->a == 0, y) == 0 - -# x = rand(100, 50) -# m = Dropout(0.5; dims = 2, rng_kwargs...) -# y = m(x) -# c = map(i->count(a->a==0, @view y[i, :]), 1:100) -# @test minimum(c) == maximum(c) -# m = Dropout(0.5; dims = 1, rng_kwargs...) -# y = m(x) -# c = map(i->count(a->a==0, @view y[:, i]), 1:50) -# @test minimum(c) == maximum(c) - -# # issue #1084 -# m = Dropout(0.9; rng_kwargs...) 
-# x = rand(100) - -# testmode!(m) -# y = m(x) -# @test count(a->a == 0, y) == 0 -# trainmode!(m) -# y = m(x) -# @test count(a->a == 0, y) > 50 - -# y = Flux.dropout(values(rng_kwargs)..., x, 0.9) # , active=true) -# @test count(a->a == 0, y) > 50 - -# y = Flux.dropout(values(rng_kwargs)..., x, 0.9 * 0) # , active=false) -# @test count(a->a == 0, y) == 0 - -# # CPU RNGs map onto CPU ok -# if isempty(rng_kwargs) -# @test cpu(m).rng isa Random.TaskLocalRNG -# else -# @test cpu(m).rng === only(values(rng_kwargs)) -# end -# end - -# @test Dropout(0.5; active=true).active === true -# @test_throws Exception Dropout(0.5; active=:something_else) -#end - -#@testset "AlphaDropout" begin -# @testset for rng_kwargs in ((), (; rng = MersenneTwister())) -# x = [1., 2., 3.] -# @test x == AlphaDropout(0.1; rng_kwargs...)(x) -# @test x == evalwgrad(AlphaDropout(0; rng_kwargs...), x) -# @test zero(x) == evalwgrad(AlphaDropout(1; rng_kwargs...), x) - -# x = randn(1000) # large enough to prevent flaky test -# m = AlphaDropout(0.5; rng_kwargs...) - -# y = evalwgrad(m, x) -# # Should preserve unit mean and variance -# @test mean(y) ≈ 0 atol=0.2 -# @test var(y) ≈ 1 atol=0.2 - -# testmode!(m, true) # should override istraining -# @test evalwgrad(m, x) == x - -# testmode!(m, false) -# y = evalwgrad(m, x) -# @test mean(y) ≈ 0 atol=0.2 -# @test var(y) ≈ 1 atol=0.2 - -# # Known good value ranges -# # Values taken from https://github.com/pytorch/pytorch/blob/v1.10.0/test/cpp/api/modules.cpp#L1337-L1338 -# x = ones(100) -# if isempty(rng_kwargs) -# @test 40 < sum(evalwgrad(m, x)) < 130 -# else -# # FIXME: this breaks spuriously for MersenneTwister -# @test_skip 40 < sum(evalwgrad(m, x)) < 130 -# end - -# # CPU RNGs map onto CPU ok -# if isempty(rng_kwargs) -# @test cpu(m).rng isa Random.TaskLocalRNG -# else -# @test cpu(m).rng === only(values(rng_kwargs)) -# end -# end - -# @test AlphaDropout(0.5; active=true).active === true -# @test_throws Exception AlphaDropout(0.5; active=:something_else) -#end - -#@testset "BatchNorm" begin -# let m = BatchNorm(2), x = [1.0 3.0 5.0; -# 2.0 4.0 6.0] - -# @test Flux.hasaffine(m) == true -# @test length(Flux.trainables(m)) == 2 - -# @test m.β == [0, 0] # initβ(2) -# @test m.γ == [1, 1] # initγ(2) -# # initial m.σ is 1 -# # initial m.μ is 0 - -# y = evalwgrad(m, x) -# @test isapprox(y, [-1.22474 0 1.22474; -1.22474 0 1.22474], atol = 1.0e-5) -# # julia> x -# # 2×3 Array{Float64,2}: -# # 1.0 3.0 5.0 -# # 2.0 4.0 6.0 -# # -# # μ of batch will be -# # (1. + 3. + 5.) / 3 = 3 -# # (2. + 4. + 6.) / 3 = 4 -# # -# # ∴ update rule with momentum: -# # .1 * 3 + 0 = .3 -# # .1 * 4 + 0 = .4 -# @test m.μ ≈ reshape([0.3, 0.4], 2, 1) - -# # julia> .1 .* var(x, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.] -# # 2×1 Array{Float64,2}: -# # 1.3 -# # 1.3 -# @test m.σ² ≈ .1 .* var(x, dims=2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.] 
- -# x′ = m(x) -# @test isapprox(x′[1], (1 .- 0.3) / sqrt(1.3), atol = 1.0e-5) - -# @inferred m(x) -# end - -# let m = BatchNorm(2; track_stats=false), x = Float32[1.0 3.0 5.0; 2.0 4.0 6.0] -# y = @inferred m(x) -# m16 = f16(m) -# y16 = @inferred m16(f16(x)) -# @test eltype(y16) == Float16 -# @test y16 ≈ y atol=1e-3 -# end - -# # with activation function -# let m = BatchNorm(2, sigmoid), x = Float32[1.0 3.0 5.0; -# 2.0 4.0 6.0] -# y = m(x) -# @test isapprox(y, sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-7) -# @inferred m(x) -# m16 = f16(m) -# y16 = @inferred m16(f16(x)) -# @test eltype(y16) == Float16 -# @test y16 ≈ y atol=1e-3 -# end - -# let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:6), 3, 2, 1) -# y = reshape(permutedims(x, [2, 1, 3]), 2, :) -# y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3]) -# @test m(x) == y -# @inferred m(x) -# end - -# let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:12), 2, 3, 2, 1) -# y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :) -# y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4]) -# @test m(x) == y -# @inferred m(x) -# end - -# let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1) -# y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :) -# y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5]) -# @test m(x) == y -# @inferred m(x) -# end - -# let m = BatchNorm(32), x = randn(Float32, 416, 416, 32, 1); -# m(x) -# @test (@allocated m(x)) < 100_000_000 -# @inferred m(x) -# end - -# @test length(Flux.trainables(BatchNorm(10))) == 2 -# @test length(Flux.trainables(BatchNorm(10, affine=true))) == 2 -# @test length(Flux.trainables(BatchNorm(10, affine=false))) == 0 - -# @test BatchNorm(5; active=true).active === true -# @test_throws Exception BatchNorm(5; active=:something_else) -#end - -#@testset "InstanceNorm" begin -# # begin tests -# let m = InstanceNorm(2; affine=true, track_stats=true), sizes = (3, 2, 2), -# x = reshape(collect(1:prod(sizes)), sizes) - -# @test length(Flux.trainables(m)) == 2 -# x = Float32.(x) -# @test m.β == [0, 0] # initβ(2) -# @test m.γ == [1, 1] # initγ(2) -# y = evalwgrad(m, x) - -# #julia> x -# #[:, :, 1] = -# # 1.0 4.0 -# # 2.0 5.0 -# # 3.0 6.0 -# # -# #[:, :, 2] = -# # 7.0 10.0 -# # 8.0 11.0 -# # 9.0 12.0 -# # -# # μ will be -# # (1. + 2. + 3.) / 3 = 2. -# # (4. + 5. + 6.) / 3 = 5. -# # -# # (7. + 8. + 9.) / 3 = 8. -# # (10. + 11. + 12.) / 3 = 11. -# # -# # ∴ update rule with momentum: -# # (1. - .1) * 0 + .1 * (2. + 8.) / 2 = .5 -# # (1. - .1) * 0 + .1 * (5. + 11.) / 2 = .8 -# N = ndims(x) -# @test m.μ ≈ [0.5, 0.8] -# n = prod(size(x,i) for i in 1:N-2) -# corr = n / (n-1) -# σ² = var(x, dims=1:N-2, corrected=false) -# @test m.σ² ≈ 0.1*corr*vec(mean(σ², dims=N)) .+ 0.9 * 1 - -# y = m(x) -# @test length(m.μ) == 2 -# @test length(m.σ²) == 2 -# @test y ≈ (x .- reshape(m.μ, 1,2,1)) ./ sqrt.(reshape(m.σ², 1,2,1) .+ 1f-5) atol=1.0e-5 - -# @inferred m(x) -# end - -# # with activation function -# let m = InstanceNorm(2, sigmoid; affine=true, track_stats=true), sizes = (3, 2, 2), -# x = reshape(collect(1:prod(sizes)), sizes) -# x = Float64.(x) -# affine_shape = collect(sizes) -# affine_shape[[1,3]] .= 1 - -# y = evalwgrad(m, x) -# y = m(x) # inference time after a training step -# μ = reshape(m.μ, affine_shape...) -# σ² = reshape(m.σ², affine_shape...) 
-# @test y ≈ sigmoid.((x .- μ) ./ sqrt.(σ² .+ m.ϵ)) atol=1.0e-7 - -# @inferred m(x) -# end - -# # with activation function -# let m = InstanceNorm(2, sigmoid; affine=true, track_stats=false), sizes = (3, 2, 2), -# x = reshape(collect(1:prod(sizes)), sizes) - -# @test Flux.hasaffine(m) == true -# @test length(Flux.trainables(m)) == 2 -# x = Float64.(x) -# y = m(x) -# μ = mean(x, dims=1) -# σ² = var(x, dims=1, corrected=false) -# @test y ≈ sigmoid.((x .- μ) ./ sqrt.(σ² .+ m.ϵ)) atol=1.0e-7 - -# @inferred m(x) -# end - -# let m = InstanceNorm(2, sigmoid), sizes = (3, 2, 2), -# x = reshape(collect(1:prod(sizes)), sizes) -# @test Flux.hasaffine(m) == false -# @test length(Flux.trainables(m)) == 0 - -# x = Float64.(x) -# y = m(x) -# μ = mean(x, dims=1) -# σ² = var(x, dims=1, corrected=false) -# @test y ≈ sigmoid.((x .- μ) ./ sqrt.(σ² .+ m.ϵ)) atol=1.0e-7 - -# @inferred m(x) -# end - - -# let m = trainmode!(InstanceNorm(2; affine=true)), sizes = (2, 4, 1, 2, 3), -# x = Float32.(reshape(collect(1:prod(sizes)), sizes)) -# y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3) -# y = reshape(m(y), sizes...) -# @test m(x) == y - -# @inferred m(x) -# end - -# # check that μ, σ², and the output are the correct size for higher rank tensors -# let m = InstanceNorm(2; affine=true,track_stats=true), sizes = (5, 5, 3, 4, 2, 6), -# x = reshape(Float32.(collect(1:prod(sizes))), sizes) -# y = evalwgrad(m, x) -# @test size(m.μ) == (sizes[end - 1], ) -# @test size(m.σ²) == (sizes[end - 1], ) -# @test size(y) == sizes - -# @inferred m(x) -# end - -# # show that instance norm is equal to batch norm when channel and batch dims are squashed -# let m_inorm = trainmode!(InstanceNorm(2; affine=true)), m_bnorm = trainmode!(BatchNorm(12)), sizes = (5, 5, 3, 4, 2, 6), -# x = reshape(Float32.(collect(1:prod(sizes))), sizes) -# @test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes) -# end - -# let m = InstanceNorm(32), x = randn(Float32, 416, 416, 32, 1); -# m(x) -# @test (@allocated m(x)) < 100_000_000 - -# @inferred m(x) -# end - -# @test length(Flux.trainables(InstanceNorm(10))) == 0 -# @test length(Flux.trainables(InstanceNorm(10, affine=true))) == 2 -# @test length(Flux.trainables(InstanceNorm(10, affine=false))) == 0 - -# @test InstanceNorm(5; active=true).active === true -# @test_throws Exception InstanceNorm(5; active=:something_else) -#end - -#@testset "LayerNorm" begin -# x = rand(2,3) -# @test LayerNorm(2)(x) ≈ Flux.normalise(x, dims=1) -# x = rand(2,3,4) -# @test LayerNorm(2)(x) ≈ Flux.normalise(x, dims=1) -# x = rand(2,3,4,5) -# @test LayerNorm(2)(x) ≈ Flux.normalise(x, dims=1) -# x = rand(2) -# @test LayerNorm(2, tanh)(x) ≈ tanh.(Flux.normalise(x, dims=1)) - -# x = rand(2,3,4,5) -# @test LayerNorm((2,3))(x) ≈ Flux.normalise(x, dims=(1,2)) -# x = rand(2,3,4,5) -# @test LayerNorm((2,3,4))(x) ≈ Flux.normalise(x, dims=1:3) - -# m = LayerNorm((2,3,4)) -# @test Flux.hasaffine(m) == true -# @test length(Flux.trainables(m)) == 2 -# m = LayerNorm((2,3,4), affine=false) -# @test Flux.hasaffine(m) == false -# @test length(Flux.trainables(m)) == 0 -#end - -#@testset "GroupNorm" begin -# # begin tests -# squeeze(x) = dropdims(x, dims = tuple(findall(size(x) .== 1)...)) # To remove all singular dimensions - -# let m = GroupNorm(4,2), sizes = (3,4,2), -# x = reshape(collect(1:prod(sizes)), sizes) - -# @test length(Flux.trainables(m)) == 2 -# x = Float32.(x) -# @test m.β == [0, 0, 0, 0] # initβ(32) -# @test m.γ == [1, 1, 1, 1] # initγ(32) - -# ŷ = evalwgrad(m, x) +@testset "Dropout" begin + 
@testset for rng_kwargs in ((), (; rng = MersenneTwister())) + x = [1.0+0im,2.0+1im,3.0+3im] + @test x == Dropout(0.1; rng_kwargs...)(x) + @test x == evalwgrad(Dropout(0; rng_kwargs...), x) + @test zero(x) == evalwgrad(Dropout(1; rng_kwargs...), x) + + x = [1.,2.,3.] + @test x == Dropout(0.1; rng_kwargs...)(x) + @test x == evalwgrad(Dropout(0; rng_kwargs...), x) + @test zero(x) == evalwgrad(Dropout(1; rng_kwargs...), x) + + x = rand(100) + m = Dropout(0.9; rng_kwargs...) + y = evalwgrad(m, x) + @test count(a->a==0, y) > 50 + testmode!(m, true) + y = evalwgrad(m, x) # should override istraining + @test count(a->a==0, y) == 0 + testmode!(m, false) + y = evalwgrad(m, x) + @test count(a->a==0, y) > 50 + + # Keyword active=false + m2 = Dropout(0.9; active=false, rng_kwargs...) + y2 = evalwgrad(m2, x) + @test count(iszero, y2) == 0 + + x = rand(Float32, 100) + m = Chain(Dense(100 => 100), + Dropout(0.9; rng_kwargs...)) + y = evalwgrad(m, x) + @test count(a->a == 0, y) > 50 + testmode!(m, true) + y = evalwgrad(m, x) # should override istraining + @test count(a->a == 0, y) == 0 + + x = rand(100, 50) + m = Dropout(0.5; dims = 2, rng_kwargs...) + y = m(x) + c = map(i->count(a->a==0, @view y[i, :]), 1:100) + @test minimum(c) == maximum(c) + m = Dropout(0.5; dims = 1, rng_kwargs...) + y = m(x) + c = map(i->count(a->a==0, @view y[:, i]), 1:50) + @test minimum(c) == maximum(c) + + # issue #1084 + m = Dropout(0.9; rng_kwargs...) + x = rand(100) + + testmode!(m) + y = m(x) + @test count(a->a == 0, y) == 0 + trainmode!(m) + y = m(x) + @test count(a->a == 0, y) > 50 + + y = Flux.dropout(values(rng_kwargs)..., x, 0.9) # , active=true) + @test count(a->a == 0, y) > 50 + + y = Flux.dropout(values(rng_kwargs)..., x, 0.9 * 0) # , active=false) + @test count(a->a == 0, y) == 0 + + # CPU RNGs map onto CPU ok + if isempty(rng_kwargs) + @test cpu(m).rng isa Random.TaskLocalRNG + else + @test cpu(m).rng === only(values(rng_kwargs)) + end + end + + @test Dropout(0.5; active=true).active === true + @test_throws Exception Dropout(0.5; active=:something_else) +end + +@testset "AlphaDropout" begin + @testset for rng_kwargs in ((), (; rng = MersenneTwister())) + x = [1., 2., 3.] + @test x == AlphaDropout(0.1; rng_kwargs...)(x) + @test x == evalwgrad(AlphaDropout(0; rng_kwargs...), x) + @test zero(x) == evalwgrad(AlphaDropout(1; rng_kwargs...), x) + + x = randn(1000) # large enough to prevent flaky test + m = AlphaDropout(0.5; rng_kwargs...) 
+ + y = evalwgrad(m, x) + # Should preserve unit mean and variance + @test mean(y) ≈ 0 atol=0.2 + @test var(y) ≈ 1 atol=0.2 + + testmode!(m, true) # should override istraining + @test evalwgrad(m, x) == x + + testmode!(m, false) + y = evalwgrad(m, x) + @test mean(y) ≈ 0 atol=0.2 + @test var(y) ≈ 1 atol=0.2 + + # Known good value ranges + # Values taken from https://github.com/pytorch/pytorch/blob/v1.10.0/test/cpp/api/modules.cpp#L1337-L1338 + x = ones(100) + if isempty(rng_kwargs) + @test 40 < sum(evalwgrad(m, x)) < 130 + else + # FIXME: this breaks spuriously for MersenneTwister + @test_skip 40 < sum(evalwgrad(m, x)) < 130 + end + + # CPU RNGs map onto CPU ok + if isempty(rng_kwargs) + @test cpu(m).rng isa Random.TaskLocalRNG + else + @test cpu(m).rng === only(values(rng_kwargs)) + end + end + + @test AlphaDropout(0.5; active=true).active === true + @test_throws Exception AlphaDropout(0.5; active=:something_else) +end + +@testset "BatchNorm" begin + let m = BatchNorm(2), x = [1.0 3.0 5.0; + 2.0 4.0 6.0] + + @test Flux.hasaffine(m) == true + @test length(Flux.trainables(m)) == 2 + + @test m.β == [0, 0] # initβ(2) + @test m.γ == [1, 1] # initγ(2) + # initial m.σ is 1 + # initial m.μ is 0 + + y = evalwgrad(m, x) + @test isapprox(y, [-1.22474 0 1.22474; -1.22474 0 1.22474], atol = 1.0e-5) + # julia> x + # 2×3 Array{Float64,2}: + # 1.0 3.0 5.0 + # 2.0 4.0 6.0 + # + # μ of batch will be + # (1. + 3. + 5.) / 3 = 3 + # (2. + 4. + 6.) / 3 = 4 + # + # ∴ update rule with momentum: + # .1 * 3 + 0 = .3 + # .1 * 4 + 0 = .4 + @test m.μ ≈ reshape([0.3, 0.4], 2, 1) + + # julia> .1 .* var(x, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.] + # 2×1 Array{Float64,2}: + # 1.3 + # 1.3 + @test m.σ² ≈ .1 .* var(x, dims=2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.] 
+ + x′ = m(x) + @test isapprox(x′[1], (1 .- 0.3) / sqrt(1.3), atol = 1.0e-5) + + @inferred m(x) + end + + let m = BatchNorm(2; track_stats=false), x = Float32[1.0 3.0 5.0; 2.0 4.0 6.0] + y = @inferred m(x) + m16 = f16(m) + y16 = @inferred m16(f16(x)) + @test eltype(y16) == Float16 + @test y16 ≈ y atol=1e-3 + end + + # with activation function + let m = BatchNorm(2, sigmoid), x = Float32[1.0 3.0 5.0; + 2.0 4.0 6.0] + y = m(x) + @test isapprox(y, sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-7) + @inferred m(x) + m16 = f16(m) + y16 = @inferred m16(f16(x)) + @test eltype(y16) == Float16 + @test y16 ≈ y atol=1e-3 + end + + let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:6), 3, 2, 1) + y = reshape(permutedims(x, [2, 1, 3]), 2, :) + y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3]) + @test m(x) == y + @inferred m(x) + end + + let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:12), 2, 3, 2, 1) + y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :) + y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4]) + @test m(x) == y + @inferred m(x) + end + + let m = trainmode!(BatchNorm(2)), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1) + y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :) + y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5]) + @test m(x) == y + @inferred m(x) + end + + let m = BatchNorm(32), x = randn(Float32, 416, 416, 32, 1); + m(x) + @test (@allocated m(x)) < 100_000_000 + @inferred m(x) + end + + @test length(Flux.trainables(BatchNorm(10))) == 2 + @test length(Flux.trainables(BatchNorm(10, affine=true))) == 2 + @test length(Flux.trainables(BatchNorm(10, affine=false))) == 0 + + @test BatchNorm(5; active=true).active === true + @test_throws Exception BatchNorm(5; active=:something_else) +end + +@testset "InstanceNorm" begin + # begin tests + let m = InstanceNorm(2; affine=true, track_stats=true), sizes = (3, 2, 2), + x = reshape(collect(1:prod(sizes)), sizes) + + @test length(Flux.trainables(m)) == 2 + x = Float32.(x) + @test m.β == [0, 0] # initβ(2) + @test m.γ == [1, 1] # initγ(2) + y = evalwgrad(m, x) + + #julia> x + #[:, :, 1] = + # 1.0 4.0 + # 2.0 5.0 + # 3.0 6.0 + # + #[:, :, 2] = + # 7.0 10.0 + # 8.0 11.0 + # 9.0 12.0 + # + # μ will be + # (1. + 2. + 3.) / 3 = 2. + # (4. + 5. + 6.) / 3 = 5. + # + # (7. + 8. + 9.) / 3 = 8. + # (10. + 11. + 12.) / 3 = 11. + # + # ∴ update rule with momentum: + # (1. - .1) * 0 + .1 * (2. + 8.) / 2 = .5 + # (1. - .1) * 0 + .1 * (5. + 11.) / 2 = .8 + N = ndims(x) + @test m.μ ≈ [0.5, 0.8] + n = prod(size(x,i) for i in 1:N-2) + corr = n / (n-1) + σ² = var(x, dims=1:N-2, corrected=false) + @test m.σ² ≈ 0.1*corr*vec(mean(σ², dims=N)) .+ 0.9 * 1 + + y = m(x) + @test length(m.μ) == 2 + @test length(m.σ²) == 2 + @test y ≈ (x .- reshape(m.μ, 1,2,1)) ./ sqrt.(reshape(m.σ², 1,2,1) .+ 1f-5) atol=1.0e-5 + + @inferred m(x) + end + + # with activation function + let m = InstanceNorm(2, sigmoid; affine=true, track_stats=true), sizes = (3, 2, 2), + x = reshape(collect(1:prod(sizes)), sizes) + x = Float64.(x) + affine_shape = collect(sizes) + affine_shape[[1,3]] .= 1 + + y = evalwgrad(m, x) + y = m(x) # inference time after a training step + μ = reshape(m.μ, affine_shape...) + σ² = reshape(m.σ², affine_shape...) 
+ @test y ≈ sigmoid.((x .- μ) ./ sqrt.(σ² .+ m.ϵ)) atol=1.0e-7 + + @inferred m(x) + end + + # with activation function + let m = InstanceNorm(2, sigmoid; affine=true, track_stats=false), sizes = (3, 2, 2), + x = reshape(collect(1:prod(sizes)), sizes) + + @test Flux.hasaffine(m) == true + @test length(Flux.trainables(m)) == 2 + x = Float64.(x) + y = m(x) + μ = mean(x, dims=1) + σ² = var(x, dims=1, corrected=false) + @test y ≈ sigmoid.((x .- μ) ./ sqrt.(σ² .+ m.ϵ)) atol=1.0e-7 + + @inferred m(x) + end + + let m = InstanceNorm(2, sigmoid), sizes = (3, 2, 2), + x = reshape(collect(1:prod(sizes)), sizes) + @test Flux.hasaffine(m) == false + @test length(Flux.trainables(m)) == 0 + + x = Float64.(x) + y = m(x) + μ = mean(x, dims=1) + σ² = var(x, dims=1, corrected=false) + @test y ≈ sigmoid.((x .- μ) ./ sqrt.(σ² .+ m.ϵ)) atol=1.0e-7 + + @inferred m(x) + end + + + let m = trainmode!(InstanceNorm(2; affine=true)), sizes = (2, 4, 1, 2, 3), + x = Float32.(reshape(collect(1:prod(sizes)), sizes)) + y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3) + y = reshape(m(y), sizes...) + @test m(x) == y + + @inferred m(x) + end + + # check that μ, σ², and the output are the correct size for higher rank tensors + let m = InstanceNorm(2; affine=true,track_stats=true), sizes = (5, 5, 3, 4, 2, 6), + x = reshape(Float32.(collect(1:prod(sizes))), sizes) + y = evalwgrad(m, x) + @test size(m.μ) == (sizes[end - 1], ) + @test size(m.σ²) == (sizes[end - 1], ) + @test size(y) == sizes + + @inferred m(x) + end + + # show that instance norm is equal to batch norm when channel and batch dims are squashed + let m_inorm = trainmode!(InstanceNorm(2; affine=true)), m_bnorm = trainmode!(BatchNorm(12)), sizes = (5, 5, 3, 4, 2, 6), + x = reshape(Float32.(collect(1:prod(sizes))), sizes) + @test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes) + end + + let m = InstanceNorm(32), x = randn(Float32, 416, 416, 32, 1); + m(x) + @test (@allocated m(x)) < 100_000_000 + + @inferred m(x) + end + + @test length(Flux.trainables(InstanceNorm(10))) == 0 + @test length(Flux.trainables(InstanceNorm(10, affine=true))) == 2 + @test length(Flux.trainables(InstanceNorm(10, affine=false))) == 0 + + @test InstanceNorm(5; active=true).active === true + @test_throws Exception InstanceNorm(5; active=:something_else) +end + +@testset "LayerNorm" begin + x = rand(2,3) + @test LayerNorm(2)(x) ≈ Flux.normalise(x, dims=1) + x = rand(2,3,4) + @test LayerNorm(2)(x) ≈ Flux.normalise(x, dims=1) + x = rand(2,3,4,5) + @test LayerNorm(2)(x) ≈ Flux.normalise(x, dims=1) + x = rand(2) + @test LayerNorm(2, tanh)(x) ≈ tanh.(Flux.normalise(x, dims=1)) + + x = rand(2,3,4,5) + @test LayerNorm((2,3))(x) ≈ Flux.normalise(x, dims=(1,2)) + x = rand(2,3,4,5) + @test LayerNorm((2,3,4))(x) ≈ Flux.normalise(x, dims=1:3) + + m = LayerNorm((2,3,4)) + @test Flux.hasaffine(m) == true + @test length(Flux.trainables(m)) == 2 + m = LayerNorm((2,3,4), affine=false) + @test Flux.hasaffine(m) == false + @test length(Flux.trainables(m)) == 0 +end + +@testset "GroupNorm" begin + # begin tests + squeeze(x) = dropdims(x, dims = tuple(findall(size(x) .== 1)...)) # To remove all singular dimensions + + let m = GroupNorm(4,2), sizes = (3,4,2), + x = reshape(collect(1:prod(sizes)), sizes) + + @test length(Flux.trainables(m)) == 2 + x = Float32.(x) + @test m.β == [0, 0, 0, 0] # initβ(32) + @test m.γ == [1, 1, 1, 1] # initγ(32) + + ŷ = evalwgrad(m, x) -# @test m.μ === nothing -# @test m.σ² === nothing -# ŷ = m(x) -# y = [-1.4638476 0.29276943 -1.4638476 0.29276943; 
-0.87830865 0.87830853 -0.8783088 0.8783083; -0.29276967 1.4638474 -0.2927699 1.4638472;;; -1.4638476 0.29276943 -1.4638472 0.29276943; -0.8783083 0.8783083 -0.8783083 0.8783083; -0.29276943 1.4638472 -0.29276943 1.4638472] - -# @test ŷ ≈ y atol=1.0e-5 -# end -# # with activation function -# let m = GroupNorm(4,2, sigmoid), sizes = (3, 4, 2), -# x = reshape(collect(1:prod(sizes)), sizes) + @test m.μ === nothing + @test m.σ² === nothing + ŷ = m(x) + y = [-1.4638476 0.29276943 -1.4638476 0.29276943; -0.87830865 0.87830853 -0.8783088 0.8783083; -0.29276967 1.4638474 -0.2927699 1.4638472;;; -1.4638476 0.29276943 -1.4638472 0.29276943; -0.8783083 0.8783083 -0.8783083 0.8783083; -0.29276943 1.4638472 -0.29276943 1.4638472] + + @test ŷ ≈ y atol=1.0e-5 + end + # with activation function + let m = GroupNorm(4,2, sigmoid), sizes = (3, 4, 2), + x = reshape(collect(1:prod(sizes)), sizes) -# x = Float32.(x) -# μ_affine_shape = ones(Int,length(sizes) + 1) -# μ_affine_shape[end-1] = 2 # Number of groups - -# affine_shape = ones(Int,length(sizes) + 1) -# affine_shape[end-2] = 2 # Channels per group -# affine_shape[end-1] = 2 # Number of groups -# affine_shape[1] = sizes[1] -# affine_shape[end] = sizes[end] - -# og_shape = size(x) - -# ŷ = m(x) -# y = [0.18787955 0.57267404 0.18787955 0.57267404; 0.2935284 0.70647156 0.29352835 0.70647156; 0.42732593 0.81212044 0.42732587 0.8121204;;; 0.18787955 0.57267404 0.1878796 0.57267404; 0.29352847 0.70647156 0.29352847 0.70647156; 0.42732602 0.8121204 0.42732602 0.8121204] -# @test ŷ ≈ y atol=1e-7 -# end - -# let m = trainmode!(GroupNorm(2,2)), sizes = (2, 4, 1, 2, 3), -# x = Float32.(reshape(collect(1:prod(sizes)), sizes)) -# y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3) -# y = reshape(m(y), sizes...) -# @test m(x) == y -# end - -# # show that group norm is the same as instance norm when the group size is the same as the number of channels -# let IN = trainmode!(InstanceNorm(4; affine=true)), GN = trainmode!(GroupNorm(4,4)), sizes = (2,2,3,4,5), -# x = Float32.(reshape(collect(1:prod(sizes)), sizes)) -# @test IN(x) ≈ GN(x) -# end - -# # show that group norm is the same as batch norm for a group of size 1 and batch of size 1 -# let BN = trainmode!(BatchNorm(4)), GN = trainmode!(GroupNorm(4,4)), sizes = (2,2,3,4,1), -# x = Float32.(reshape(collect(1:prod(sizes)), sizes)) -# @test BN(x) ≈ GN(x) -# end - -# @test GroupNorm(5, 5; active=true).active === true -# @test_throws Exception GroupNorm(5, 5; active=:something_else) -#end + x = Float32.(x) + μ_affine_shape = ones(Int,length(sizes) + 1) + μ_affine_shape[end-1] = 2 # Number of groups + + affine_shape = ones(Int,length(sizes) + 1) + affine_shape[end-2] = 2 # Channels per group + affine_shape[end-1] = 2 # Number of groups + affine_shape[1] = sizes[1] + affine_shape[end] = sizes[end] + + og_shape = size(x) + + ŷ = m(x) + y = [0.18787955 0.57267404 0.18787955 0.57267404; 0.2935284 0.70647156 0.29352835 0.70647156; 0.42732593 0.81212044 0.42732587 0.8121204;;; 0.18787955 0.57267404 0.1878796 0.57267404; 0.29352847 0.70647156 0.29352847 0.70647156; 0.42732602 0.8121204 0.42732602 0.8121204] + @test ŷ ≈ y atol=1e-7 + end + + let m = trainmode!(GroupNorm(2,2)), sizes = (2, 4, 1, 2, 3), + x = Float32.(reshape(collect(1:prod(sizes)), sizes)) + y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3) + y = reshape(m(y), sizes...) 
+ @test m(x) == y + end + + # show that group norm is the same as instance norm when the group size is the same as the number of channels + let IN = trainmode!(InstanceNorm(4; affine=true)), GN = trainmode!(GroupNorm(4,4)), sizes = (2,2,3,4,5), + x = Float32.(reshape(collect(1:prod(sizes)), sizes)) + @test IN(x) ≈ GN(x) + end + + # show that group norm is the same as batch norm for a group of size 1 and batch of size 1 + let BN = trainmode!(BatchNorm(4)), GN = trainmode!(GroupNorm(4,4)), sizes = (2,2,3,4,1), + x = Float32.(reshape(collect(1:prod(sizes)), sizes)) + @test BN(x) ≈ GN(x) + end + + @test GroupNorm(5, 5; active=true).active === true + @test_throws Exception GroupNorm(5, 5; active=:something_else) +end @testset "WeightNorm" begin x = rand(Float32, 1, 3) @@ -453,14 +453,27 @@ evalwgrad(f, x...) = pullback(f, x...)[1] @test_throws ArgumentError WeightNorm(m, :bias) @test_throws "is all zero" WeightNorm(m, :bias) + og = (Zygote.gradient(m) do m + sum(m(x)) + end)[1] g = (Zygote.gradient(mn) do mn sum(mn(x)) end)[1] + @test g.layer.weight ≡ nothing # Original weight does not participate. @test g.layer.bias ≢ nothing @test g.g ≢ nothing @test g.v ≢ nothing + # Compare gradients with original layer. + + n2 = sum(abs2, mn.v; dims=2) + ϵ = eps(eltype(mn.v)) + @test (og.weight .* mn.v ./ sqrt.(n2 .+ ϵ)) ≈ g.g + @test (og.weight .* mn.g ./ n2 .- mn.g .* g.g .* mn.v ./ n2.^2) ≈ g.v atol=1f-6 + + # Test WeightNorm removal. + om = Flux.remove_weight_norms(mn) @test om isa Dense @test om.weight ≈ m.weight @@ -483,36 +496,36 @@ evalwgrad(f, x...) = pullback(f, x...)[1] @test c(x) ≈ oc(x) end -# @testset "second derivatives" begin -# m1 = Dropout(0.5) -# @test Zygote.hessian_reverse(sum∘m1, [1.0,2.0,3.0]) == zeros(3, 3) - -# m2 = Chain(BatchNorm(3), sum) -# @test Zygote.hessian_reverse(m2, Float32[1 2; 3 4; 5 6]) == zeros(Float32, 6, 6) broken = VERSION >= v"1.11" -# end - -# @testset "ForwardDiff" begin -# bn = BatchNorm(3) -# @test ForwardDiff.jacobian(bn, rand(Float32, 3, 4)) isa Matrix{Float32} -# # iszero(bn.μ) # is true. But ideally would not be, if Flux would automatically choose trainmode -# Flux.trainmode!(bn) -# # This was an error, https://github.com/FluxML/Flux.jl/issues/2122 -# @test ForwardDiff.jacobian(bn, rand(Float32, 3, 4)) isa Matrix{Float32} -# @test !iszero(bn.μ) - -# # Easy case of 2122, gradient with x -# x5 = rand(Float32, 5, 3) -# bn1 = BatchNorm(5, relu) -# bn2 = BatchNorm(5, relu) -# g1 = Zygote.gradient(x -> sum(abs2, bn1(x)), x5)[1] -# g2 = ForwardDiff.gradient(x -> sum(abs2, bn2(x)), x5) -# @test g1 ≈ g2 - -# # Harder case? -# v1, re1 = Flux.destructure(BatchNorm(5, relu)); -# g1 = Zygote.gradient(v -> sum(abs2, re1(v)(x5)), v1)[1] - -# v2, re2 = Flux.destructure(BatchNorm(5, relu)); -# g2 = ForwardDiff.gradient(v -> sum(abs2, re2(v)(x5)), v2) -# end +@testset "second derivatives" begin + m1 = Dropout(0.5) + @test Zygote.hessian_reverse(sum∘m1, [1.0,2.0,3.0]) == zeros(3, 3) + + m2 = Chain(BatchNorm(3), sum) + @test Zygote.hessian_reverse(m2, Float32[1 2; 3 4; 5 6]) == zeros(Float32, 6, 6) broken = VERSION >= v"1.11" +end + +@testset "ForwardDiff" begin + bn = BatchNorm(3) + @test ForwardDiff.jacobian(bn, rand(Float32, 3, 4)) isa Matrix{Float32} + # iszero(bn.μ) # is true. 
But ideally would not be, if Flux would automatically choose trainmode + Flux.trainmode!(bn) + # This was an error, https://github.com/FluxML/Flux.jl/issues/2122 + @test ForwardDiff.jacobian(bn, rand(Float32, 3, 4)) isa Matrix{Float32} + @test !iszero(bn.μ) + + # Easy case of 2122, gradient with x + x5 = rand(Float32, 5, 3) + bn1 = BatchNorm(5, relu) + bn2 = BatchNorm(5, relu) + g1 = Zygote.gradient(x -> sum(abs2, bn1(x)), x5)[1] + g2 = ForwardDiff.gradient(x -> sum(abs2, bn2(x)), x5) + @test g1 ≈ g2 + + # Harder case? + v1, re1 = Flux.destructure(BatchNorm(5, relu)); + g1 = Zygote.gradient(v -> sum(abs2, re1(v)(x5)), v1)[1] + + v2, re2 = Flux.destructure(BatchNorm(5, relu)); + g2 = ForwardDiff.gradient(v -> sum(abs2, re2(v)(x5)), v2) +end From 7f3a905928b09c73b7366b78d2f169eb4bb4dc36 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Thu, 12 Dec 2024 17:15:35 +0200 Subject: [PATCH 05/12] Fix link --- src/layers/normalise.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 56a7622f95..b22b772be2 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -607,7 +607,7 @@ true # Reference -Salimans & Kingma, _Weight Normalization_ (2016) https://arxiv.org/abs/1602.07868 +Salimans & Kingma, _Weight Normalization_ (2016) """ function WeightNorm(layer::L, which::Symbol = :weight; dims = -1) where L hasfield(L, which) || throw(ArgumentError("`$L` does not have field `:$which`.")) From 780a1386fca45ca25bef2c8f4d6044d8ac20f766 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Thu, 12 Dec 2024 19:43:05 +0200 Subject: [PATCH 06/12] Store 'v' in weights of the original layer --- src/layers/normalise.jl | 14 +++++++------- test/layers/normalisation.jl | 17 +++++++++-------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index b22b772be2..41028bf62b 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -569,10 +569,9 @@ See [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`Laye """ hasaffine(l::Union{BatchNorm, InstanceNorm, LayerNorm, GroupNorm}) = l.affine -struct WeightNorm{which, dims, L, G, V} +struct WeightNorm{which, dims, L, G} layer::L g::G - v::V end @layer WeightNorm @@ -625,16 +624,17 @@ function WeightNorm(layer::L, which::Symbol = :weight; dims = -1) where L end g = sqrt.(sum(abs2, x; dims) .+ eps(eltype(x))) - v = x ./ g - WeightNorm{which, dims, L, typeof(g), typeof(v)}(layer, g, v) + x ./= g # Store `v` in the original weights. + WeightNorm{which, dims, L, typeof(g)}(layer, g) end (w::WeightNorm)(x) = transform(w)(x) function transform(wn::WeightNorm{which, dims}) where {which, dims} - ϵ = eps(eltype(wn.v)) - n2 = sum(abs2, wn.v; dims) - w = @. wn.g * wn.v / sqrt(n2 + ϵ) + ϵ = eps(eltype(wn.g)) + v = getfield(wn.layer, which) + n2 = sum(abs2, v; dims) + w = @. wn.g * v / sqrt(n2 + ϵ) fields, ctor = Functors.functor(wn.layer) return ctor(merge( diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl index 2ea46bf03e..314dcbe293 100644 --- a/test/layers/normalisation.jl +++ b/test/layers/normalisation.jl @@ -443,8 +443,8 @@ end @testset "WeightNorm" begin x = rand(Float32, 1, 3) - m = Dense(1 => 2) - mn = WeightNorm(m) + mn = WeightNorm(Dense(1 => 2)) + m = Flux.remove_weight_norms(mn) @test m(x) ≈ mn(x) @test_throws ArgumentError WeightNorm(m, :weights) @@ -460,17 +460,18 @@ end sum(mn(x)) end)[1] - @test g.layer.weight ≡ nothing # Original weight does not participate. 
+ @test g.layer.weight ≢ nothing # Original weight acts as a direction `v`. @test g.layer.bias ≢ nothing @test g.g ≢ nothing - @test g.v ≢ nothing # Compare gradients with original layer. - n2 = sum(abs2, mn.v; dims=2) - ϵ = eps(eltype(mn.v)) - @test (og.weight .* mn.v ./ sqrt.(n2 .+ ϵ)) ≈ g.g - @test (og.weight .* mn.g ./ n2 .- mn.g .* g.g .* mn.v ./ n2.^2) ≈ g.v atol=1f-6 + v = mn.layer.weight + ϵ = eps(eltype(v)) + n2 = sum(abs2, v; dims=2) + + @test (og.weight .* v ./ sqrt.(n2 .+ ϵ)) ≈ g.g + @test (og.weight .* mn.g ./ n2 .- mn.g .* g.g .* v ./ n2.^2) ≈ g.layer.weight atol=1f-6 # Test WeightNorm removal. From 52d0a7a423f8ddffadd335ece957204ee67e7a03 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Thu, 12 Dec 2024 20:44:23 +0200 Subject: [PATCH 07/12] Rename 'transform' to 'reparametrize' & other minor changes --- src/Flux.jl | 1 + src/layers/normalise.jl | 12 ++++++++---- test/layers/normalisation.jl | 19 +++++++++++++++---- 3 files changed, 24 insertions(+), 8 deletions(-) diff --git a/src/Flux.jl b/src/Flux.jl index 5b3d3f7977..7041472a6b 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -94,6 +94,7 @@ export Chain, Dense, Embedding, EmbeddingBag, siamese_contrastive_loss, squared_hinge_loss, tversky_loss, + remove_weight_norms, )) include("gradient.jl") diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 41028bf62b..b06c6a146a 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -624,13 +624,17 @@ function WeightNorm(layer::L, which::Symbol = :weight; dims = -1) where L end g = sqrt.(sum(abs2, x; dims) .+ eps(eltype(x))) - x ./= g # Store `v` in the original weights. WeightNorm{which, dims, L, typeof(g)}(layer, g) end -(w::WeightNorm)(x) = transform(w)(x) +(w::WeightNorm)(x) = reparametrize(w)(x) -function transform(wn::WeightNorm{which, dims}) where {which, dims} +""" + reparametrize(wn::WeightNorm) + +Apply `WeightNorm` reparametrization and return underlying `layer`. +""" +function reparametrize(wn::WeightNorm{which, dims}) where {which, dims} ϵ = eps(eltype(wn.g)) v = getfield(wn.layer, which) n2 = sum(abs2, v; dims) @@ -681,4 +685,4 @@ Chain( ) # Total: 4 arrays, 22 parameters, 392 bytes. ``` """ -remove_weight_norms(x) = fmap(transform, x; exclude=l -> l isa WeightNorm) +remove_weight_norms(x) = fmap(reparametrize, x; exclude=l -> l isa WeightNorm) diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl index 314dcbe293..744cefffbc 100644 --- a/test/layers/normalisation.jl +++ b/test/layers/normalisation.jl @@ -469,9 +469,10 @@ end v = mn.layer.weight ϵ = eps(eltype(v)) n2 = sum(abs2, v; dims=2) + v = v ./ sqrt.(n2 .+ ϵ) - @test (og.weight .* v ./ sqrt.(n2 .+ ϵ)) ≈ g.g - @test (og.weight .* mn.g ./ n2 .- mn.g .* g.g .* v ./ n2.^2) ≈ g.layer.weight atol=1f-6 + @test (og.weight .* v) ≈ g.g + @test (og.weight .* mn.g .- mn.g .* g.g .* v) ≈ g.layer.weight atol=1f-6 # Test WeightNorm removal. 
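The `reparametrize` step is exactly the identity `w = g ⋅ v / ‖v‖` applied along the normalised dimensions, so the effective weight can also be reconstructed by hand from the stored magnitude `mn.g` and the direction kept in `mn.layer.weight`. A minimal sketch of that relationship under the layout introduced above (the `Dense(3 => 2)` layer and `Float32` eltype are arbitrary illustrative choices, not part of the patch):

```julia
using Flux

mn = WeightNorm(Dense(3 => 2))               # which = :weight, dims = ndims(weight) = 2 by default
v  = mn.layer.weight                         # unnormalised direction
n2 = sum(abs2, v; dims = 2)                  # squared norm along the normalised slice
w  = mn.g .* v ./ sqrt.(n2 .+ eps(Float32))  # effective weight used in the forward pass

d = Flux.remove_weight_norms(mn)             # plain Dense carrying `w` directly
@assert d.weight ≈ w
@assert d(ones(Float32, 3)) ≈ mn(ones(Float32, 3))
```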
@@ -484,14 +485,24 @@ end c = Chain( WeightNorm(Conv((3,), 1 => 2)), - WeightNorm(Conv((3,), 2 => 2)), + Conv((3,), 2 => 2), + WeightNorm(Conv((3,), 2 => 3)), + x -> reshape(x, 18, :), + WeightNorm(Dense(18, 4)), + Dense(4, 1), ) @test c[1] isa WeightNorm - @test c[2] isa WeightNorm + @test c[2] isa Conv + @test c[3] isa WeightNorm + @test c[5] isa WeightNorm + @test c[6] isa Dense oc = Flux.remove_weight_norms(c) @test oc[1] isa Conv @test oc[2] isa Conv + @test oc[3] isa Conv + @test oc[5] isa Dense + @test oc[6] isa Dense x = rand(Float32, 12, 1, 1) @test c(x) ≈ oc(x) From ed7f51ab1d69ca3eb19fc8c4fce8c6ccd8268586 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Thu, 12 Dec 2024 22:25:59 +0200 Subject: [PATCH 08/12] Adjust for GPUs --- src/layers/normalise.jl | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index b06c6a146a..33350a904c 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -569,9 +569,12 @@ See [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`Laye """ hasaffine(l::Union{BatchNorm, InstanceNorm, LayerNorm, GroupNorm}) = l.affine -struct WeightNorm{which, dims, L, G} +struct WeightNorm{L, G, D} layer::L g::G + + which::Symbol + dims::D end @layer WeightNorm @@ -624,7 +627,7 @@ function WeightNorm(layer::L, which::Symbol = :weight; dims = -1) where L end g = sqrt.(sum(abs2, x; dims) .+ eps(eltype(x))) - WeightNorm{which, dims, L, typeof(g)}(layer, g) + WeightNorm(layer, g, which, dims) end (w::WeightNorm)(x) = reparametrize(w)(x) @@ -634,22 +637,22 @@ end Apply `WeightNorm` reparametrization and return underlying `layer`. """ -function reparametrize(wn::WeightNorm{which, dims}) where {which, dims} +function reparametrize(wn::WeightNorm) ϵ = eps(eltype(wn.g)) - v = getfield(wn.layer, which) - n2 = sum(abs2, v; dims) + v = getfield(wn.layer, wn.which) + n2 = sum(abs2, v; wn.dims) w = @. 
wn.g * v / sqrt(n2 + ϵ) fields, ctor = Functors.functor(wn.layer) return ctor(merge( - fields, NamedTuple{(which,)}((w,)), + fields, NamedTuple{(wn.which,)}((w,)), )) end -function Base.show(io::IO, w::WeightNorm{which, dims}) where {which, dims} +function Base.show(io::IO, w::WeightNorm) print(io, "WeightNorm(") Base.show(io, w.layer) - print(io, ", :", which, "; dims=", dims) + print(io, ", :", w.which, "; dims=", w.dims) print(io, ")") end From ee4a8c368ba8e5b0ac1ad5939fed3bc05ae0f104 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Thu, 12 Dec 2024 22:52:34 +0200 Subject: [PATCH 09/12] Add Flux testsuite --- test/layers/normalisation.jl | 67 -------------------------------- test/runtests.jl | 18 +++++++++ test/testsuite/normalization.jl | 68 +++++++++++++++++++++++++++++++++ 3 files changed, 86 insertions(+), 67 deletions(-) create mode 100644 test/testsuite/normalization.jl diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl index 744cefffbc..b0c6584b84 100644 --- a/test/layers/normalisation.jl +++ b/test/layers/normalisation.jl @@ -441,73 +441,6 @@ end @test_throws Exception GroupNorm(5, 5; active=:something_else) end -@testset "WeightNorm" begin - x = rand(Float32, 1, 3) - mn = WeightNorm(Dense(1 => 2)) - m = Flux.remove_weight_norms(mn) - @test m(x) ≈ mn(x) - - @test_throws ArgumentError WeightNorm(m, :weights) - @test_throws "does not have field" WeightNorm(m, :weights) - - @test_throws ArgumentError WeightNorm(m, :bias) - @test_throws "is all zero" WeightNorm(m, :bias) - - og = (Zygote.gradient(m) do m - sum(m(x)) - end)[1] - g = (Zygote.gradient(mn) do mn - sum(mn(x)) - end)[1] - - @test g.layer.weight ≢ nothing # Original weight acts as a direction `v`. - @test g.layer.bias ≢ nothing - @test g.g ≢ nothing - - # Compare gradients with original layer. - - v = mn.layer.weight - ϵ = eps(eltype(v)) - n2 = sum(abs2, v; dims=2) - v = v ./ sqrt.(n2 .+ ϵ) - - @test (og.weight .* v) ≈ g.g - @test (og.weight .* mn.g .- mn.g .* g.g .* v) ≈ g.layer.weight atol=1f-6 - - # Test WeightNorm removal. - - om = Flux.remove_weight_norms(mn) - @test om isa Dense - @test om.weight ≈ m.weight - @test om.bias ≈ m.bias - - # Test with Chain. 
- - c = Chain( - WeightNorm(Conv((3,), 1 => 2)), - Conv((3,), 2 => 2), - WeightNorm(Conv((3,), 2 => 3)), - x -> reshape(x, 18, :), - WeightNorm(Dense(18, 4)), - Dense(4, 1), - ) - @test c[1] isa WeightNorm - @test c[2] isa Conv - @test c[3] isa WeightNorm - @test c[5] isa WeightNorm - @test c[6] isa Dense - - oc = Flux.remove_weight_norms(c) - @test oc[1] isa Conv - @test oc[2] isa Conv - @test oc[3] isa Conv - @test oc[5] isa Dense - @test oc[6] isa Dense - - x = rand(Float32, 12, 1, 1) - @test c(x) ≈ oc(x) -end - @testset "second derivatives" begin m1 = Dropout(0.5) @test Zygote.hessian_reverse(sum∘m1, [1.0,2.0,3.0]) == zeros(3, 3) diff --git a/test/runtests.jl b/test/runtests.jl index f9936fd3ae..d3476cc4d1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -25,8 +25,20 @@ include("test_utils.jl") # for test_gradients Random.seed!(0) +include("testsuite/normalization.jl") + +function flux_testsuite(dev) + @testset "Flux Test Suite" begin + @testset "Normalization" begin + normalization_testsuite(dev) + end + end +end + @testset verbose=true "Flux.jl" begin if get(ENV, "FLUX_TEST_CPU", "true") == "true" + flux_testsuite(cpu) + @testset "Utils" begin include("utils.jl") end @@ -84,6 +96,8 @@ Random.seed!(0) if CUDA.functional() @testset "CUDA" begin include("ext_cuda/runtests.jl") + + flux_testsuite(gpu) end else @warn "CUDA.jl package is not functional. Skipping CUDA tests." @@ -99,6 +113,8 @@ Random.seed!(0) if AMDGPU.functional() && AMDGPU.functional(:MIOpen) @testset "AMDGPU" begin include("ext_amdgpu/runtests.jl") + + flux_testsuite(gpu) end else @info "AMDGPU.jl package is not functional. Skipping AMDGPU tests." @@ -114,6 +130,8 @@ Random.seed!(0) if Metal.functional() @testset "Metal" begin include("ext_metal/runtests.jl") + + flux_testsuite(gpu) end else @info "Metal.jl package is not functional. Skipping Metal tests." diff --git a/test/testsuite/normalization.jl b/test/testsuite/normalization.jl new file mode 100644 index 0000000000..ac84154532 --- /dev/null +++ b/test/testsuite/normalization.jl @@ -0,0 +1,68 @@ +function normalization_testsuite(dev) + @testset "WeightNorm" begin + x = rand(Float32, 1, 3) |> dev + mn = WeightNorm(Dense(1 => 2)) |> dev + m = Flux.remove_weight_norms(mn) + @test m(x) ≈ mn(x) + + @test_throws ArgumentError WeightNorm(m, :weights) + @test_throws "does not have field" WeightNorm(m, :weights) + + @test_throws ArgumentError WeightNorm(m, :bias) + @test_throws "is all zero" WeightNorm(m, :bias) + + og = (Zygote.gradient(m) do m + sum(m(x)) + end)[1] + g = (Zygote.gradient(mn) do mn + sum(mn(x)) + end)[1] + + @test g.layer.weight ≢ nothing # Original weight acts as a direction `v`. + @test g.layer.bias ≢ nothing + @test g.g ≢ nothing + + # Compare gradients with original layer. + + v = mn.layer.weight + ϵ = eps(eltype(v)) + n2 = sum(abs2, v; dims=2) + v = v ./ sqrt.(n2 .+ ϵ) + + @test (og.weight .* v) ≈ g.g + @test (og.weight .* mn.g .- mn.g .* g.g .* v) ≈ g.layer.weight atol=1f-6 + + # Test WeightNorm removal. + + om = Flux.remove_weight_norms(mn) + @test om isa Dense + @test om.weight ≈ m.weight + @test om.bias ≈ m.bias + + # Test with Chain. 
+ + c = Chain( + WeightNorm(Conv((3,), 1 => 2)), + Conv((3,), 2 => 2), + WeightNorm(Conv((3,), 2 => 3)), + x -> reshape(x, 18, :), + WeightNorm(Dense(18, 4)), + Dense(4, 1), + ) + @test c[1] isa WeightNorm + @test c[2] isa Conv + @test c[3] isa WeightNorm + @test c[5] isa WeightNorm + @test c[6] isa Dense + + oc = Flux.remove_weight_norms(c) + @test oc[1] isa Conv + @test oc[2] isa Conv + @test oc[3] isa Conv + @test oc[5] isa Dense + @test oc[6] isa Dense + + x = rand(Float32, 12, 1, 1) + @test c(x) ≈ oc(x) + end +end From f91c726ed3db1558258935c5abc180fd5a11f29d Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Fri, 13 Dec 2024 00:27:38 +0200 Subject: [PATCH 10/12] Fix doctests --- src/layers/normalise.jl | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 33350a904c..ed490f919c 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -598,8 +598,9 @@ julia> wc = WeightNorm(c, :weight) WeightNorm( Conv((3,), 1 => 2), # 8 parameters 3×1×1 Array{Float32,...}, # 3 parameters - 3×1×2 Array{Float32,...}, # 6 parameters -) # Total: 4 arrays, 17 parameters, 348 bytes. + :weight, + 3, +) # Total: 3 arrays, 11 parameters, 276 bytes. julia> x = ones(Float32, 12, 1, 1); @@ -672,14 +673,16 @@ Chain( WeightNorm( Conv((3,), 1 => 2), # 8 parameters 3×1×1 Array{Float32,...}, # 3 parameters - 3×1×2 Array{Float32,...}, # 6 parameters + :weight, + 3, ), WeightNorm( Conv((3,), 2 => 2), # 14 parameters 3×2×1 Array{Float32,...}, # 6 parameters - 3×2×2 Array{Float32,...}, # 12 parameters + :weight, + 3, ), -) # Total: 8 arrays, 49 parameters, 756 bytes. +) # Total: 6 arrays, 31 parameters, 588 bytes. julia> Flux.remove_weight_norms(model) Chain( From 648883fd3848ab1f2303346f3a8ebaf7827cf8f6 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Fri, 13 Dec 2024 01:20:08 +0200 Subject: [PATCH 11/12] Fix doctests --- src/layers/normalise.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index ed490f919c..778968350f 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -665,10 +665,11 @@ Remove any [WeightNorm](@ref) parametrization in the model. ### Example ```jldoctest + julia> model = Chain( - WeightNorm(Conv((3,), 1 => 2), :weight), - WeightNorm(Conv((3,), 2 => 2), :weight), -) + WeightNorm(Conv((3,), 1 => 2), :weight), + WeightNorm(Conv((3,), 2 => 2), :weight), + ) Chain( WeightNorm( Conv((3,), 1 => 2), # 8 parameters From 54319e1314d7cd2147239b941bd748b96bcad71a Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Fri, 13 Dec 2024 10:18:12 +0200 Subject: [PATCH 12/12] Cleanup --- src/layers/normalise.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 778968350f..9cd0ff5c8b 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -665,7 +665,6 @@ Remove any [WeightNorm](@ref) parametrization in the model. ### Example ```jldoctest - julia> model = Chain( WeightNorm(Conv((3,), 1 => 2), :weight), WeightNorm(Conv((3,), 2 => 2), :weight),
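With the series complete, the intended workflow is to wrap individual layers in `WeightNorm` during training and then strip the parametrisation with `Flux.remove_weight_norms` for inference. A minimal end-to-end sketch of that usage, assuming the API as it stands after the final patch (the model shape, data, and hyperparameters below are made-up illustrations, not part of the changes):

```julia
using Flux

model = Chain(
    WeightNorm(Conv((3,), 1 => 2)),   # normalises :weight, dims = ndims(weight) by default
    Flux.flatten,
    WeightNorm(Dense(20 => 1)),
)

x = rand(Float32, 12, 1, 8)           # length-12 signal, 1 channel, batch of 8
y = rand(Float32, 1, 8)

opt_state = Flux.setup(Adam(1f-3), model)
for _ in 1:100
    grad = Flux.gradient(m -> Flux.mse(m(x), y), model)[1]
    Flux.update!(opt_state, model, grad)
end

# Fold `g` and the stored direction back into plain weights once training is done.
infer_model = Flux.remove_weight_norms(model)
@assert infer_model(x) ≈ model(x)
```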