Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add "a" parameter to softplus() #83 #85

Merged
merged 13 commits into from
Dec 11, 2024
2 changes: 2 additions & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ logcosh
logabssinh
log1psq
log1pexp
softplus
invsoftplus
log1mexp
log2mexp
logexpm1
Expand Down
14 changes: 14 additions & 0 deletions ext/LogExpFunctionsChangesOfVariablesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,25 @@ function ChangesOfVariables.with_logabsdet_jacobian(::typeof(log1pexp), x::Real)
y = log1pexp(x)
return y, x - y
end
# One-argument `softplus`: the value and the log-Jacobian are taken directly from the
# `log1pexp` method.
ChangesOfVariables.with_logabsdet_jacobian(::typeof(softplus), x::Real) =
    ChangesOfVariables.with_logabsdet_jacobian(log1pexp, x)
# Method for `Base.Fix2(softplus, a)`, i.e. the map `x -> softplus(x, a)`.
# Returns the transformed value together with `a * (x - value)`.
function ChangesOfVariables.with_logabsdet_jacobian(
    f::Base.Fix2{typeof(softplus),<:Real}, x::Real
)
    a = f.x  # the argument fixed in the second position of `softplus`
    value = f(x)
    logjac = a * (x - value)
    return (value, logjac)
end

# Value of `logexpm1` at `x` paired with the log-Jacobian term, computed as `x - value`.
function ChangesOfVariables.with_logabsdet_jacobian(::typeof(logexpm1), x::Real)
    value = logexpm1(x)
    logjac = x - value
    return (value, logjac)
end
# One-argument `invsoftplus`: the value and the log-Jacobian are taken directly from the
# `logexpm1` method.
ChangesOfVariables.with_logabsdet_jacobian(::typeof(invsoftplus), x::Real) =
    ChangesOfVariables.with_logabsdet_jacobian(logexpm1, x)
# Method for `Base.Fix2(invsoftplus, a)`, i.e. the map `x -> invsoftplus(x, a)`.
# Returns the transformed value together with `a * (x - value)`.
function ChangesOfVariables.with_logabsdet_jacobian(
    f::Base.Fix2{typeof(invsoftplus),<:Real}, x::Real
)
    a = f.x  # the argument fixed in the second position of `invsoftplus`
    value = f(x)
    logjac = a * (x - value)
    return (value, logjac)
end

function ChangesOfVariables.with_logabsdet_jacobian(::typeof(log1mexp), x::Real)
y = log1mexp(x)
Expand Down
10 changes: 10 additions & 0 deletions ext/LogExpFunctionsInverseFunctionsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,14 @@ InverseFunctions.inverse(::typeof(logitexp)) = loglogistic
# `log1mlogistic` and `logit1mexp` are registered as each other's inverse.
function InverseFunctions.inverse(::typeof(log1mlogistic))
    return logit1mexp
end
function InverseFunctions.inverse(::typeof(logit1mexp))
    return log1mlogistic
end

# Inverse of the one-argument `softplus`.
function InverseFunctions.inverse(::typeof(softplus))
    return invsoftplus
end
# Inverse of `Base.Fix2(softplus, a)`: `invsoftplus` with the same fixed second argument.
InverseFunctions.inverse(f::Base.Fix2{typeof(softplus),<:Real}) =
    Base.Fix2(invsoftplus, f.x)

# Inverse of the one-argument `invsoftplus`.
function InverseFunctions.inverse(::typeof(invsoftplus))
    return softplus
end
# Inverse of `Base.Fix2(invsoftplus, a)`: `softplus` with the same fixed second argument.
InverseFunctions.inverse(f::Base.Fix2{typeof(invsoftplus),<:Real}) =
    Base.Fix2(softplus, f.x)

