diff --git a/NEWS.md b/NEWS.md
index b0a2a40136..2a40a64aec 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -2,6 +2,9 @@
 See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a complete list of
 PRs merged before each release.
 
+## v0.15.3
+* Add `WeightNorm` normalization layer.
+
 ## v0.15.0 (December 2024)
 This release includes two **breaking changes**:
 - The recurrent layers have been thoroughly revised. See below and read the [documentation](https://fluxml.ai/Flux.jl/v0.15/guide/models/recurrence/) for details.
diff --git a/docs/src/reference/models/layers.md b/docs/src/reference/models/layers.md
index 355d3e7833..b57141ea2b 100644
--- a/docs/src/reference/models/layers.md
+++ b/docs/src/reference/models/layers.md
@@ -126,6 +126,8 @@ AlphaDropout
 LayerNorm
 InstanceNorm
 GroupNorm
+WeightNorm
+Flux.remove_weight_norms
 Flux.normalise
 ```
diff --git a/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl b/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl
index 85ce0365cb..2398b322d5 100644
--- a/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl
+++ b/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl
@@ -3,8 +3,7 @@ module FluxAMDGPUExt
 import ChainRulesCore
 import ChainRulesCore: NoTangent
 import Flux
-import Flux: adapt_storage, fmap
-import Flux: DenseConvDims, Conv, ConvTranspose, conv, conv_reshape_bias
+import Flux: fmap, DenseConvDims, Conv, ConvTranspose, conv, conv_reshape_bias
 import NNlib
 using MLDataDevices
 using AMDGPU
diff --git a/src/Flux.jl b/src/Flux.jl
index 8fb2351aa2..7041472a6b 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -42,7 +42,7 @@ export Chain, Dense, Embedding, EmbeddingBag,
        SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
        AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
        Dropout, AlphaDropout,
-       LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
+       LayerNorm, BatchNorm, InstanceNorm, GroupNorm, WeightNorm,
        MultiHeadAttention,
        Upsample, PixelShuffle,
        fmap, cpu, gpu, f32, f64, f16, rand32, randn32, zeros32, ones32,
@@ -94,6 +94,7 @@ export Chain, Dense, Embedding, EmbeddingBag,
   siamese_contrastive_loss,
   squared_hinge_loss,
   tversky_loss,
+  remove_weight_norms,
 ))
 
 include("gradient.jl")
diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index 85cece9477..9cd0ff5c8b 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -568,3 +568,127 @@ scale parameters, `false` otherwise.
 See [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`LayerNorm`](@ref).
 """
 hasaffine(l::Union{BatchNorm, InstanceNorm, LayerNorm, GroupNorm}) = l.affine
+
+struct WeightNorm{L, G, D}
+    layer::L
+    g::G
+
+    which::Symbol
+    dims::D
+end
+@layer WeightNorm
+
+"""
+    WeightNorm(layer::L, which::Symbol = :weight; dims = -1)
+
+Apply weight normalization to the parameter given by `which` in a `layer`.
+
+``w = g \\frac{\\mathbf{v}}{\\lVert \\mathbf{v} \\rVert}``
+
+Decouples the magnitude of a weight tensor from its direction.
+By default, the norm is computed along the output channel dimension
+(`dims = -1`, which is equivalent to `dims = ndims(w)`).
+
+### Example
+
+```jldoctest
+julia> c = Conv((3,), 1 => 2);
+
+julia> wc = WeightNorm(c, :weight)
+WeightNorm(
+  Conv((3,), 1 => 2),                   # 8 parameters
+  3×1×1 Array{Float32,...},             # 3 parameters
+  :weight,
+  3,
+)                   # Total: 3 arrays, 11 parameters, 276 bytes.
+
+julia> x = ones(Float32, 12, 1, 1);
+
+julia> c(x) ≈ wc(x)  # forward pass is the same as with the original layer
+true
+```
+
+### Reference
+
+Salimans & Kingma, _Weight Normalization: A Simple Reparameterization to Accelerate Training of Deep Neural Networks_ (2016)
+"""
+function WeightNorm(layer::L, which::Symbol = :weight; dims = -1) where L
+    hasfield(L, which) || throw(ArgumentError("`$L` does not have field `:$which`."))
+
+    x = getfield(layer, which)
+    iszero(x) && throw(ArgumentError(
+        "`$which` field for `$(typeof(layer))` is all zero, which will result in NaN."))
+
+    dims = if dims isa Colon
+        1:ndims(x)
+    elseif dims == -1
+        ndims(x)
+    else
+        dims
+    end
+
+    g = sqrt.(sum(abs2, x; dims) .+ eps(eltype(x)))
+    WeightNorm(layer, g, which, dims)
+end
+
+(w::WeightNorm)(x) = reparametrize(w)(x)
+
+"""
+    reparametrize(wn::WeightNorm)
+
+Apply the `WeightNorm` reparametrization and return the underlying `layer` with the normalized weight materialized.
+"""
+function reparametrize(wn::WeightNorm)
+    ϵ = eps(eltype(wn.g))
+    v = getfield(wn.layer, wn.which)
+    n2 = sum(abs2, v; wn.dims)
+    w = @. wn.g * v / sqrt(n2 + ϵ)
+
+    fields, ctor = Functors.functor(wn.layer)
+    return ctor(merge(
+        fields, NamedTuple{(wn.which,)}((w,)),
+    ))
+end
+
+function Base.show(io::IO, w::WeightNorm)
+    print(io, "WeightNorm(")
+    Base.show(io, w.layer)
+    print(io, ", :", w.which, "; dims=", w.dims)
+    print(io, ")")
+end
+
+"""
+    remove_weight_norms(x)
+
+Remove any [`WeightNorm`](@ref) parametrization in the model.
+
+### Example
+
+```jldoctest
+julia> model = Chain(
+           WeightNorm(Conv((3,), 1 => 2), :weight),
+           WeightNorm(Conv((3,), 2 => 2), :weight),
+       )
+Chain(
+  WeightNorm(
+    Conv((3,), 1 => 2),                 # 8 parameters
+    3×1×1 Array{Float32,...},           # 3 parameters
+    :weight,
+    3,
+  ),
+  WeightNorm(
+    Conv((3,), 2 => 2),                 # 14 parameters
+    3×2×1 Array{Float32,...},           # 6 parameters
+    :weight,
+    3,
+  ),
+)                  # Total: 6 arrays, 31 parameters, 588 bytes.
+
+julia> Flux.remove_weight_norms(model)
+Chain(
+  Conv((3,), 1 => 2),                   # 8 parameters
+  Conv((3,), 2 => 2),                   # 14 parameters
+)                   # Total: 4 arrays, 22 parameters, 392 bytes.
+```
+"""
+remove_weight_norms(x) = fmap(reparametrize, x; exclude=l -> l isa WeightNorm)
diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index 9386a3fc2d..40864be294 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -103,6 +103,7 @@ x = rand(Float32, 10)
 
 # Run forward
 res = rnn(x, h0)
+```
 """
 
 initialstates(rnn::RNNCell) = zeros_like(rnn.Wh, size(rnn.Wh, 2))
diff --git a/test/runtests.jl b/test/runtests.jl
index f9936fd3ae..d3476cc4d1 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -25,8 +25,20 @@ include("test_utils.jl")  # for test_gradients
 
 Random.seed!(0)
 
+include("testsuite/normalization.jl")
+
+function flux_testsuite(dev)
+    @testset "Flux Test Suite" begin
+        @testset "Normalization" begin
+            normalization_testsuite(dev)
+        end
+    end
+end
+
 @testset verbose=true "Flux.jl" begin
     if get(ENV, "FLUX_TEST_CPU", "true") == "true"
+        flux_testsuite(cpu)
+
         @testset "Utils" begin
             include("utils.jl")
         end
@@ -84,6 +96,8 @@ Random.seed!(0)
     if CUDA.functional()
         @testset "CUDA" begin
             include("ext_cuda/runtests.jl")
+
+            flux_testsuite(gpu)
         end
     else
         @warn "CUDA.jl package is not functional. Skipping CUDA tests."
@@ -99,6 +113,8 @@ Random.seed!(0)
     if AMDGPU.functional() && AMDGPU.functional(:MIOpen)
         @testset "AMDGPU" begin
             include("ext_amdgpu/runtests.jl")
+
+            flux_testsuite(gpu)
         end
     else
         @info "AMDGPU.jl package is not functional. Skipping AMDGPU tests."
@@ -114,6 +130,8 @@ Random.seed!(0)
     if Metal.functional()
         @testset "Metal" begin
            include("ext_metal/runtests.jl")
+
+            flux_testsuite(gpu)
         end
     else
         @info "Metal.jl package is not functional. Skipping Metal tests."
diff --git a/test/testsuite/normalization.jl b/test/testsuite/normalization.jl
new file mode 100644
index 0000000000..ac84154532
--- /dev/null
+++ b/test/testsuite/normalization.jl
@@ -0,0 +1,68 @@
+function normalization_testsuite(dev)
+    @testset "WeightNorm" begin
+        x = rand(Float32, 1, 3) |> dev
+        mn = WeightNorm(Dense(1 => 2)) |> dev
+        m = Flux.remove_weight_norms(mn)
+        @test m(x) ≈ mn(x)
+
+        @test_throws ArgumentError WeightNorm(m, :weights)
+        @test_throws "does not have field" WeightNorm(m, :weights)
+
+        @test_throws ArgumentError WeightNorm(m, :bias)
+        @test_throws "is all zero" WeightNorm(m, :bias)
+
+        og = (Zygote.gradient(m) do m
+            sum(m(x))
+        end)[1]
+        g = (Zygote.gradient(mn) do mn
+            sum(mn(x))
+        end)[1]
+
+        @test g.layer.weight ≢ nothing # Original weight acts as a direction `v`.
+        @test g.layer.bias ≢ nothing
+        @test g.g ≢ nothing
+
+        # Compare gradients with original layer.
+
+        v = mn.layer.weight
+        ϵ = eps(eltype(v))
+        n2 = sum(abs2, v; dims=2)
+        v = v ./ sqrt.(n2 .+ ϵ)
+
+        @test (og.weight .* v) ≈ g.g
+        @test (og.weight .* mn.g .- mn.g .* g.g .* v) ≈ g.layer.weight atol=1f-6
+
+        # Test WeightNorm removal.
+
+        om = Flux.remove_weight_norms(mn)
+        @test om isa Dense
+        @test om.weight ≈ m.weight
+        @test om.bias ≈ m.bias
+
+        # Test with Chain.
+
+        c = Chain(
+            WeightNorm(Conv((3,), 1 => 2)),
+            Conv((3,), 2 => 2),
+            WeightNorm(Conv((3,), 2 => 3)),
+            x -> reshape(x, 18, :),
+            WeightNorm(Dense(18 => 4)),
+            Dense(4 => 1),
+        )
+        @test c[1] isa WeightNorm
+        @test c[2] isa Conv
+        @test c[3] isa WeightNorm
+        @test c[5] isa WeightNorm
+        @test c[6] isa Dense
+
+        oc = Flux.remove_weight_norms(c)
+        @test oc[1] isa Conv
+        @test oc[2] isa Conv
+        @test oc[3] isa Conv
+        @test oc[5] isa Dense
+        @test oc[6] isa Dense
+
+        x = rand(Float32, 12, 1, 1)
+        @test c(x) ≈ oc(x)
+    end
+end
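Taken together, the pieces above are meant to be used in two phases: `WeightNorm` wraps a layer during training so that the magnitude `g` and the direction `v` (the wrapped layer's original weight) receive separate gradients, and `Flux.remove_weight_norms` folds the parametrization back into a plain layer for inference. Below is a minimal end-to-end sketch using the standard `Flux.setup`/`Flux.update!` training API; the model shape, random data, and `Adam` optimizer are illustrative choices, not part of the changes above.

```julia
using Flux

# Weight-normalize the first layer's weight (the default `which = :weight`).
model = Chain(
    WeightNorm(Dense(4 => 8, relu), :weight),
    Dense(8 => 1),
)

x = rand(Float32, 4, 16)   # random inputs, purely to drive a gradient step
y = rand(Float32, 1, 16)   # random targets

opt_state = Flux.setup(Adam(1f-3), model)

# Gradients flow into both the direction `v` (inside the wrapped layer) and the magnitude `g`.
grads = Flux.gradient(m -> Flux.mse(m(x), y), model)
Flux.update!(opt_state, model, grads[1])

# For deployment, materialize the normalized weight and drop the wrapper.
deployed = Flux.remove_weight_norms(model)
@assert deployed(x) ≈ model(x)   # same forward pass, mirroring the `c(x) ≈ oc(x)` check in the tests
```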