diff --git a/src/solver/main.jl b/src/solver/main.jl index e9d3f3aeb..e6671e1f5 100644 --- a/src/solver/main.jl +++ b/src/solver/main.jl @@ -195,7 +195,7 @@ function symbolic_solve(expr, x::T; dropmultiplicity = true, warns = true) where for e in expr for var in x if !check_poly_inunivar(e, var) - warns && @warn("This system can not be currently solved by solve.") + warns && @warn("This system can not be currently solved by `symbolic_solve`.") return nothing end end @@ -276,7 +276,7 @@ function solve_univar(expression, x; dropmultiplicity=true) end end - subs, filtered_expr = filter_poly(expression, x) + subs, filtered_expr, assumptions = filter_poly(expression, x, assum=true) coeffs, constant = polynomial_coeffs(filtered_expr, [x]) degree = sdegree(coeffs, x) @@ -296,18 +296,28 @@ function solve_univar(expression, x; dropmultiplicity=true) append!(arr_roots, og_arr_roots) end end - - return arr_roots end if length(factors) != 1 - for factor in factors_subbed - roots = solve_univar(factor, x, dropmultiplicity = dropmultiplicity) + for i in eachindex(factors_subbed) + if !any(isequal(x, var) for var in get_variables(factors[i])) + continue + end + roots = solve_univar(factors_subbed[i], x, dropmultiplicity = dropmultiplicity) append!(arr_roots, roots) end end + for i in reverse(eachindex(arr_roots)) + for j in eachindex(assumptions) + if isequal(substitute(assumptions[j], Dict(x=>arr_roots[i])), 0) + deleteat!(arr_roots, i) + end + end + end + if isequal(arr_roots, []) + @assert check_polynomial(expression) "This expression could not be solved by `symbolic_solve`." return [RootsOf(wrap(expression), wrap(x))] end diff --git a/src/solver/preprocess.jl b/src/solver/preprocess.jl index cd586edb3..2e0e238af 100644 --- a/src/solver/preprocess.jl +++ b/src/solver/preprocess.jl @@ -40,16 +40,19 @@ function clean_f(filtered_expr, var, subs) unwrapped_f = unwrap(filtered_expr) !iscall(unwrapped_f) && return filtered_expr oper = operation(unwrapped_f) + assumptions = [] if oper === (/) args = arguments(unwrapped_f) if any(isequal(var, x) for x in get_variables(args[2])) - return filtered_expr + filtered_expr = expand(args[1] * args[2]) + push!(assumptions, substitute(args[2], subs, fold=false)) + return filtered_expr, assumptions end filtered_expr = args[1] @info "Assuming $(substitute(args[2], subs, fold=false) != 0)" end - return filtered_expr + return filtered_expr, assumptions end """ @@ -238,15 +241,17 @@ julia> filter_poly((x+1)*term(log, 3), x) (Dict{Any, Any}(var"##247" => log(3)), var"##247"*(1 + x)) ``` """ -function filter_poly(og_expr, var) +function filter_poly(og_expr, var; assum=false) expr = deepcopy(og_expr) expr = unwrap(expr) vars = get_variables(expr) # handle edge cases if !isequal(vars, []) && isequal(vars[1], expr) + assum && return Dict{Any, Any}(), expr, [] return (Dict{Any, Any}(), expr) elseif isequal(vars, []) + assum && return filter_stuff(expr), [] return filter_stuff(expr) end @@ -256,14 +261,16 @@ function filter_poly(og_expr, var) # reassemble expr to avoid variables remembering original values issue and clean args = arguments(expr) oper = operation(expr) - new_expr = clean_f(term(oper, args...), var, subs) + new_expr, assumptions = clean_f(term(oper, args...), var, subs) + assum && return subs, new_expr, assumptions return subs, new_expr end -function filter_poly(og_expr) + +function filter_poly(og_expr; assum=false) new_var = gensym() new_var = (@variables $(new_var))[1] - return filter_poly(og_expr, new_var) + return filter_poly(og_expr, new_var; assum=assum) end diff --git a/test/solver.jl b/test/solver.jl index 7a39c8924..9eb2e8ec8 100644 --- a/test/solver.jl +++ b/test/solver.jl @@ -57,9 +57,8 @@ end @variables x y z a b c d e @testset "Invalid input" begin - @test_throws AssertionError Symbolics.get_roots(x, x^2) - @test_throws AssertionError Symbolics.get_roots(x^3 + sin(x), x) - @test_throws AssertionError Symbolics.get_roots(1/x, x) + @test_throws AssertionError symbolic_solve(x, x^2) + @test_throws AssertionError symbolic_solve(1/x, x) end @testset "Deg 1 univar" begin