Rename Diagonal to Scale (#1927)
* rename Diagonal to Scale

* fix a test

* types etc

* spaces
mcabbott authored Apr 5, 2022
1 parent 6405ab3 commit 5f17f1c
Showing 8 changed files with 88 additions and 38 deletions.
1 change: 1 addition & 0 deletions NEWS.md
@@ -11,6 +11,7 @@ been removed in favour of MLDatasets.jl.
* Many utility functions and the `DataLoader` are [now provided by MLUtils.jl](https://github.com/FluxML/Flux.jl/pull/1874).
* The DataLoader is now compatible with generic dataset types implementing `MLUtils.numobs` and `MLUtils.getobs`.
* Added [truncated normal initialisation](https://github.com/FluxML/Flux.jl/pull/1877) of weights.
* The `Flux.Diagonal` layer is now called `Scale`, and accepts an activation function.

## v0.12.10
* `Dropout`/`AlphaDropout` now supports [user-specified RNGs](https://github.com/FluxML/Flux.jl/pull/1838)
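The NEWS entry above summarises the user-facing change: the activation that `Flux.Diagonal` previously took via the `σ` keyword is now a positional argument of `Scale`. A minimal migration sketch (layer size and activation chosen purely for illustration):

```julia
using Flux

# Previously, the activation was passed as a keyword to the old constructor:
#   Flux.Diagonal(3; σ = relu)
# Now the same element-wise layer takes the activation as a trailing positional argument.
layer = Flux.Scale(3, relu)      # scale = ones32(3), bias = zeros32(3)
layer(Float32[-1, 0, 2])         # relu.(1 .* x .+ 0) == Float32[0, 0, 2]
```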
2 changes: 1 addition & 1 deletion docs/src/models/layers.md
@@ -56,7 +56,7 @@ Maxout
SkipConnection
Parallel
Flux.Bilinear
Flux.Diagonal
Flux.Scale
Flux.Embedding
```

9 changes: 9 additions & 0 deletions src/deprecations.jl
@@ -39,6 +39,15 @@ function Optimise.update!(x::AbstractArray, x̄)
x .-= x̄
end

function Diagonal(size::Integer...; kw...)
Base.depwarn("Flux.Diagonal is now Flux.Scale, and also allows an activation function.", :Diagonal)
Scale(size...; kw...)
end
function Diagonal(size::Tuple; kw...)
Base.depwarn("Flux.Diagonal is now Flux.Scale, and also allows an activation function.", :Diagonal)
Scale(size...; kw...)
end

# Channel notation: Changed to match Conv, but very softly deprecated!
# Perhaps change to @deprecate for v0.14, but there is no plan to remove these.
Dense(in::Integer, out::Integer, σ = identity; kw...) =
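The two deprecation methods added above keep old constructor calls working: calling `Flux.Diagonal` with integer sizes or a size tuple warns and forwards to `Scale`. A rough sketch of the expected behaviour (the warning only prints when Julia is started with `--depwarn=yes`; variable names are illustrative):

```julia
using Flux

d = Flux.Diagonal(2, 3)     # depwarn: "Flux.Diagonal is now Flux.Scale, ..."
d isa Flux.Scale            # true — the shim simply forwards to Scale
size(d.scale)               # (2, 3), initialised by ones32

t = Flux.Diagonal((4,))     # the Tuple method splats the sizes into Scale(4)
t isa Flux.Scale            # true
```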
57 changes: 43 additions & 14 deletions src/layers/basic.jl
@@ -171,40 +171,69 @@ function Base.show(io::IO, l::Dense)
end

"""
Diagonal(size::Integer...; σ = identity, bias=true, init=ones32)
Diagonal(scale::AbstractArray, [bias, activation])
Scale(size::Integer..., σ=identity; bias=true, init=ones32)
Scale(scale::AbstractArray, [bias, σ])
Create an element-wise linear layer, which performs
Create an element-wise layer, whose forward pass is given by:
y = σ.(scale .* x .+ bias)
This uses `.*` instead of matrix multiplication `*` of [`Dense`](@ref).
The learnable scale & bias are initialised `init(size...)` and `zeros32(size...)`,
with `init=ones32` by default. You may specify the function `init`,
turn off trainable bias with `bias=false`, or provide the array(s) explicitly.
Used by [`LayerNorm`](@ref).
Used by [`LayerNorm`](@ref) with `affine=true`.
# Examples
```jldoctest
julia> a = Flux.Scale(2)
Scale(2) # 4 parameters
julia> Flux.params(a)
Params([Float32[1.0, 1.0], Float32[0.0, 0.0]])
julia> a([1 2 3])
2×3 Matrix{Float32}:
1.0 2.0 3.0
1.0 2.0 3.0
julia> b = Flux.Scale([1 2 3 4], false, abs2)
Scale(1, 4, abs2; bias=false) # 4 parameters
julia> b([1, 10])
2×4 Matrix{Int64}:
1 4 9 16
100 400 900 1600
julia> Flux.params(b)
Params([[1 2 3 4]])
```
"""
struct Diagonal{A<:AbstractArray, B, F}
struct Scale{F, A<:AbstractArray, B}
scale::A
bias::B
σ::F
function Diagonal(W::M, bias = true, σ::F = identity) where {M<:AbstractArray, F}
b = create_bias(W, bias, size(W)...)
new{M, typeof(b), F}(W, b, σ)
function Scale(scale::A, bias::B = true, σ::F = identity) where {A<:AbstractArray, B<:Union{Bool, AbstractArray}, F}
b = create_bias(scale, bias, size(scale)...)
new{F, A, typeof(b)}(scale, b, σ)
end
end

Diagonal(sz::Integer...; σ = identity, bias = true, init = ones32) = Diagonal(init(sz...), bias, σ)
Scale(s1::Integer, s23::Integer...; bias = true, init = ones32, _act = identity) = Scale(init(s1, s23...), bias, _act)
Scale(size_act...; bias = true, init = ones32) = Scale(size_act[1:end-1]...; bias, init, _act = size_act[end])

@functor Diagonal
@functor Scale

function (a::Diagonal)(x::AbstractArray)
function (a::Scale)(x::AbstractArray)
σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc
return σ === typeof(identity) ? a.scale .* x .+ a.bias : σ.(a.scale .* x .+ a.bias)
σ.(a.scale .* x .+ a.bias)
end

function Base.show(io::IO, l::Diagonal)
print(io, "Diagonal(", join(size(l.scale), ", "))
function Base.show(io::IO, l::Scale)
print(io, "Scale(", join(size(l.scale), ", "))
l.σ == identity || print(io, ", ", l.σ)
l.bias == false && print(io, "; bias=false")
print(io, ")")
end
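The docstring above carries jldoctest examples for the new layer; a condensed sketch of the two constructor forms added in this file (values are illustrative):

```julia
using Flux

# Integer form: trailing positional activation, keywords for bias and init.
s = Flux.Scale(2, 3, tanh; bias = false)   # scale is a 2×3 ones32 array, bias disabled
size(s(rand(Float32, 2, 3, 4)))            # (2, 3, 4): y = tanh.(scale .* x), broadcast over dim 3

# Array form: pass the parameters explicitly; bias may be an array, true, or false.
s2 = Flux.Scale(Float32[1 2 3; 4 5 6], false, tanh)
s2.scale                                   # the given 2×3 matrix, stored as-is
```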
16 changes: 9 additions & 7 deletions src/layers/normalise.jl
@@ -139,7 +139,7 @@ testmode!(m::AlphaDropout, mode=true) =
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)

"""
LayerNorm(sz, λ=identity; affine=true, ϵ=1f-5)
LayerNorm(size..., λ=identity; affine=true, ϵ=1f-5)
A [normalisation layer](https://arxiv.org/abs/1607.06450) designed to be
used with recurrent hidden states.
@@ -151,10 +151,10 @@ for tuple `sz`, along the first dimension for integer `sz`.
The input is expected to have first dimensions' size equal to `sz`.
If `affine=true` also applies a learnable shift and rescaling
as in the [`Diagonal`](@ref) layer.
using the [`Scale`](@ref) layer.
Se also [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`normalise`](@ref).
See also [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`normalise`](@ref).
"""
struct LayerNorm{F,D,T,N}
λ::F
@@ -164,17 +164,19 @@ struct LayerNorm{F,D,T,N}
affine::Bool
end

function LayerNorm(sz, λ=identity; affine::Bool=true, ϵ::Real=1f-5)
diag = affine ? Diagonal(sz...; σ = λ) : Base.Fix1(broadcast, λ)
return LayerNorm(λ, diag, ϵ, Tuple(sz), affine)
function LayerNorm(size::Tuple{Vararg{Int}}, λ=identity; affine::Bool=true, ϵ::Real=1f-5)
diag = affine ? Scale(size..., λ) : λ!=identity ? Base.Fix1(broadcast, λ) : identity
return LayerNorm(λ, diag, ϵ, size, affine)
end
LayerNorm(size::Integer...; kw...) = LayerNorm(Int.(size); kw...)
LayerNorm(size_act...; kw...) = LayerNorm(Int.(size_act[1:end-1]), size_act[end]; kw...)

@functor LayerNorm

(a::LayerNorm)(x) = a.diag(normalise(x, dims=1:length(a.size), ϵ=a.ϵ))

function Base.show(io::IO, l::LayerNorm)
print(io, "LayerNorm($(l.size)")
print(io, "LayerNorm(", join(l.size, ", "))
l.λ === identity || print(io, ", ", l.λ)
hasaffine(l) || print(io, ", affine=false")
print(io, ")")
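The `LayerNorm` changes above swap its affine part from `Diagonal` to `Scale` and let an activation be passed as a trailing positional argument. A short sketch of the resulting behaviour (sizes are arbitrary):

```julia
using Flux

ln = LayerNorm(5, relu)             # affine=true by default, so the shift/rescale is a Scale
ln.diag isa Flux.Scale              # true
size(ln(randn(Float32, 5, 16)))     # (5, 16): normalise over dim 1, then apply the Scale

ln2 = LayerNorm(5; affine = false)  # no learnable parameters
ln2.diag                            # identity, per the new constructor (λ == identity here)
```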
2 changes: 1 addition & 1 deletion src/layers/show.jl
@@ -55,7 +55,7 @@ _show_children(m::Maxout) = m.layers
_show_children(p::Parallel) = (p.connection, p.layers...)

for T in [
:Conv, :ConvTranspose, :CrossCor, :Dense, :Bilinear, :Embedding,
:Conv, :ConvTranspose, :CrossCor, :Dense, :Scale, :Bilinear, :Embedding,
:BatchNorm, :LayerNorm, :InstanceNorm, :GroupNorm,
]
@eval function Base.show(io::IO, m::MIME"text/plain", x::$T)
37 changes: 23 additions & 14 deletions test/layers/basic.jl
@@ -87,20 +87,29 @@ import Flux: activations
end
end

@testset "Diagonal" begin
@test length(Flux.Diagonal(10)(randn(10))) == 10
@test length(Flux.Diagonal(10)(randn(1))) == 10
@test length(Flux.Diagonal(10; bias = false)(randn(10))) == 10
@test_throws DimensionMismatch Flux.Diagonal(10)(randn(2))

@test Flux.Diagonal(2)([1 2]) == [1 2; 1 2]
@test Flux.Diagonal(2)([1, 2]) == [1, 2]
@test Flux.Diagonal(2; bias = false)([1 2; 3 4]) == [1 2; 3 4]

@test Flux.Diagonal(2)(rand(2, 3, 4)) |> size == (2, 3, 4)
@test Flux.Diagonal(2, 3;)(rand(2, 3, 4)) |> size == (2, 3, 4)
@test Flux.Diagonal(2, 3, 4; bias = false)(rand(2, 3, 4)) |> size == (2, 3, 4)
@test Flux.Diagonal(2, 3; bias = false)(rand(2, 1, 4)) |> size == (2, 3, 4)
@testset "Scale" begin
@test length(Flux.Scale(10)(randn(10))) == 10
@test length(Flux.Scale(10)(randn(1))) == 10
@test length(Flux.Scale(10; bias = false)(randn(10))) == 10
@test length(Flux.Scale(10, tanh)(randn(10))) == 10
@test_throws DimensionMismatch Flux.Scale(10)(randn(2))

@test Flux.Scale(2)([1 2]) == [1 2; 1 2]
@test Flux.Scale(2)([1, 2]) == [1, 2]
@test Flux.Scale(2; init = randn)([1, 2]) != [1, 2]
@test Flux.Scale(2; bias = false)([1 2; 3 4]) == [1 2; 3 4]
@test Flux.Scale(2, abs2; bias = false, init = ones)([1 2; 3 4]) == [1 4; 9 16]

@test Flux.Scale(2)(rand(2, 3, 4)) |> size == (2, 3, 4)
@test Flux.Scale(2, 3;)(rand(2, 3, 4)) |> size == (2, 3, 4)
@test Flux.Scale(2, 3, 4; bias = false)(rand(2, 3, 4)) |> size == (2, 3, 4)
@test Flux.Scale(2, 3; bias = false)(rand(2, 1, 4)) |> size == (2, 3, 4)
@test Flux.Scale(2, 3, tanh; bias = false, init = zeros)(rand(2, 1, 4)) == zeros(2, 3, 4)

@test_throws MethodError Flux.Scale(1.)
@test_throws MethodError Flux.Scale(1., 2.)
@test_throws Exception Flux.Scale()
@test_throws MethodError Flux.Scale(sin)
end

@testset "Maxout" begin
2 changes: 1 addition & 1 deletion test/utils.jl
@@ -457,7 +457,7 @@ end
+),
LayerNorm(8)))
@test length(mod_skip) == 6
@test mod_skip[end] isa Flux.Diagonal
@test mod_skip[end] isa Flux.Scale
end

@testset "Patience triggers" begin
