Skip to content

Commit

Permalink
Store 'v' in weights of the original layer
Browse files Browse the repository at this point in the history
  • Loading branch information
pxl-th committed Dec 12, 2024
1 parent 7f3a905 commit 780a138
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 15 deletions.
14 changes: 7 additions & 7 deletions src/layers/normalise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -569,10 +569,9 @@ See [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`Laye
"""
hasaffine(l::Union{BatchNorm, InstanceNorm, LayerNorm, GroupNorm}) = l.affine

struct WeightNorm{which, dims, L, G, V}
struct WeightNorm{which, dims, L, G}
layer::L
g::G
v::V
end
@layer WeightNorm

Expand Down Expand Up @@ -625,16 +624,17 @@ function WeightNorm(layer::L, which::Symbol = :weight; dims = -1) where L
end

g = sqrt.(sum(abs2, x; dims) .+ eps(eltype(x)))
v = x ./ g
WeightNorm{which, dims, L, typeof(g), typeof(v)}(layer, g, v)
x ./= g # Store `v` in the original weights.
WeightNorm{which, dims, L, typeof(g)}(layer, g)

Check warning on line 628 in src/layers/normalise.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/normalise.jl#L626-L628

Added lines #L626 - L628 were not covered by tests
end

(w::WeightNorm)(x) = transform(w)(x)

Check warning on line 631 in src/layers/normalise.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/normalise.jl#L631

Added line #L631 was not covered by tests

function transform(wn::WeightNorm{which, dims}) where {which, dims}
ϵ = eps(eltype(wn.v))
n2 = sum(abs2, wn.v; dims)
w = @. wn.g * wn.v / sqrt(n2 + ϵ)
ϵ = eps(eltype(wn.g))
v = getfield(wn.layer, which)
n2 = sum(abs2, v; dims)
w = @. wn.g * v / sqrt(n2 + ϵ)

Check warning on line 637 in src/layers/normalise.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/normalise.jl#L633-L637

Added lines #L633 - L637 were not covered by tests

fields, ctor = Functors.functor(wn.layer)
return ctor(merge(

Check warning on line 640 in src/layers/normalise.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/normalise.jl#L639-L640

Added lines #L639 - L640 were not covered by tests
Expand Down
17 changes: 9 additions & 8 deletions test/layers/normalisation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -443,8 +443,8 @@ end

@testset "WeightNorm" begin
x = rand(Float32, 1, 3)
m = Dense(1 => 2)
mn = WeightNorm(m)
mn = WeightNorm(Dense(1 => 2))
m = Flux.remove_weight_norms(mn)
@test m(x) ≈ mn(x)

@test_throws ArgumentError WeightNorm(m, :weights)
Expand All @@ -460,17 +460,18 @@ end
sum(mn(x))
end)[1]

@test g.layer.weight ≡ nothing # Original weight does not participate.
@test g.layer.weight ≢ nothing # Original weight acts as a direction `v`.
@test g.layer.bias ≢ nothing
@test g.g ≢ nothing
@test g.v ≢ nothing

# Compare gradients with original layer.

n2 = sum(abs2, mn.v; dims=2)
ϵ = eps(eltype(mn.v))
@test (og.weight .* mn.v ./ sqrt.(n2 .+ ϵ)) ≈ g.g
@test (og.weight .* mn.g ./ n2 .- mn.g .* g.g .* mn.v ./ n2.^2) ≈ g.v atol=1f-6
v = mn.layer.weight
ϵ = eps(eltype(v))
n2 = sum(abs2, v; dims=2)

@test (og.weight .* v ./ sqrt.(n2 .+ ϵ)) ≈ g.g
@test (og.weight .* mn.g ./ n2 .- mn.g .* g.g .* v ./ n2.^2) ≈ g.layer.weight atol=1f-6

# Test WeightNorm removal.

Expand Down

0 comments on commit 780a138

Please sign in to comment.