end # module
26 changes: 24 additions & 2 deletions src/basicfuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,9 @@ Return `log(1+exp(x))` evaluated carefully for largish `x`.
This is also called the ["softplus"](https://en.wikipedia.org/wiki/Rectifier_(neural_networks))
transformation, being a smooth approximation to `max(0,x)`. Its inverse is [`logexpm1`](@ref).

This is also called the ["softplus"](https://en.wikipedia.org/wiki/Rectifier_(neural_networks))
transformation (in its default parametrization, see [`softplus`](@ref)), being a smooth approximation to `max(0,x)`.

See:
* Martin Maechler (2012) [“Accurately Computing log(1 − exp(− |a|))”](http://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf)
"""
Expand Down Expand Up @@ -257,8 +260,27 @@ Return `log(exp(x) - 1)` or the “invsoftplus” function. It is the inverse o
function logexpm1(x::Real)
    # Three regimes selected by the size of `x`; anything failing both comparisons
    # falls through to the last branch.
    if x <= 18.0
        return log(_expm1(x))
    elseif x <= 33.3
        return x - exp(-x)
    end
    # Return `x` itself, converted to the type of `exp(-x)`.
    return oftype(exp(-x), x)
end
function logexpm1(x::Float32)
    # Same three regimes as the generic method, with `Float32` cutoffs `9f0` and `16f0`.
    if x <= 9f0
        return log(expm1(x))
    elseif x <= 16f0
        return x - exp(-x)
    end
    # Return `x` itself, converted to the type of `exp(-x)`.
    return oftype(exp(-x), x)
end

# Plain aliases: each name is bound to the very same function object as its target,
# so these two lines create no separate methods for `softplus` / `invsoftplus`.
# NOTE(review): this text comes from a diff view; these lines look like the pre-change
# definitions that the method definitions further down replace. Confirm before keeping both.
const softplus = log1pexp
const invsoftplus = logexpm1
"""
$(SIGNATURES)

The generalized `softplus` function (Wiemann et al., 2024) takes an additional optional parameter `a` that controls
the approximation error with respect to the linear spline. It defaults to `a=1.0`, in which case the softplus is
equivalent to [`log1pexp`](@ref).

See:
* Wiemann, P. F., Kneib, T., & Hambuckers, J. (2024). Using the softplus function to construct alternative link functions in generalized linear models and beyond. Statistical Papers, 65(5), 3155-3180.
"""
function softplus(x::Real)
    # One-argument form: exactly `log1pexp`.
    return log1pexp(x)
end
function softplus(x::Real, a::Real)
    # Scale the input by `a`, apply `log1pexp`, then divide the result by `a`.
    scaled = a * x
    return log1pexp(scaled) / a
end

"""
$(SIGNATURES)

The inverse generalized `softplus` function (Wiemann et al., 2024). See [`softplus`](@ref).
"""
function invsoftplus(y::Real)
    # One-argument form: exactly `logexpm1`.
    return logexpm1(y)
end
DominiqueMakowski marked this conversation as resolved.
Show resolved Hide resolved
function invsoftplus(y::Real, a::Real)
    # Scale the input by `a`, apply `logexpm1`, then divide the result by `a`.
    scaled = a * y
    return logexpm1(scaled) / a
end


"""
$(SIGNATURES)
Expand Down
20 changes: 20 additions & 0 deletions test/basicfuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,16 @@ end
end
end

@testset "softplus" begin
    for T in (Int, Float64, Float32, Float16)
        x = T(2)

        # The one-argument form must be identical to `log1pexp`.
        @test @inferred(softplus(x)) === log1pexp(x)

        # With `a = 1` the result is of type `float(T)` and matches the one-argument form.
        @test @inferred(softplus(x, 1)) isa float(T)
        @test @inferred(softplus(x, 1)) ≈ softplus(x)

        # Other values of `a`: compare against the scaled one-argument form.
        for a in (5, 10)
            @test @inferred(softplus(x, a)) ≈ softplus(a * x) / a
        end
    end
end

@testset "log1mexp" begin
for T in (Float64, Float32, Float16)
@test @inferred(log1mexp(-T(1))) isa T
Expand All @@ -186,6 +196,16 @@ end
end
end

@testset "invsoftplus" begin
    for T in (Int, Float64, Float32, Float16)
        y = T(2)

        # The one-argument form must be identical to `logexpm1`.
        @test @inferred(invsoftplus(y)) === logexpm1(y)

        # With `a = 1` the result is of type `float(T)` and matches the one-argument form.
        @test @inferred(invsoftplus(y, 1)) isa float(T)
        @test @inferred(invsoftplus(y, 1)) ≈ invsoftplus(y)

        # Other values of `a`: compare against the scaled one-argument form.
        for a in (5, 10)
            @test @inferred(invsoftplus(y, a)) ≈ invsoftplus(a * y) / a
        end
    end
end

@testset "log1pmx" begin
@test iszero(log1pmx(0.0))
@test log1pmx(1.0) ≈ log(2.0) - 1.0
Expand Down
5 changes: 5 additions & 0 deletions test/inverse.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
@testset "inverse.jl" begin
InverseFunctions.test_inverse(log1pexp, randn())
InverseFunctions.test_inverse(softplus, randn())
InverseFunctions.test_inverse(Base.Fix2(softplus, randexp()), randn())

InverseFunctions.test_inverse(logexpm1, randexp())
InverseFunctions.test_inverse(invsoftplus, randexp())
InverseFunctions.test_inverse(Base.Fix2(invsoftplus, randexp()), randexp())

InverseFunctions.test_inverse(log1mexp, -randexp())

Expand Down
11 changes: 11 additions & 0 deletions test/with_logabsdet_jacobian.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,23 @@
@testset "with_logabsdet_jacobian" begin
# Generic fallback: the second element of the `frule` result for `f` at `x`, using `1`
# as the input tangent.
derivative(f, x) = ChainRulesTestUtils.frule((ChainRulesTestUtils.NoTangent(), 1), f, x)[2]
# `softplus` / `invsoftplus` are routed to `log1pexp` / `logexpm1`.
# NOTE(review): presumably because no `frule` is defined for them directly; confirm.
derivative(::typeof(softplus), x) = derivative(log1pexp, x)
# For `Base.Fix2(softplus, a)`, i.e. `x -> log1pexp(a * x) / a`, the chain rule gives the
# derivative of `log1pexp` evaluated at `a * x`: the inner factor `a` cancels the `1 / a`.
derivative(f::Base.Fix2{typeof(softplus),<:Real}, x) = derivative(log1pexp, f.x * x)
derivative(::typeof(invsoftplus), x) = derivative(logexpm1, x)
# Same cancellation for `x -> logexpm1(a * x) / a`.
derivative(f::Base.Fix2{typeof(invsoftplus),<:Real}, x) = derivative(logexpm1, f.x * x)

x = randexp()
y = randexp()

ChangesOfVariables.test_with_logabsdet_jacobian(log1pexp, x, derivative)
ChangesOfVariables.test_with_logabsdet_jacobian(log1pexp, -x, derivative)
ChangesOfVariables.test_with_logabsdet_jacobian(softplus, x, derivative)
ChangesOfVariables.test_with_logabsdet_jacobian(softplus, -x, derivative)
ChangesOfVariables.test_with_logabsdet_jacobian(Base.Fix2(softplus, y), x, derivative)
ChangesOfVariables.test_with_logabsdet_jacobian(Base.Fix2(softplus, y), -x, derivative)

ChangesOfVariables.test_with_logabsdet_jacobian(logexpm1, x, derivative)
ChangesOfVariables.test_with_logabsdet_jacobian(invsoftplus, x, derivative)
ChangesOfVariables.test_with_logabsdet_jacobian(Base.Fix2(invsoftplus, y), x, derivative)

ChangesOfVariables.test_with_logabsdet_jacobian(log1mexp, -x, derivative)

Expand Down
Loading