Add faster tanh implementation(s) (#345)
* tanh_fast, and friends

* rm some lines

* Revert "improve relu, leakyrelyu, and add InplaceableThunk"

This reverts commit e6befa1.

* switch to new tanh_new from Remez.jl

* better Float64 version with polynomial

* better coeff, better cutoff

* translate to sigmoid, add tests

* change sigmoid implementation, just fastmath of basic one

* comments

* test bounds

* bad rebase + docstr

* update coeff

* don't say < 2
[skip ci]

* restrict to Float64

* Update src/activations.jl

Co-authored-by: Carlo Lucibello <[email protected]>

Co-authored-by: Carlo Lucibello <[email protected]>
mcabbott and CarloLucibello authored Nov 8, 2021
1 parent df0b067 commit 37093c7
Showing 2 changed files with 186 additions and 10 deletions.
105 changes: 95 additions & 10 deletions src/activations.jl
@@ -3,12 +3,13 @@
# Some activation functions have wrapper functions for the GPU in NNlibCUDA.jl.
# https://github.com/JuliaGPU/CuArrays.jl/issues/614

const ACTIVATIONS =
[:σ, :hardσ, :hardtanh, :relu,
:leakyrelu, :relu6, :rrelu, :elu, :gelu, :swish, :selu,
:celu, :softplus, :softsign, :logσ, :logcosh,
:mish, :tanhshrink, :softshrink, :trelu,
:lisht]
ACTIVATIONS = [
:σ, :hardσ, :hardtanh, :relu,
:leakyrelu, :relu6, :rrelu, :elu, :gelu, :swish, :selu,
:celu, :softplus, :softsign, :logσ, :logcosh,
:mish, :tanhshrink, :softshrink, :trelu, :lisht,
:tanh_fast, :sigmoid_fast,
]

for f in ACTIVATIONS
@eval export $(f)
@@ -28,6 +29,8 @@ function.
Unicode `σ` can be entered as `\\sigma` then tab, in many editors.
The ascii name `sigmoid` is also exported.
See also [`sigmoid_fast`](@ref).
```
julia> lineplot(sigmoid, -5, 5, height=7)
@@ -113,6 +116,7 @@ julia> lineplot(logsigmoid, -5, 5, height=7)
```
"""
logσ(x) = -softplus(-x)

const logsigmoid = logσ

"""
@@ -121,6 +125,7 @@ const logsigmoid = logσ
Segment-wise linear approximation of `tanh`, much cheaper to compute.
See ["Large Scale Machine Learning"](https://ronan.collobert.com/pub/matos/2004_phdthesis_lip6.pdf).
See also [`tanh_fast`](@ref).
```
julia> lineplot(hardtanh, -2, 2, height=7)
@@ -652,22 +657,102 @@ for f in ACTIVATIONS
error("Use broadcasting (`", $(string(f)), ".(x)`) to apply activation functions to arrays.")
end

## Faster, less accurate, versions of some.

"""
tanh_fast(x)
This is a faster but slightly less accurate version of `tanh`.
Where Julia's `tanh` function has an error under 2 eps, this
may be wrong by 5 eps, a reduction by less than one decimal digit.
For `x::Float32` this is usually about 10 times faster,
with a smaller speedup for `x::Float64`.
For any other number types, it just calls `tanh`.
See also [`sigmoid_fast`](@ref).
```
julia> tanh(0.5f0)
0.46211717f0
julia> tanh_fast(0.5f0)
0.46211714f0
julia> hardtanh(0.5f0)
0.5f0
```
"""
@inline function tanh_fast(x::Float32)
x2 = abs2(x)
n = evalpoly(x2, (1.0f0, 0.1346604f0, 0.0035974074f0, 2.2332108f-5, 1.587199f-8))
d = evalpoly(x2, (1.0f0, 0.4679937f0, 0.026262015f0, 0.0003453992f0, 8.7767893f-7))
ifelse(x2 < 66f0, x * (n / d), sign(x))
end

@inline function tanh_fast(x::Float64)
exp2x = @fastmath exp(x + x)
y = (exp2x - 1) / (exp2x + 1)
# That has large errors near zero; using `expm1` would be more accurate, but about as slow as `tanh`.
# Instead, we switch to a polynomial, which is very accurate within its range:
x2 = x * x
ypoly = x * evalpoly(x2, (1.0, -0.33333333333324583, 0.13333333325511604, -0.05396823125794372, 0.02186660872609521, -0.008697141630499953))
ifelse(x2 > 900.0, sign(y), ifelse(x2 < 0.017, oftype(y, ypoly), y))
end
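A small illustration (not part of this commit) of the cancellation that the comment above describes, and hence why the polynomial branch exists:

```julia
# Near zero, exp2x - 1 suffers catastrophic cancellation, so the exp-based
# formula keeps only about 8 of Float64's ~16 significant digits here:
x = 1e-9
exp2x = exp(2x)
(exp2x - 1) / (exp2x + 1)  # accurate to roughly 8 digits only
tanh(x)                    # correctly rounded to full precision
```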

# These approximations are very badly behaved for Float16; none are fast.
# They are also a bit slower with ForwardDiff.Dual numbers, so we use Base:
tanh_fast(x::Real) = Base.tanh(x)
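As a rough, machine-dependent way to check the docstring's speed claims — a sketch assuming the BenchmarkTools package is installed, not part of this commit:

```julia
# Rough speed comparison; timings vary by machine and Julia version.
using BenchmarkTools
using NNlib: tanh_fast

x = randn(Float32, 1000)
@btime tanh.($x)       # Base tanh over 1000 Float32s
@btime tanh_fast.($x)  # expected to be several times faster
```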

"""
sigmoid_fast(x)
This is a faster, and very slightly less accurate, version of `sigmoid`.
For `x::Float32`, this is perhaps 3 times faster, with maximum errors of 2 eps instead of 1.
See also [`tanh_fast`](@ref).
```
julia> sigmoid(0.2f0)
0.54983395f0
julia> sigmoid_fast(0.2f0)
0.54983395f0
julia> hardσ(0.2f0)
0.53333336f0
```
"""
@inline function sigmoid_fast(x::Real)
t = @fastmath exp(-abs(x))
y = ifelse(x ≥ 0, inv(1 + t), t / (1 + t))
ifelse(x > 40, one(y), ifelse(x < -80, zero(y), y))
end
# For x::Float32, this is not as quick as the rational tanh_fast(x) above,
# but that polynomial has poor relative accuracy for negative x.
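The `abs` trick above is worth spelling out: for negative `x`, σ(x) = eˣ/(1 + eˣ) = t/(1 + t) with t = exp(-|x|), so `exp` is only ever called with a non-positive argument and cannot overflow. A quick check (illustrative, not part of the commit):

```julia
# For x < 0: 1/(1 + exp(-x)) == exp(x)/(1 + exp(x)) == t/(1 + t), t = exp(-abs(x)).
# Keeping exp's argument ≤ 0 avoids overflow for very negative x.
x = -3.0
t = exp(-abs(x))
t / (1 + t) ≈ 1 / (1 + exp(-x))  # true
```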

sigmoid_fast(x::Float16) = sigmoid(x) # sigmoid_fast is extremely badly behaved at large x
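A quick spot-check of the accuracy claim against a BigFloat reference — again an illustrative sketch, not part of the commit:

```julia
# Spot-check sigmoid_fast against a high-precision reference.
using NNlib: sigmoid, sigmoid_fast

xs = range(-10f0, 10f0, length=1001)
# error at each point, in units of eps of the correctly rounded value
errs = map(xs) do x
    truth = Float32(sigmoid(big(x)))
    abs(sigmoid_fast(x) - truth) / eps(truth)
end
maximum(errs)  # expected to be on the order of a few eps
```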

## Define rrules for some activation functions, along with the
## broadcasted rrule activation functions.
## TODO: add all activations to the lists below.

## This is a performance hack specifically for Zygote, because it doesn't handle fused
## broadcasts well; but it generally should be good (or at least harmless) for any AD, as
## it saves ADing the broadcasting machinery.
## Related Issue https://github.com/JuliaDiff/ChainRulesCore.jl/issues/271

UNARY_ACTS = [ # f, df
(:relu, :(x > 0)),
(:hardtanh, :(-1 < x < 1)),
(:selu, :(deriv_selu(Ω))),
(:σ, :(conj(Ω * (1 - Ω)))),
(:elu, :(deriv_elu(Ω))),
(:softplus, :(σ(x))),

(:tanh_fast, :(conj(1 - Ω^2))),
(:sigmoid_fast, :(conj(Ω * (1 - Ω)))),
]

for (f, df) in UNARY_ACTS
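The loop over `UNARY_ACTS` metaprograms one reverse rule per activation. A hand-written sketch of roughly what it produces for `tanh_fast` follows; the names and details are illustrative, since the real code is generated by `@eval` inside NNlib (and defining such a rule outside the package would be type piracy):

```julia
# Illustrative expansion of the UNARY_ACTS entry (:tanh_fast, :(conj(1 - Ω^2))).
import ChainRulesCore: rrule, NoTangent
using Base.Broadcast: broadcasted
using NNlib: tanh_fast

function rrule(::typeof(broadcasted), ::typeof(tanh_fast), x::AbstractArray{<:Real})
    Ω = tanh_fast.(x)
    tanh_fast_pullback(Δ) = (NoTangent(), NoTangent(), @. Δ * conj(1 - Ω^2))
    return Ω, tanh_fast_pullback
end
```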
91 changes: 91 additions & 0 deletions test/activations.jl
@@ -221,6 +221,97 @@ end
@test trelu(0.9,0.5) == 0.9
end

## Faster variants

using NNlib: tanh_fast, sigmoid_fast

function countepsfrom(x::T, xtrue) where {T<:AbstractFloat}
target = T(xtrue)
for n in Iterators.flatten(zip(0:100, -1:-1:-100))
nextfloat(x, n) === target && return n
end
return round(Int, (target - x) / eps(x))
end

mean_eps(f, g, xs) = mean(x -> abs(countepsfrom(f(x), g(big(x)))), xs)
worst_eps(f, g, xs) = maximum(x -> abs(countepsfrom(f(x), g(big(x)))), xs)
function find_worst(f, g, xs)
c, i = findmax(x -> abs(countepsfrom(f(x), g(big(x)))), xs)
c, xs[i]
end
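For concreteness, the helpers above might be used like this (illustrative, not part of the test file):

```julia
# countepsfrom returns the signed number of floats between f's answer and
# the correctly rounded one; 0 means exactly rounded.
countepsfrom(tanh_fast(0.5), tanh(big(0.5)))

mean_eps(tanh_fast, tanh, 1e-6:1e-4:5)   # average |error| in eps over a grid
worst_eps(tanh_fast, tanh, 1e-6:1e-4:5)  # worst-case |error| in eps
```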

@testset "tanh_fast & sigmoid_fast: Float64" begin

x64 = 1e-6:1e-4:5
xbig = 6:3:200.0

@testset "tanh" begin
mean_eps(tanh, tanh, x64) # 0.06582
worst_eps(tanh, tanh, x64) # 2

@test mean_eps(tanh_fast, tanh, x64) < 0.2 # 0.13164
@test worst_eps(tanh_fast, tanh, x64) <= 5 # 5

@test mean_eps(tanh_fast, tanh, -x64) < 0.6 # 0.5248
@test worst_eps(tanh_fast, tanh, -x64) <= 5 # 5

@test tanh_fast.(xbig) ≈ tanh.(xbig)
@test tanh_fast.(-xbig) ≈ tanh.(-xbig)
end
@testset "sigmoid" begin
mean_eps(sigmoid, sigmoid, x64) # 0.39246
worst_eps(sigmoid, sigmoid, x64) # 1

@test mean_eps(sigmoid_fast, sigmoid, x64) < 0.5 # 0.40432
@test worst_eps(sigmoid_fast, sigmoid, x64) <= 5 # 2

mean_eps(sigmoid, sigmoid, -x64) # 0.37672
worst_eps(sigmoid, sigmoid, -x64) # 2

@test mean_eps(sigmoid_fast, sigmoid, -x64) < 0.6 # 0.56478
@test worst_eps(sigmoid_fast, sigmoid, -x64) <= 5 # 4

@test sigmoid_fast.(xbig) ≈ sigmoid.(xbig)
@test sigmoid_fast.(-xbig) ≈ sigmoid.(-xbig)
end
end

@testset "tanh_fast & sigmoid_fast: Float32" begin

x32 = 1f-6:1f-4:5
xbig32 = 6:3:200f0

@testset "tanh" begin
mean_eps(tanh, tanh, x32) # 0.065
worst_eps(tanh, tanh, x32) # 1

@test mean_eps(tanh_fast, tanh, x32) < 0.8 # 0.65414
@test worst_eps(tanh_fast, tanh, x32) <= 5 # 5

@test mean_eps(tanh_fast, tanh, -x32) < 0.8 # 0.65414
@test worst_eps(tanh_fast, tanh, -x32) <= 5 # 5

@test tanh_fast.(xbig32) ≈ tanh.(xbig32)
@test tanh_fast.(-xbig32) ≈ tanh.(-xbig32)
end
@testset "sigmoid" begin
mean_eps(sigmoid, sigmoid, x32) # 0.38896
worst_eps(sigmoid, sigmoid, x32) # 1

@test mean_eps(sigmoid_fast, sigmoid, x32) < 0.5 # 0.38896
@test worst_eps(sigmoid_fast, sigmoid, x32) <= 2 # 2

mean_eps(sigmoid, sigmoid, -x32) # 0.38088
worst_eps(sigmoid, sigmoid, -x32) # 2

@test mean_eps(sigmoid_fast, sigmoid, -x32) < 0.5 # 0.38088
@test worst_eps(sigmoid_fast, sigmoid, -x32) <= 2 # 2

@test sigmoid_fast.(xbig32) ≈ sigmoid.(xbig32)
@test sigmoid_fast.(-xbig32) ≈ sigmoid.(-xbig32)
end
end

@testset "AutoDiff" begin

local rng = StableRNG(17)

2 comments on commit 37093c7

@mcabbott (Member Author)

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/48390

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.7.30 -m "<description of version>" 37093c706a974972e8cced77ffc61155375c55dc
git push origin v0.7.30
