Skip to content

Commit

Permalink
changes to sciml init dispatch interface
Browse files Browse the repository at this point in the history
  • Loading branch information
ValentinKaisermayer committed Oct 28, 2022
1 parent 106c7c6 commit 1aa0395
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 39 deletions.
20 changes: 10 additions & 10 deletions lib/OptimizationMOI/src/OptimizationMOI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,19 +111,19 @@ end
include("nlp.jl")
include("moi.jl")

function SciMLBase.solve(prob::OptimizationProblem,
opt::Union{MOI.AbstractOptimizer, MOI.OptimizerWithAttributes};
kwargs...)
cache = SciMLBase.init(prob, opt)
SciMLBase.solve(cache; kwargs...)
end
SciMLBase.supports_opt_cache_interface(alg::Union{MOI.AbstractOptimizer, MOI.OptimizerWithAttributes}) = true

function SciMLBase.init(prob::OptimizationProblem,
opt::Union{MOI.AbstractOptimizer, MOI.OptimizerWithAttributes})
function SciMLBase.__init(prob::OptimizationProblem,
opt::Union{MOI.AbstractOptimizer, MOI.OptimizerWithAttributes};
maxiters::Union{Number, Nothing} = nothing,
maxtime::Union{Number, Nothing} = nothing,
abstol::Union{Number, Nothing} = nothing,
reltol::Union{Number, Nothing} = nothing,
kwargs...)
cache = if MOI.supports(_create_new_optimizer(opt), MOI.NLPBlock())
MOIOptimizationNLPCache(prob, opt)
MOIOptimizationNLPCache(prob, opt; maxiters, maxtime, abstol, reltol, kwargs...)
else
MOIOptimizationCache(prob, opt)
MOIOptimizationCache(prob, opt; maxiters, maxtime, abstol, reltol, kwargs...)
end
return cache
end
Expand Down
23 changes: 10 additions & 13 deletions lib/OptimizationMOI/src/moi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@ mutable struct MOIOptimizationCache{F <: OptimizationFunction, uType, P, LB, UB,
expr::EX
cons_expr::CEX
opt::O
solver_args::NamedTuple
end

function MOIOptimizationCache(prob::OptimizationProblem, opt)
function MOIOptimizationCache(prob::OptimizationProblem, opt; kwargs...)
isnothing(prob.f.sys) &&
throw(ArgumentError("Expected a `OptimizationProblem` that was setup via an `OptimizationSystem`, consider `modelingtoolkitize(prob).`"))

Expand All @@ -32,7 +33,8 @@ function MOIOptimizationCache(prob::OptimizationProblem, opt)
prob.sense,
expr,
cons_expr,
opt)
opt,
NamedTuple(kwargs))
end

SciMLBase.has_reinit(cache::MOIOptimizationCache) = true
Expand Down Expand Up @@ -86,21 +88,16 @@ function _add_moi_variables!(opt_setup, cache::MOIOptimizationCache)
return θ
end

function SciMLBase.solve(cache::MOIOptimizationCache;
maxiters::Union{Number, Nothing} = nothing,
maxtime::Union{Number, Nothing} = nothing,
abstol::Union{Number, Nothing} = nothing,
reltol::Union{Number, Nothing} = nothing,
kwargs...)
maxiters = Optimization._check_and_convert_maxiters(maxiters)
maxtime = Optimization._check_and_convert_maxtime(maxtime)
function SciMLBase.__solve(cache::MOIOptimizationCache)
maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters)
maxtime = Optimization._check_and_convert_maxtime(cache.solver_args.maxtime)
opt_setup = __map_optimizer_args(cache,
cache.opt;
abstol = abstol,
reltol = reltol,
abstol = cache.solver_args.abstol,
reltol = cache.solver_args.reltol,
maxiters = maxiters,
maxtime = maxtime,
kwargs...)
cache.solver_args...)

θ = _add_moi_variables!(opt_setup, cache)
MOI.set(opt_setup,
Expand Down
22 changes: 9 additions & 13 deletions lib/OptimizationMOI/src/nlp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ struct MOIOptimizationNLPCache{E <: MOIOptimizationNLPEvaluator, O} <:
SciMLBase.AbstractOptimizationCache
evaluator::E
opt::O
solver_args::NamedTuple
end

function SciMLBase.get_p(sol::SciMLBase.OptimizationSolution{T, N, uType, C}) where {T, N,
Expand Down Expand Up @@ -63,7 +64,7 @@ function SciMLBase.get_paramsyms(sol::SciMLBase.OptimizationSolution{T, N, uType
sol.cache.evaluator.f.paramsyms
end

function MOIOptimizationNLPCache(prob::OptimizationProblem, opt)
function MOIOptimizationNLPCache(prob::OptimizationProblem, opt; kwargs...)
reinit_cache = Optimization.ReInitCache(prob.u0, prob.p) # everything that can be changed via `reinit`

num_cons = prob.ucons === nothing ? 0 : length(prob.ucons)
Expand Down Expand Up @@ -100,7 +101,7 @@ function MOIOptimizationNLPCache(prob::OptimizationProblem, opt)
J,
H,
cons_H)
return MOIOptimizationNLPCache(evaluator, opt)
return MOIOptimizationNLPCache(evaluator, opt, NamedTuple(kwargs))
end

SciMLBase.has_reinit(cache::MOIOptimizationNLPCache) = true
Expand Down Expand Up @@ -319,21 +320,16 @@ function _add_moi_variables!(opt_setup, evaluator::MOIOptimizationNLPEvaluator)
return θ
end

function SciMLBase.solve(cache::MOIOptimizationNLPCache;
maxiters::Union{Number, Nothing} = nothing,
maxtime::Union{Number, Nothing} = nothing,
abstol::Union{Number, Nothing} = nothing,
reltol::Union{Number, Nothing} = nothing,
kwargs...)
maxiters = Optimization._check_and_convert_maxiters(maxiters)
maxtime = Optimization._check_and_convert_maxtime(maxtime)
function SciMLBase.__solve(cache::MOIOptimizationNLPCache)
maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters)
maxtime = Optimization._check_and_convert_maxtime(cache.solver_args.maxtime)
opt_setup = __map_optimizer_args(cache,
cache.opt;
abstol = abstol,
reltol = reltol,
abstol = cache.solver_args.abstol,
reltol = cache.solver_args.reltol,
maxiters = maxiters,
maxtime = maxtime,
kwargs...)
cache.solver_args...)

θ = _add_moi_variables!(opt_setup, cache.evaluator)
MOI.set(opt_setup,
Expand Down
6 changes: 3 additions & 3 deletions lib/OptimizationMOI/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,12 @@ end
@named sys = OptimizationSystem((x - a)^2, [x], [a];)

prob = OptimizationProblem(sys, [x => 0.0], []; grad = true, hess = true)
cache = Optimization.init(prob, Ipopt.Optimizer())
sol = Optimization.solve(cache; print_level = 0)
cache = Optimization.init(prob, Ipopt.Optimizer(); print_level = 0)
sol = Optimization.solve(cache)
@test sol.u [1.0] # ≈ [1]

cache = OptimizationMOI.reinit!(cache; p = [2.0])
sol = Optimization.solve(cache; print_level = 0)
sol = Optimization.solve(cache)
@test sol.u [2.0] # ≈ [2]
end

Expand Down

0 comments on commit 1aa0395

Please sign in to comment.