diff --git a/src/Symbolics_customised.jl b/src/Symbolics_customised.jl index 6ce207dfa..b643cdbc8 100644 --- a/src/Symbolics_customised.jl +++ b/src/Symbolics_customised.jl @@ -32,34 +32,40 @@ using Symbolics: expand, operation -"Returns true if expr is an exponential" -is_exp(expr) = isterm(expr) && expr.f == exp -"Expand powers of exponential such that exp(x)^n => exp(x*n) " -expand_exp_power(expr) = - ispow(expr) && is_exp(expr.base) ? exp(expr.base.arguments[1] * expr.exp) : expr -expand_exp_power_add(expr) = sum([expand_exp_power(arg) for arg in arguments(expr)]) -expand_exp_power_mul(expr) = prod([expand_exp_power(arg) for arg in arguments(expr)]) +# macro wrap_method(f) +# return quote +# @eval function $f(x::Num) +# wrap($f(unwrap(x))) +# end +# end +# end + +# @wrap_method expand_all +# @wrap_method expand_exp_power expand_exp_power(expr::Num) = expand_exp_power(expr.val) +expand_all(x::Num) = Num(expand_all(x.val)) + + +"Returns true if expr is an exponential" +isexp(expr) = isterm(expr) && expr.f == exp +"Expand powers of exponential such that exp(x)^n => exp(x*n) " function expand_exp_power(expr::BasicSymbolic) - if isadd(expr) - return expand_exp_power_add(expr) - elseif ismul(expr) - return expand_exp_power_mul(expr) - else - return if ispow(expr) && is_exp(expr.base) - exp(expr.base.arguments[1] * expr.exp) - else - expr - end + @compactified expr::BasicSymbolic begin + Add => sum([expand_exp_power(arg) for arg in arguments(expr)]) + Mul => prod([expand_exp_power(arg) for arg in arguments(expr)]) + _ => ispow(expr) && isexp(expr.base) ? exp(expr.base.arguments[1] * expr.exp) : expr end end +expand_exp_power(expr) = expr "Expands using SymbolicUtils.expand and expand_exp_power (changes exp(x)^n to exp(x*n)" -expand_all(x) = Postwalk(expand_exp_power)(SymbolicUtils.expand(x)) +function expand_all(x) + result = Postwalk(expand_exp_power)(SymbolicUtils.expand(x)) + return isnothing(result) ? x : result +end expand_all(x::Complex{Num}) = expand_all(x.re) + im * expand_all(x.im) -expand_all(x::Num) = Num(expand_all(x.val)) "Apply a function f on every member of a sum or a product" function _apply_termwise(f, x::BasicSymbolic) @@ -85,7 +91,7 @@ end "Simplify products of exponentials such that exp(a)*exp(b) => exp(a+b) This is included in SymbolicUtils as of 17.0 but the method here avoid other simplify calls" function simplify_exp_products_mul(expr) - ind = findall(x -> is_exp(x), arguments(expr)) + ind = findall(x -> isexp(x), arguments(expr)) rest_ind = setdiff(1:length(arguments(expr)), ind) rest = isempty(rest_ind) ? 1 : prod(arguments(expr)[rest_ind]) total = isempty(ind) ? 0 : sum(getindex.(arguments.(arguments(expr)[ind]), 1)) diff --git a/test/symbolics.jl b/test/symbolics.jl index 694605707..bbba52951 100644 --- a/test/symbolics.jl +++ b/test/symbolics.jl @@ -13,9 +13,10 @@ macro eqtest(expr) end @testset "exp(x)^n => exp(x*n)" begin - using HarmonicBalance: expand_all + using HarmonicBalance: expand_all, expand_exp_power @variables a n + @eqtest expand_exp_power(exp(a)^3) == exp(3 * a) @eqtest simplify(exp(a)^3) == exp(3 * a) @eqtest simplify(exp(a)^n) == exp(n * a) @eqtest expand_all(exp(a)^3) == exp(3 * a)