From 0fdda7aff32aa1d4d2ef015dc579146104f27493 Mon Sep 17 00:00:00 2001 From: Dominique Makowski Date: Wed, 4 Sep 2024 07:53:51 +0100 Subject: [PATCH] Update src/basicfuns.jl Co-authored-by: David Widmann --- src/basicfuns.jl | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/src/basicfuns.jl b/src/basicfuns.jl index b6cc55b..d2bc90a 100644 --- a/src/basicfuns.jl +++ b/src/basicfuns.jl @@ -262,19 +262,11 @@ Return `log(exp(x) - 1)` or the “invsoftplus” function. It is the inverse o logexpm1(x::Real) = x <= 18.0 ? log(_expm1(x)) : x <= 33.3 ? x - exp(-x) : oftype(exp(-x), x) logexpm1(x::Float32) = x <= 9f0 ? log(expm1(x)) : x <= 16f0 ? x - exp(-x) : oftype(exp(-x), x) -function softplus(x; a::Real=1.0) - if a == 1.0 - return log1pexp(x) - end - return log1pexp(a * x) / a -end +softplus(x::Real) = log1pexp(x) +softplus(x::Real, a::Real) = log1pexp(a * x) / a -function invsoftplus(y; a::Real=1.0) - if a == 1.0 - return logexpm1(y) - end - return logexpm1(a * y) / a -end +invsoftplus(y::Real) = logexpm1(y) +invsoftplus(y::Real, a::Real) = logexpm1(a * y) / a """