Add WeightNorm reparametrization #2550

Merged · 12 commits · Dec 13, 2024
Changes from 6 commits
3 changes: 3 additions & 0 deletions NEWS.md
@@ -2,6 +2,9 @@

See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a complete list of PRs merged before each release.

## v0.15.3
* Add `WeightNorm` normalization layer.

## v0.15.0 (December 2024)
This release includes two **breaking changes**:
- The recurrent layers have been thoroughly revised. See below and read the [documentation](https://fluxml.ai/Flux.jl/v0.15/guide/models/recurrence/) for details.
2 changes: 2 additions & 0 deletions docs/src/reference/models/layers.md
@@ -126,6 +126,8 @@ AlphaDropout
LayerNorm
InstanceNorm
GroupNorm
WeightNorm
Flux.remove_weight_norms
Flux.normalise
```

3 changes: 1 addition & 2 deletions ext/FluxAMDGPUExt/FluxAMDGPUExt.jl
@@ -3,8 +3,7 @@ module FluxAMDGPUExt
import ChainRulesCore
import ChainRulesCore: NoTangent
import Flux
import Flux: adapt_storage, fmap
import Flux: DenseConvDims, Conv, ConvTranspose, conv, conv_reshape_bias
import Flux: fmap, DenseConvDims, Conv, ConvTranspose, conv, conv_reshape_bias
import NNlib
using MLDataDevices
using AMDGPU
2 changes: 1 addition & 1 deletion src/Flux.jl
@@ -42,7 +42,7 @@ export Chain, Dense, Embedding, EmbeddingBag,
SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
Dropout, AlphaDropout,
LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
LayerNorm, BatchNorm, InstanceNorm, GroupNorm, WeightNorm,
MultiHeadAttention,
Upsample, PixelShuffle,
fmap, cpu, gpu, f32, f64, f16, rand32, randn32, zeros32, ones32,
114 changes: 114 additions & 0 deletions src/layers/normalise.jl
@@ -568,3 +568,117 @@
See [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`LayerNorm`](@ref).
"""
hasaffine(l::Union{BatchNorm, InstanceNorm, LayerNorm, GroupNorm}) = l.affine

struct WeightNorm{which, dims, L, G}
    layer::L
    g::G
end
@layer WeightNorm

"""
    WeightNorm(layer::L, which::Symbol = :weight; dims = -1)

Apply weight normalization to the parameter named `which` in the given `layer`.

``w = g \\frac{\\mathbf{v}}{\\lVert \\mathbf{v} \\rVert}``

Decouples the magnitude of a weight tensor from its direction.
By default, normalization is applied along the output channel dimension,
`dims = -1` (equivalent to `dims = ndims(w)`).

### Example

```jldoctest
julia> c = Conv((3,), 1 => 2);

julia> wc = WeightNorm(c, :weight)
WeightNorm(
  Conv((3,), 1 => 2),            # 8 parameters
  3×1×1 Array{Float32,...},      # 3 parameters
  3×1×2 Array{Float32,...},      # 6 parameters
)                   # Total: 4 arrays, 17 parameters, 348 bytes.

julia> x = ones(Float32, 12, 1, 1);

julia> c(x) ≈ wc(x) # forward pass is the same as with the original layer
true
```

# Reference

Salimans & Kingma, _Weight Normalization_ (2016) <https://arxiv.org/abs/1602.07868>
"""
function WeightNorm(layer::L, which::Symbol = :weight; dims = -1) where L
    hasfield(L, which) || throw(ArgumentError("`$L` does not have field `:$which`."))

    x = getfield(layer, which)
    iszero(x) && throw(ArgumentError(
        "`$which` field for `$(typeof(layer))` is all zero, which will result in NaN."))

    d = if dims isa Colon
        1:ndims(x)
    elseif dims == -1
        dims = ndims(x)
    else
        dims
    end

    g = sqrt.(sum(abs2, x; dims) .+ eps(eltype(x)))
    x ./= g # Store `v` in the original weights.
    WeightNorm{which, dims, L, typeof(g)}(layer, g)
end
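As a quick sanity check on what this constructor does (a minimal standalone sketch, not part of the PR, assuming a `Dense` layer and the default `dims`): after construction the wrapped field holds the direction `v`, normalized along `dims` up to the `eps` term, while `g` holds the magnitude, so `g .* v` recovers the original weight.

```julia
using Flux

d  = Dense(3 => 4)
w0 = copy(d.weight)          # original 4×3 weight, copied before it is rescaled in place
wn = WeightNorm(d, :weight)  # default dims = -1, i.e. dims = ndims(w0) = 2

g = wn.g                     # magnitude, size 4×1
v = wn.layer.weight          # stored direction, each row has norm ≈ 1

@assert g .* v ≈ w0          # the reparametrization reproduces the original weight
@assert all(x -> isapprox(x, 1f0; atol = 1f-5), sqrt.(sum(abs2, v; dims = 2)))
```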

(w::WeightNorm)(x) = transform(w)(x)

function transform(wn::WeightNorm{which, dims}) where {which, dims}
    ϵ = eps(eltype(wn.g))
    v = getfield(wn.layer, which)
    n2 = sum(abs2, v; dims)
    w = @. wn.g * v / sqrt(n2 + ϵ)

    fields, ctor = Functors.functor(wn.layer)
    return ctor(merge(
        fields, NamedTuple{(which,)}((w,)),
    ))
end
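Since `transform` rebuilds the wrapped layer with the effective weight on every call, gradients reach both `g` and the stored direction during training. A hedged sketch of a single training step (not part of the diff), assuming the usual Flux `setup`/`withgradient`/`update!` workflow:

```julia
using Flux

model = WeightNorm(Dense(3 => 4))
x, y = randn(Float32, 3, 16), randn(Float32, 4, 16)

opt_state = Flux.setup(Adam(1f-3), model)

loss, grads = Flux.withgradient(model) do m
    Flux.mse(m(x), y)
end

# Both the magnitude `g` and the direction stored in `layer.weight` receive gradients.
@assert grads[1].g !== nothing
@assert grads[1].layer.weight !== nothing

Flux.update!(opt_state, model, grads[1])
```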

function Base.show(io::IO, w::WeightNorm{which, dims}) where {which, dims}
    print(io, "WeightNorm(")
    Base.show(io, w.layer)
    print(io, ", :", which, "; dims=", dims)
    print(io, ")")
end

"""
    remove_weight_norms(x)

Remove any [`WeightNorm`](@ref) parametrization in the model.

### Example

```jldoctest
julia> model = Chain(
           WeightNorm(Conv((3,), 1 => 2), :weight),
           WeightNorm(Conv((3,), 2 => 2), :weight),
       )
Chain(
  WeightNorm(
    Conv((3,), 1 => 2),            # 8 parameters
    3×1×1 Array{Float32,...},      # 3 parameters
    3×1×2 Array{Float32,...},      # 6 parameters
  ),
  WeightNorm(
    Conv((3,), 2 => 2),            # 14 parameters
    3×2×1 Array{Float32,...},      # 6 parameters
    3×2×2 Array{Float32,...},      # 12 parameters
  ),
)                   # Total: 8 arrays, 49 parameters, 756 bytes.

julia> Flux.remove_weight_norms(model)
Chain(
  Conv((3,), 1 => 2),      # 8 parameters
  Conv((3,), 2 => 2),      # 14 parameters
)                   # Total: 4 arrays, 22 parameters, 392 bytes.
```
"""
remove_weight_norms(x) = fmap(transform, x; exclude=l -> l isa WeightNorm)
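Because `fmap` stops recursing exactly at `WeightNorm` nodes (via `exclude`) and applies `transform` there, nested containers are unwrapped in place while untouched layers pass through unchanged. A small illustrative sketch with a nested `Chain` (not part of the diff):

```julia
using Flux

nested = Chain(
    Chain(WeightNorm(Conv((3,), 1 => 2)), Conv((3,), 2 => 2)),
    WeightNorm(Conv((3,), 2 => 4)),
)

plain = Flux.remove_weight_norms(nested)
@assert plain[1][1] isa Conv  # unwrapped
@assert plain[1][2] isa Conv  # passed through unchanged
@assert plain[2]    isa Conv  # unwrapped
```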

1 change: 1 addition & 0 deletions src/layers/recurrent.jl
@@ -103,6 +103,7 @@ x = rand(Float32, 10)

# Run forward
res = rnn(x, h0)
```
"""
initialstates(rnn::RNNCell) = zeros_like(rnn.Wh, size(rnn.Wh, 2))

56 changes: 56 additions & 0 deletions test/layers/normalisation.jl
@@ -441,6 +441,62 @@ end
    @test_throws Exception GroupNorm(5, 5; active=:something_else)
end

@testset "WeightNorm" begin
x = rand(Float32, 1, 3)
mn = WeightNorm(Dense(1 => 2))
m = Flux.remove_weight_norms(mn)
@test m(x) ≈ mn(x)

@test_throws ArgumentError WeightNorm(m, :weights)
@test_throws "does not have field" WeightNorm(m, :weights)

@test_throws ArgumentError WeightNorm(m, :bias)
@test_throws "is all zero" WeightNorm(m, :bias)

og = (Zygote.gradient(m) do m
sum(m(x))
end)[1]
g = (Zygote.gradient(mn) do mn
sum(mn(x))
end)[1]

@test g.layer.weight ≢ nothing # Original weight acts as a direction `v`.
@test g.layer.bias ≢ nothing
@test g.g ≢ nothing

# Compare gradients with original layer.

v = mn.layer.weight
ϵ = eps(eltype(v))
n2 = sum(abs2, v; dims=2)

@test (og.weight .* v ./ sqrt.(n2 .+ ϵ)) ≈ g.g
@test (og.weight .* mn.g ./ n2 .- mn.g .* g.g .* v ./ n2.^2) ≈ g.layer.weight atol=1f-6

# Test WeightNorm removal.

om = Flux.remove_weight_norms(mn)
@test om isa Dense
@test om.weight ≈ m.weight
@test om.bias ≈ m.bias

# Test with Chain.

c = Chain(
WeightNorm(Conv((3,), 1 => 2)),
WeightNorm(Conv((3,), 2 => 2)),
)
@test c[1] isa WeightNorm
@test c[2] isa WeightNorm

oc = Flux.remove_weight_norms(c)
@test oc[1] isa Conv
@test oc[2] isa Conv

x = rand(Float32, 12, 1, 1)
@test c(x) ≈ oc(x)
end
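For reference (added here, not part of the PR), the two gradient comparisons above follow from the chain rule for ``w = g\,v / \lVert v\rVert``, written per output slice and ignoring the `eps` term:

```math
\frac{\partial L}{\partial g}
  = \frac{\partial L}{\partial w} \cdot \frac{v}{\lVert v \rVert},
\qquad
\frac{\partial L}{\partial v}
  = \frac{\partial L}{\partial w}\,\frac{g}{\lVert v \rVert}
  - g\,\frac{\partial L}{\partial g}\,\frac{v}{\lVert v \rVert^{2}} .
```

The second `@test` divides by `n2` and `n2.^2` (powers of ``\lVert v\rVert^2``) rather than `sqrt.(n2)` and `n2`; the comparison still holds because the `WeightNorm` constructor stores a direction with ``\lVert v\rVert \approx 1``, which is the state this test runs in.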

@testset "second derivatives" begin
m1 = Dropout(0.5)
@test Zygote.hessian_reverse(sum∘m1, [1.0,2.0,3.0]) == zeros(3, 3)