Skip to content

Commit

Permalink
bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
n0rbed committed Aug 17, 2024
1 parent 11a5bc1 commit 8675b38
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 19 deletions.
21 changes: 9 additions & 12 deletions ext/SymbolicsGroebnerExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,13 +199,19 @@ function Symbolics.solve_multivar(eqs::Vector, vars::Vector{Num}; dropmultiplici
push!(new_eqs, eq)

new_eqs = Symbolics.groebner_basis(new_eqs, ordering=Lex(vcat(vars, params)))

# handle "unsolvable" case
if isequal(1, new_eqs[1])
return []
end

new_eqs = demote(new_eqs, vars, params)
new_eqs = map(Symbolics.unwrap, new_eqs)

# condition for positive dimensionality
# condition for positive dimensionality, i.e. infinite solutions
if length(new_eqs) < length(vars)
generating &= false
break
warns && @warn("Infinite number of solutions")
return nothing
end

# We exit when the system is in Shape Lemma case:
Expand All @@ -232,15 +238,6 @@ function Symbolics.solve_multivar(eqs::Vector, vars::Vector{Num}; dropmultiplici

solutions = []

# handle "unsolvable" cases
if isequal(1, new_eqs[1])
return solutions
end
if length(new_eqs) < length(vars)
warns && @warn("Infinite number of solutions")
return nothing
end

# first, solve the first minimal polynomial
@assert length(new_eqs) == length(vars)
@assert isequal(setdiff(Symbolics.get_variables(new_eqs[1]), params), [new_var])
Expand Down
2 changes: 1 addition & 1 deletion src/solver/attract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ function attract_logs(lhs, var)
condition_y = expand(simplify(lhs, rewriter = SymbolicUtils.Postwalk(SymbolicUtils.Chain(r_conditiony))))
lhs = expand(simplify(lhs, rewriter = SymbolicUtils.Postwalk(SymbolicUtils.Chain(r_addlogs))))

return lhs, [condition_x, condition_y]
return lhs, [(condition_x, >), (condition_y, >)]
end

"""
Expand Down
13 changes: 7 additions & 6 deletions src/solver/ia_main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,21 +94,21 @@ function isolate(lhs, var; warns=true, conditions=[])
elseif oper === (log) || oper === (slog)
lhs = args[1]
rhs = map(sol -> term(^, Base.MathConstants.e, sol), rhs)
push!(conditions, args[1])
push!(conditions, (args[1], >))

elseif oper === (log2)
lhs = args[1]
rhs = map(sol -> term(^, 2, sol), rhs)
push!(conditions, args[1])
push!(conditions, (args[1], >))

elseif oper === (log10)
lhs = args[1]
rhs = map(sol -> term(^, 10, sol), rhs)
push!(conditions, args[1])
push!(conditions, (args[1], >))

elseif oper === (sqrt)
lhs = args[1]
append!(conditions, rhs)
append!(conditions, [(r, >=) for r in rhs])
rhs = map(sol -> term(^, sol, 2), rhs)

elseif oper === (cbrt)
Expand Down Expand Up @@ -284,9 +284,10 @@ function ia_solve(lhs, var; warns = true)
end
domain_error = false
for j in eachindex(conditions)
cond_val = substitute(conditions[j], Dict(var=>eval(toexpr(sols[i]))))
condition, t = conditions[j]
cond_val = substitute(condition, Dict(var=>eval(toexpr(sols[i]))))
cond_val isa Complex && continue
domain_error |= cond_val <= 0
domain_error |= !t(cond_val, 0)
end
!domain_error && push!(filtered_sols, sols[i])
end
Expand Down
1 change: 1 addition & 0 deletions src/solver/polynomialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ function detect_sqrtpoly(lhs, var)
!iscall(arg) && continue

if isequal(check_sqrt(arg, sqrt_term, var), true)
sqrt_term_n += arg
sqrt_term = true
continue
elseif isequal(check_sqrt(arg, sqrt_term, var), false)
Expand Down

0 comments on commit 8675b38

Please sign in to comment.