Skip to content

Commit

Permalink
@compact expand_all
Browse files Browse the repository at this point in the history
  • Loading branch information
oameye committed Oct 13, 2024
1 parent 3d2a655 commit 1c47abc
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 21 deletions.
46 changes: 26 additions & 20 deletions src/Symbolics_customised.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,34 +32,40 @@ using Symbolics:
expand,
operation

"Returns true if expr is an exponential"
is_exp(expr) = isterm(expr) && expr.f == exp

"Expand powers of exponential such that exp(x)^n => exp(x*n) "
expand_exp_power(expr) =
ispow(expr) && is_exp(expr.base) ? exp(expr.base.arguments[1] * expr.exp) : expr
expand_exp_power_add(expr) = sum([expand_exp_power(arg) for arg in arguments(expr)])
expand_exp_power_mul(expr) = prod([expand_exp_power(arg) for arg in arguments(expr)])
# macro wrap_method(f)
# return quote
# @eval function $f(x::Num)
# wrap($f(unwrap(x)))
# end
# end
# end

# @wrap_method expand_all
# @wrap_method expand_exp_power
expand_exp_power(expr::Num) = expand_exp_power(expr.val)
expand_all(x::Num) = Num(expand_all(x.val))


"Returns true if expr is an exponential"
isexp(expr) = isterm(expr) && expr.f == exp

"Expand powers of exponential such that exp(x)^n => exp(x*n) "
function expand_exp_power(expr::BasicSymbolic)
if isadd(expr)
return expand_exp_power_add(expr)
elseif ismul(expr)
return expand_exp_power_mul(expr)
else
return if ispow(expr) && is_exp(expr.base)
exp(expr.base.arguments[1] * expr.exp)
else
expr
end
@compactified expr::BasicSymbolic begin
Add => sum([expand_exp_power(arg) for arg in arguments(expr)])
Mul => prod([expand_exp_power(arg) for arg in arguments(expr)])
_ => ispow(expr) && isexp(expr.base) ? exp(expr.base.arguments[1] * expr.exp) : expr
end
end
expand_exp_power(expr) = expr

"Expands using SymbolicUtils.expand and expand_exp_power (changes exp(x)^n to exp(x*n)"
expand_all(x) = Postwalk(expand_exp_power)(SymbolicUtils.expand(x))
function expand_all(x)
result = Postwalk(expand_exp_power)(SymbolicUtils.expand(x))
return isnothing(result) ? x : result
end
expand_all(x::Complex{Num}) = expand_all(x.re) + im * expand_all(x.im)
expand_all(x::Num) = Num(expand_all(x.val))

"Apply a function f on every member of a sum or a product"
function _apply_termwise(f, x::BasicSymbolic)
Expand All @@ -85,7 +91,7 @@ end
"Simplify products of exponentials such that exp(a)*exp(b) => exp(a+b)
This is included in SymbolicUtils as of 17.0 but the method here avoid other simplify calls"
function simplify_exp_products_mul(expr)
ind = findall(x -> is_exp(x), arguments(expr))
ind = findall(x -> isexp(x), arguments(expr))
rest_ind = setdiff(1:length(arguments(expr)), ind)
rest = isempty(rest_ind) ? 1 : prod(arguments(expr)[rest_ind])
total = isempty(ind) ? 0 : sum(getindex.(arguments.(arguments(expr)[ind]), 1))
Expand Down
3 changes: 2 additions & 1 deletion test/symbolics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@ macro eqtest(expr)
end

@testset "exp(x)^n => exp(x*n)" begin
using HarmonicBalance: expand_all
using HarmonicBalance: expand_all, expand_exp_power
@variables a n

@eqtest expand_exp_power(exp(a)^3) == exp(3 * a)
@eqtest simplify(exp(a)^3) == exp(3 * a)
@eqtest simplify(exp(a)^n) == exp(n * a)
@eqtest expand_all(exp(a)^3) == exp(3 * a)
Expand Down

0 comments on commit 1c47abc

Please sign in to comment.