Skip to content

Commit

Permalink
Merge pull request #1271 from n0rbed/rational
Browse files Browse the repository at this point in the history
Simplifying fractions before solving
  • Loading branch information
n0rbed authored Sep 14, 2024
2 parents 6366280 + 8cb25a8 commit 7043d1d
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 16 deletions.
3 changes: 2 additions & 1 deletion .typos.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[default.extend-words]
numer = "numer"
Commun = "Commun"
nd = "nd"
nd = "nd"
assum = "assum"
22 changes: 16 additions & 6 deletions src/solver/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ function symbolic_solve(expr, x::T; dropmultiplicity = true, warns = true) where
for e in expr
for var in x
if !check_poly_inunivar(e, var)
warns && @warn("This system can not be currently solved by solve.")
warns && @warn("This system can not be currently solved by `symbolic_solve`.")
return nothing
end
end
Expand Down Expand Up @@ -276,7 +276,7 @@ function solve_univar(expression, x; dropmultiplicity=true)
end
end

subs, filtered_expr = filter_poly(expression, x)
subs, filtered_expr, assumptions = filter_poly(expression, x, assumptions=true)
coeffs, constant = polynomial_coeffs(filtered_expr, [x])
degree = sdegree(coeffs, x)

Expand All @@ -296,18 +296,28 @@ function solve_univar(expression, x; dropmultiplicity=true)
append!(arr_roots, og_arr_roots)
end
end

return arr_roots
end

if length(factors) != 1
for factor in factors_subbed
roots = solve_univar(factor, x, dropmultiplicity = dropmultiplicity)
for i in eachindex(factors_subbed)
if !any(isequal(x, var) for var in get_variables(factors[i]))
continue
end
roots = solve_univar(factors_subbed[i], x, dropmultiplicity = dropmultiplicity)
append!(arr_roots, roots)
end
end

for i in reverse(eachindex(arr_roots))
for j in eachindex(assumptions)
if isequal(substitute(assumptions[j], Dict(x=>arr_roots[i])), 0)
deleteat!(arr_roots, i)
end
end
end

if isequal(arr_roots, [])
@assert check_polynomial(expression) "This expression could not be solved by `symbolic_solve`."
return [RootsOf(wrap(expression), wrap(x))]
end

Expand Down
19 changes: 13 additions & 6 deletions src/solver/preprocess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,19 @@ function clean_f(filtered_expr, var, subs)
unwrapped_f = unwrap(filtered_expr)
!iscall(unwrapped_f) && return filtered_expr
oper = operation(unwrapped_f)
assumptions = []

if oper === (/)
args = arguments(unwrapped_f)
if any(isequal(var, x) for x in get_variables(args[2]))
return filtered_expr
filtered_expr = expand(args[1] * args[2])
push!(assumptions, substitute(args[2], subs, fold=false))
return filtered_expr, assumptions
end
filtered_expr = args[1]
@info "Assuming $(substitute(args[2], subs, fold=false) != 0)"
end
return filtered_expr
return filtered_expr, assumptions
end

"""
Expand Down Expand Up @@ -238,15 +241,17 @@ julia> filter_poly((x+1)*term(log, 3), x)
(Dict{Any, Any}(var"##247" => log(3)), var"##247"*(1 + x))
```
"""
function filter_poly(og_expr, var)
function filter_poly(og_expr, var; assumptions=false)
expr = deepcopy(og_expr)
expr = unwrap(expr)
vars = get_variables(expr)

# handle edge cases
if !isequal(vars, []) && isequal(vars[1], expr)
assumptions && return Dict{Any, Any}(), expr, []
return (Dict{Any, Any}(), expr)
elseif isequal(vars, [])
assumptions && return filter_stuff(expr), []
return filter_stuff(expr)
end

Expand All @@ -256,14 +261,16 @@ function filter_poly(og_expr, var)
# reassemble expr to avoid variables remembering original values issue and clean
args = arguments(expr)
oper = operation(expr)
new_expr = clean_f(term(oper, args...), var, subs)
new_expr, assum_array = clean_f(term(oper, args...), var, subs)

assumptions && return subs, new_expr, assum_array
return subs, new_expr
end
function filter_poly(og_expr)

function filter_poly(og_expr; assumptions=false)
new_var = gensym()
new_var = (@variables $(new_var))[1]
return filter_poly(og_expr, new_var)
return filter_poly(og_expr, new_var; assumptions=assumptions)
end


Expand Down
11 changes: 8 additions & 3 deletions test/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,14 @@ end
@variables x y z a b c d e

@testset "Invalid input" begin
@test_throws AssertionError Symbolics.get_roots(x, x^2)
@test_throws AssertionError Symbolics.get_roots(x^3 + sin(x), x)
@test_throws AssertionError Symbolics.get_roots(1/x, x)
@test_throws AssertionError symbolic_solve(x, x^2)
@test_throws AssertionError symbolic_solve(1/x, x)
end

@testset "Nice univar cases" begin
found_roots = symbolic_solve(1/x^2 ~ 1/y^2 - 2/x^3 * (x-y), x)
known_roots = Symbolics.unwrap.([y, -2y])
@test isequal(found_roots, known_roots)
end

@testset "Deg 1 univar" begin
Expand Down

0 comments on commit 7043d1d

Please sign in to comment.