Skip to content

Commit

Permalink
docs and const arr
Browse files Browse the repository at this point in the history
  • Loading branch information
n0rbed committed Sep 13, 2024
1 parent d352d8a commit 4a5b616
Showing 1 changed file with 34 additions and 14 deletions.
48 changes: 34 additions & 14 deletions src/solver/postprocess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,7 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic)
end
end

trig_simplified = check_trig_consts(x)
!isequal(trig_simplified, x) && return trig_simplified
x = convert_consts(x)

if oper === (+)
args = arguments(x)
Expand Down Expand Up @@ -136,28 +135,49 @@ function postprocess_root(x)
x # unreachable
end

function check_trig_consts(x)
!iscall(x) && return x

oper = operation(x)
inv_opers = [asin, acos, atan]
inv_exacts = [0, Symbolics.term(*, pi),
inv_exacts = [0, Symbolics.term(*, pi),
Symbolics.term(/,pi,3),
Symbolics.term(/, pi, 2),
Symbolics.term(/, Symbolics.term(*, 2, pi), 3),
Symbolics.term(/, pi, 6),
Symbolics.term(/, Symbolics.term(*, 5, pi), 6),
Symbolics.term(/, pi, 4)
]
]
inv_evald = Symbolics.symbolic_to_float.(inv_exacts)

const inv_pairs = collect(zip(inv_exacts, inv_evald))
"""
function convert_consts(x)
This function takes BasicSymbolic terms as input (x) and attempts
to simplify these basic symbolic terms using known values.
Currently, this function only supports inverse trignometric functions.

Check warning on line 154 in src/solver/postprocess.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"trignometric" should be "trigonometric".
## Examples
```jldoctest
julia> Symbolics.convert_consts(Symbolics.term(acos, 0))
π / 2
julia> Symbolics.convert_consts(Symbolics.term(atan, 0))
0
julia> Symbolics.convert_consts(Symbolics.term(atan, 1))
π / 4
```
"""
function convert_consts(x)
!iscall(x) && return x

oper = operation(x)
inv_opers = [asin, acos, atan]

if any(isequal(oper, o) for o in inv_opers) && isempty(Symbolics.get_variables(x))
val = Symbolics.symbolic_to_float(x)
for i in eachindex(inv_exacts)
exact_val = Symbolics.symbolic_to_float(inv_exacts[i])
if isapprox(exact_val, val, atol=1e-6)
return inv_exacts[i]
elseif isapprox(-exact_val, val, atol=1e-6)
return -inv_exacts[i]
for (exact, evald) in inv_pairs
if isapprox(evald, val)
return exact
elseif isapprox(-evald, val)
return -exact
end
end
end
Expand Down

0 comments on commit 4a5b616

Please sign in to comment.