
Commit

Allow activation function for Diagonal
theabhirath committed Apr 2, 2022
1 parent 57beb23 commit c2c6ab7
Showing 2 changed files with 18 additions and 19 deletions.
30 changes: 16 additions & 14 deletions src/layers/basic.jl
@@ -138,13 +138,13 @@ julia> Flux.params(d1) # no trainable bias
 Params([[1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0]])
 ```
 """
-struct Dense{F, M<:AbstractMatrix, B}
+struct Dense{M<:AbstractMatrix, B, F}
   weight::M
   bias::B
   σ::F
   function Dense(W::M, bias = true, σ::F = identity) where {M<:AbstractMatrix, F}
     b = create_bias(W, bias, size(W,1))
-    new{F,M,typeof(b)}(W, b, σ)
+    new{M, typeof(b), F}(W, b, σ)
   end
 end

@@ -158,7 +158,7 @@ end
 function (a::Dense)(x::AbstractVecOrMat)
   W, b = a.weight, a.bias
   σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc
-  return σ.(W*x .+ b)
+  return σ.(W * x .+ b)
 end

 (a::Dense)(x::AbstractArray) =
@@ -172,35 +172,37 @@ function Base.show(io::IO, l::Dense)
 end

 """
-    Diagonal(size::Integer...; bias=true, init=ones32)
-    Diagonal(scale::AbstractArray, [bias])
+    Diagonal(size::Integer...; σ = identity, bias=true, init=ones32)
+    Diagonal(scale::AbstractArray, [bias, activation])

 Create an element-wise linear layer, which performs

-    y = scale .* x .+ bias
+    y = σ.(scale .* x .+ bias)

-with no activation function.
 The learnable scale & bias are initialised `init(size...)` and `zeros32(size...)`,
 with `init=ones32` by default. You may specify the function `init`,
 turn off trainable bias with `bias=false`, or provide the array(s) explicitly.

 Used by [`LayerNorm`](@ref).
 """
-struct Diagonal{A<:AbstractArray, B}
+struct Diagonal{A<:AbstractArray, B, F}
   scale::A
   bias::B
-  function Diagonal(W::M, bias = true) where M<:AbstractArray
+  σ::F
+  function Diagonal(W::M, bias = true, σ::F = identity) where {M<:AbstractArray, F}
     b = create_bias(W, bias, size(W)...)
-    new{M, typeof(b)}(W, b)
+    new{M, typeof(b), F}(W, b, σ)
   end
 end

-Diagonal(sz::Integer...; bias = true, init = ones32) = Diagonal(init(sz...), bias)
+Diagonal(sz::Integer...; σ = identity, bias = true, init = ones32) = Diagonal(init(sz...), bias, σ)

 @functor Diagonal

-(a::Diagonal)(x) = a.scale .* x .+ a.bias
+function (a::Diagonal)(x::AbstractArray)
+  σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc
+  return σ === identity ? a.scale .* x .+ a.bias : σ.(a.scale .* x .+ a.bias)
+end

 function Base.show(io::IO, l::Diagonal)
   print(io, "Diagonal(", join(size(l.scale), ", "))
@@ -212,7 +214,7 @@ end
     Maxout(layers...)
     Maxout(f, n_alts)

-This contains a number of internal layes, each of which receives the same input.
+This contains a number of internal layers, each of which receives the same input.
 Its output is the elementwise maximum of the internal layers' outputs.

 Instead of defining layers individually, you can provide a zero-argument function
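With this change, `Diagonal` takes an activation: either via the `σ` keyword of the size-based constructor or as a trailing positional argument after the bias. A minimal usage sketch, assuming Flux with this commit applied; the input values and the `tanh`/`relu` choices are only illustrative:

```julia
using Flux

# Size-based constructor: the new σ keyword sets the element-wise activation;
# scale is initialised with ones32 and bias with zeros32 as before.
d = Flux.Diagonal(3; σ = tanh)

x = Float32[0.5, -1.0, 2.0]
d(x)                        # == tanh.(d.scale .* x .+ d.bias)

# Array-based constructor: the activation is the new trailing positional argument.
d2 = Flux.Diagonal(Float32[2, 2, 2], true, relu)
d2(x)                       # == relu.(2 .* x), since the bias starts at zero
```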
7 changes: 2 additions & 5 deletions src/layers/normalise.jl
@@ -165,16 +165,13 @@ struct LayerNorm{F,D,T,N}
 end

 function LayerNorm(sz, λ=identity; affine::Bool=true, ϵ::Real=1f-5)
-  diag = affine ? Diagonal(sz...) : identity
+  diag = affine ? Diagonal(sz...; σ = λ) : Base.Fix1(broadcast, λ)
   return LayerNorm(λ, diag, ϵ, Tuple(sz), affine)
 end

 @functor LayerNorm

-function (a::LayerNorm)(x)
-  x = a.diag(normalise(x, dims=1:length(a.size), ϵ=a.ϵ))
-  return a.λ === identity ? x : a.λ.(x)
-end
+(a::LayerNorm)(x) = a.diag(normalise(x, dims=1:length(a.size), ϵ=a.ϵ))

 function Base.show(io::IO, l::LayerNorm)
   print(io, "LayerNorm($(l.size)")
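For the `affine = false` branch, `Base.Fix1(broadcast, λ)` now stands in for the explicit `a.λ.(x)` broadcast that used to live in the forward pass. A small sketch of why that wrapper is equivalent; the `abs` activation here is only illustrative:

```julia
# Base.Fix1(broadcast, λ) is a callable that fixes the first argument of
# broadcast, so calling it on x computes broadcast(λ, x), i.e. λ.(x).
λ = abs
f = Base.Fix1(broadcast, λ)

f([-1.0, 2.0, -3.0])   # == abs.([-1.0, 2.0, -3.0]) == [1.0, 2.0, 3.0]

# After this commit LayerNorm's forward pass is just a.diag(normalise(x, ...));
# the activation is applied either inside Diagonal (affine = true) or by a
# broadcast wrapper like `f` above (affine = false).
```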
