diff --git a/src/activations.jl b/src/activations.jl
index 20e082ff3..3f49978a5 100644
--- a/src/activations.jl
+++ b/src/activations.jl
@@ -3,12 +3,13 @@
 # Some of activation functions have its wrapper function for GPU in NNlibCUDA.jl.
 # https://github.com/JuliaGPU/CuArrays.jl/issues/614
 
-const ACTIVATIONS =
-    [:σ, :hardσ, :hardtanh, :relu,
-    :leakyrelu, :relu6, :rrelu, :elu, :gelu, :swish, :selu,
-    :celu, :softplus, :softsign, :logσ, :logcosh,
-    :mish, :tanhshrink, :softshrink, :trelu,
-    :lisht]
+ACTIVATIONS = [
+    :σ, :hardσ, :hardtanh, :relu,
+    :leakyrelu, :relu6, :rrelu, :elu, :gelu, :swish, :selu,
+    :celu, :softplus, :softsign, :logσ, :logcosh,
+    :mish, :tanhshrink, :softshrink, :trelu, :lisht,
+    :tanh_fast, :sigmoid_fast,
+    ]
 
 for f in ACTIVATIONS
     @eval export $(f)
@@ -28,6 +29,8 @@ function.
 Unicode `σ` can be entered as `\\sigma` then tab, in many editors.
 The ascii name `sigmoid` is also exported.
 
+See also [`sigmoid_fast`](@ref).
+
 ```
 julia> lineplot(sigmoid, -5, 5, height=7)
           ┌────────────────────────────────────────┐
@@ -113,6 +116,7 @@ julia> lineplot(logsigmoid, -5, 5, height=7)
 ```
 """
 logσ(x) = -softplus(-x)
+
 const logsigmoid = logσ
 
 """
@@ -121,6 +125,7 @@ const logsigmoid = logσ
 Segment-wise linear approximation of `tanh`, much cheaper to compute.
 See ["Large Scale Machine Learning"](https://ronan.collobert.com/pub/matos/2004_phdthesis_lip6.pdf).
+See also [`tanh_fast`](@ref).
 
 ```
 julia> lineplot(hardtanh, -2, 2, height=7)
           ┌────────────────────────────────────────┐
@@ -652,15 +657,92 @@ for f in ACTIVATIONS
     error("Use broadcasting (`", $(string(f)), ".(x)`) to apply activation functions to arrays.")
 end
 
-## Define rrules for some activation functions, along with the
+## Faster, less accurate, versions of some activations.
+
+"""
+    tanh_fast(x)
+
+This is a faster but slightly less accurate version of `tanh`.
+
+Where Julia's `tanh` function has an error under 2 eps, this
+may be wrong by 5 eps, a loss of less than one decimal digit of accuracy.
+
+For `x::Float32` this is usually about 10 times faster,
+with a smaller speedup for `x::Float64`.
+For any other number types, it just calls `tanh`.
+
+See also [`sigmoid_fast`](@ref).
+
+```
+julia> tanh(0.5f0)
+0.46211717f0
+
+julia> tanh_fast(0.5f0)
+0.46211714f0
+
+julia> hardtanh(0.5f0)
+0.5f0
+```
+"""
+@inline function tanh_fast(x::Float32)
+    x2 = abs2(x)
+    n = evalpoly(x2, (1.0f0, 0.1346604f0, 0.0035974074f0, 2.2332108f-5, 1.587199f-8))
+    d = evalpoly(x2, (1.0f0, 0.4679937f0, 0.026262015f0, 0.0003453992f0, 8.7767893f-7))
+    ifelse(x2 < 66f0, x * (n / d), sign(x))
+end
+
+@inline function tanh_fast(x::Float64)
+    exp2x = @fastmath exp(x + x)
+    y = (exp2x - 1) / (exp2x + 1)
+    # That has large errors near zero; using `expm1` would be more accurate, but about as slow as `tanh`.
+    # Instead, we switch to a polynomial, which is very accurate within its range:
+    x2 = x * x
+    ypoly = x * evalpoly(x2, (1.0, -0.33333333333324583, 0.13333333325511604, -0.05396823125794372, 0.02186660872609521, -0.008697141630499953))
+    ifelse(x2 > 900.0, sign(y), ifelse(x2 < 0.017, oftype(y, ypoly), y))
+end
+
+# These approximations are very badly behaved for Float16; none are fast.
+# They are also a bit slower with ForwardDiff.Dual numbers, so let's use Base:
+tanh_fast(x::Real) = Base.tanh(x)
+
+"""
+    sigmoid_fast(x)
+
+This is a faster, and very slightly less accurate, version of `sigmoid`.
+For `x::Float32`, it is perhaps 3 times faster, with maximum errors of 2 eps instead of 1.
+
+See also [`tanh_fast`](@ref).
+
+```
+julia> sigmoid(0.2f0)
+0.54983395f0
+
+julia> sigmoid_fast(0.2f0)
+0.54983395f0
+
+julia> hardσ(0.2f0)
+0.53333336f0
+```
+"""
+@inline function sigmoid_fast(x::Real)
+    t = @fastmath exp(-abs(x))
+    y = ifelse(x ≥ 0, inv(1 + t), t / (1 + t))
+    ifelse(x > 40, one(y), ifelse(x < -80, zero(y), y))
+end
+# For x::Float32, this is not as quick as the rational tanh_fast(x) above,
+# but that polynomial has poor relative accuracy for negative x.
+
+sigmoid_fast(x::Float16) = sigmoid(x) # sigmoid_fast is extremely badly behaved at large x
+
+## Define rrules for some activation functions, along with the
 ## broadcasted rrule activation functions.
-## TODO: add to the lists below all activations.
+## TODO: add to the lists below all activations. 
 ## This is a performance hack specifically for Zygote, because it doesn't handle fused
-## broadcasts well; but it generally should be good (or at least harmless) for any AD, as
+## broadcasts well; but it generally should be good (or at least harmless) for any AD, as 
 ## it saves ADing the broadcasting machinery.
 ## Related Issue https://github.com/JuliaDiff/ChainRulesCore.jl/issues/271
-
+ 
 UNARY_ACTS = [ # f, df
     (:relu, :(x > 0)),
     (:hardtanh, :(-1 < x < 1)),
@@ -668,6 +750,9 @@ UNARY_ACTS = [ # f, df
     (:σ, :(conj(Ω * (1 - Ω)))),
     (:elu, :(deriv_elu(Ω))),
     (:softplus, :(σ(x))),
+
+    (:tanh_fast, :(conj(1 - Ω^2))),
+    (:sigmoid_fast, :(conj(Ω * (1 - Ω)))),
 ]
 
 for (f, df) in UNARY_ACTS
diff --git a/test/activations.jl b/test/activations.jl
index 12162649e..0a0c5a2f5 100644
--- a/test/activations.jl
+++ b/test/activations.jl
@@ -221,6 +221,97 @@ end
     @test trelu(0.9,0.5) == 0.9
 end
 
+## Faster variants
+
+using NNlib: tanh_fast, sigmoid_fast
+
+function countepsfrom(x::T, xtrue) where {T<:AbstractFloat}
+    target = T(xtrue)
+    for n in Iterators.flatten(zip(0:100, -1:-1:-100))
+        nextfloat(x, n) === target && return n
+    end
+    return round(Int, (target - x) / eps(x))
+end
+
+mean_eps(f, g, xs) = mean(x -> abs(countepsfrom(f(x), g(big(x)))), xs)
+worst_eps(f, g, xs) = maximum(x -> abs(countepsfrom(f(x), g(big(x)))), xs)
+function find_worst(f, g, xs)
+    c, i = findmax(x -> abs(countepsfrom(f(x), g(big(x)))), xs)
+    c, xs[i]
+end
+
+@testset "tanh_fast & sigmoid_fast: Float64" begin
+
+    x64 = 1e-6:1e-4:5
+    xbig = 6:3:200.0
+
+    @testset "tanh" begin
+        mean_eps(tanh, tanh, x64) # 0.06582
+        worst_eps(tanh, tanh, x64) # 2
+
+        @test mean_eps(tanh_fast, tanh, x64) < 0.2 # 0.13164
+        @test worst_eps(tanh_fast, tanh, x64) <= 5 # 5
+
+        @test mean_eps(tanh_fast, tanh, -x64) < 0.6 # 0.5248
+        @test worst_eps(tanh_fast, tanh, -x64) <= 5 # 5
+
+        @test tanh_fast.(xbig) ≈ tanh.(xbig)
+        @test tanh_fast.(-xbig) ≈ tanh.(-xbig)
+    end
+    @testset "sigmoid" begin
+        mean_eps(sigmoid, sigmoid, x64) # 0.39246
+        worst_eps(sigmoid, sigmoid, x64) # 1
+
+        @test mean_eps(sigmoid_fast, sigmoid, x64) < 0.5 # 0.40432
+        @test worst_eps(sigmoid_fast, sigmoid, x64) <= 5 # 2
+
+        mean_eps(sigmoid, sigmoid, -x64) # 0.37672
+        worst_eps(sigmoid, sigmoid, -x64) # 2
+
+        @test mean_eps(sigmoid_fast, sigmoid, -x64) < 0.6 # 0.56478
+        @test worst_eps(sigmoid_fast, sigmoid, -x64) <= 5 # 4
+
+        @test sigmoid_fast.(xbig) ≈ sigmoid.(xbig)
+        @test sigmoid_fast.(-xbig) ≈ sigmoid.(-xbig)
+    end
+end
+
+@testset "tanh_fast & sigmoid_fast: Float32" begin
+
+    x32 = 1f-6:1f-4:5
+    xbig32 = 6:3:200f0
+
+    @testset "tanh" begin
+        mean_eps(tanh, tanh, x32) # 0.065
+        worst_eps(tanh, tanh, x32) # 1
+
+        @test mean_eps(tanh_fast, tanh, x32) < 0.8 # 0.65414
+        @test worst_eps(tanh_fast, tanh, x32) <= 5 # 5
+
+        @test mean_eps(tanh_fast, tanh, -x32) < 0.8 # 0.65414
+        @test worst_eps(tanh_fast, tanh, -x32) <= 5 # 5
+
+        @test tanh_fast.(xbig32) ≈ tanh.(xbig32)
+        @test tanh_fast.(-xbig32) ≈ tanh.(-xbig32)
+    end
+    @testset "sigmoid" begin
+        mean_eps(sigmoid, sigmoid, x32) # 0.38896
+        worst_eps(sigmoid, sigmoid, x32) # 1
+
+        @test mean_eps(sigmoid_fast, sigmoid, x32) < 0.5 # 0.38896
+        @test worst_eps(sigmoid_fast, sigmoid, x32) <= 2 # 2
+
+        mean_eps(sigmoid, sigmoid, -x32) # 0.38088
+        worst_eps(sigmoid, sigmoid, -x32) # 2
+
+        @test mean_eps(sigmoid_fast, sigmoid, -x32) < 0.5 # 0.38088
+        @test worst_eps(sigmoid_fast, sigmoid, -x32) <= 2 # 2
+
+        @test sigmoid_fast.(xbig32) ≈ sigmoid.(xbig32)
+        @test sigmoid_fast.(-xbig32) ≈ sigmoid.(-xbig32)
+    end
+end
+
 @testset "AutoDiff" begin
     local rng = StableRNG(17)
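As a quick sanity check of the patch above outside the test suite, the snippet below mirrors the tolerances asserted in the tests. This is only a sketch: it assumes a Julia session where NNlib with this diff is loaded, and `max_ulps` is an ad-hoc helper defined here, not part of the package.

```
using NNlib: tanh_fast, sigmoid_fast, sigmoid

# Largest deviation of `f` from the reference `g`, in multiples of eps at the reference value.
max_ulps(f, g, xs) = maximum(x -> abs(f(x) - g(x)) / eps(g(x)), xs)

x32 = 1f-6:1f-4:5f0                   # same grid as the Float32 tests above

max_ulps(tanh_fast, tanh, x32)        # expected to stay within ~5 ulps, as the tests assert
max_ulps(sigmoid_fast, sigmoid, x32)  # expected to stay within ~2 ulps, as the tests assert
```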
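The two new `UNARY_ACTS` entries encode the usual derivative identities, tanh'(x) = 1 - tanh(x)^2 and σ'(x) = σ(x) * (1 - σ(x)), applied to the fast variants. A minimal finite-difference check of those identities, again only a sketch (`fdiff` is a hypothetical helper, not an NNlib function):

```
using NNlib: tanh_fast, sigmoid_fast

fdiff(f, x; h=1e-6) = (f(x + h) - f(x - h)) / (2h)  # central finite difference

x = 0.3
isapprox(fdiff(tanh_fast, x), 1 - tanh_fast(x)^2; atol=1e-6)                         # true
isapprox(fdiff(sigmoid_fast, x), sigmoid_fast(x) * (1 - sigmoid_fast(x)); atol=1e-6) # true
```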