diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index b22b772be2..41028bf62b 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -569,10 +569,9 @@ See [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`Laye """ hasaffine(l::Union{BatchNorm, InstanceNorm, LayerNorm, GroupNorm}) = l.affine -struct WeightNorm{which, dims, L, G, V} +struct WeightNorm{which, dims, L, G} layer::L g::G - v::V end @layer WeightNorm @@ -625,16 +624,17 @@ function WeightNorm(layer::L, which::Symbol = :weight; dims = -1) where L end g = sqrt.(sum(abs2, x; dims) .+ eps(eltype(x))) - v = x ./ g - WeightNorm{which, dims, L, typeof(g), typeof(v)}(layer, g, v) + x ./= g # Store `v` in the original weights. + WeightNorm{which, dims, L, typeof(g)}(layer, g) end (w::WeightNorm)(x) = transform(w)(x) function transform(wn::WeightNorm{which, dims}) where {which, dims} - ϵ = eps(eltype(wn.v)) - n2 = sum(abs2, wn.v; dims) - w = @. wn.g * wn.v / sqrt(n2 + ϵ) + ϵ = eps(eltype(wn.g)) + v = getfield(wn.layer, which) + n2 = sum(abs2, v; dims) + w = @. wn.g * v / sqrt(n2 + ϵ) fields, ctor = Functors.functor(wn.layer) return ctor(merge( diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl index 2ea46bf03e..314dcbe293 100644 --- a/test/layers/normalisation.jl +++ b/test/layers/normalisation.jl @@ -443,8 +443,8 @@ end @testset "WeightNorm" begin x = rand(Float32, 1, 3) - m = Dense(1 => 2) - mn = WeightNorm(m) + mn = WeightNorm(Dense(1 => 2)) + m = Flux.remove_weight_norms(mn) @test m(x) ≈ mn(x) @test_throws ArgumentError WeightNorm(m, :weights) @@ -460,17 +460,18 @@ end sum(mn(x)) end)[1] - @test g.layer.weight ≡ nothing # Original weight does not participate. + @test g.layer.weight ≢ nothing # Original weight acts as a direction `v`. @test g.layer.bias ≢ nothing @test g.g ≢ nothing - @test g.v ≢ nothing # Compare gradients with original layer. - n2 = sum(abs2, mn.v; dims=2) - ϵ = eps(eltype(mn.v)) - @test (og.weight .* mn.v ./ sqrt.(n2 .+ ϵ)) ≈ g.g - @test (og.weight .* mn.g ./ n2 .- mn.g .* g.g .* mn.v ./ n2.^2) ≈ g.v atol=1f-6 + v = mn.layer.weight + ϵ = eps(eltype(v)) + n2 = sum(abs2, v; dims=2) + + @test (og.weight .* v ./ sqrt.(n2 .+ ϵ)) ≈ g.g + @test (og.weight .* mn.g ./ n2 .- mn.g .* g.g .* v ./ n2.^2) ≈ g.layer.weight atol=1f-6 # Test WeightNorm removal.