From 35f18e309e14472004b745c3b79de12b1eeb72fc Mon Sep 17 00:00:00 2001 From: Dominique Makowski Date: Tue, 3 Sep 2024 08:44:51 +0100 Subject: [PATCH 01/12] Add "a" parameter to softplus() --- src/basicfuns.jl | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/src/basicfuns.jl b/src/basicfuns.jl index e013adf..859db5b 100644 --- a/src/basicfuns.jl +++ b/src/basicfuns.jl @@ -165,9 +165,14 @@ Return `log(1+exp(x))` evaluated carefully for largish `x`. This is also called the ["softplus"](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)) transformation, being a smooth approximation to `max(0,x)`. Its inverse is [`logexpm1`](@ref). +The generalized `softplus` function (Wiemann et al., 2024) takes an additional optional parameter `a` that control +the approximation error with respect to the linear spline. It defaults to `a=1.0`, in which case the softplus is +equivalent to `log1pexp`. + See: * Martin Maechler (2012) [“Accurately Computing log(1 − exp(− |a|))”](http://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf) -""" + * Wiemann, P. F., Kneib, T., & Hambuckers, J. (2024). Using the softplus function to construct alternative link functions in generalized linear models and beyond. Statistical Papers, 65(5), 3155-3180. + """ log1pexp(x::Real) = _log1pexp(float(x)) # ensures that BigInt/BigFloat, Int/Float64 etc. dispatch to the same algorithm # Approximations based on Maechler (2012) @@ -255,10 +260,22 @@ Return `log(exp(x) - 1)` or the “invsoftplus” function. It is the inverse o [`log1pexp`](@ref) (aka “softplus”). """ logexpm1(x::Real) = x <= 18.0 ? log(_expm1(x)) : x <= 33.3 ? x - exp(-x) : oftype(exp(-x), x) -logexpm1(x::Float32) = x <= 9f0 ? log(expm1(x)) : x <= 16f0 ? x - exp(-x) : oftype(exp(-x), x) +logexpm1(x::Float32) = x <= 9.0f0 ? log(expm1(x)) : x <= 16.0f0 ? x - exp(-x) : oftype(exp(-x), x) + +function softplus(x; a::Real=1.0) + if a == 1.0 + return log1pexp(x) + end + return log1pexp(a * x) / a +end + +function invsoftplus(y; a::Real=1.0) + if a == 1.0 + return logexpm1(y) + end + return logexpm1(a * y) / a +end -const softplus = log1pexp -const invsoftplus = logexpm1 """ $(SIGNATURES) From 83acf1d01c53e12b36bbd3e5d308a861cf7218b4 Mon Sep 17 00:00:00 2001 From: Dominique Makowski Date: Wed, 4 Sep 2024 07:52:29 +0100 Subject: [PATCH 02/12] Update src/basicfuns.jl Co-authored-by: David Widmann --- src/basicfuns.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/basicfuns.jl b/src/basicfuns.jl index 859db5b..b6cc55b 100644 --- a/src/basicfuns.jl +++ b/src/basicfuns.jl @@ -260,7 +260,7 @@ Return `log(exp(x) - 1)` or the “invsoftplus” function. It is the inverse o [`log1pexp`](@ref) (aka “softplus”). """ logexpm1(x::Real) = x <= 18.0 ? log(_expm1(x)) : x <= 33.3 ? x - exp(-x) : oftype(exp(-x), x) -logexpm1(x::Float32) = x <= 9.0f0 ? log(expm1(x)) : x <= 16.0f0 ? x - exp(-x) : oftype(exp(-x), x) +logexpm1(x::Float32) = x <= 9f0 ? log(expm1(x)) : x <= 16f0 ? x - exp(-x) : oftype(exp(-x), x) function softplus(x; a::Real=1.0) if a == 1.0 From 0fdda7aff32aa1d4d2ef015dc579146104f27493 Mon Sep 17 00:00:00 2001 From: Dominique Makowski Date: Wed, 4 Sep 2024 07:53:51 +0100 Subject: [PATCH 03/12] Update src/basicfuns.jl Co-authored-by: David Widmann --- src/basicfuns.jl | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/src/basicfuns.jl b/src/basicfuns.jl index b6cc55b..d2bc90a 100644 --- a/src/basicfuns.jl +++ b/src/basicfuns.jl @@ -262,19 +262,11 @@ Return `log(exp(x) - 1)` or the “invsoftplus” function. It is the inverse o logexpm1(x::Real) = x <= 18.0 ? log(_expm1(x)) : x <= 33.3 ? x - exp(-x) : oftype(exp(-x), x) logexpm1(x::Float32) = x <= 9f0 ? log(expm1(x)) : x <= 16f0 ? x - exp(-x) : oftype(exp(-x), x) -function softplus(x; a::Real=1.0) - if a == 1.0 - return log1pexp(x) - end - return log1pexp(a * x) / a -end +softplus(x::Real) = log1pexp(x) +softplus(x::Real, a::Real) = log1pexp(a * x) / a -function invsoftplus(y; a::Real=1.0) - if a == 1.0 - return logexpm1(y) - end - return logexpm1(a * y) / a -end +invsoftplus(y::Real) = logexpm1(y) +invsoftplus(y::Real, a::Real) = logexpm1(a * y) / a """ From bf85177ad28557dfc737935ea1103bd7976ef8d1 Mon Sep 17 00:00:00 2001 From: Dominique Makowski Date: Wed, 4 Sep 2024 08:06:52 +0100 Subject: [PATCH 04/12] add own docstring to softplus --- activate.jl | 31 +++++++++++++++++++++++++++++++ src/basicfuns.jl | 16 ++++++++++++---- 2 files changed, 43 insertions(+), 4 deletions(-) create mode 100644 activate.jl diff --git a/activate.jl b/activate.jl new file mode 100644 index 0000000..49da998 --- /dev/null +++ b/activate.jl @@ -0,0 +1,31 @@ +using Revise +using Pkg + +# Package +Pkg.activate("C:/Users/domma/Dropbox/Software/LogExpFunctions.jl/") + +using LogExpFunctions +using CairoMakie + + +xrange = range(-1.5, 1.5, length=100) +yexp = exp.(xrange) +ysoftplus1 = softplus.(xrange) +ysoftplus2 = softplus.(xrange; a=2) +ysoftplus3 = softplus.(xrange; a=3) + +ysoftplus5 = softplus.(xrange; a=5) +ysoftplus10 = softplus.(xrange; a=10) + + +# f = lines(xrange, yexp, color=:black) +f = lines(xrange, ysoftplus1, color=:red) +lines!(xrange, ysoftplus2, color=:orange) +lines!(xrange, ysoftplus3, color=:darkorange) +lines!(xrange, ysoftplus5, color=:green) +lines!(xrange, ysoftplus10, color=:blue) + +ablines!(0, 1, color=:grey, linestyle=:dash) +f + +softplus(0; a=3) \ No newline at end of file diff --git a/src/basicfuns.jl b/src/basicfuns.jl index d2bc90a..1e85c3c 100644 --- a/src/basicfuns.jl +++ b/src/basicfuns.jl @@ -165,13 +165,11 @@ Return `log(1+exp(x))` evaluated carefully for largish `x`. This is also called the ["softplus"](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)) transformation, being a smooth approximation to `max(0,x)`. Its inverse is [`logexpm1`](@ref). -The generalized `softplus` function (Wiemann et al., 2024) takes an additional optional parameter `a` that control -the approximation error with respect to the linear spline. It defaults to `a=1.0`, in which case the softplus is -equivalent to `log1pexp`. +This is also called the ["softplus"](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)) +transformation (in its default parametrization, see [`softplus`](@ref)), being a smooth approximation to `max(0,x)`. See: * Martin Maechler (2012) [“Accurately Computing log(1 − exp(− |a|))”](http://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf) - * Wiemann, P. F., Kneib, T., & Hambuckers, J. (2024). Using the softplus function to construct alternative link functions in generalized linear models and beyond. Statistical Papers, 65(5), 3155-3180. """ log1pexp(x::Real) = _log1pexp(float(x)) # ensures that BigInt/BigFloat, Int/Float64 etc. dispatch to the same algorithm @@ -262,6 +260,16 @@ Return `log(exp(x) - 1)` or the “invsoftplus” function. It is the inverse o logexpm1(x::Real) = x <= 18.0 ? log(_expm1(x)) : x <= 33.3 ? x - exp(-x) : oftype(exp(-x), x) logexpm1(x::Float32) = x <= 9f0 ? log(expm1(x)) : x <= 16f0 ? x - exp(-x) : oftype(exp(-x), x) +""" +$(SIGNATURES) + +The generalized `softplus` function (Wiemann et al., 2024) takes an additional optional parameter `a` that control +the approximation error with respect to the linear spline. It defaults to `a=1.0`, in which case the softplus is +equivalent to [`log1pexp`](@ref). + +See: + * Wiemann, P. F., Kneib, T., & Hambuckers, J. (2024). Using the softplus function to construct alternative link functions in generalized linear models and beyond. Statistical Papers, 65(5), 3155-3180. +""" softplus(x::Real) = log1pexp(x) softplus(x::Real, a::Real) = log1pexp(a * x) / a From af527b546d475264668a4cf28f14dd6733d31e8d Mon Sep 17 00:00:00 2001 From: Dominique Makowski Date: Wed, 4 Sep 2024 08:20:04 +0100 Subject: [PATCH 05/12] Update src/basicfuns.jl Co-authored-by: David Widmann --- src/basicfuns.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/basicfuns.jl b/src/basicfuns.jl index 1e85c3c..b35c4ae 100644 --- a/src/basicfuns.jl +++ b/src/basicfuns.jl @@ -170,7 +170,7 @@ transformation (in its default parametrization, see [`softplus`](@ref)), being a See: * Martin Maechler (2012) [“Accurately Computing log(1 − exp(− |a|))”](http://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf) - """ +""" log1pexp(x::Real) = _log1pexp(float(x)) # ensures that BigInt/BigFloat, Int/Float64 etc. dispatch to the same algorithm # Approximations based on Maechler (2012) From be659f86e92d6c7d863b52833883c6dab7904b16 Mon Sep 17 00:00:00 2001 From: Dominique Makowski Date: Wed, 4 Sep 2024 08:20:26 +0100 Subject: [PATCH 06/12] Delete activate.jl --- activate.jl | 31 ------------------------------- 1 file changed, 31 deletions(-) delete mode 100644 activate.jl diff --git a/activate.jl b/activate.jl deleted file mode 100644 index 49da998..0000000 --- a/activate.jl +++ /dev/null @@ -1,31 +0,0 @@ -using Revise -using Pkg - -# Package -Pkg.activate("C:/Users/domma/Dropbox/Software/LogExpFunctions.jl/") - -using LogExpFunctions -using CairoMakie - - -xrange = range(-1.5, 1.5, length=100) -yexp = exp.(xrange) -ysoftplus1 = softplus.(xrange) -ysoftplus2 = softplus.(xrange; a=2) -ysoftplus3 = softplus.(xrange; a=3) - -ysoftplus5 = softplus.(xrange; a=5) -ysoftplus10 = softplus.(xrange; a=10) - - -# f = lines(xrange, yexp, color=:black) -f = lines(xrange, ysoftplus1, color=:red) -lines!(xrange, ysoftplus2, color=:orange) -lines!(xrange, ysoftplus3, color=:darkorange) -lines!(xrange, ysoftplus5, color=:green) -lines!(xrange, ysoftplus10, color=:blue) - -ablines!(0, 1, color=:grey, linestyle=:dash) -f - -softplus(0; a=3) \ No newline at end of file From a113b43fff4ec8c45829dcc12a8691d2e5cf91e8 Mon Sep 17 00:00:00 2001 From: Dominique Makowski Date: Sat, 14 Sep 2024 16:33:31 +0100 Subject: [PATCH 07/12] Add docstring --- src/basicfuns.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/basicfuns.jl b/src/basicfuns.jl index b35c4ae..d561fc0 100644 --- a/src/basicfuns.jl +++ b/src/basicfuns.jl @@ -273,6 +273,11 @@ See: softplus(x::Real) = log1pexp(x) softplus(x::Real, a::Real) = log1pexp(a * x) / a +""" +$(SIGNATURES) + +The inverse generalized `softplus` function (Wiemann et al., 2024). See [`softplus`](@ref). +""" invsoftplus(y::Real) = logexpm1(y) invsoftplus(y::Real, a::Real) = logexpm1(a * y) / a From 0a03aba24231fcea3d911b370eb4bc667b832156 Mon Sep 17 00:00:00 2001 From: Dominique Makowski Date: Sat, 14 Sep 2024 17:09:27 +0100 Subject: [PATCH 08/12] docs and test --- docs/src/index.md | 2 ++ src/basicfuns.jl | 2 ++ test/basicfuns.jl | 10 ++++++++++ 3 files changed, 14 insertions(+) diff --git a/docs/src/index.md b/docs/src/index.md index 5ad110b..3c0a7f3 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -18,6 +18,8 @@ logcosh logabssinh log1psq log1pexp +softplus +invsoftplus log1mexp log2mexp logexpm1 diff --git a/src/basicfuns.jl b/src/basicfuns.jl index d561fc0..abb15ff 100644 --- a/src/basicfuns.jl +++ b/src/basicfuns.jl @@ -272,6 +272,7 @@ See: """ softplus(x::Real) = log1pexp(x) softplus(x::Real, a::Real) = log1pexp(a * x) / a +softplus(x::Real; a::Real=1) = softplus(x, a) """ $(SIGNATURES) @@ -280,6 +281,7 @@ The inverse generalized `softplus` function (Wiemann et al., 2024). See [`softpl """ invsoftplus(y::Real) = logexpm1(y) invsoftplus(y::Real, a::Real) = logexpm1(a * y) / a +invsoftplus(y::Real; a::Real=1) = invsoftplus(y, a) """ diff --git a/test/basicfuns.jl b/test/basicfuns.jl index 1e615c0..23dcd58 100644 --- a/test/basicfuns.jl +++ b/test/basicfuns.jl @@ -161,6 +161,16 @@ end end end +@testset "softplus" begin + @test softplus(2) ≈ log1pexp(2) + @test softplus(2, 1) ≈ log1pexp(2) + @test softplus(2, a=1) ≈ log1pexp(2) + @test softplus(2, 10) < log1pexp(2) + @test invsoftplus(softplus(2), 1) ≈ 2 + @test invsoftplus(softplus(2, 10), a=10) ≈ 2 +end + + @testset "log1mexp" begin for T in (Float64, Float32, Float16) @test @inferred(log1mexp(-T(1))) isa T From 0c0ab2a854dc5eee2b10a64025b5341e74c86326 Mon Sep 17 00:00:00 2001 From: Dominique Makowski Date: Fri, 27 Sep 2024 09:39:55 +0100 Subject: [PATCH 09/12] Add inverse --- ext/LogExpFunctionsInverseFunctionsExt.jl | 3 +++ test/inverse.jl | 3 +++ 2 files changed, 6 insertions(+) diff --git a/ext/LogExpFunctionsInverseFunctionsExt.jl b/ext/LogExpFunctionsInverseFunctionsExt.jl index 1981493..b373b61 100644 --- a/ext/LogExpFunctionsInverseFunctionsExt.jl +++ b/ext/LogExpFunctionsInverseFunctionsExt.jl @@ -22,4 +22,7 @@ InverseFunctions.inverse(::typeof(logitexp)) = loglogistic InverseFunctions.inverse(::typeof(log1mlogistic)) = logit1mexp InverseFunctions.inverse(::typeof(logit1mexp)) = log1mlogistic +InverseFunctions.inverse(::typeof(softplus)) = invsofplus +InverseFunctions.inverse(::typeof(invsofplus)) = softplus + end # module diff --git a/test/inverse.jl b/test/inverse.jl index e489db5..7518b05 100644 --- a/test/inverse.jl +++ b/test/inverse.jl @@ -17,4 +17,7 @@ InverseFunctions.test_inverse(log1mlogistic, randexp()) InverseFunctions.test_inverse(logit1mexp, -randexp()) + + InverseFunctions.test_inverse(softplus, randn()) + InverseFunctions.test_inverse(invsoftplus, randexp()) end From 5d8a05b87dd08614d8a619bcc0dd1cd08b6b8b71 Mon Sep 17 00:00:00 2001 From: Dominique Makowski Date: Fri, 27 Sep 2024 09:44:58 +0100 Subject: [PATCH 10/12] typo --- ext/LogExpFunctionsInverseFunctionsExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/LogExpFunctionsInverseFunctionsExt.jl b/ext/LogExpFunctionsInverseFunctionsExt.jl index b373b61..611966f 100644 --- a/ext/LogExpFunctionsInverseFunctionsExt.jl +++ b/ext/LogExpFunctionsInverseFunctionsExt.jl @@ -22,7 +22,7 @@ InverseFunctions.inverse(::typeof(logitexp)) = loglogistic InverseFunctions.inverse(::typeof(log1mlogistic)) = logit1mexp InverseFunctions.inverse(::typeof(logit1mexp)) = log1mlogistic -InverseFunctions.inverse(::typeof(softplus)) = invsofplus -InverseFunctions.inverse(::typeof(invsofplus)) = softplus +InverseFunctions.inverse(::typeof(softplus)) = invsoftplus +InverseFunctions.inverse(::typeof(invsoftplus)) = softplus end # module From da5130f0ed6dbb62c1cb107d2e6eedc0513615b5 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 10 Dec 2024 20:55:11 +0100 Subject: [PATCH 11/12] Fix method overrides --- src/basicfuns.jl | 2 -- test/basicfuns.jl | 2 -- 2 files changed, 4 deletions(-) diff --git a/src/basicfuns.jl b/src/basicfuns.jl index abb15ff..d561fc0 100644 --- a/src/basicfuns.jl +++ b/src/basicfuns.jl @@ -272,7 +272,6 @@ See: """ softplus(x::Real) = log1pexp(x) softplus(x::Real, a::Real) = log1pexp(a * x) / a -softplus(x::Real; a::Real=1) = softplus(x, a) """ $(SIGNATURES) @@ -281,7 +280,6 @@ The inverse generalized `softplus` function (Wiemann et al., 2024). See [`softpl """ invsoftplus(y::Real) = logexpm1(y) invsoftplus(y::Real, a::Real) = logexpm1(a * y) / a -invsoftplus(y::Real; a::Real=1) = invsoftplus(y, a) """ diff --git a/test/basicfuns.jl b/test/basicfuns.jl index 23dcd58..461c840 100644 --- a/test/basicfuns.jl +++ b/test/basicfuns.jl @@ -164,10 +164,8 @@ end @testset "softplus" begin @test softplus(2) ≈ log1pexp(2) @test softplus(2, 1) ≈ log1pexp(2) - @test softplus(2, a=1) ≈ log1pexp(2) @test softplus(2, 10) < log1pexp(2) @test invsoftplus(softplus(2), 1) ≈ 2 - @test invsoftplus(softplus(2, 10), a=10) ≈ 2 end From 5f1d99d287ac500f2d0e380263406a2f7004d052 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 10 Dec 2024 21:38:18 +0100 Subject: [PATCH 12/12] Add ChangesOfVariables definitions and extend tests --- ext/LogExpFunctionsChangesOfVariablesExt.jl | 14 +++++++++++++ ext/LogExpFunctionsInverseFunctionsExt.jl | 7 +++++++ test/basicfuns.jl | 22 ++++++++++++++++----- test/inverse.jl | 8 +++++--- test/with_logabsdet_jacobian.jl | 11 +++++++++++ 5 files changed, 54 insertions(+), 8 deletions(-) diff --git a/ext/LogExpFunctionsChangesOfVariablesExt.jl b/ext/LogExpFunctionsChangesOfVariablesExt.jl index 52a105b..cf7c960 100644 --- a/ext/LogExpFunctionsChangesOfVariablesExt.jl +++ b/ext/LogExpFunctionsChangesOfVariablesExt.jl @@ -8,11 +8,25 @@ function ChangesOfVariables.with_logabsdet_jacobian(::typeof(log1pexp), x::Real) y = log1pexp(x) return y, x - y end +function ChangesOfVariables.with_logabsdet_jacobian(::typeof(softplus), x::Real) + return ChangesOfVariables.with_logabsdet_jacobian(log1pexp, x) +end +function ChangesOfVariables.with_logabsdet_jacobian(f::Base.Fix2{typeof(softplus),<:Real}, x::Real) + y = f(x) + return y, f.x * (x - y) +end function ChangesOfVariables.with_logabsdet_jacobian(::typeof(logexpm1), x::Real) y = logexpm1(x) return y, x - y end +function ChangesOfVariables.with_logabsdet_jacobian(::typeof(invsoftplus), x::Real) + return ChangesOfVariables.with_logabsdet_jacobian(logexpm1, x) +end +function ChangesOfVariables.with_logabsdet_jacobian(f::Base.Fix2{typeof(invsoftplus),<:Real}, x::Real) + y = f(x) + return y, f.x * (x - y) +end function ChangesOfVariables.with_logabsdet_jacobian(::typeof(log1mexp), x::Real) y = log1mexp(x) diff --git a/ext/LogExpFunctionsInverseFunctionsExt.jl b/ext/LogExpFunctionsInverseFunctionsExt.jl index 611966f..6903303 100644 --- a/ext/LogExpFunctionsInverseFunctionsExt.jl +++ b/ext/LogExpFunctionsInverseFunctionsExt.jl @@ -23,6 +23,13 @@ InverseFunctions.inverse(::typeof(log1mlogistic)) = logit1mexp InverseFunctions.inverse(::typeof(logit1mexp)) = log1mlogistic InverseFunctions.inverse(::typeof(softplus)) = invsoftplus +function InverseFunctions.inverse(f::Base.Fix2{typeof(softplus),<:Real}) + Base.Fix2(invsoftplus, f.x) +end + InverseFunctions.inverse(::typeof(invsoftplus)) = softplus +function InverseFunctions.inverse(f::Base.Fix2{typeof(invsoftplus),<:Real}) + Base.Fix2(softplus, f.x) +end end # module diff --git a/test/basicfuns.jl b/test/basicfuns.jl index 461c840..72d0e44 100644 --- a/test/basicfuns.jl +++ b/test/basicfuns.jl @@ -162,13 +162,15 @@ end end @testset "softplus" begin - @test softplus(2) ≈ log1pexp(2) - @test softplus(2, 1) ≈ log1pexp(2) - @test softplus(2, 10) < log1pexp(2) - @test invsoftplus(softplus(2), 1) ≈ 2 + for T in (Int, Float64, Float32, Float16) + @test @inferred(softplus(T(2))) === log1pexp(T(2)) + @test @inferred(softplus(T(2), 1)) isa float(T) + @test @inferred(softplus(T(2), 1)) ≈ softplus(T(2)) + @test @inferred(softplus(T(2), 5)) ≈ softplus(5 * T(2)) / 5 + @test @inferred(softplus(T(2), 10)) ≈ softplus(10 * T(2)) / 10 + end end - @testset "log1mexp" begin for T in (Float64, Float32, Float16) @test @inferred(log1mexp(-T(1))) isa T @@ -194,6 +196,16 @@ end end end +@testset "invsoftplus" begin + for T in (Int, Float64, Float32, Float16) + @test @inferred(invsoftplus(T(2))) === logexpm1(T(2)) + @test @inferred(invsoftplus(T(2), 1)) isa float(T) + @test @inferred(invsoftplus(T(2), 1)) ≈ invsoftplus(T(2)) + @test @inferred(invsoftplus(T(2), 5)) ≈ invsoftplus(5 * T(2)) / 5 + @test @inferred(invsoftplus(T(2), 10)) ≈ invsoftplus(10 * T(2)) / 10 + end +end + @testset "log1pmx" begin @test iszero(log1pmx(0.0)) @test log1pmx(1.0) ≈ log(2.0) - 1.0 diff --git a/test/inverse.jl b/test/inverse.jl index 7518b05..630e2f7 100644 --- a/test/inverse.jl +++ b/test/inverse.jl @@ -1,6 +1,11 @@ @testset "inverse.jl" begin InverseFunctions.test_inverse(log1pexp, randn()) + InverseFunctions.test_inverse(softplus, randn()) + InverseFunctions.test_inverse(Base.Fix2(softplus, randexp()), randn()) + InverseFunctions.test_inverse(logexpm1, randexp()) + InverseFunctions.test_inverse(invsoftplus, randexp()) + InverseFunctions.test_inverse(Base.Fix2(invsoftplus, randexp()), randexp()) InverseFunctions.test_inverse(log1mexp, -randexp()) @@ -17,7 +22,4 @@ InverseFunctions.test_inverse(log1mlogistic, randexp()) InverseFunctions.test_inverse(logit1mexp, -randexp()) - - InverseFunctions.test_inverse(softplus, randn()) - InverseFunctions.test_inverse(invsoftplus, randexp()) end diff --git a/test/with_logabsdet_jacobian.jl b/test/with_logabsdet_jacobian.jl index d5f4484..3de5e9f 100644 --- a/test/with_logabsdet_jacobian.jl +++ b/test/with_logabsdet_jacobian.jl @@ -1,12 +1,23 @@ @testset "with_logabsdet_jacobian" begin derivative(f, x) = ChainRulesTestUtils.frule((ChainRulesTestUtils.NoTangent(), 1), f, x)[2] + derivative(::typeof(softplus), x) = derivative(log1pexp, x) + derivative(f::Base.Fix2{typeof(softplus),<:Real}, x) = derivative(log1pexp, f.x * x) + derivative(::typeof(invsoftplus), x) = derivative(logexpm1, x) + derivative(f::Base.Fix2{typeof(invsoftplus),<:Real}, x) = derivative(logexpm1, f.x * x) x = randexp() + y = randexp() ChangesOfVariables.test_with_logabsdet_jacobian(log1pexp, x, derivative) ChangesOfVariables.test_with_logabsdet_jacobian(log1pexp, -x, derivative) + ChangesOfVariables.test_with_logabsdet_jacobian(softplus, x, derivative) + ChangesOfVariables.test_with_logabsdet_jacobian(softplus, -x, derivative) + ChangesOfVariables.test_with_logabsdet_jacobian(Base.Fix2(softplus, y), x, derivative) + ChangesOfVariables.test_with_logabsdet_jacobian(Base.Fix2(softplus, y), -x, derivative) ChangesOfVariables.test_with_logabsdet_jacobian(logexpm1, x, derivative) + ChangesOfVariables.test_with_logabsdet_jacobian(invsoftplus, x, derivative) + ChangesOfVariables.test_with_logabsdet_jacobian(Base.Fix2(invsoftplus, y), x, derivative) ChangesOfVariables.test_with_logabsdet_jacobian(log1mexp, -x, derivative)