Skip to content

Commit

Permalink
Merge branch 'RootFinding' of https://github.com/n0rbed/Symbolics.jl
Browse files Browse the repository at this point in the history
…into RootFinding
  • Loading branch information
n0rbed committed Aug 17, 2024
2 parents 968fd9e + 0a21bbb commit 4863017
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 13 deletions.
19 changes: 12 additions & 7 deletions ext/SymbolicsGroebnerExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,10 @@ end

function gen_separating_var(vars)
n = 1
new_var = (Symbolics.@variables HAT)[1]
new_var = (Symbolics.@variables _T)[1]
present = any(isequal(new_var, var) for var in vars)
while present
new_var = Symbolics.variables(repeat("_", n) * "HAT")[1]
new_var = Symbolics.variables(repeat("_", n) * "_T")[1]
present = any(isequal(new_var, var) for var in vars)
n += 1
end
Expand All @@ -112,6 +112,7 @@ end

# Given a GB in k[params][vars] produces a GB in k(params)[vars]
function demote(gb, vars::Vector{Num}, params::Vector{Num})
isequal(gb, [1]) && return gb
gb = Symbolics.wrap.(SymbolicUtils.toterm.(gb))
Symbolics.check_polynomial.(gb)

Expand Down Expand Up @@ -173,6 +174,8 @@ function Symbolics.solve_multivar(eqs::Vector, vars::Vector{Num}; dropmultiplici
# Through the Rational Univariate Representation.
# AAECC 9, 433–461 (1999). https://doi.org/10.1007/s002000050114

rng = Groebner.Random.Xoshiro(42)

all_indeterminates = reduce(union, map(Symbolics.get_variables, eqs))
params = map(Symbolics.Num Symbolics.wrap, setdiff(all_indeterminates, vars))

Expand All @@ -184,19 +187,20 @@ function Symbolics.solve_multivar(eqs::Vector, vars::Vector{Num}; dropmultiplici
new_eqs = []
generating = true
n_iterations = 1
separating_form = new_var

while generating
new_eqs = copy(eqs)
eq = new_var
separating_form = new_var
for i = 1:(old_len)
eq += BigInt(rand(-n_iterations:n_iterations))*vars[i]
separating_form += BigInt(rand(rng, -n_iterations:n_iterations))*vars[i]
end

if isequal(eq, new_var)
if isequal(separating_form, new_var)
continue
end

push!(new_eqs, eq)
push!(new_eqs, separating_form)

new_eqs = Symbolics.groebner_basis(new_eqs, ordering=Lex(vcat(vars, params)))

Expand Down Expand Up @@ -260,7 +264,8 @@ function Symbolics.solve_multivar(eqs::Vector, vars::Vector{Num}; dropmultiplici
subbded_eq = Symbolics.substitute(subbded_eq, Dict([var_tosolve => 0]); fold=false)
new_var_sols = [-subbded_eq]
@assert length(new_var_sols) == 1
roots[var_tosolve] = new_var_sols[1]
root = new_var_sols[1]
roots[var_tosolve] = root
end
end

Expand Down
3 changes: 2 additions & 1 deletion src/solver/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ function symbolic_solve(expr, x::T; dropmultiplicity = true, warns = true) where
isequal(sols, nothing) && return nothing
end

sols = map(sol -> sol isa RootsOf ? sol : postprocess_root(sol), sols)
# sols = map(sol -> sol isa RootsOf ? sol : postprocess_root(sol), sols)
sols = map(postprocess_root, sols)
return sols
end

Expand Down
1 change: 1 addition & 0 deletions src/solver/postprocess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,5 @@ function postprocess_root(x)
end
isequal(typeof(old_x), typeof(x)) && isequal(old_x, x) && return x
end
x # unreachable
end
6 changes: 1 addition & 5 deletions src/solver/solve_helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,8 @@ function slog(n)
return term(slog, n)
end

struct RootsOf
poly::Num
var::Num
end
const RootsOf = (SymbolicUtils.@syms roots_of(poly,var))[1]

Base.show(io::IO, r::RootsOf) = print(io, "roots_of(", r.poly, ", ", r.var, ")")
Base.show(io::IO, f::typeof(ssqrt)) = print(io, "")
Base.show(io::IO, r::typeof(scbrt)) = print(io, "")
Base.show(io::IO, r::typeof(slog)) = print(io, "slog")
Expand Down
5 changes: 5 additions & 0 deletions test/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,11 @@ end
end

@testset "Multivar solver" begin
@test isequal(symbolic_solve([x^4 - 1, x - 2], [x]), [])

# TODO: test this properly
sol = symbolic_solve([x^3 + 1, x*y^3 - 1], [x, y])

eqs = [x*y + 2x^2, y^2 -1]
arr_calcd_roots = sort_arr(symbolic_solve(eqs, [x,y]), [x,y])
arr_known_roots = sort_arr([Dict(x=>-1//2, y=>1), Dict(x=>0, y=>-1), Dict(x=>0, y=>1), Dict(x=>1//2, y=>-1)], [x,y])
Expand Down

0 comments on commit 4863017

Please sign in to comment.