diff --git a/lib/OptimizationMOI/src/OptimizationMOI.jl b/lib/OptimizationMOI/src/OptimizationMOI.jl index bad06ccca..9505448a8 100644 --- a/lib/OptimizationMOI/src/OptimizationMOI.jl +++ b/lib/OptimizationMOI/src/OptimizationMOI.jl @@ -111,19 +111,19 @@ end include("nlp.jl") include("moi.jl") -function SciMLBase.solve(prob::OptimizationProblem, - opt::Union{MOI.AbstractOptimizer, MOI.OptimizerWithAttributes}; - kwargs...) - cache = SciMLBase.init(prob, opt) - SciMLBase.solve(cache; kwargs...) -end +SciMLBase.supports_opt_cache_interface(alg::Union{MOI.AbstractOptimizer, MOI.OptimizerWithAttributes}) = true -function SciMLBase.init(prob::OptimizationProblem, - opt::Union{MOI.AbstractOptimizer, MOI.OptimizerWithAttributes}) +function SciMLBase.__init(prob::OptimizationProblem, + opt::Union{MOI.AbstractOptimizer, MOI.OptimizerWithAttributes}; + maxiters::Union{Number, Nothing} = nothing, + maxtime::Union{Number, Nothing} = nothing, + abstol::Union{Number, Nothing} = nothing, + reltol::Union{Number, Nothing} = nothing, + kwargs...) cache = if MOI.supports(_create_new_optimizer(opt), MOI.NLPBlock()) - MOIOptimizationNLPCache(prob, opt) + MOIOptimizationNLPCache(prob, opt; maxiters, maxtime, abstol, reltol, kwargs...) else - MOIOptimizationCache(prob, opt) + MOIOptimizationCache(prob, opt; maxiters, maxtime, abstol, reltol, kwargs...) end return cache end diff --git a/lib/OptimizationMOI/src/moi.jl b/lib/OptimizationMOI/src/moi.jl index 41e73df47..0684f8cf0 100644 --- a/lib/OptimizationMOI/src/moi.jl +++ b/lib/OptimizationMOI/src/moi.jl @@ -10,9 +10,10 @@ mutable struct MOIOptimizationCache{F <: OptimizationFunction, uType, P, LB, UB, expr::EX cons_expr::CEX opt::O + solver_args::NamedTuple end -function MOIOptimizationCache(prob::OptimizationProblem, opt) +function MOIOptimizationCache(prob::OptimizationProblem, opt; kwargs...) isnothing(prob.f.sys) && throw(ArgumentError("Expected a `OptimizationProblem` that was setup via an `OptimizationSystem`, consider `modelingtoolkitize(prob).`")) @@ -32,7 +33,8 @@ function MOIOptimizationCache(prob::OptimizationProblem, opt) prob.sense, expr, cons_expr, - opt) + opt, + NamedTuple(kwargs)) end SciMLBase.has_reinit(cache::MOIOptimizationCache) = true @@ -86,21 +88,16 @@ function _add_moi_variables!(opt_setup, cache::MOIOptimizationCache) return θ end -function SciMLBase.solve(cache::MOIOptimizationCache; - maxiters::Union{Number, Nothing} = nothing, - maxtime::Union{Number, Nothing} = nothing, - abstol::Union{Number, Nothing} = nothing, - reltol::Union{Number, Nothing} = nothing, - kwargs...) - maxiters = Optimization._check_and_convert_maxiters(maxiters) - maxtime = Optimization._check_and_convert_maxtime(maxtime) +function SciMLBase.__solve(cache::MOIOptimizationCache) + maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters) + maxtime = Optimization._check_and_convert_maxtime(cache.solver_args.maxtime) opt_setup = __map_optimizer_args(cache, cache.opt; - abstol = abstol, - reltol = reltol, + abstol = cache.solver_args.abstol, + reltol = cache.solver_args.reltol, maxiters = maxiters, maxtime = maxtime, - kwargs...) + cache.solver_args...) θ = _add_moi_variables!(opt_setup, cache) MOI.set(opt_setup, diff --git a/lib/OptimizationMOI/src/nlp.jl b/lib/OptimizationMOI/src/nlp.jl index 843433cb9..91f434e0d 100644 --- a/lib/OptimizationMOI/src/nlp.jl +++ b/lib/OptimizationMOI/src/nlp.jl @@ -27,6 +27,7 @@ struct MOIOptimizationNLPCache{E <: MOIOptimizationNLPEvaluator, O} <: SciMLBase.AbstractOptimizationCache evaluator::E opt::O + solver_args::NamedTuple end function SciMLBase.get_p(sol::SciMLBase.OptimizationSolution{T, N, uType, C}) where {T, N, @@ -63,7 +64,7 @@ function SciMLBase.get_paramsyms(sol::SciMLBase.OptimizationSolution{T, N, uType sol.cache.evaluator.f.paramsyms end -function MOIOptimizationNLPCache(prob::OptimizationProblem, opt) +function MOIOptimizationNLPCache(prob::OptimizationProblem, opt; kwargs...) reinit_cache = Optimization.ReInitCache(prob.u0, prob.p) # everything that can be changed via `reinit` num_cons = prob.ucons === nothing ? 0 : length(prob.ucons) @@ -100,7 +101,7 @@ function MOIOptimizationNLPCache(prob::OptimizationProblem, opt) J, H, cons_H) - return MOIOptimizationNLPCache(evaluator, opt) + return MOIOptimizationNLPCache(evaluator, opt, NamedTuple(kwargs)) end SciMLBase.has_reinit(cache::MOIOptimizationNLPCache) = true @@ -319,21 +320,16 @@ function _add_moi_variables!(opt_setup, evaluator::MOIOptimizationNLPEvaluator) return θ end -function SciMLBase.solve(cache::MOIOptimizationNLPCache; - maxiters::Union{Number, Nothing} = nothing, - maxtime::Union{Number, Nothing} = nothing, - abstol::Union{Number, Nothing} = nothing, - reltol::Union{Number, Nothing} = nothing, - kwargs...) - maxiters = Optimization._check_and_convert_maxiters(maxiters) - maxtime = Optimization._check_and_convert_maxtime(maxtime) +function SciMLBase.__solve(cache::MOIOptimizationNLPCache) + maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters) + maxtime = Optimization._check_and_convert_maxtime(cache.solver_args.maxtime) opt_setup = __map_optimizer_args(cache, cache.opt; - abstol = abstol, - reltol = reltol, + abstol = cache.solver_args.abstol, + reltol = cache.solver_args.reltol, maxiters = maxiters, maxtime = maxtime, - kwargs...) + cache.solver_args...) θ = _add_moi_variables!(opt_setup, cache.evaluator) MOI.set(opt_setup, diff --git a/lib/OptimizationMOI/test/runtests.jl b/lib/OptimizationMOI/test/runtests.jl index 65e63b48a..a30c9cf06 100644 --- a/lib/OptimizationMOI/test/runtests.jl +++ b/lib/OptimizationMOI/test/runtests.jl @@ -105,12 +105,12 @@ end @named sys = OptimizationSystem((x - a)^2, [x], [a];) prob = OptimizationProblem(sys, [x => 0.0], []; grad = true, hess = true) - cache = Optimization.init(prob, Ipopt.Optimizer()) - sol = Optimization.solve(cache; print_level = 0) + cache = Optimization.init(prob, Ipopt.Optimizer(); print_level = 0) + sol = Optimization.solve(cache) @test sol.u ≈ [1.0] # ≈ [1] cache = OptimizationMOI.reinit!(cache; p = [2.0]) - sol = Optimization.solve(cache; print_level = 0) + sol = Optimization.solve(cache) @test sol.u ≈ [2.0] # ≈ [2] end