Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify sqrt((A/B)^2) => A/B #1243

Merged
merged 2 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 25 additions & 14 deletions src/solver/postprocess.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@

# Alex: make sure `Num`s are not processed here as they'd break it.
_is_const_number(x::Number) = true
function _is_const_number(x::SymbolicUtils.BasicSymbolic)
!iscall(x) && return false
all(_is_const_number, arguments(x))
end

_postprocess_root(x) = x

function _postprocess_root(x::Number)
Expand Down Expand Up @@ -60,21 +54,38 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic)
return arguments(x)[1]
end

# sqrt(M^2 * N) => M * sqrt(N)
# sqrt((N / D)^2 * M) => N / D * sqrt(M)
if iscall(x) && (operation(x) === sqrt || operation(x) === ssqrt)
arg = arguments(x)[1]
if arg isa Integer
square, radical = big(1), big(1)
for (p, d) in collect(Primes.factor(abs(arg)))
function squarefree_decomp(x::Integer)
square, squarefree = big(1), big(1)
for (p, d) in collect(Primes.factor(abs(x)))
q, r = divrem(d, 2)
square *= p^q
radical *= p^r
squarefree *= p^r
end
square, squarefree
end
arg = arguments(x)[1]
if arg isa Integer
square, squarefree = squarefree_decomp(arg)
if arg < 0
square = im * square
end
isone(radical) && return square
return square * Symbolics.term(Symbolics.operation(x), radical)
if !isone(square)
return square * Symbolics.term(Symbolics.operation(x), squarefree)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we use maketerm to be more generic?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use Symbolics.term since maketerm simplifies arithmetic expressions that involve Base.MathConstants.pi and friends into floats

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see.

end
elseif arg isa Rational
n, d = numerator(arg), denominator(arg)
n_square, n_squarefree = squarefree_decomp(n)
if n < 0
n_square = im * n_square
end
d_square, d_squarefree = squarefree_decomp(d)
nd_square = n_square // d_square
nd_squarefree = n_squarefree // d_squarefree
if !isone(nd_square)
return nd_square * Symbolics.term(Symbolics.operation(x), nd_squarefree)
end
end
end

Expand Down
16 changes: 7 additions & 9 deletions test/solver.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Symbolics
import Symbolics: ssqrt, slog, scbrt, symbolic_solve, ia_solve, postprocess_root, _is_const_number
import Symbolics: ssqrt, slog, scbrt, symbolic_solve, ia_solve, postprocess_root
using Groebner, Nemo
E = Base.MathConstants.e

Expand Down Expand Up @@ -309,14 +309,9 @@ end
@testset "Post Process roots" begin
SymbolicUtils.@syms __x
__symsqrt(x) = SymbolicUtils.term(ssqrt, x)
@test postprocess_root(2 // 1) == 2 && postprocess_root(2 + 0*im) == 2
@test postprocess_root(__symsqrt(4)) == 2
@test isequal(postprocess_root(__symsqrt(__x)^2), __x)

@test !_is_const_number(__x) && !_is_const_number(sqrt(__x))
@test _is_const_number(1) && _is_const_number(2 // 3) && _is_const_number(3 + 4im)
@test _is_const_number(SymbolicUtils.term(sqrt, 2) + 21)
@test _is_const_number((SymbolicUtils.term(exp, 2) * SymbolicUtils.term(exp, 2)) // 99)
@test Symbolics.postprocess_root(2 // 1) == 2 && Symbolics.postprocess_root(2 + 0*im) == 2
@test Symbolics.postprocess_root(__symsqrt(4)) == 2
@test isequal(Symbolics.postprocess_root(__symsqrt(__x)^2), __x)

@test Symbolics.postprocess_root( SymbolicUtils.term(^, __x, 0) ) == 1
@test Symbolics.postprocess_root( SymbolicUtils.term(^, Base.MathConstants.e, 0) ) == 1
Expand All @@ -326,6 +321,9 @@ end
x = Symbolics.term(sqrt, 2)
@test isequal(Symbolics.postprocess_root( expand((x + 1)^4) ), 17 + 12x)
@test isequal(Symbolics.postprocess_root( x^5 ), 4 * x)

@test isequal(Symbolics.postprocess_root(Symbolics.term(sqrt, 9//4)), 3//2)
@test isequal(Symbolics.postprocess_root(Symbolics.term(sqrt, -27//8)), im*3//2*Symbolics.term(sqrt, 3//2))
end


Expand Down
Loading