Add WeightNorm reparametrization (#2550)
pxl-th authored Dec 13, 2024
1 parent 2bbd8b3 commit 9050ef0
Showing 8 changed files with 219 additions and 3 deletions.
3 changes: 3 additions & 0 deletions NEWS.md
@@ -2,6 +2,9 @@

See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a complete list of PRs merged before each release.

## v0.15.3
* Add `WeightNorm` normalization layer.

## v0.15.0 (December 2024)
This release includes two **breaking changes**:
- The recurrent layers have been thoroughly revised. See below and read the [documentation](https://fluxml.ai/Flux.jl/v0.15/guide/models/recurrence/) for details.
2 changes: 2 additions & 0 deletions docs/src/reference/models/layers.md
@@ -126,6 +126,8 @@ AlphaDropout
LayerNorm
InstanceNorm
GroupNorm
WeightNorm
Flux.remove_weight_norms
Flux.normalise
```

3 changes: 1 addition & 2 deletions ext/FluxAMDGPUExt/FluxAMDGPUExt.jl
@@ -3,8 +3,7 @@ module FluxAMDGPUExt
import ChainRulesCore
import ChainRulesCore: NoTangent
import Flux
import Flux: adapt_storage, fmap
import Flux: DenseConvDims, Conv, ConvTranspose, conv, conv_reshape_bias
import Flux: fmap, DenseConvDims, Conv, ConvTranspose, conv, conv_reshape_bias
import NNlib
using MLDataDevices
using AMDGPU
3 changes: 2 additions & 1 deletion src/Flux.jl
@@ -42,7 +42,7 @@ export Chain, Dense, Embedding, EmbeddingBag,
SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
Dropout, AlphaDropout,
LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
LayerNorm, BatchNorm, InstanceNorm, GroupNorm, WeightNorm,
MultiHeadAttention,
Upsample, PixelShuffle,
fmap, cpu, gpu, f32, f64, f16, rand32, randn32, zeros32, ones32,
@@ -94,6 +94,7 @@ export Chain, Dense, Embedding, EmbeddingBag,
siamese_contrastive_loss,
squared_hinge_loss,
tversky_loss,
remove_weight_norms,
))

include("gradient.jl")
124 changes: 124 additions & 0 deletions src/layers/normalise.jl
@@ -568,3 +568,127 @@ scale parameters, `false` otherwise.
See [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`LayerNorm`](@ref).
"""
hasaffine(l::Union{BatchNorm, InstanceNorm, LayerNorm, GroupNorm}) = l.affine

struct WeightNorm{L, G, D}
layer::L
g::G

which::Symbol
dims::D
end
@layer WeightNorm

"""
    WeightNorm(layer::L, which::Symbol = :weight; dims = -1)

Apply weight normalization to a parameter given by `which` in a `layer`:

``w = g \\frac{\\mathbf{v}}{\\lVert \\mathbf{v} \\rVert}``

Weight normalization decouples the magnitude of a weight tensor from its direction.
By default, normalization is applied along the output channel `dims = -1`
(equivalent to `dims = ndims(w)`).

### Example

```jldoctest
julia> c = Conv((3,), 1 => 2);

julia> wc = WeightNorm(c, :weight)
WeightNorm(
Conv((3,), 1 => 2), # 8 parameters
3×1×1 Array{Float32,...}, # 3 parameters
:weight,
3,
) # Total: 3 arrays, 11 parameters, 276 bytes.

julia> x = ones(Float32, 12, 1, 1);

julia> c(x) ≈ wc(x) # forward pass is the same as with the original layer
true
```
# Reference
Salimans & Kingma, _Weight Normalization_ (2016) <https://arxiv.org/abs/1602.07868>
"""
function WeightNorm(layer::L, which::Symbol = :weight; dims = -1) where L
hasfield(L, which) || throw(ArgumentError("`$L` does not have field `:$which`."))

x = getfield(layer, which)
iszero(x) && throw(ArgumentError(
"`$which` field for `$(typeof(layer))` is all zero, which will result in NaN."))

    if dims isa Colon          # normalise over the whole array
        dims = 1:ndims(x)
    elseif dims == -1          # default: normalise along the output channel dimension
        dims = ndims(x)
    end

g = sqrt.(sum(abs2, x; dims) .+ eps(eltype(x)))
WeightNorm(layer, g, which, dims)
end
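
For reference, the gradient decoupling this reparametrization provides is given in Salimans & Kingma (2016), eqs. (3)–(4); it is restated here only as background for the gradient checks added in `test/testsuite/normalization.jl`:

```math
\nabla_g L = \frac{\nabla_w L \cdot \mathbf{v}}{\lVert \mathbf{v} \rVert}, \qquad
\nabla_{\mathbf{v}} L = \frac{g}{\lVert \mathbf{v} \rVert} \nabla_w L
  - \frac{g \, \nabla_g L}{\lVert \mathbf{v} \rVert^2} \mathbf{v}
```

The gradient with respect to ``\mathbf{v}`` is the weight gradient with its component along the current direction removed, scaled by ``g / \lVert \mathbf{v} \rVert``, which is what decouples magnitude from direction during optimisation.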

(w::WeightNorm)(x) = reparametrize(w)(x)

"""
    reparametrize(wn::WeightNorm)

Apply the `WeightNorm` reparametrization and return the underlying `layer`.
"""
function reparametrize(wn::WeightNorm)
ϵ = eps(eltype(wn.g))
v = getfield(wn.layer, wn.which)
n2 = sum(abs2, v; wn.dims)
w = @. wn.g * v / sqrt(n2 + ϵ)

fields, ctor = Functors.functor(wn.layer)
return ctor(merge(
fields, NamedTuple{(wn.which,)}((w,)),
))
end
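
As an illustration of what `reparametrize` computes, here is a minimal sketch (not part of the source file; the `Dense` layer and its sizes are arbitrary) that rebuilds the weight by hand and compares it with the stored magnitude `g` and direction `v`:

```julia
using Flux

d  = Dense(3 => 2)                  # weight is 2×3, so `dims` defaults to ndims(weight) == 2
wn = WeightNorm(d)                  # stores g = sqrt.(sum(abs2, weight; dims=2) .+ eps)

v = d.weight
w = Flux.reparametrize(wn).weight   # g .* v ./ sqrt.(sum(abs2, v; dims=2) .+ eps)

@assert w ≈ wn.g .* v ./ sqrt.(sum(abs2, v; dims=2) .+ eps(Float32))
@assert w ≈ v   # at construction g equals the norm of v, so the layer is unchanged
```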

function Base.show(io::IO, w::WeightNorm)
print(io, "WeightNorm(")
Base.show(io, w.layer)
print(io, ", :", w.which, "; dims=", w.dims)
print(io, ")")
end

"""
    remove_weight_norms(x)

Remove any [`WeightNorm`](@ref) parametrization in the model.

### Example

```jldoctest
julia> model = Chain(
WeightNorm(Conv((3,), 1 => 2), :weight),
WeightNorm(Conv((3,), 2 => 2), :weight),
)
Chain(
WeightNorm(
Conv((3,), 1 => 2), # 8 parameters
3×1×1 Array{Float32,...}, # 3 parameters
:weight,
3,
),
WeightNorm(
Conv((3,), 2 => 2), # 14 parameters
3×2×1 Array{Float32,...}, # 6 parameters
:weight,
3,
),
) # Total: 6 arrays, 31 parameters, 588 bytes.

julia> Flux.remove_weight_norms(model)
Chain(
Conv((3,), 1 => 2), # 8 parameters
Conv((3,), 2 => 2), # 14 parameters
) # Total: 4 arrays, 22 parameters, 392 bytes.
```
"""
remove_weight_norms(x) = fmap(reparametrize, x; exclude=l -> l isa WeightNorm)
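
Taken together, the intended workflow is to wrap layers with `WeightNorm` while training and to strip the reparametrization for inference. The snippet below is a hedged sketch of that flow; the model, data, and optimiser are placeholders chosen only for illustration:

```julia
using Flux

model = Chain(
    WeightNorm(Dense(4 => 8, relu)),
    WeightNorm(Dense(8 => 1)),
)

x, y = rand(Float32, 4, 16), rand(Float32, 1, 16)
opt_state = Flux.setup(Adam(1f-2), model)

for _ in 1:100
    grads = Flux.gradient(m -> Flux.mse(m(x), y), model)
    Flux.update!(opt_state, model, grads[1])
end

# Fold g back into the weights: the result is a Chain of plain Dense layers
# with the same outputs as the wrapped model.
inference_model = Flux.remove_weight_norms(model)
@assert inference_model(x) ≈ model(x)
```

After `remove_weight_norms` the model no longer carries any `WeightNorm` wrapper, so it can be saved or deployed as ordinary `Dense` layers.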
1 change: 1 addition & 0 deletions src/layers/recurrent.jl
@@ -103,6 +103,7 @@ x = rand(Float32, 10)
# Run forward
res = rnn(x, h0)
```
"""
initialstates(rnn::RNNCell) = zeros_like(rnn.Wh, size(rnn.Wh, 2))

18 changes: 18 additions & 0 deletions test/runtests.jl
@@ -25,8 +25,20 @@ include("test_utils.jl") # for test_gradients

Random.seed!(0)

include("testsuite/normalization.jl")

function flux_testsuite(dev)
@testset "Flux Test Suite" begin
@testset "Normalization" begin
normalization_testsuite(dev)
end
end
end

@testset verbose=true "Flux.jl" begin
if get(ENV, "FLUX_TEST_CPU", "true") == "true"
flux_testsuite(cpu)

@testset "Utils" begin
include("utils.jl")
end
@@ -84,6 +96,8 @@ Random.seed!(0)
if CUDA.functional()
@testset "CUDA" begin
include("ext_cuda/runtests.jl")

flux_testsuite(gpu)
end
else
@warn "CUDA.jl package is not functional. Skipping CUDA tests."
@@ -99,6 +113,8 @@ Random.seed!(0)
if AMDGPU.functional() && AMDGPU.functional(:MIOpen)
@testset "AMDGPU" begin
include("ext_amdgpu/runtests.jl")

flux_testsuite(gpu)
end
else
@info "AMDGPU.jl package is not functional. Skipping AMDGPU tests."
@@ -114,6 +130,8 @@ Random.seed!(0)
if Metal.functional()
@testset "Metal" begin
include("ext_metal/runtests.jl")

flux_testsuite(gpu)
end
else
@info "Metal.jl package is not functional. Skipping Metal tests."
68 changes: 68 additions & 0 deletions test/testsuite/normalization.jl
@@ -0,0 +1,68 @@
function normalization_testsuite(dev)
@testset "WeightNorm" begin
x = rand(Float32, 1, 3) |> dev
mn = WeightNorm(Dense(1 => 2)) |> dev
m = Flux.remove_weight_norms(mn)
        @test m(x) ≈ mn(x)

@test_throws ArgumentError WeightNorm(m, :weights)
@test_throws "does not have field" WeightNorm(m, :weights)

@test_throws ArgumentError WeightNorm(m, :bias)
@test_throws "is all zero" WeightNorm(m, :bias)

og = (Zygote.gradient(m) do m
sum(m(x))
end)[1]
g = (Zygote.gradient(mn) do mn
sum(mn(x))
end)[1]

        @test g.layer.weight ≢ nothing # Original weight acts as a direction `v`.
        @test g.layer.bias ≢ nothing
        @test g.g ≢ nothing

# Compare gradients with original layer.

v = mn.layer.weight
ϵ = eps(eltype(v))
n2 = sum(abs2, v; dims=2)
v = v ./ sqrt.(n2 .+ ϵ)

        @test (og.weight .* v) ≈ g.g
        @test (og.weight .* mn.g .- mn.g .* g.g .* v) ≈ g.layer.weight atol=1f-6

# Test WeightNorm removal.

om = Flux.remove_weight_norms(mn)
@test om isa Dense
        @test om.weight ≈ m.weight
        @test om.bias ≈ m.bias

# Test with Chain.

c = Chain(
WeightNorm(Conv((3,), 1 => 2)),
Conv((3,), 2 => 2),
WeightNorm(Conv((3,), 2 => 3)),
x -> reshape(x, 18, :),
WeightNorm(Dense(18, 4)),
Dense(4, 1),
)
@test c[1] isa WeightNorm
@test c[2] isa Conv
@test c[3] isa WeightNorm
@test c[5] isa WeightNorm
@test c[6] isa Dense

oc = Flux.remove_weight_norms(c)
@test oc[1] isa Conv
@test oc[2] isa Conv
@test oc[3] isa Conv
@test oc[5] isa Dense
@test oc[6] isa Dense

x = rand(Float32, 12, 1, 1)
        @test c(x) ≈ oc(x)
end
end
