Add WeightNorm reparametrization #2550

Merged (12 commits) on Dec 13, 2024
3 changes: 3 additions & 0 deletions NEWS.md
@@ -2,6 +2,9 @@

See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a complete list of PRs merged before each release.

## v0.15.3
* Add `WeightNorm` normalization layer.

## v0.15.0 (December 2024)
This release includes two **breaking changes**:
- The recurrent layers have been thoroughly revised. See below and read the [documentation](https://fluxml.ai/Flux.jl/v0.15/guide/models/recurrence/) for details.
2 changes: 2 additions & 0 deletions docs/src/reference/models/layers.md
@@ -126,6 +126,8 @@ AlphaDropout
LayerNorm
InstanceNorm
GroupNorm
WeightNorm
Flux.remove_weight_norms
Flux.normalise
```

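For orientation, the two newly documented names are the wrapper type and the helper that strips it again. A minimal usage sketch (the layer sizes and input below are illustrative, not part of this diff):

```julia
using Flux

d  = Dense(4 => 3)
wn = WeightNorm(d, :weight)      # reparametrize d.weight as g * v / ‖v‖
x  = randn(Float32, 4, 8)

wn(x) ≈ d(x)                     # true: the forward pass is unchanged at construction

plain = Flux.remove_weight_norms(wn)
plain isa Dense                  # true: the plain layer is recovered
```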
3 changes: 1 addition & 2 deletions ext/FluxAMDGPUExt/FluxAMDGPUExt.jl
@@ -3,8 +3,7 @@ module FluxAMDGPUExt
import ChainRulesCore
import ChainRulesCore: NoTangent
import Flux
import Flux: adapt_storage, fmap
import Flux: DenseConvDims, Conv, ConvTranspose, conv, conv_reshape_bias
import Flux: fmap, DenseConvDims, Conv, ConvTranspose, conv, conv_reshape_bias
import NNlib
using MLDataDevices
using AMDGPU
3 changes: 2 additions & 1 deletion src/Flux.jl
@@ -42,7 +42,7 @@ export Chain, Dense, Embedding, EmbeddingBag,
SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
Dropout, AlphaDropout,
LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
LayerNorm, BatchNorm, InstanceNorm, GroupNorm, WeightNorm,
MultiHeadAttention,
Upsample, PixelShuffle,
fmap, cpu, gpu, f32, f64, f16, rand32, randn32, zeros32, ones32,
@@ -94,6 +94,7 @@ export Chain, Dense, Embedding, EmbeddingBag,
siamese_contrastive_loss,
squared_hinge_loss,
tversky_loss,
remove_weight_norms,
))

include("gradient.jl")
124 changes: 124 additions & 0 deletions src/layers/normalise.jl
@@ -568,3 +568,127 @@ scale parameters, `false` otherwise.
See [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`LayerNorm`](@ref).
"""
hasaffine(l::Union{BatchNorm, InstanceNorm, LayerNorm, GroupNorm}) = l.affine

struct WeightNorm{L, G, D}
    layer::L
    g::G

    which::Symbol
    dims::D
end
@layer WeightNorm

"""
WeightNorm(layer::L, which::Symbol = :weight; dims = -1)

Apply weight normalization to a parameter given by `which` in a `layer`.

``\\mathbf{w} = g \\frac{\\mathbf{v}}{\\lVert \\mathbf{v} \\rVert}``

This decouples the magnitude of a weight tensor from its direction.
By default, the norm is computed along the output-channel dimension, `dims = -1`
(equivalent to `dims = ndims(w)`).

### Example

```jldoctest
julia> c = Conv((3,), 1 => 2);

julia> wc = WeightNorm(c, :weight)
WeightNorm(
  Conv((3,), 1 => 2),                   # 8 parameters
  3×1×1 Array{Float32,...},             # 3 parameters
  :weight,
  3,
)                   # Total: 3 arrays, 11 parameters, 276 bytes.

julia> x = ones(Float32, 12, 1, 1);

julia> c(x) ≈ wc(x) # forward pass is the same as with the original layer
true
```

# Reference

Salimans & Kingma, _Weight Normalization_ (2016) <https://arxiv.org/abs/1602.07868>
"""
function WeightNorm(layer::L, which::Symbol = :weight; dims = -1) where L
    hasfield(L, which) || throw(ArgumentError("`$L` does not have field `:$which`."))

    x = getfield(layer, which)
    iszero(x) && throw(ArgumentError(
        "`$which` field for `$(typeof(layer))` is all zero, which will result in NaN."))

    if dims isa Colon
        dims = 1:ndims(x)
    elseif dims == -1
        dims = ndims(x)
    end

    g = sqrt.(sum(abs2, x; dims) .+ eps(eltype(x)))
    WeightNorm(layer, g, which, dims)
end

(w::WeightNorm)(x) = reparametrize(w)(x)

"""
reparametrize(wn::WeightNorm)

Apply the `WeightNorm` reparametrization and return the underlying `layer`
with the effective weight in place of the direction parameter.
"""
function reparametrize(wn::WeightNorm)
    ϵ = eps(eltype(wn.g))
    v = getfield(wn.layer, wn.which)
    n2 = sum(abs2, v; wn.dims)
    w = @. wn.g * v / sqrt(n2 + ϵ)

    fields, ctor = Functors.functor(wn.layer)
    return ctor(merge(
        fields, NamedTuple{(wn.which,)}((w,)),
    ))
end

function Base.show(io::IO, w::WeightNorm)
    print(io, "WeightNorm(")
    Base.show(io, w.layer)
    print(io, ", :", w.which, "; dims=", w.dims)
    print(io, ")")
end

"""
remove_weight_norms(x)

Remove any [`WeightNorm`](@ref) reparametrization in the model.

### Example

```jldoctest
julia> model = Chain(
           WeightNorm(Conv((3,), 1 => 2), :weight),
           WeightNorm(Conv((3,), 2 => 2), :weight),
       )
Chain(
  WeightNorm(
    Conv((3,), 1 => 2),                 # 8 parameters
    3×1×1 Array{Float32,...},           # 3 parameters
    :weight,
    3,
  ),
  WeightNorm(
    Conv((3,), 2 => 2),                 # 14 parameters
    3×2×1 Array{Float32,...},           # 6 parameters
    :weight,
    3,
  ),
)                   # Total: 6 arrays, 31 parameters, 588 bytes.

julia> Flux.remove_weight_norms(model)
Chain(
  Conv((3,), 1 => 2),                   # 8 parameters
  Conv((3,), 2 => 2),                   # 14 parameters
)                   # Total: 4 arrays, 22 parameters, 392 bytes.
```
"""
remove_weight_norms(x) = fmap(reparametrize, x; exclude=l -> l isa WeightNorm)
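To make the `reparametrize` step concrete, the sketch below rebuilds the effective weight by hand and checks it against the wrapped layer. The `Conv` size is illustrative, and `dims = 3` simply spells out the default `dims = -1` for a 3-dimensional kernel:

```julia
using Flux

c  = Conv((3,), 1 => 2)
wc = WeightNorm(c, :weight)       # stores g = ‖v‖ along dims = ndims(weight) = 3

v = c.weight                      # direction parameter, size (3, 1, 2)
g = wc.g                          # magnitude parameter, size (3, 1, 1)
w = g .* v ./ sqrt.(sum(abs2, v; dims = 3) .+ eps(Float32))

w ≈ Flux.reparametrize(wc).weight # true: this is the weight used in the forward pass
```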
1 change: 1 addition & 0 deletions src/layers/recurrent.jl
@@ -103,6 +103,7 @@ x = rand(Float32, 10)

# Run forward
res = rnn(x, h0)
```
"""
initialstates(rnn::RNNCell) = zeros_like(rnn.Wh, size(rnn.Wh, 2))

18 changes: 18 additions & 0 deletions test/runtests.jl
@@ -25,8 +25,20 @@ include("test_utils.jl") # for test_gradients

Random.seed!(0)

include("testsuite/normalization.jl")

function flux_testsuite(dev)
    @testset "Flux Test Suite" begin
        @testset "Normalization" begin
            normalization_testsuite(dev)
        end
    end
end

@testset verbose=true "Flux.jl" begin
    if get(ENV, "FLUX_TEST_CPU", "true") == "true"
        flux_testsuite(cpu)

        @testset "Utils" begin
            include("utils.jl")
        end
@@ -84,6 +96,8 @@ Random.seed!(0)
if CUDA.functional()
    @testset "CUDA" begin
        include("ext_cuda/runtests.jl")

        flux_testsuite(gpu)
    end
else
    @warn "CUDA.jl package is not functional. Skipping CUDA tests."
@@ -99,6 +113,8 @@ Random.seed!(0)
if AMDGPU.functional() && AMDGPU.functional(:MIOpen)
    @testset "AMDGPU" begin
        include("ext_amdgpu/runtests.jl")

        flux_testsuite(gpu)
    end
else
    @info "AMDGPU.jl package is not functional. Skipping AMDGPU tests."
@@ -114,6 +130,8 @@ Random.seed!(0)
if Metal.functional()
    @testset "Metal" begin
        include("ext_metal/runtests.jl")

        flux_testsuite(gpu)
    end
else
    @info "Metal.jl package is not functional. Skipping Metal tests."
68 changes: 68 additions & 0 deletions test/testsuite/normalization.jl
@@ -0,0 +1,68 @@
function normalization_testsuite(dev)
    @testset "WeightNorm" begin
        x = rand(Float32, 1, 3) |> dev
        mn = WeightNorm(Dense(1 => 2)) |> dev
        m = Flux.remove_weight_norms(mn)
        @test m(x) ≈ mn(x)

        @test_throws ArgumentError WeightNorm(m, :weights)
        @test_throws "does not have field" WeightNorm(m, :weights)

        @test_throws ArgumentError WeightNorm(m, :bias)
        @test_throws "is all zero" WeightNorm(m, :bias)

        og = (Zygote.gradient(m) do m
            sum(m(x))
        end)[1]
        g = (Zygote.gradient(mn) do mn
            sum(mn(x))
        end)[1]

        @test g.layer.weight ≢ nothing # Original weight acts as a direction `v`.
        @test g.layer.bias ≢ nothing
        @test g.g ≢ nothing

        # Compare gradients with original layer.

        v = mn.layer.weight
        ϵ = eps(eltype(v))
        n2 = sum(abs2, v; dims=2)
        v = v ./ sqrt.(n2 .+ ϵ)

        @test (og.weight .* v) ≈ g.g
        @test (og.weight .* mn.g .- mn.g .* g.g .* v) ≈ g.layer.weight atol=1f-6

        # Test WeightNorm removal.

        om = Flux.remove_weight_norms(mn)
        @test om isa Dense
        @test om.weight ≈ m.weight
        @test om.bias ≈ m.bias

        # Test with Chain.

        c = Chain(
            WeightNorm(Conv((3,), 1 => 2)),
            Conv((3,), 2 => 2),
            WeightNorm(Conv((3,), 2 => 3)),
            x -> reshape(x, 18, :),
            WeightNorm(Dense(18, 4)),
            Dense(4, 1),
        )
        @test c[1] isa WeightNorm
        @test c[2] isa Conv
        @test c[3] isa WeightNorm
        @test c[5] isa WeightNorm
        @test c[6] isa Dense

        oc = Flux.remove_weight_norms(c)
        @test oc[1] isa Conv
        @test oc[2] isa Conv
        @test oc[3] isa Conv
        @test oc[5] isa Dense
        @test oc[6] isa Dense

        x = rand(Float32, 12, 1, 1)
        @test c(x) ≈ oc(x)
    end
end
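The two gradient assertions above encode the chain rule through `w = g * v / ‖v‖`; in particular, `∂L/∂g` is the dot product of `∂L/∂w` with the unit direction `v̂` along the normalization dims. A standalone numeric check of that first identity, with illustrative layer sizes, might look like:

```julia
using Flux, Zygote

d  = Dense(3 => 2)
wn = WeightNorm(d, :weight)
x  = randn(Float32, 3, 5)

gw = Zygote.gradient(m -> sum(m(x)), d)[1].weight   # ∂L/∂w of the plain layer
gg = Zygote.gradient(m -> sum(m(x)), wn)[1].g       # ∂L/∂g of the wrapped layer

# At construction g = ‖v‖, so the effective weight equals d.weight and the two
# weight gradients coincide; projecting onto v̂ recovers the gradient of g.
vhat = d.weight ./ sqrt.(sum(abs2, d.weight; dims = 2) .+ eps(Float32))
sum(gw .* vhat; dims = 2) ≈ gg                      # true (up to floating point)
```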