From f4d9cb5e552adc1bd588022296908c966f5018e8 Mon Sep 17 00:00:00 2001 From: Valentin Kaisermayer Date: Sat, 26 Nov 2022 15:51:12 +0100 Subject: [PATCH] robustifies MOI term collection --- lib/OptimizationMOI/src/moi.jl | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/lib/OptimizationMOI/src/moi.jl b/lib/OptimizationMOI/src/moi.jl index c822bcd6c..a82b6188a 100644 --- a/lib/OptimizationMOI/src/moi.jl +++ b/lib/OptimizationMOI/src/moi.jl @@ -152,7 +152,7 @@ function SciMLBase.__solve(cache::MOIOptimizationCache) maxtime = maxtime, cache.solver_args...) - θ = _add_moi_variables!(opt_setup, cache) + Theta = _add_moi_variables!(opt_setup, cache) MOI.set(opt_setup, MOI.ObjectiveSense(), cache.sense === Optimization.MaxSense ? MOI.MAX_SENSE : MOI.MIN_SENSE) @@ -161,7 +161,7 @@ function SciMLBase.__solve(cache::MOIOptimizationCache) for cons_expr in cache.cons_expr expr = _replace_parameter_indices!(repl_getindex!(deepcopy(cons_expr.args[2])), # f(x) == 0 or f(x) <= 0 cache.p) - fixpoint_simplify_and_expand!(expr) + expr = fixpoint_simplify_and_expand!(expr; iter_max = length(Theta)^2) func, c = try get_moi_function(expr) # find: f(x) + c == 0 or f(x) + c <= 0 catch e @@ -183,7 +183,7 @@ function SciMLBase.__solve(cache::MOIOptimizationCache) # objective expr = _replace_parameter_indices!(repl_getindex!(deepcopy(cache.expr)), cache.p) - fixpoint_simplify_and_expand!(expr) + expr = fixpoint_simplify_and_expand!(expr; iter_max = length(Theta)^2) func, c = try get_moi_function(expr) catch e @@ -197,11 +197,11 @@ function SciMLBase.__solve(cache::MOIOptimizationCache) MOI.optimize!(opt_setup) if MOI.get(opt_setup, MOI.ResultCount()) >= 1 - minimizer = MOI.get(opt_setup, MOI.VariablePrimal(), θ) + minimizer = MOI.get(opt_setup, MOI.VariablePrimal(), Theta) minimum = MOI.get(opt_setup, MOI.ObjectiveValue()) opt_ret = Symbol(string(MOI.get(opt_setup, MOI.TerminationStatus()))) else - minimizer = fill(NaN, length(θ)) + minimizer = fill(NaN, length(Theta)) minimum = NaN opt_ret = :Default end @@ -282,10 +282,10 @@ function simplify_and_expand!(expr::Expr) # looks awful but is actually much fas return expr end -function fixpoint_simplify_and_expand!(expr; iter_max = 1000) +function fixpoint_simplify_and_expand!(expr; iter_max = Inf) i = 1 - while i < iter_max - expr_old = copy(expr) + while i < iter_max # unsure that this returns + expr_old = deepcopy(expr) expr = simplify_and_expand!(expr) expr_old == expr && break i += 1 @@ -305,7 +305,7 @@ function collect_moi_terms!(expr::Expr, affine_terms, quadratic_terms, constant) collect_moi_terms!(expr.args[i], affine_terms, quadratic_terms, constant) end elseif expr.args[1] == :(*) - if expr.args[2] isa Number + if expr.args[2] isa Number && isa(expr.args[3], Expr) if expr.args[3].head == :call && expr.args[3].args[1] == :(*) # a::Number * (x[i] * x[j]) x1 = _get_variable_index_from_expr(expr.args[3].args[2]) x2 = _get_variable_index_from_expr(expr.args[3].args[3]) @@ -318,7 +318,9 @@ function collect_moi_terms!(expr::Expr, affine_terms, quadratic_terms, constant) MOI.ScalarAffineTerm(Float64(expr.args[2]), _get_variable_index_from_expr(expr.args[3]))) end - else + elseif isa(expr.args[2], Number) && isa(expr.args[3], Number) # a::Number * b::Number + constant[] += expr.args[2] * expr.args[3] + elseif isa(expr.args[2], Expr) && isa(expr.args[3], Expr) if expr.args[2].head == :call && expr.args[2].args[1] == :(*) && expr.args[2].args[2] isa Number # (a::Number * x[i]) * x[j] x1 = _get_variable_index_from_expr(expr.args[2].args[3]) @@ -335,6 +337,8 @@ function collect_moi_terms!(expr::Expr, affine_terms, quadratic_terms, constant) MOI.ScalarQuadraticTerm(factor, x1, x2)) end + else + throw(MalformedExprException("$expr")) end end elseif expr.head == :ref # x[i]