diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index e76efea60a..3374e2fc1c 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -156,9 +156,8 @@ end
 @functor Dense
 
 function (a::Dense)(x::AbstractVecOrMat)
-  W, b = a.weight, a.bias
   σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc
-  return σ.(W*x .+ b)
+  return σ.(a.weight * x .+ a.bias)
 end
 
 (a::Dense)(x::AbstractArray) =
@@ -172,35 +171,37 @@ function Base.show(io::IO, l::Dense)
 end
 
 """
-    Diagonal(size::Integer...; bias=true, init=ones32)
-    Diagonal(scale::AbstractArray, [bias])
+    Diagonal(size::Integer...; σ=identity, bias=true, init=ones32)
+    Diagonal(scale::AbstractArray, [bias, activation])
 
 Create an element-wise linear layer, which performs
 
-    y = scale .* x .+ bias
+    y = σ.(scale .* x .+ bias)
 
-with no activation function.
-
 The learnable scale & bias are initialised `init(size...)` and `zeros32(size...)`,
 with `init=ones32` by default. You may specify the function `init`,
 turn off trainable bias with `bias=false`, or provide the array(s) explicitly.
 
 Used by [`LayerNorm`](@ref).
 """
-struct Diagonal{A<:AbstractArray, B}
+struct Diagonal{A<:AbstractArray, B, F}
   scale::A
   bias::B
-  function Diagonal(W::M, bias = true) where M<:AbstractArray
+  σ::F
+  function Diagonal(W::M, bias = true, σ::F = identity) where {M<:AbstractArray, F}
     b = create_bias(W, bias, size(W)...)
-    new{M, typeof(b)}(W, b)
+    new{M, typeof(b), F}(W, b, σ)
   end
 end
 
-Diagonal(sz::Integer...; bias = true, init = ones32) = Diagonal(init(sz...), bias)
+Diagonal(sz::Integer...; σ = identity, bias = true, init = ones32) = Diagonal(init(sz...), bias, σ)
 
 @functor Diagonal
 
-(a::Diagonal)(x) = a.scale .* x .+ a.bias
+function (a::Diagonal)(x::AbstractArray)
+  σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc
+  return σ === identity ? a.scale .* x .+ a.bias : σ.(a.scale .* x .+ a.bias)
+end
 
 function Base.show(io::IO, l::Diagonal)
   print(io, "Diagonal(", join(size(l.scale), ", "))
@@ -212,7 +213,7 @@ end
     Maxout(layers...)
     Maxout(f, n_alts)
 
-This contains a number of internal layes, each of which receives the same input.
-Its output is the elementwise maximum of the the internal layers' outputs.
+This contains a number of internal layers, each of which receives the same input.
+Its output is the elementwise maximum of the internal layers' outputs.
 
 Instead of defining layers individually, you can provide a zero-argument function
diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index f9d8bcfec4..0324b4b90b 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -165,16 +165,13 @@ struct LayerNorm{F,D,T,N}
 end
 
 function LayerNorm(sz, λ=identity; affine::Bool=true, ϵ::Real=1f-5)
-  diag = affine ? Diagonal(sz...) : identity
+  diag = affine ? Diagonal(sz...; σ = λ) : Base.Fix1(broadcast, λ)
   return LayerNorm(λ, diag, ϵ, Tuple(sz), affine)
 end
 
 @functor LayerNorm
 
-function (a::LayerNorm)(x)
-  x = a.diag(normalise(x, dims=1:length(a.size), ϵ=a.ϵ))
-  return a.λ === identity ? x : a.λ.(x)
-end
+(a::LayerNorm)(x) = a.diag(normalise(x, dims=1:length(a.size), ϵ=a.ϵ))
 
 function Base.show(io::IO, l::LayerNorm)
   print(io, "LayerNorm($(l.size)")
diff --git a/test/layers/basic.jl b/test/layers/basic.jl
index 0c12b22d11..8c8e15422d 100644
--- a/test/layers/basic.jl
+++ b/test/layers/basic.jl
@@ -89,19 +89,18 @@ import Flux: activations
 
   @testset "Diagonal" begin
     @test length(Flux.Diagonal(10)(randn(10))) == 10
-    @test length(Flux.Diagonal(10)(1)) == 10
     @test length(Flux.Diagonal(10)(randn(1))) == 10
     @test length(Flux.Diagonal(10; bias = false)(randn(10))) == 10
     @test_throws DimensionMismatch Flux.Diagonal(10)(randn(2))
 
     @test Flux.Diagonal(2)([1 2]) == [1 2; 1 2]
-    @test Flux.Diagonal(2)([1,2]) == [1,2]
+    @test Flux.Diagonal(2)([1, 2]) == [1, 2]
     @test Flux.Diagonal(2; bias = false)([1 2; 3 4]) == [1 2; 3 4]
 
-    @test Flux.Diagonal(2)(rand(2,3,4)) |> size == (2, 3, 4)
-    @test Flux.Diagonal(2,3)(rand(2,3,4)) |> size == (2, 3, 4)
-    @test Flux.Diagonal(2, 3, 4; bias = false)(rand(2,3,4)) |> size == (2, 3, 4)
-    @test Flux.Diagonal(2, 3; bias = false)(rand(2,1,4)) |> size == (2, 3, 4)
+    @test Flux.Diagonal(2)(rand(2, 3, 4)) |> size == (2, 3, 4)
+    @test Flux.Diagonal(2, 3)(rand(2, 3, 4)) |> size == (2, 3, 4)
+    @test Flux.Diagonal(2, 3, 4; bias = false)(rand(2, 3, 4)) |> size == (2, 3, 4)
+    @test Flux.Diagonal(2, 3; bias = false)(rand(2, 1, 4)) |> size == (2, 3, 4)
   end
 
   @testset "Maxout" begin
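
Usage sketch, not part of the patch: a minimal illustration of the new `Diagonal` activation argument and of `LayerNorm` delegating its `λ`, assuming a Flux build with this change applied. The names `d`, `ln`, and `x` are invented for the example.

    using Flux

    # Diagonal's forward pass is now σ.(scale .* x .+ bias)
    d = Flux.Diagonal(5; σ = tanh)          # scale from ones32, bias from zeros32
    x = randn(Float32, 5, 8)
    size(d(x))                              # (5, 8)
    d(x) ≈ tanh.(d.scale .* x .+ d.bias)    # ≈ because fast_act may swap tanh for tanh_fast

    # LayerNorm pushes λ into its affine Diagonal (or into a broadcast when
    # affine=false), so the activation is applied once, inside a.diag:
    ln = LayerNorm(5, relu)
    ln(x) ≈ relu.(Flux.normalise(x; dims=1, ϵ=1f-5))   # holds with the default ones/zeros parameters

Since `a.diag` now carries the activation, the `σ === identity` fast path in `Diagonal` keeps the unactivated case free of an extra broadcast over `identity`.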