diff --git a/src/Symbolics_customised.jl b/src/Symbolics_customised.jl
index 6ce207dfa..b643cdbc8 100644
--- a/src/Symbolics_customised.jl
+++ b/src/Symbolics_customised.jl
@@ -32,34 +32,40 @@ using Symbolics:
     expand,
     operation
 
-"Returns true if expr is an exponential"
-is_exp(expr) = isterm(expr) && expr.f == exp
 
-"Expand powers of exponential such that exp(x)^n => exp(x*n) "
-expand_exp_power(expr) =
-    ispow(expr) && is_exp(expr.base) ? exp(expr.base.arguments[1] * expr.exp) : expr
-expand_exp_power_add(expr) = sum([expand_exp_power(arg) for arg in arguments(expr)])
-expand_exp_power_mul(expr) = prod([expand_exp_power(arg) for arg in arguments(expr)])
+# macro wrap_method(f)
+#     return quote
+#         @eval function $f(x::Num)
+#             wrap($f(unwrap(x)))
+#         end
+#     end
+# end
+
+# @wrap_method expand_all
+# @wrap_method expand_exp_power
 expand_exp_power(expr::Num) = expand_exp_power(expr.val)
+expand_all(x::Num) = Num(expand_all(x.val))
+
+
+"Returns true if expr is an exponential"
+isexp(expr) = isterm(expr) && expr.f == exp
 
+"Expand powers of exponential such that exp(x)^n => exp(x*n) "
 function expand_exp_power(expr::BasicSymbolic)
-    if isadd(expr)
-        return expand_exp_power_add(expr)
-    elseif ismul(expr)
-        return expand_exp_power_mul(expr)
-    else
-        return if ispow(expr) && is_exp(expr.base)
-            exp(expr.base.arguments[1] * expr.exp)
-        else
-            expr
-        end
+    @compactified expr::BasicSymbolic begin
+    Add => sum([expand_exp_power(arg) for arg in arguments(expr)])
+    Mul => prod([expand_exp_power(arg) for arg in arguments(expr)])
+    _   => ispow(expr) && isexp(expr.base) ? exp(expr.base.arguments[1] * expr.exp) : expr
     end
 end
+expand_exp_power(expr) = expr
 
 "Expands using SymbolicUtils.expand and expand_exp_power (changes exp(x)^n to exp(x*n)"
-expand_all(x) = Postwalk(expand_exp_power)(SymbolicUtils.expand(x))
+function expand_all(x)
+    result = Postwalk(expand_exp_power)(SymbolicUtils.expand(x))
+    return isnothing(result) ? x : result
+end
 expand_all(x::Complex{Num}) = expand_all(x.re) + im * expand_all(x.im)
-expand_all(x::Num) = Num(expand_all(x.val))
 
 "Apply a function f on every member of a sum or a product"
 function _apply_termwise(f, x::BasicSymbolic)
@@ -85,7 +91,7 @@ end
 "Simplify products of exponentials such that exp(a)*exp(b) => exp(a+b)
 This is included in SymbolicUtils as of 17.0 but the method here avoid other simplify calls"
 function simplify_exp_products_mul(expr)
-    ind = findall(x -> is_exp(x), arguments(expr))
+    ind = findall(x -> isexp(x), arguments(expr))
     rest_ind = setdiff(1:length(arguments(expr)), ind)
     rest = isempty(rest_ind) ? 1 : prod(arguments(expr)[rest_ind])
     total = isempty(ind) ? 0 : sum(getindex.(arguments.(arguments(expr)[ind]), 1))
diff --git a/test/symbolics.jl b/test/symbolics.jl
index 694605707..bbba52951 100644
--- a/test/symbolics.jl
+++ b/test/symbolics.jl
@@ -13,9 +13,10 @@ macro eqtest(expr)
 end
 
 @testset "exp(x)^n => exp(x*n)" begin
-    using HarmonicBalance: expand_all
+    using HarmonicBalance: expand_all, expand_exp_power
     @variables a n
 
+    @eqtest expand_exp_power(exp(a)^3) == exp(3 * a)
     @eqtest simplify(exp(a)^3) == exp(3 * a)
     @eqtest simplify(exp(a)^n) == exp(n * a)
     @eqtest expand_all(exp(a)^3) == exp(3 * a)