diff --git a/ext/SymbolicsGroebnerExt.jl b/ext/SymbolicsGroebnerExt.jl index ea156fe23..0d9c3f8fb 100644 --- a/ext/SymbolicsGroebnerExt.jl +++ b/ext/SymbolicsGroebnerExt.jl @@ -182,6 +182,7 @@ function solve_zerodim(eqs::Vector, vars::Vector{Num}; dropmultiplicity=true, wa generating = true n_iterations = 1 separating_form = new_var + eqs = Symbolics.wrap.(eqs) while generating new_eqs = copy(eqs) @@ -302,18 +303,17 @@ function Symbolics.solve_multivar(eqs::Vector, vars::Vector{Num}; dropmultiplici sol end +# Helps with precompilation time PrecompileTools.@setup_workload begin @variables a b c x y z equation1 = a*log(x)^b + c ~ 0 equation_actually_polynomial = sin(x^2 +1)^2 + sin(x^2 + 1) + 3 simple_linear_equations = [x - y, y + 2z] - expr_with_params = expand((x + b)*(x^2 + 2x + 1)*(x^2 - a)) equations_intersect_sphere_line = [x^2 + y^2 + z^2 - 9, x - 2y + 3, y - z] PrecompileTools.@compile_workload begin symbolic_solve(equation1, x) symbolic_solve(equation_actually_polynomial) symbolic_solve(simple_linear_equations, [x, y]) - symbolic_solve(expr_with_params, x, dropmultiplicity=false) symbolic_solve(equations_intersect_sphere_line, [x, y, z]) end end diff --git a/ext/SymbolicsNemoExt.jl b/ext/SymbolicsNemoExt.jl index 137a7660f..f40205289 100644 --- a/ext/SymbolicsNemoExt.jl +++ b/ext/SymbolicsNemoExt.jl @@ -1,5 +1,6 @@ module SymbolicsNemoExt using Nemo +import Symbolics.PrecompileTools if isdefined(Base, :get_extension) using Symbolics @@ -73,4 +74,16 @@ function Symbolics.gcd_use_nemo(poly1::Num, poly2::Num) return sym_gcd end + +# Helps with precompilation time +PrecompileTools.@setup_workload begin + @variables a b c x y z + expr_with_params = expand((x + b)*(x^2 + 2x + 1)*(x^2 - a)) + PrecompileTools.@compile_workload begin + symbolic_solve(expr_with_params, x, dropmultiplicity=false) + symbolic_solve(x^10 - a^10, x, dropmultiplicity=false) + symbolic_solve([x^2 - a^2, x + a], x) + end +end + end # module diff --git a/src/solver/main.jl b/src/solver/main.jl index 10392fd3e..be9fd73a7 100644 --- a/src/solver/main.jl +++ b/src/solver/main.jl @@ -145,6 +145,11 @@ function symbolic_solve(expr, x::T; dropmultiplicity = true, warns = true) where expr = Vector{Num}(expr) end + if expr_univar && !x_univar + expr = [expr] + expr_univar = false + end + if x_univar sols = [] if expr_univar diff --git a/test/solver.jl b/test/solver.jl index 9015b369a..6c9e4a905 100644 --- a/test/solver.jl +++ b/test/solver.jl @@ -54,7 +54,7 @@ function check_approx(arr1, arr2) return true end -@variables x y z +@variables x y z a b c d e @testset "Invalid input" begin @test_throws AssertionError Symbolics.get_roots(x, x^2) @@ -172,6 +172,11 @@ end # expr = x^3 + sqrt(complex(-2//1))*x + 2 end +@testset "Multipoly solver" begin + @test isequal(symbolic_solve([x^2 - 1, x + 1], x)[1], -1) + @test isequal(symbolic_solve([x^2 - a^2, x + a], x)[1], -a) + @test isequal(symbolic_solve([x^20 - a^20, x + a], x)[1], -a) +end @testset "Multivar solver" begin @variables x y z @test isequal(symbolic_solve([x^4 - 1, x - 2], [x]), []) @@ -356,6 +361,7 @@ end # standby # @test Symbolics.(log(y) + x , x) == 1 + @test Symbolics.n_func_occ(log(a*x) + b, x) == 1 @test Symbolics.n_func_occ(log(x + sin((x^2 + x)/log(x))), x) == 3 @test Symbolics.n_func_occ(x^2 + x + x^3, x) == 1 @@ -373,7 +379,6 @@ end @testset "Isolate/Attract solve" begin - @variables a b c d e x lhs = ia_solve(a*x^b + c, x)[1] lhs2 = symbolic_solve(a*x^b + c, x)[1] rhs = Symbolics.term(^, -c.val/a.val, 1/b.val) @@ -513,8 +518,6 @@ using LambertW #Testing @testset "Algebraic solver tests" begin - - @variables x y z a b c function correctAns(solve_roots, known_roots) solve_roots = sort_roots(eval.(Symbolics.toexpr.(solve_roots))) known_roots = sort_roots(known_roots)