Merge pull request #1925 from theabhirath/diag-act
Allow activation function for Diagonal
CarloLucibello authored Apr 2, 2022
2 parents a0b804a + 9ab71f7 commit ebda582
Showing 3 changed files with 21 additions and 24 deletions.
27 changes: 14 additions & 13 deletions src/layers/basic.jl
@@ -156,9 +156,8 @@ end
 @functor Dense
 
 function (a::Dense)(x::AbstractVecOrMat)
-  W, b = a.weight, a.bias
   σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc
-  return σ.(W*x .+ b)
+  return σ.(a.weight * x .+ a.bias)
 end
 
 (a::Dense)(x::AbstractArray) =
@@ -172,35 +171,37 @@ function Base.show(io::IO, l::Dense)
 end
 
 """
-    Diagonal(size::Integer...; bias=true, init=ones32)
-    Diagonal(scale::AbstractArray, [bias])
+    Diagonal(size::Integer...; σ = identity, bias=true, init=ones32)
+    Diagonal(scale::AbstractArray, [bias, activation])
 
 Create an element-wise linear layer, which performs
 
-    y = scale .* x .+ bias
+    y = σ.(scale .* x .+ bias)
 
-with no activation function.
 The learnable scale & bias are initialised `init(size...)` and `zeros32(size...)`,
 with `init=ones32` by default. You may specify the function `init`,
 turn off trainable bias with `bias=false`, or provide the array(s) explicitly.
 
 Used by [`LayerNorm`](@ref).
 """
-struct Diagonal{A<:AbstractArray, B}
+struct Diagonal{A<:AbstractArray, B, F}
   scale::A
   bias::B
-  function Diagonal(W::M, bias = true) where M<:AbstractArray
+  σ::F
+  function Diagonal(W::M, bias = true, σ::F = identity) where {M<:AbstractArray, F}
     b = create_bias(W, bias, size(W)...)
-    new{M, typeof(b)}(W, b)
+    new{M, typeof(b), F}(W, b, σ)
   end
 end
 
-Diagonal(sz::Integer...; bias = true, init = ones32) = Diagonal(init(sz...), bias)
+Diagonal(sz::Integer...; σ = identity, bias = true, init = ones32) = Diagonal(init(sz...), bias, σ)
 
 @functor Diagonal
 
-(a::Diagonal)(x) = a.scale .* x .+ a.bias
+function (a::Diagonal)(x::AbstractArray)
+  σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc
+  return σ === identity ? a.scale .* x .+ a.bias : σ.(a.scale .* x .+ a.bias)
+end

 function Base.show(io::IO, l::Diagonal)
   print(io, "Diagonal(", join(size(l.scale), ", "))
@@ -212,7 +213,7 @@ end
     Maxout(layers...)
    Maxout(f, n_alts)
 
-This contains a number of internal layes, each of which receives the same input.
+This contains a number of internal layers, each of which receives the same input.
 Its output is the elementwise maximum of the internal layers' outputs.
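For reference, a minimal sketch of the new `Diagonal` behaviour (hypothetical usage, not part of the diff; `relu` is just an arbitrary activation chosen for illustration). Note that `NNlib.fast_act(σ, x)`, now used by both `Dense` and the new `Diagonal` forward pass, only swaps in faster variants for a few activations (e.g. `tanh` → `tanh_fast`) and returns `σ` unchanged otherwise:

```julia
using Flux

# This PR's change: Diagonal now takes an activation σ.
# scale defaults to ones32(4), bias to trainable zeros.
d = Flux.Diagonal(4; σ = relu)

x = randn(Float32, 4)
@assert d(x) == relu.(d.scale .* x .+ d.bias)  # σ is broadcast last
```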
7 changes: 2 additions & 5 deletions src/layers/normalise.jl
@@ -165,16 +165,13 @@ struct LayerNorm{F,D,T,N}
 end
 
 function LayerNorm(sz, λ=identity; affine::Bool=true, ϵ::Real=1f-5)
-  diag = affine ? Diagonal(sz...) : identity
+  diag = affine ? Diagonal(sz...; σ = λ) : Base.Fix1(broadcast, λ)
   return LayerNorm(λ, diag, ϵ, Tuple(sz), affine)
 end
 
 @functor LayerNorm
 
-function (a::LayerNorm)(x)
-  x = a.diag(normalise(x, dims=1:length(a.size), ϵ=a.ϵ))
-  return a.λ === identity ? x : a.λ.(x)
-end
+(a::LayerNorm)(x) = a.diag(normalise(x, dims=1:length(a.size), ϵ=a.ϵ))

 function Base.show(io::IO, l::LayerNorm)
   print(io, "LayerNorm($(l.size)")
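As a sanity check of the `LayerNorm` rewiring: with `affine = true` the activation now lives inside the `Diagonal`, while for the non-affine path `Base.Fix1(broadcast, λ)` is simply a callable equivalent to `x -> λ.(x)`. A short sketch under those assumptions:

```julia
using Flux

ln = LayerNorm(5, tanh; affine = false)  # ln.diag == Base.Fix1(broadcast, tanh)
x = randn(Float32, 5, 3)

# Normalise over the first dimension, then broadcast tanh:
@assert ln(x) ≈ tanh.(Flux.normalise(x, dims = 1))
```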
11 changes: 5 additions & 6 deletions test/layers/basic.jl
@@ -89,19 +89,18 @@ import Flux: activations
 
 @testset "Diagonal" begin
   @test length(Flux.Diagonal(10)(randn(10))) == 10
-  @test length(Flux.Diagonal(10)(1)) == 10
   @test length(Flux.Diagonal(10)(randn(1))) == 10
   @test length(Flux.Diagonal(10; bias = false)(randn(10))) == 10
   @test_throws DimensionMismatch Flux.Diagonal(10)(randn(2))
 
   @test Flux.Diagonal(2)([1 2]) == [1 2; 1 2]
-  @test Flux.Diagonal(2)([1,2]) == [1,2]
+  @test Flux.Diagonal(2)([1, 2]) == [1, 2]
   @test Flux.Diagonal(2; bias = false)([1 2; 3 4]) == [1 2; 3 4]
 
-  @test Flux.Diagonal(2)(rand(2,3,4)) |> size == (2, 3, 4)
-  @test Flux.Diagonal(2,3)(rand(2,3,4)) |> size == (2, 3, 4)
-  @test Flux.Diagonal(2, 3, 4; bias = false)(rand(2,3,4)) |> size == (2, 3, 4)
-  @test Flux.Diagonal(2, 3; bias = false)(rand(2,1,4)) |> size == (2, 3, 4)
+  @test Flux.Diagonal(2)(rand(2, 3, 4)) |> size == (2, 3, 4)
+  @test Flux.Diagonal(2, 3)(rand(2, 3, 4)) |> size == (2, 3, 4)
+  @test Flux.Diagonal(2, 3, 4; bias = false)(rand(2, 3, 4)) |> size == (2, 3, 4)
+  @test Flux.Diagonal(2, 3; bias = false)(rand(2, 1, 4)) |> size == (2, 3, 4)
 end

@testset "Maxout" begin
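A further check one might add alongside these tests (hypothetical, not in the PR) to confirm that the activation is applied after scale and bias:

```julia
using Flux, Test

d = Flux.Diagonal(2, 3; σ = abs2)
x = randn(2, 3)
@test d(x) == abs2.(d.scale .* x .+ d.bias)
```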
