Skip to content

Commit

Permalink
Merge pull request #1243 from sumiya11/master
Browse files Browse the repository at this point in the history
Simplify sqrt((A/B)^2) => A/B
  • Loading branch information
ChrisRackauckas authored Aug 30, 2024
2 parents 8cc74f9 + 95e61a2 commit 0c9cb7b
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 23 deletions.
39 changes: 25 additions & 14 deletions src/solver/postprocess.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@

# Alex: make sure `Num`s are not processed here as they'd break it.
_is_const_number(x::Number) = true
function _is_const_number(x::SymbolicUtils.BasicSymbolic)
!iscall(x) && return false
all(_is_const_number, arguments(x))
end

_postprocess_root(x) = x

function _postprocess_root(x::Number)
Expand Down Expand Up @@ -60,21 +54,38 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic)
return arguments(x)[1]
end

# sqrt(M^2 * N) => M * sqrt(N)
# sqrt((N / D)^2 * M) => N / D * sqrt(M)
if iscall(x) && (operation(x) === sqrt || operation(x) === ssqrt)
arg = arguments(x)[1]
if arg isa Integer
square, radical = big(1), big(1)
for (p, d) in collect(Primes.factor(abs(arg)))
function squarefree_decomp(x::Integer)
square, squarefree = big(1), big(1)
for (p, d) in collect(Primes.factor(abs(x)))
q, r = divrem(d, 2)
square *= p^q
radical *= p^r
squarefree *= p^r
end
square, squarefree
end
arg = arguments(x)[1]
if arg isa Integer
square, squarefree = squarefree_decomp(arg)
if arg < 0
square = im * square
end
isone(radical) && return square
return square * Symbolics.term(Symbolics.operation(x), radical)
if !isone(square)
return square * Symbolics.term(Symbolics.operation(x), squarefree)
end
elseif arg isa Rational
n, d = numerator(arg), denominator(arg)
n_square, n_squarefree = squarefree_decomp(n)
if n < 0
n_square = im * n_square
end
d_square, d_squarefree = squarefree_decomp(d)
nd_square = n_square // d_square
nd_squarefree = n_squarefree // d_squarefree
if !isone(nd_square)
return nd_square * Symbolics.term(Symbolics.operation(x), nd_squarefree)
end
end
end

Expand Down
16 changes: 7 additions & 9 deletions test/solver.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Symbolics
import Symbolics: ssqrt, slog, scbrt, symbolic_solve, ia_solve, postprocess_root, _is_const_number
import Symbolics: ssqrt, slog, scbrt, symbolic_solve, ia_solve, postprocess_root
using Groebner, Nemo
E = Base.MathConstants.e

Expand Down Expand Up @@ -309,14 +309,9 @@ end
@testset "Post Process roots" begin
SymbolicUtils.@syms __x
__symsqrt(x) = SymbolicUtils.term(ssqrt, x)
@test postprocess_root(2 // 1) == 2 && postprocess_root(2 + 0*im) == 2
@test postprocess_root(__symsqrt(4)) == 2
@test isequal(postprocess_root(__symsqrt(__x)^2), __x)

@test !_is_const_number(__x) && !_is_const_number(sqrt(__x))
@test _is_const_number(1) && _is_const_number(2 // 3) && _is_const_number(3 + 4im)
@test _is_const_number(SymbolicUtils.term(sqrt, 2) + 21)
@test _is_const_number((SymbolicUtils.term(exp, 2) * SymbolicUtils.term(exp, 2)) // 99)
@test Symbolics.postprocess_root(2 // 1) == 2 && Symbolics.postprocess_root(2 + 0*im) == 2
@test Symbolics.postprocess_root(__symsqrt(4)) == 2
@test isequal(Symbolics.postprocess_root(__symsqrt(__x)^2), __x)

@test Symbolics.postprocess_root( SymbolicUtils.term(^, __x, 0) ) == 1
@test Symbolics.postprocess_root( SymbolicUtils.term(^, Base.MathConstants.e, 0) ) == 1
Expand All @@ -326,6 +321,9 @@ end
x = Symbolics.term(sqrt, 2)
@test isequal(Symbolics.postprocess_root( expand((x + 1)^4) ), 17 + 12x)
@test isequal(Symbolics.postprocess_root( x^5 ), 4 * x)

@test isequal(Symbolics.postprocess_root(Symbolics.term(sqrt, 9//4)), 3//2)
@test isequal(Symbolics.postprocess_root(Symbolics.term(sqrt, -27//8)), im*3//2*Symbolics.term(sqrt, 3//2))
end


Expand Down

0 comments on commit 0c9cb7b

Please sign in to comment.