diff --git a/Project.toml b/Project.toml index 3bda4fd20..3271b8d54 100644 --- a/Project.toml +++ b/Project.toml @@ -20,6 +20,7 @@ LambertW = "984bce1d-4616-540c-a9ee-88d1112d94c9" Latexify = "23fbe1c1-3f47-55db-b15f-69d7ec21a316" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" @@ -57,6 +58,7 @@ IfElse = "0.1" LaTeXStrings = "1.3" LambertW = "0.4.5" Latexify = "0.11, 0.12, 0.13, 0.14, 0.15, 0.16" +LogExpFunctions = "0.3" MacroTools = "0.5" NaNMath = "0.3, 1" PrecompileTools = "1" diff --git a/src/Symbolics.jl b/src/Symbolics.jl index 47066f31e..9138b234a 100644 --- a/src/Symbolics.jl +++ b/src/Symbolics.jl @@ -110,6 +110,9 @@ include("integral.jl") include("array-lib.jl") +using LogExpFunctions +include("logexpfunctions-lib.jl") + include("linear_algebra.jl") using Groebner diff --git a/src/logexpfunctions-lib.jl b/src/logexpfunctions-lib.jl new file mode 100644 index 000000000..95cc6f94b --- /dev/null +++ b/src/logexpfunctions-lib.jl @@ -0,0 +1,16 @@ +# Implement a few of the LogExpFuncitons methods when those rely on boolean workflows. + +LogExpFunctions.log1mexp(x::RCNum) = log(1 - exp(x)) +LogExpFunctions.log1pexp(x::RCNum) = log(1 + exp(x)) +LogExpFunctions.logexpm1(x::RCNum) = log(exp(x) - 1) +LogExpFunctions.logmxp1(x::RCNum) = log(x) - x + 1 +for (f, op) in ((:logaddexp, +), (:logsubexp, -)) + @eval begin + LogExpFunctions.$(f)(x::RCNum, y::Real) = log($(op)(exp(x), exp(y))) + LogExpFunctions.$(f)(x::Real, y::RCNum) = log($(op)(exp(x), exp(y))) + LogExpFunctions.$(f)(x::RCNum, y::RCNum) = log($(op)(exp(x), exp(y))) + end +end +function LogExpFunctions.logsumexp(x::Union{AbstractVector{<:Num}, Arr}) + log(sum(exp, x; init = 0.0)) +end diff --git a/test/logexpfunctions.jl b/test/logexpfunctions.jl new file mode 100644 index 000000000..37b7f76ef --- /dev/null +++ b/test/logexpfunctions.jl @@ -0,0 +1,25 @@ +using Symbolics +using LogExpFunctions + +N = 10 + +@variables a, b, c, x[1:N] + +_a = -0.2 +_b = -1.0 +_c = 2.0 +_x = rand(N) + +vals = Dict(a => _a, b => _b, c => _c, x => _x) + +@test substitute(logaddexp(a, b), vals) ≈ logaddexp(_a, _b) +@test substitute(logaddexp(a, _b), vals) ≈ logaddexp(_a, _b) +@test substitute(logaddexp(_a, b), vals) ≈ logaddexp(_a, _b) +@test substitute(logsubexp(a, b), vals) ≈ logsubexp(_a, _b) +@test substitute(logsubexp(a, _b), vals) ≈ logsubexp(_a, _b) +@test substitute(logsubexp(_a, b), vals) ≈ logsubexp(_a, _b) +@test substitute(log1mexp(a), vals) ≈ log1mexp(_a) +@test substitute(log1pexp(a), vals) ≈ log1pexp(_a) +@test substitute(logexpm1(c), vals) ≈ logexpm1(_c) +@test substitute(logmxp1(c), vals) ≈ logmxp1(_c) +@test substitute(logsumexp(x), vals) ≈ logsumexp(_x) \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 6b4e2d3b8..86cd4ac24 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -44,6 +44,7 @@ if GROUP == "All" || GROUP == "Core" @safetestset "Inequality Test" begin include("inequality.jl") end @safetestset "Integral Test" begin include("integral.jl") end @safetestset "CartesianIndex Test" begin include("cartesianindex.jl") end + @safetestset "LogExpFunctions Test" begin include("logexpfunctions.jl") end end if GROUP == "Downstream"