diff --git a/NEWS.md b/NEWS.md
index 1b91c7823c..9c4907dc33 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -11,6 +11,7 @@ been removed in favour of MLDatasets.jl.
 * Many utily functions and the `DataLoader` are [now provided by MLUtils.jl](https://github.com/FluxML/Flux.jl/pull/1874).
 * The DataLoader is now compatible with generic dataset types implementing `MLUtils.numobs` and `MLUtils.getobs`.
 * Added [truncated normal initialisation](https://github.com/FluxML/Flux.jl/pull/1877) of weights.
+* The `Flux.Diagonal` layer is now called `Scale`, and accepts an activation function.
 
 ## v0.12.10
 * `Dropout`/`AlphaDropout` now supports [user-specified RNGs](https://github.com/FluxML/Flux.jl/pull/1838)
diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md
index d44e291031..81fbb60a2d 100644
--- a/docs/src/models/layers.md
+++ b/docs/src/models/layers.md
@@ -56,7 +56,7 @@ Maxout
 SkipConnection
 Parallel
 Flux.Bilinear
-Flux.Diagonal
+Flux.Scale
 Flux.Embedding
 ```
 
diff --git a/src/deprecations.jl b/src/deprecations.jl
index bce71bc2ab..597fc5a913 100644
--- a/src/deprecations.jl
+++ b/src/deprecations.jl
@@ -39,6 +39,15 @@ function Optimise.update!(x::AbstractArray, x̄)
   x .-= x̄
 end
 
+function Diagonal(size::Integer...; kw...)
+  Base.depwarn("Flux.Diagonal is now Flux.Scale, and also allows an activation function.", :Diagonal)
+  Scale(size...; kw...)
+end
+function Diagonal(size::Tuple; kw...)
+  Base.depwarn("Flux.Diagonal is now Flux.Scale, and also allows an activation function.", :Diagonal)
+  Scale(size...; kw...)
+end
+
 # Channel notation: Changed to match Conv, but very softly deprecated!
 # Perhaps change to @deprecate for v0.14, but there is no plan to remove these.
 Dense(in::Integer, out::Integer, σ = identity; kw...) =
diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index 3374e2fc1c..eb3e9b5180 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -171,40 +171,69 @@ function Base.show(io::IO, l::Dense)
 end
 
 """
-    Diagonal(size::Integer...; σ = identity, bias=true, init=ones32)
-    Diagonal(scale::AbstractArray, [bias, activation])
+    Scale(size::Integer..., σ=identity; bias=true, init=ones32)
+    Scale(scale::AbstractArray, [bias, σ])
 
-Create an element-wise linear layer, which performs
+Create an element-wise layer, whose forward pass is given by:
 
     y = σ.(scale .* x .+ bias)
 
+This uses `.*` instead of matrix multiplication `*` of [`Dense`](@ref).
+
 The learnable scale & bias are initialised `init(size...)` and `zeros32(size...)`,
 with `init=ones32` by default. You may specify the function `init`,
 turn off trainable bias with `bias=false`, or provide the array(s) explicitly.
 
-Used by [`LayerNorm`](@ref).
+Used by [`LayerNorm`](@ref) with `affine=true`.
+
+# Examples
+```jldoctest
+julia> a = Flux.Scale(2)
+Scale(2)  # 4 parameters
+
+julia> Flux.params(a)
+Params([Float32[1.0, 1.0], Float32[0.0, 0.0]])
+
+julia> a([1 2 3])
+2×3 Matrix{Float32}:
+ 1.0  2.0  3.0
+ 1.0  2.0  3.0
+
+julia> b = Flux.Scale([1 2 3 4], false, abs2)
+Scale(1, 4, abs2; bias=false)  # 4 parameters
+
+julia> b([1, 10])
+2×4 Matrix{Int64}:
+   1    4    9    16
+ 100  400  900  1600
+
+julia> Flux.params(b)
+Params([[1 2 3 4]])
+```
 """
-struct Diagonal{A<:AbstractArray, B, F}
+struct Scale{F, A<:AbstractArray, B}
   scale::A
   bias::B
   σ::F
-  function Diagonal(W::M, bias = true, σ::F = identity) where {M<:AbstractArray, F}
-    b = create_bias(W, bias, size(W)...)
-    new{M, typeof(b), F}(W, b, σ)
+  function Scale(scale::A, bias::B = true, σ::F = identity) where {A<:AbstractArray, B<:Union{Bool, AbstractArray}, F}
+    b = create_bias(scale, bias, size(scale)...)
+    new{F, A, typeof(b)}(scale, b, σ)
   end
 end
 
-Diagonal(sz::Integer...; σ = identity, bias = true, init = ones32) = Diagonal(init(sz...), bias, σ)
+Scale(s1::Integer, s23::Integer...; bias = true, init = ones32, _act = identity) = Scale(init(s1, s23...), bias, _act)
+Scale(size_act...; bias = true, init = ones32) = Scale(size_act[1:end-1]...; bias, init, _act = size_act[end])
 
-@functor Diagonal
+@functor Scale
 
-function (a::Diagonal)(x::AbstractArray)
+function (a::Scale)(x::AbstractArray)
   σ = NNlib.fast_act(a.σ, x)  # replaces tanh => tanh_fast, etc
-  return σ === typeof(identity) ? a.scale .* x .+ a.bias : σ.(a.scale .* x .+ a.bias)
+  σ.(a.scale .* x .+ a.bias)
 end
 
-function Base.show(io::IO, l::Diagonal)
-  print(io, "Diagonal(", join(size(l.scale), ", "))
+function Base.show(io::IO, l::Scale)
+  print(io, "Scale(", join(size(l.scale), ", "))
+  l.σ == identity || print(io, ", ", l.σ)
   l.bias == false && print(io, "; bias=false")
   print(io, ")")
 end
diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index 0324b4b90b..1a51615f6f 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -139,7 +139,7 @@ testmode!(m::AlphaDropout, mode=true) =
   (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
 
 """
-    LayerNorm(sz, λ=identity; affine=true, ϵ=1fe-5)
+    LayerNorm(size..., λ=identity; affine=true, ϵ=1f-5)
 
 A [normalisation layer](https://arxiv.org/abs/1607.06450) designed to be
 used with recurrent hidden states.
@@ -151,10 +151,10 @@ for tuple `sz`, along the first dimension for integer `sz`.
 The input is expected to have first dimensions' size equal to `sz`.
 
 If `affine=true` also applies a learnable shift and rescaling
-as in the [`Diagonal`](@ref) layer.
+using the [`Scale`](@ref) layer.
 
-Se also [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`normalise`](@ref).
+See also [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`normalise`](@ref).
 
 """
 struct LayerNorm{F,D,T,N}
   λ::F
@@ -164,17 +164,19 @@ struct LayerNorm{F,D,T,N}
   affine::Bool
 end
 
-function LayerNorm(sz, λ=identity; affine::Bool=true, ϵ::Real=1f-5)
-  diag = affine ? Diagonal(sz...; σ = λ) : Base.Fix1(broadcast, λ)
-  return LayerNorm(λ, diag, ϵ, Tuple(sz), affine)
+function LayerNorm(size::Tuple{Vararg{Int}}, λ=identity; affine::Bool=true, ϵ::Real=1f-5)
+  diag = affine ? Scale(size..., λ) : λ!=identity ? Base.Fix1(broadcast, λ) : identity
+  return LayerNorm(λ, diag, ϵ, size, affine)
 end
+LayerNorm(size::Integer...; kw...) = LayerNorm(Int.(size); kw...)
+LayerNorm(size_act...; kw...) = LayerNorm(Int.(size_act[1:end-1]), size_act[end]; kw...)
 
 @functor LayerNorm
 
 (a::LayerNorm)(x) = a.diag(normalise(x, dims=1:length(a.size), ϵ=a.ϵ))
 
 function Base.show(io::IO, l::LayerNorm)
-  print(io, "LayerNorm($(l.size)")
+  print(io, "LayerNorm(", join(l.size, ", "))
   l.λ === identity || print(io, ", ", l.λ)
   hasaffine(l) || print(io, ", affine=false")
   print(io, ")")
diff --git a/src/layers/show.jl b/src/layers/show.jl
index 47772f7e72..d393985042 100644
--- a/src/layers/show.jl
+++ b/src/layers/show.jl
@@ -55,7 +55,7 @@ _show_children(m::Maxout) = m.layers
 _show_children(p::Parallel) = (p.connection, p.layers...)
 
 for T in [
-  :Conv, :ConvTranspose, :CrossCor, :Dense, :Bilinear, :Embedding,
+  :Conv, :ConvTranspose, :CrossCor, :Dense, :Scale, :Bilinear, :Embedding,
   :BatchNorm, :LayerNorm, :InstanceNorm, :GroupNorm,
 ]
   @eval function Base.show(io::IO, m::MIME"text/plain", x::$T)
diff --git a/test/layers/basic.jl b/test/layers/basic.jl
index 8c8e15422d..6c9030874d 100644
--- a/test/layers/basic.jl
+++ b/test/layers/basic.jl
@@ -87,20 +87,29 @@ import Flux: activations
     end
   end
 
-  @testset "Diagonal" begin
-    @test length(Flux.Diagonal(10)(randn(10))) == 10
-    @test length(Flux.Diagonal(10)(randn(1))) == 10
-    @test length(Flux.Diagonal(10; bias = false)(randn(10))) == 10
-    @test_throws DimensionMismatch Flux.Diagonal(10)(randn(2))
-
-    @test Flux.Diagonal(2)([1 2]) == [1 2; 1 2]
-    @test Flux.Diagonal(2)([1, 2]) == [1, 2]
-    @test Flux.Diagonal(2; bias = false)([1 2; 3 4]) == [1 2; 3 4]
-
-    @test Flux.Diagonal(2)(rand(2, 3, 4)) |> size == (2, 3, 4)
-    @test Flux.Diagonal(2, 3;)(rand(2, 3, 4)) |> size == (2, 3, 4)
-    @test Flux.Diagonal(2, 3, 4; bias = false)(rand(2, 3, 4)) |> size == (2, 3, 4)
-    @test Flux.Diagonal(2, 3; bias = false)(rand(2, 1, 4)) |> size == (2, 3, 4)
+  @testset "Scale" begin
+    @test length(Flux.Scale(10)(randn(10))) == 10
+    @test length(Flux.Scale(10)(randn(1))) == 10
+    @test length(Flux.Scale(10; bias = false)(randn(10))) == 10
+    @test length(Flux.Scale(10, tanh)(randn(10))) == 10
+    @test_throws DimensionMismatch Flux.Scale(10)(randn(2))
+
+    @test Flux.Scale(2)([1 2]) == [1 2; 1 2]
+    @test Flux.Scale(2)([1, 2]) == [1, 2]
+    @test Flux.Scale(2; init = randn)([1, 2]) != [1, 2]
+    @test Flux.Scale(2; bias = false)([1 2; 3 4]) == [1 2; 3 4]
+    @test Flux.Scale(2, abs2; bias = false, init = ones)([1 2; 3 4]) == [1 4; 9 16]
+
+    @test Flux.Scale(2)(rand(2, 3, 4)) |> size == (2, 3, 4)
+    @test Flux.Scale(2, 3;)(rand(2, 3, 4)) |> size == (2, 3, 4)
+    @test Flux.Scale(2, 3, 4; bias = false)(rand(2, 3, 4)) |> size == (2, 3, 4)
+    @test Flux.Scale(2, 3; bias = false)(rand(2, 1, 4)) |> size == (2, 3, 4)
+    @test Flux.Scale(2, 3, tanh; bias = false, init = zeros)(rand(2, 1, 4)) == zeros(2, 3, 4)
+
+    @test_throws MethodError Flux.Scale(1.)
+    @test_throws MethodError Flux.Scale(1., 2.)
+    @test_throws Exception Flux.Scale()
+    @test_throws MethodError Flux.Scale(sin)
   end
 
   @testset "Maxout" begin
diff --git a/test/utils.jl b/test/utils.jl
index cc97db4fc9..14b5ad9bbc 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -457,7 +457,7 @@ end
                                 +), LayerNorm(8)))
   @test length(mod_skip) == 6
-  @test mod_skip[end] isa Flux.Diagonal
+  @test mod_skip[end] isa Flux.Scale
 end
 
 @testset "Patience triggers" begin
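
A minimal usage sketch of the renamed layer (not part of the patch), assuming a Flux build that includes this change; it only exercises the constructors shown in the new docstring and the deprecation path added in src/deprecations.jl.

using Flux

# Integer sizes, with an optional trailing activation (new in this PR):
s = Flux.Scale(2, abs2)             # learnable scale & bias, element-wise abs2
x = [1.0 2.0 3.0; 4.0 5.0 6.0]
y = s(x)                            # computes abs2.(s.scale .* x .+ s.bias)

# Explicit array, no trainable bias, as in the docstring example:
b = Flux.Scale([1 2 3 4], false, abs2)
b([1, 10])                          # broadcasts to [1 4 9 16; 100 400 900 1600]

# The old name is soft-deprecated: it emits a depwarn and forwards to Scale.
d = Flux.Diagonal(2)                # returns a Scale(2) layer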