diff --git a/src/solver/postprocess.jl b/src/solver/postprocess.jl index 2a0ce1c85..6236573f9 100644 --- a/src/solver/postprocess.jl +++ b/src/solver/postprocess.jl @@ -104,26 +104,8 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic) end end - opers = [acos, asin, atan] - exacts = [0, Symbolics.term(*, pi), Symbolics.term(/,pi,3), - Symbolics.term(/, pi, 2), - Symbolics.term(/, Symbolics.term(*, 2, pi), 3), - Symbolics.term(/, pi, 6), - Symbolics.term(/, Symbolics.term(*, 5, pi), 6), - Symbolics.term(/, pi, 4) - ] - - if any(isequal(oper, o) for o in opers) && isempty(Symbolics.get_variables(x)) - val = eval(Symbolics.toexpr(x)) - for i in eachindex(exacts) - exact_val = eval(Symbolics.toexpr(exacts[i])) - if isapprox(exact_val, val, atol=1e-6) - return exacts[i] - elseif isapprox(-exact_val, val, atol=1e-6) - return -exacts[i] - end - end - end + trig_simplified = check_trig_consts(x) + !isequal(trig_simplified, x) && return trig_simplified if oper === (+) args = arguments(x) @@ -153,3 +135,33 @@ function postprocess_root(x) end x # unreachable end + +function check_trig_consts(x) + !iscall(x) && return x + + oper = operation(x) + inv_opers = [asin, acos, atan] + inv_exacts = [0, Symbolics.term(*, pi), + Symbolics.term(/,pi,3), + Symbolics.term(/, pi, 2), + Symbolics.term(/, Symbolics.term(*, 2, pi), 3), + Symbolics.term(/, pi, 6), + Symbolics.term(/, Symbolics.term(*, 5, pi), 6), + Symbolics.term(/, pi, 4) + ] + + if any(isequal(oper, o) for o in inv_opers) && isempty(Symbolics.get_variables(x)) + val = eval(Symbolics.toexpr(x)) + for i in eachindex(inv_exacts) + exact_val = eval(Symbolics.toexpr(inv_exacts[i])) + if isapprox(exact_val, val, atol=1e-6) + return inv_exacts[i] + elseif isapprox(-exact_val, val, atol=1e-6) + return -inv_exacts[i] + end + end + end + + # add [sin, cos, tan] simplifications in the future? + return x +end