diff --git a/ext/SymbolicsGroebnerExt.jl b/ext/SymbolicsGroebnerExt.jl index b3ae7d8d0..ff65cd996 100644 --- a/ext/SymbolicsGroebnerExt.jl +++ b/ext/SymbolicsGroebnerExt.jl @@ -199,13 +199,19 @@ function Symbolics.solve_multivar(eqs::Vector, vars::Vector{Num}; dropmultiplici push!(new_eqs, eq) new_eqs = Symbolics.groebner_basis(new_eqs, ordering=Lex(vcat(vars, params))) + + # handle "unsolvable" case + if isequal(1, new_eqs[1]) + return [] + end + new_eqs = demote(new_eqs, vars, params) new_eqs = map(Symbolics.unwrap, new_eqs) - # condition for positive dimensionality + # condition for positive dimensionality, i.e. infinite solutions if length(new_eqs) < length(vars) - generating &= false - break + warns && @warn("Infinite number of solutions") + return nothing end # We exit when the system is in Shape Lemma case: @@ -232,15 +238,6 @@ function Symbolics.solve_multivar(eqs::Vector, vars::Vector{Num}; dropmultiplici solutions = [] - # handle "unsolvable" cases - if isequal(1, new_eqs[1]) - return solutions - end - if length(new_eqs) < length(vars) - warns && @warn("Infinite number of solutions") - return nothing - end - # first, solve the first minimal polynomial @assert length(new_eqs) == length(vars) @assert isequal(setdiff(Symbolics.get_variables(new_eqs[1]), params), [new_var]) diff --git a/src/solver/attract.jl b/src/solver/attract.jl index 8a5889823..027f85a99 100644 --- a/src/solver/attract.jl +++ b/src/solver/attract.jl @@ -128,7 +128,7 @@ function attract_logs(lhs, var) condition_y = expand(simplify(lhs, rewriter = SymbolicUtils.Postwalk(SymbolicUtils.Chain(r_conditiony)))) lhs = expand(simplify(lhs, rewriter = SymbolicUtils.Postwalk(SymbolicUtils.Chain(r_addlogs)))) - return lhs, [condition_x, condition_y] + return lhs, [(condition_x, >), (condition_y, >)] end """ diff --git a/src/solver/ia_main.jl b/src/solver/ia_main.jl index 0b52abe8c..0c7e5710d 100644 --- a/src/solver/ia_main.jl +++ b/src/solver/ia_main.jl @@ -94,21 +94,21 @@ function isolate(lhs, var; warns=true, conditions=[]) elseif oper === (log) || oper === (slog) lhs = args[1] rhs = map(sol -> term(^, Base.MathConstants.e, sol), rhs) - push!(conditions, args[1]) + push!(conditions, (args[1], >)) elseif oper === (log2) lhs = args[1] rhs = map(sol -> term(^, 2, sol), rhs) - push!(conditions, args[1]) + push!(conditions, (args[1], >)) elseif oper === (log10) lhs = args[1] rhs = map(sol -> term(^, 10, sol), rhs) - push!(conditions, args[1]) + push!(conditions, (args[1], >)) elseif oper === (sqrt) lhs = args[1] - append!(conditions, rhs) + append!(conditions, [(r, >=) for r in rhs]) rhs = map(sol -> term(^, sol, 2), rhs) elseif oper === (cbrt) @@ -284,9 +284,10 @@ function ia_solve(lhs, var; warns = true) end domain_error = false for j in eachindex(conditions) - cond_val = substitute(conditions[j], Dict(var=>eval(toexpr(sols[i])))) + condition, t = conditions[j] + cond_val = substitute(condition, Dict(var=>eval(toexpr(sols[i])))) cond_val isa Complex && continue - domain_error |= cond_val <= 0 + domain_error |= !t(cond_val, 0) end !domain_error && push!(filtered_sols, sols[i]) end diff --git a/src/solver/polynomialization.jl b/src/solver/polynomialization.jl index 542c5856d..095d7b5e6 100644 --- a/src/solver/polynomialization.jl +++ b/src/solver/polynomialization.jl @@ -316,6 +316,7 @@ function detect_sqrtpoly(lhs, var) !iscall(arg) && continue if isequal(check_sqrt(arg, sqrt_term, var), true) + sqrt_term_n += arg sqrt_term = true continue elseif isequal(check_sqrt(arg, sqrt_term, var), false)