Skip to content

Commit

Permalink
Merge pull request #1237 from n0rbed/multivar_stuff
Browse files Browse the repository at this point in the history
changed output format of zero dimension sols
  • Loading branch information
ChrisRackauckas authored Sep 7, 2024
2 parents 38c83d5 + fe29321 commit f84b877
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 19 deletions.
8 changes: 8 additions & 0 deletions docs/src/manual/solver.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,14 @@ Symbolics.symbolic_solve(eqs, [x,y,z])
- [ ] Systems of polynomial equations with parameters and positive dimensional systems
- [ ] Inequalities

### Expressions we can not solve (but aim to)
```
# Mathematica
In[1]:= Reduce[x^2 - x - 6 > 0, x]
Out[1]= x < -2 || x > 3
```

# References

[^1]: [Rouillier, F. Solving Zero-Dimensional Systems Through the Rational Univariate Representation. AAECC 9, 433–461 (1999).](https://doi.org/10.1007/s002000050114)
Expand Down
22 changes: 18 additions & 4 deletions ext/SymbolicsGroebnerExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ end
# Given a GB in k[params][vars] produces a GB in k(params)[vars]
function demote(gb, vars::Vector{Num}, params::Vector{Num})
isequal(gb, [1]) && return gb

gb = Symbolics.wrap.(SymbolicUtils.toterm.(gb))
Symbolics.check_polynomial.(gb)

Expand All @@ -126,7 +127,7 @@ function demote(gb, vars::Vector{Num}, params::Vector{Num})
ring_param, params_demoted = Nemo.polynomial_ring(Nemo.base_ring(ring_flat), map(string, nemo_params))
ring_demoted, vars_demoted = Nemo.polynomial_ring(Nemo.fraction_field(ring_param), map(string, nemo_vars), internal_ordering=:lex)
varmap = Dict((nemo_vars .=> vars_demoted)..., (nemo_params .=> params_demoted)...)
gb_demoted = map(f -> nemo_crude_evaluate(f, varmap), nemo_gb)
gb_demoted = map(f -> ring_demoted(nemo_crude_evaluate(f, varmap)), nemo_gb)
result = empty(gb_demoted)
while true
gb_demoted = map(f -> Nemo.map_coefficients(c -> c // Nemo.leading_coefficient(f), f), gb_demoted)
Expand Down Expand Up @@ -176,6 +177,7 @@ function solve_zerodim(eqs::Vector, vars::Vector{Num}; dropmultiplicity=true, wa
# Use a new variable to separate the input polynomials (Reference above)
new_var = gen_separating_var(vars)
old_len = length(vars)
old_vars = deepcopy(vars)
vars = vcat(vars, new_var)

new_eqs = []
Expand Down Expand Up @@ -204,6 +206,13 @@ function solve_zerodim(eqs::Vector, vars::Vector{Num}; dropmultiplicity=true, wa
return []
end

for i in reverse(eachindex(new_eqs))
all_present = Symbolics.get_variables(new_eqs[i])
if length(intersect(all_present, vars)) < 1
deleteat!(new_eqs, i)
end
end

new_eqs = demote(new_eqs, vars, params)
new_eqs = map(Symbolics.unwrap, new_eqs)

Expand Down Expand Up @@ -233,7 +242,10 @@ function solve_zerodim(eqs::Vector, vars::Vector{Num}; dropmultiplicity=true, wa
end

# non-cyclic case
n_iterations > 10 && return []
if n_iterations > 10
warns && @warn("symbolic_solve can not currently solve this system of polynomials.")
return nothing
end

n_iterations += 1
end
Expand Down Expand Up @@ -295,11 +307,13 @@ function Symbolics.solve_multivar(eqs::Vector, vars::Vector{Num}; dropmultiplici
isempty(tr_basis) && return nothing
vars_gen = setdiff(vars, tr_basis)
sol = solve_zerodim(eqs, vars_gen; dropmultiplicity=dropmultiplicity, warns=warns)

for roots in sol
for x in tr_basis
roots[x] = x
end
end

sol
end

Expand All @@ -313,8 +327,8 @@ PrecompileTools.@setup_workload begin
PrecompileTools.@compile_workload begin
symbolic_solve(equation1, x)
symbolic_solve(equation_actually_polynomial)
symbolic_solve(simple_linear_equations, [x, y])
symbolic_solve(equations_intersect_sphere_line, [x, y, z])
symbolic_solve(simple_linear_equations, [x, y], warns=false)
symbolic_solve(equations_intersect_sphere_line, [x, y, z], warns=false)
end
end

Expand Down
19 changes: 9 additions & 10 deletions src/solver/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,6 @@ function symbolic_solve(expr, x::T; dropmultiplicity = true, warns = true) where
for var in x
check_x(var)
end
if length(x) == 1
x = x[1]
x_univar = true
end
end

if !(expr isa Vector)
Expand Down Expand Up @@ -174,8 +170,8 @@ function symbolic_solve(expr, x::T; dropmultiplicity = true, warns = true) where
sols = []
if expr_univar
sols = check_poly_inunivar(expr, x) ?
solve_univar(expr, x, dropmultiplicity = dropmultiplicity) :
ia_solve(expr, x, warns = warns)
solve_univar(expr, x, dropmultiplicity=dropmultiplicity) :
ia_solve(expr, x, warns=warns)
isequal(sols, nothing) && return nothing
else
for i in eachindex(expr)
Expand All @@ -185,7 +181,7 @@ function symbolic_solve(expr, x::T; dropmultiplicity = true, warns = true) where
end
end
sols = solve_multipoly(
expr, x, dropmultiplicity = dropmultiplicity, warns = warns)
expr, x, dropmultiplicity=dropmultiplicity, warns=warns)
isequal(sols, nothing) && return nothing
end

Expand All @@ -203,11 +199,13 @@ function symbolic_solve(expr, x::T; dropmultiplicity = true, warns = true) where
end
end

sols = solve_multivar(expr, x, dropmultiplicity = dropmultiplicity)
sols = solve_multivar(expr, x, dropmultiplicity=dropmultiplicity, warns=warns)
isequal(sols, nothing) && return nothing
for sol in sols
for var in x
sol[var] = postprocess_root(sol[var])
if haskey(sol, var)
sol[var] = postprocess_root(sol[var])
end
end
end

Expand All @@ -231,6 +229,7 @@ function symbolic_solve(expr; x...)
vars = wrap.(vars)
@assert all(v isa Num for v in vars) "All variables should be Nums or BasicSymbolics"

vars = isone(length(vars)) ? vars[1] : vars
return symbolic_solve(expr, vars; x...)
end

Expand All @@ -256,7 +255,7 @@ implemented in the function `get_roots` and its children.
# Examples
"""
function solve_univar(expression, x; dropmultiplicity = true)
function solve_univar(expression, x; dropmultiplicity=true)
args = []
mult_n = 1
expression = unwrap(expression)
Expand Down
21 changes: 16 additions & 5 deletions test/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,13 +236,13 @@ end
# cyclic 3
@variables z1 z2 z3
eqs = [z1 + z2 + z3, z1*z2 + z1*z3 + z2*z3, z1*z2*z3 - 1]
sol = Symbolics.symbolic_solve(eqs, [z1,z2,z3])
sol = symbolic_solve(eqs, [z1,z2,z3])
backward = [Symbolics.substitute(eqs, s) for s in sol]
@test all(x -> all(isapprox.(eval(Symbolics.toexpr(x)), 0; atol=1e-6)), backward)

@variables x y
eqs = [2332//232*x + 2131232*y - 1//343434, x + y + 1]
sol = Symbolics.symbolic_solve(eqs, [x,y])
sol = symbolic_solve(eqs, [x,y])
backward = [Symbolics.substitute(eqs, s) for s in sol]
@test all(x -> all(isapprox.(eval(Symbolics.toexpr(x)), 0; atol=1e-6)), backward)

Expand All @@ -259,10 +259,12 @@ end
# at most 4 roots by Bézout's theorem
rand_eq(xs, d) = rand(-10:10) + rand(-10:10)*x + rand(-10:10)*y + rand(-10:10)*x*y + rand(-10:10)*x^2 + rand(-10:10)*y^2
eqs = [rand_eq([x,y],2), rand_eq([x,y],2)]
sol = Symbolics.symbolic_solve(eqs, [x,y])
sol = symbolic_solve(eqs, [x,y])
backward = [Symbolics.substitute(eqs, s) for s in sol]
@test all(x -> all(isapprox.(eval(Symbolics.toexpr(x)), 0; atol=1e-6)), backward)
end

@test isnothing(symbolic_solve([x^2, x*y, y^2], [x,y], warns=false))
end

@testset "Multivar parametric" begin
Expand All @@ -277,8 +279,17 @@ end
@test isnothing(symbolic_solve([x*y - a, sin(x)], [x, y]))

@variables t w u v
sol = symbolic_solve([t*w - 1 ~ 4, u + v + w ~ 1], [t,w,u,v])
@test isequal(sol, [Dict(u => u, t => -5 / (-1 + u + v), v => v, w => 1 - u - v)])
sol = symbolic_solve([t*w - 1 ~ 4, u + v + w ~ 1], [t,w])
@test isequal(sol, [Dict(t => -5 / (-1 + u + v), w => 1 - u - v)])

sol = symbolic_solve([x-y, y-z], [x])
@test isequal(sol, [Dict(x=>z)])

sol = symbolic_solve([x-y, y-z], [x, y])
@test isequal(sol, [Dict(x=>z, y=>z)])

sol = symbolic_solve([x + y - z, y - z], [x])
@test isequal(sol, [Dict(x=>0)])
end

@testset "Factorisation" begin
Expand Down

0 comments on commit f84b877

Please sign in to comment.