Skip to content

Commit

Permalink
Merge pull request #1234 from n0rbed/multivar_stuff
Browse files Browse the repository at this point in the history
Fixed some bugs, symbolic_solve(single expr, [multiple_vars]), and precompiled nemo
  • Loading branch information
ChrisRackauckas authored Aug 25, 2024
2 parents 046d2ef + bf454dc commit 7097212
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 6 deletions.
4 changes: 2 additions & 2 deletions ext/SymbolicsGroebnerExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ function solve_zerodim(eqs::Vector, vars::Vector{Num}; dropmultiplicity=true, wa
generating = true
n_iterations = 1
separating_form = new_var
eqs = Symbolics.wrap.(eqs)

while generating
new_eqs = copy(eqs)
Expand Down Expand Up @@ -302,18 +303,17 @@ function Symbolics.solve_multivar(eqs::Vector, vars::Vector{Num}; dropmultiplici
sol
end

# Helps with precompilation time
PrecompileTools.@setup_workload begin
@variables a b c x y z
equation1 = a*log(x)^b + c ~ 0
equation_actually_polynomial = sin(x^2 +1)^2 + sin(x^2 + 1) + 3
simple_linear_equations = [x - y, y + 2z]
expr_with_params = expand((x + b)*(x^2 + 2x + 1)*(x^2 - a))
equations_intersect_sphere_line = [x^2 + y^2 + z^2 - 9, x - 2y + 3, y - z]
PrecompileTools.@compile_workload begin
symbolic_solve(equation1, x)
symbolic_solve(equation_actually_polynomial)
symbolic_solve(simple_linear_equations, [x, y])
symbolic_solve(expr_with_params, x, dropmultiplicity=false)
symbolic_solve(equations_intersect_sphere_line, [x, y, z])
end
end
Expand Down
13 changes: 13 additions & 0 deletions ext/SymbolicsNemoExt.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module SymbolicsNemoExt
using Nemo
import Symbolics.PrecompileTools

if isdefined(Base, :get_extension)
using Symbolics
Expand Down Expand Up @@ -73,4 +74,16 @@ function Symbolics.gcd_use_nemo(poly1::Num, poly2::Num)
return sym_gcd
end


# Helps with precompilation time
PrecompileTools.@setup_workload begin
@variables a b c x y z
expr_with_params = expand((x + b)*(x^2 + 2x + 1)*(x^2 - a))
PrecompileTools.@compile_workload begin
symbolic_solve(expr_with_params, x, dropmultiplicity=false)
symbolic_solve(x^10 - a^10, x, dropmultiplicity=false)
symbolic_solve([x^2 - a^2, x + a], x)
end
end

end # module
5 changes: 5 additions & 0 deletions src/solver/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,11 @@ function symbolic_solve(expr, x::T; dropmultiplicity = true, warns = true) where
expr = Vector{Num}(expr)
end

if expr_univar && !x_univar
expr = [expr]
expr_univar = false
end

if x_univar
sols = []
if expr_univar
Expand Down
11 changes: 7 additions & 4 deletions test/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ function check_approx(arr1, arr2)
return true
end

@variables x y z
@variables x y z a b c d e

@testset "Invalid input" begin
@test_throws AssertionError Symbolics.get_roots(x, x^2)
Expand Down Expand Up @@ -172,6 +172,11 @@ end
# expr = x^3 + sqrt(complex(-2//1))*x + 2
end

@testset "Multipoly solver" begin
@test isequal(symbolic_solve([x^2 - 1, x + 1], x)[1], -1)
@test isequal(symbolic_solve([x^2 - a^2, x + a], x)[1], -a)
@test isequal(symbolic_solve([x^20 - a^20, x + a], x)[1], -a)
end
@testset "Multivar solver" begin
@variables x y z
@test isequal(symbolic_solve([x^4 - 1, x - 2], [x]), [])
Expand Down Expand Up @@ -356,6 +361,7 @@ end

# standby
# @test Symbolics.(log(y) + x , x) == 1
@test Symbolics.n_func_occ(log(a*x) + b, x) == 1

@test Symbolics.n_func_occ(log(x + sin((x^2 + x)/log(x))), x) == 3
@test Symbolics.n_func_occ(x^2 + x + x^3, x) == 1
Expand All @@ -373,7 +379,6 @@ end


@testset "Isolate/Attract solve" begin
@variables a b c d e x
lhs = ia_solve(a*x^b + c, x)[1]
lhs2 = symbolic_solve(a*x^b + c, x)[1]
rhs = Symbolics.term(^, -c.val/a.val, 1/b.val)
Expand Down Expand Up @@ -513,8 +518,6 @@ using LambertW
#Testing

@testset "Algebraic solver tests" begin

@variables x y z a b c
function correctAns(solve_roots, known_roots)
solve_roots = sort_roots(eval.(Symbolics.toexpr.(solve_roots)))
known_roots = sort_roots(known_roots)
Expand Down

0 comments on commit 7097212

Please sign in to comment.