diff --git a/lib/OptimizationManopt/src/OptimizationManopt.jl b/lib/OptimizationManopt/src/OptimizationManopt.jl
index b2e14a2fe..3507f7660 100644
--- a/lib/OptimizationManopt/src/OptimizationManopt.jl
+++ b/lib/OptimizationManopt/src/OptimizationManopt.jl
@@ -2,7 +2,7 @@ module OptimizationManopt
 using Reexport
 @reexport using Manopt
-using Optimization, Manopt, ManifoldsBase, ManifoldDiff
+using Optimization, Manopt, ManifoldsBase, ManifoldDiff, Optimization.SciMLBase
 """
 abstract type AbstractManoptOptimizer end
@@ -12,233 +12,191 @@ internal state.
 """
 abstract type AbstractManoptOptimizer end
-function stopping_criterion_to_kwarg(stopping_criterion::Nothing)
-    return NamedTuple()
-end
-function stopping_criterion_to_kwarg(stopping_criterion::StoppingCriterion)
-    return (; stopping_criterion = stopping_criterion)
-end
+SciMLBase.supports_opt_cache_interface(opt::AbstractManoptOptimizer) = true
-## gradient descent
+function __map_optimizer_args!(cache::OptimizationCache,
+    opt::AbstractManoptOptimizer;
+    callback = nothing,
+    maxiters::Union{Number, Nothing} = nothing,
+    maxtime::Union{Number, Nothing} = nothing,
+    abstol::Union{Number, Nothing} = nothing,
+    reltol::Union{Number, Nothing} = nothing,
+    kwargs...)
-struct GradientDescentOptimizer{
-    Teval <: AbstractEvaluationType,
-    TM <: AbstractManifold,
-    TLS <: Linesearch
-} <: AbstractManoptOptimizer
-    M::TM
-    stepsize::TLS
-end
+    solver_kwargs = (; kwargs...)
-function GradientDescentOptimizer(M::AbstractManifold;
-    eval::AbstractEvaluationType = Manopt.AllocatingEvaluation(),
-    stepsize::Stepsize = ArmijoLinesearch(M))
-    GradientDescentOptimizer{typeof(eval), typeof(M), typeof(stepsize)}(M, stepsize)
+    if !isnothing(maxiters)
+        solver_kwargs = (; solver_kwargs..., stopping_criterion = [Manopt.StopAfterIteration(maxiters)])
+    end
+
+    if !isnothing(maxtime)
+        if haskey(solver_kwargs, :stopping_criterion)
+            solver_kwargs = (; solver_kwargs..., stopping_criterion = push!(solver_kwargs.stopping_criterion, Manopt.StopAfter(maxtime)))
+        else
+            solver_kwargs = (; solver_kwargs..., stopping_criterion = [Manopt.StopAfter(maxtime)])
+        end
+    end
+
+    if !isnothing(abstol)
+        if haskey(solver_kwargs, :stopping_criterion)
+            solver_kwargs = (; solver_kwargs..., stopping_criterion = push!(solver_kwargs.stopping_criterion, Manopt.StopWhenChangeLess(abstol)))
+        else
+            solver_kwargs = (; solver_kwargs..., stopping_criterion = [Manopt.StopWhenChangeLess(abstol)])
+        end
+    end
+
+    if !isnothing(reltol)
+        @warn "common reltol is currently not used by $(typeof(opt).super)"
+    end
+    return solver_kwargs
 end
-function call_manopt_optimizer(opt::GradientDescentOptimizer{Teval},
+## gradient descent
+struct GradientDescentOptimizer <: AbstractManoptOptimizer end
+
+function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold, opt::GradientDescentOptimizer,
     loss,
     gradF,
-    x0,
-    stopping_criterion::Union{Nothing, Manopt.StoppingCriterion}) where {
-    Teval <:
-    AbstractEvaluationType
-}
-    sckwarg = stopping_criterion_to_kwarg(stopping_criterion)
-    opts = gradient_descent(opt.M,
+    x0;
+    stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
+    evaluation::AbstractEvaluationType = Manopt.AllocatingEvaluation(),
+    stepsize::Stepsize = ArmijoLinesearch(M),
+    kwargs...)
+    opts = gradient_descent(M,
         loss,
        gradF,
        x0;
        return_state = true,
-        evaluation = Teval(),
-        stepsize = opt.stepsize,
-        sckwarg...)
+        evaluation,
+        stepsize,
+        stopping_criterion)
     # we unwrap DebugOptions here
     minimizer = Manopt.get_solver_result(opts)
-    return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
-    :who_knows
+    return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opts)
 end
 ## Nelder-Mead
+struct NelderMeadOptimizer <: AbstractManoptOptimizer end
-struct NelderMeadOptimizer{
-    TM <: AbstractManifold,
-} <: AbstractManoptOptimizer
-    M::TM
-end
-
-function call_manopt_optimizer(opt::NelderMeadOptimizer,
+function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold, opt::NelderMeadOptimizer,
     loss,
     gradF,
-    x0,
-    stopping_criterion::Union{Nothing, Manopt.StoppingCriterion})
-    sckwarg = stopping_criterion_to_kwarg(stopping_criterion)
+    x0;
+    stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
+    kwargs...)
-    opts = NelderMead(opt.M,
+    opts = NelderMead(M,
         loss;
         return_state = true,
-        sckwarg...)
+        stopping_criterion)
     minimizer = Manopt.get_solver_result(opts)
-    return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
-    :who_knows
+    return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opts)
 end
 ## conjugate gradient descent
+struct ConjugateGradientDescentOptimizer <: AbstractManoptOptimizer end
-struct ConjugateGradientDescentOptimizer{Teval <: AbstractEvaluationType,
-    TM <: AbstractManifold, TLS <: Stepsize} <:
-    AbstractManoptOptimizer
-    M::TM
-    stepsize::TLS
-end
-
-function ConjugateGradientDescentOptimizer(M::AbstractManifold;
-    eval::AbstractEvaluationType = InplaceEvaluation(),
-    stepsize::Stepsize = ArmijoLinesearch(M))
-    ConjugateGradientDescentOptimizer{typeof(eval), typeof(M), typeof(stepsize)}(M,
-        stepsize)
-end
-
-function call_manopt_optimizer(opt::ConjugateGradientDescentOptimizer{Teval},
+function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
+    opt::ConjugateGradientDescentOptimizer,
     loss,
     gradF,
-    x0,
-    stopping_criterion::Union{Nothing, Manopt.StoppingCriterion}) where {
-    Teval <:
-    AbstractEvaluationType
-}
-    sckwarg = stopping_criterion_to_kwarg(stopping_criterion)
-    opts = conjugate_gradient_descent(opt.M,
+    x0;
+    stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
+    evaluation::AbstractEvaluationType = InplaceEvaluation(),
+    stepsize::Stepsize = ArmijoLinesearch(M),
+    kwargs...)
+
+    opts = conjugate_gradient_descent(M,
         loss,
        gradF,
        x0;
        return_state = true,
-        evaluation = Teval(),
-        stepsize = opt.stepsize,
-        sckwarg...)
+        evaluation,
+        stepsize,
+        stopping_criterion)
     # we unwrap DebugOptions here
     minimizer = Manopt.get_solver_result(opts)
-    return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
-    :who_knows
+    return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opts)
 end
 ## particle swarm
+struct ParticleSwarmOptimizer <: AbstractManoptOptimizer end
-struct ParticleSwarmOptimizer{Teval <: AbstractEvaluationType,
-    TM <: AbstractManifold, Tretr <: AbstractRetractionMethod,
-    Tinvretr <: AbstractInverseRetractionMethod,
-    Tvt <: AbstractVectorTransportMethod} <:
-    AbstractManoptOptimizer
-    M::TM
-    retraction_method::Tretr
-    inverse_retraction_method::Tinvretr
-    vector_transport_method::Tvt
-    population_size::Int
-end
-
-function ParticleSwarmOptimizer(M::AbstractManifold;
-    eval::AbstractEvaluationType = InplaceEvaluation(),
+function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
+    opt::ParticleSwarmOptimizer,
+    loss,
+    gradF,
+    x0;
+    stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
+    evaluation::AbstractEvaluationType = InplaceEvaluation(),
     population_size::Int = 100,
     retraction_method::AbstractRetractionMethod = default_retraction_method(M),
     inverse_retraction_method::AbstractInverseRetractionMethod = default_inverse_retraction_method(M),
-    vector_transport_method::AbstractVectorTransportMethod = default_vector_transport_method(M))
-    ParticleSwarmOptimizer{typeof(eval), typeof(M), typeof(retraction_method),
-        typeof(inverse_retraction_method),
-        typeof(vector_transport_method)}(M,
-        retraction_method,
-        inverse_retraction_method,
-        vector_transport_method,
-        population_size)
-end
-
-function call_manopt_optimizer(opt::ParticleSwarmOptimizer{Teval},
-    loss,
-    gradF,
-    x0,
-    stopping_criterion::Union{Nothing, Manopt.StoppingCriterion}) where {
-    Teval <:
-    AbstractEvaluationType
-}
-    sckwarg = stopping_criterion_to_kwarg(stopping_criterion)
-    initial_population = vcat([x0], [rand(opt.M) for _ in 1:(opt.population_size - 1)])
-    opts = particle_swarm(opt.M,
+    vector_transport_method::AbstractVectorTransportMethod = default_vector_transport_method(M),
+    kwargs...)
+
+    initial_population = vcat([x0], [rand(M) for _ in 1:(population_size - 1)])
+    opts = particle_swarm(M,
         loss;
        x0 = initial_population,
-        n = opt.population_size,
+        n = population_size,
         return_state = true,
-        retraction_method = opt.retraction_method,
-        inverse_retraction_method = opt.inverse_retraction_method,
-        vector_transport_method = opt.vector_transport_method,
-        sckwarg...)
+        retraction_method,
+        inverse_retraction_method,
+        vector_transport_method,
+        stopping_criterion)
     # we unwrap DebugOptions here
     minimizer = Manopt.get_solver_result(opts)
-    return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
-    :who_knows
+    return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opts)
 end
 ## quasi Newton
-struct QuasiNewtonOptimizer{Teval <: AbstractEvaluationType,
-    TM <: AbstractManifold, Tretr <: AbstractRetractionMethod,
-    Tvt <: AbstractVectorTransportMethod, TLS <: Stepsize} <:
-    AbstractManoptOptimizer
-    M::TM
-    retraction_method::Tretr
-    vector_transport_method::Tvt
-    stepsize::TLS
-end
+struct QuasiNewtonOptimizer <: AbstractManoptOptimizer end
-function QuasiNewtonOptimizer(M::AbstractManifold;
-    eval::AbstractEvaluationType = InplaceEvaluation(),
+function call_manopt_optimizer(M::Manopt.AbstractManifold,
+    opt::QuasiNewtonOptimizer,
+    loss,
+    gradF,
+    x0;
+    stopping_criterion::Union{Manopt.StoppingCriterion, Manopt.StoppingCriterionSet},
+    evaluation::AbstractEvaluationType = InplaceEvaluation(),
     retraction_method::AbstractRetractionMethod = default_retraction_method(M),
     vector_transport_method::AbstractVectorTransportMethod = default_vector_transport_method(M),
     stepsize = WolfePowellLinesearch(M;
         retraction_method = retraction_method,
         vector_transport_method = vector_transport_method,
-        linesearch_stopsize = 1e-12))
-    QuasiNewtonOptimizer{typeof(eval), typeof(M), typeof(retraction_method),
-        typeof(vector_transport_method), typeof(stepsize)}(M,
-        retraction_method,
-        vector_transport_method,
-        stepsize)
-end
-
-function call_manopt_optimizer(opt::QuasiNewtonOptimizer{Teval},
-    loss,
-    gradF,
-    x0,
-    stopping_criterion::Union{Nothing, Manopt.StoppingCriterion}) where {
-    Teval <:
-    AbstractEvaluationType
-}
-    sckwarg = stopping_criterion_to_kwarg(stopping_criterion)
-    opts = quasi_Newton(opt.M,
+        linesearch_stopsize = 1e-12),
+    kwargs...
+    )
+
+    opts = quasi_Newton(M,
         loss,
        gradF,
        x0;
        return_state = true,
-        evaluation = Teval(),
-        retraction_method = opt.retraction_method,
-        vector_transport_method = opt.vector_transport_method,
-        stepsize = opt.stepsize,
-        sckwarg...)
+        evaluation,
+        retraction_method,
+        vector_transport_method,
+        stepsize,
+        stopping_criterion)
     # we unwrap DebugOptions here
     minimizer = Manopt.get_solver_result(opts)
-    return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
-    :who_knows
+    return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opts)
 end
 ## Optimization.jl stuff
-function build_loss(f::OptimizationFunction, prob)
+function build_loss(f::OptimizationFunction, prob, cb)
     function (::AbstractManifold, θ)
-        x = f.f(θ)
+        x = f.f(θ, prob.p)
+        cb(x, θ)
         __x = first(x)
         return prob.sense === Optimization.MaxSense ? -__x : __x
     end
 end
-function build_gradF(f::OptimizationFunction{true}, prob, cur)
+function build_gradF(f::OptimizationFunction{true}, cur)
     function g(M::AbstractManifold, G, θ)
         f.grad(G, θ, cur...)
         G .= riemannian_gradient(M, θ, G)
@@ -255,50 +213,91 @@ end
 # 2) return convergence information
 # 3) add callbacks to Manopt.jl
-function SciMLBase.__solve(prob::OptimizationProblem,
-    opt::AbstractManoptOptimizer,
-    data = Optimization.DEFAULT_DATA;
-    callback = (args...) -> (false),
-    maxiters::Union{Number, Nothing} = nothing,
-    maxtime::Union{Number, Nothing} = nothing,
-    abstol::Union{Number, Nothing} = nothing,
-    reltol::Union{Number, Nothing} = nothing,
-    progress = false,
-    kwargs...)
+function SciMLBase.__solve(cache::OptimizationCache{
+        F,
+        RC,
+        LB,
+        UB,
+        LC,
+        UC,
+        S,
+        O,
+        D,
+        P,
+        C
+}) where {
+    F,
+    RC,
+    LB,
+    UB,
+    LC,
+    UC,
+    S,
+    O <:
+    AbstractManoptOptimizer,
+    D,
+    P,
+    C
+}
     local x, cur, state
-    manifold = haskey(prob.kwargs, :manifold) ? prob.kwargs[:manifold] : nothing
+    manifold = haskey(cache.solver_args, :manifold) ? cache.solver_args[:manifold] : nothing
-    if manifold === nothing || manifold !== opt.M
-        throw(ArgumentError("Either manifold not specified in the problem `OptimizationProblem(f, x, p; manifold = SymmetricPositiveDefinite(5))` or it doesn't match the manifold specified in the optimizer `$(opt.M)`"))
+    if manifold === nothing
+        throw(ArgumentError("Manifold not specified in the problem for e.g. `OptimizationProblem(f, x, p; manifold = SymmetricPositiveDefinite(5))`."))
     end
-    if data !== Optimization.DEFAULT_DATA
-        maxiters = length(data)
+    if cache.data !== Optimization.DEFAULT_DATA
+        maxiters = length(cache.data)
+    else
+        maxiters = cache.solver_args.maxiters
     end
-    cur, state = iterate(data)
-
-    stopping_criterion = nothing
-    if maxiters !== nothing
-        stopping_criterion = StopAfterIteration(maxiters)
+    cur, state = iterate(cache.data)
+
+    function _cb(x, θ)
+        opt_state = Optimization.OptimizationState(iter = 0,
+            u = θ,
+            objective = x[1])
+        cb_call = cache.callback(opt_state, x...)
+        if !(cb_call isa Bool)
+            error("The callback should return a boolean `halt` for whether to stop the optimization process.")
+        end
+        nx_itr = iterate(cache.data, state)
+        if isnothing(nx_itr)
+            true
+        else
+            cur, state = nx_itr
+            cb_call
+        end
     end
+    solver_kwarg = __map_optimizer_args!(cache, cache.opt, callback = _cb,
+        maxiters = maxiters,
+        maxtime = cache.solver_args.maxtime,
+        abstol = cache.solver_args.abstol,
+        reltol = cache.solver_args.reltol;
+    )
-    maxiters = Optimization._check_and_convert_maxiters(maxiters)
-    maxtime = Optimization._check_and_convert_maxtime(maxtime)
+    _loss = build_loss(cache.f, cache, _cb)
-    f = Optimization.instantiate_function(prob.f, prob.u0, prob.f.adtype, prob.p)
+    gradF = build_gradF(cache.f, cur)
+
+    if haskey(solver_kwarg, :stopping_criterion)
+        stopping_criterion = Manopt.StopWhenAny(solver_kwarg.stopping_criterion...)
+    else
+        stopping_criterion = Manopt.StopAfterIteration(500)
+    end
-    _loss = build_loss(f, prob)
+    opt_res = call_manopt_optimizer(manifold, cache.opt, _loss, gradF, cache.u0; solver_kwarg..., stopping_criterion=stopping_criterion)
-    gradF = build_gradF(f, prob, cur)
+    asc = get_active_stopping_criteria(opt_res.options.stop)
-    opt_res, opt_ret = call_manopt_optimizer(opt, _loss, gradF, prob.u0, stopping_criterion)
+    opt_ret = any(Manopt.indicates_convergence, asc) ? ReturnCode.Success : ReturnCode.Failure
-    return SciMLBase.build_solution(SciMLBase.DefaultOptimizationCache(prob.f, prob.p),
-        opt,
+    return SciMLBase.build_solution(cache,
+        cache.opt,
         opt_res.minimizer,
-        prob.sense === Optimization.MaxSense ?
+        cache.sense === Optimization.MaxSense ?
         -opt_res.minimum : opt_res.minimum;
        original = opt_res.options,
        retcode = opt_ret)
diff --git a/lib/OptimizationManopt/test/runtests.jl b/lib/OptimizationManopt/test/runtests.jl
index 4a9ed3176..b309d9932 100644
--- a/lib/OptimizationManopt/test/runtests.jl
+++ b/lib/OptimizationManopt/test/runtests.jl
@@ -21,18 +21,12 @@ R2 = Euclidean(2)
     p = [1.0, 100.0]
     stepsize = Manopt.ArmijoLinesearch(R2)
-    opt = OptimizationManopt.GradientDescentOptimizer(R2,
-        stepsize = stepsize)
+    opt = OptimizationManopt.GradientDescentOptimizer()
     optprob_forwarddiff = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff())
     prob_forwarddiff = OptimizationProblem(optprob_forwarddiff, x0, p)
-    @test_throws ArgumentError("Either manifold not specified in the problem `OptimizationProblem(f, x, p; manifold = SymmetricPositiveDefinite(5))` or it doesn't match the manifold specified in the optimizer `$(opt.M)`") Optimization.solve(
+    @test_throws ArgumentError("Manifold not specified in the problem for e.g. `OptimizationProblem(f, x, p; manifold = SymmetricPositiveDefinite(5))`.") Optimization.solve(
         prob_forwarddiff, opt)
-
-    optprob = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff())
-    prob = OptimizationProblem(optprob, x0, p; manifold = SymmetricPositiveDefinite(5))
-    @test_throws ArgumentError("Either manifold not specified in the problem `OptimizationProblem(f, x, p; manifold = SymmetricPositiveDefinite(5))` or it doesn't match the manifold specified in the optimizer `$(opt.M)`") Optimization.solve(
-        prob, opt)
 end
 
 @testset "Gradient descent" begin
@@ -40,16 +34,15 @@ end
     p = [1.0, 100.0]
     stepsize = Manopt.ArmijoLinesearch(R2)
-    opt = OptimizationManopt.GradientDescentOptimizer(R2,
-        stepsize = stepsize)
+    opt = OptimizationManopt.GradientDescentOptimizer()
     optprob_forwarddiff = OptimizationFunction(rosenbrock, Optimization.AutoEnzyme())
-    prob_forwarddiff = OptimizationProblem(optprob_forwarddiff, x0, p; manifold = R2)
+    prob_forwarddiff = OptimizationProblem(optprob_forwarddiff, x0, p; manifold = R2, stepsize = stepsize)
     sol = Optimization.solve(prob_forwarddiff, opt)
     @test sol.minimum < 0.2
     optprob_grad = OptimizationFunction(rosenbrock; grad = rosenbrock_grad!)
-    prob_grad = OptimizationProblem(optprob_grad, x0, p; manifold = R2)
+    prob_grad = OptimizationProblem(optprob_grad, x0, p; manifold = R2, stepsize = stepsize)
     sol = Optimization.solve(prob_grad, opt)
     @test sol.minimum < 0.2
 end
@@ -58,7 +51,7 @@ end
     x0 = zeros(2)
     p = [1.0, 100.0]
-    opt = OptimizationManopt.NelderMeadOptimizer(R2)
+    opt = OptimizationManopt.NelderMeadOptimizer()
     optprob = OptimizationFunction(rosenbrock)
     prob = OptimizationProblem(optprob, x0, p; manifold = R2)
@@ -72,13 +65,12 @@ end
     p = [1.0, 100.0]
     stepsize = Manopt.ArmijoLinesearch(R2)
-    opt = OptimizationManopt.ConjugateGradientDescentOptimizer(R2,
-        stepsize = stepsize)
+    opt = OptimizationManopt.ConjugateGradientDescentOptimizer()
     optprob = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff())
     prob = OptimizationProblem(optprob, x0, p; manifold = R2)
-    sol = Optimization.solve(prob, opt)
+    sol = Optimization.solve(prob, opt, stepsize = stepsize)
     @test sol.minimum < 0.5
 end
@@ -86,12 +78,16 @@ end
     x0 = zeros(2)
     p = [1.0, 100.0]
-    opt = OptimizationManopt.QuasiNewtonOptimizer(R2)
-
+    opt = OptimizationManopt.QuasiNewtonOptimizer()
+    function callback(state, l)
+        println(state.u)
+        println(l)
+        return false
+    end
     optprob = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff())
     prob = OptimizationProblem(optprob, x0, p; manifold = R2)
-    sol = Optimization.solve(prob, opt)
+    sol = Optimization.solve(prob, opt, callback = callback, maxiters = 30)
     @test sol.minimum < 1e-14
 end
@@ -99,7 +95,7 @@ end
     x0 = zeros(2)
     p = [1.0, 100.0]
-    opt = OptimizationManopt.ParticleSwarmOptimizer(R2)
+    opt = OptimizationManopt.ParticleSwarmOptimizer()
     optprob = OptimizationFunction(rosenbrock)
     prob = OptimizationProblem(optprob, x0, p; manifold = R2)
@@ -113,7 +109,7 @@ end
     x0 = zeros(2)
     p = [1.0, 100.0]
-    opt = OptimizationManopt.GradientDescentOptimizer(R2)
+    opt = OptimizationManopt.GradientDescentOptimizer()
     optprob_cons = OptimizationFunction(rosenbrock; grad = rosenbrock_grad!, cons = cons)
     prob_cons = OptimizationProblem(optprob_cons, x0, p)
@@ -133,7 +129,7 @@ end
     optf = OptimizationFunction(f, Optimization.AutoForwardDiff())
     prob = OptimizationProblem(optf, data2[1]; manifold = M, maxiters = 1000)
-    opt = OptimizationManopt.GradientDescentOptimizer(M)
+    opt = OptimizationManopt.GradientDescentOptimizer()
     @time sol = Optimization.solve(prob, opt)
     @test sol.u ≈ q atol = 1e-2
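Usage note: after this change the optimizer structs carry no state; the manifold is attached to the `OptimizationProblem` and the remaining Manopt options (stepsize, tolerances, iteration caps) travel through `solve` as keyword arguments, where `__map_optimizer_args!` turns them into stopping criteria. A minimal sketch of the updated call pattern, mirroring the tests above (it assumes `Manifolds` is loaded for `Euclidean`, as in `runtests.jl`; this is an illustration, not part of the diff):

using Optimization, OptimizationManopt, Manopt, Manifolds

# Objective in Optimization.jl form: f(x, p).
rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2

R2 = Euclidean(2)
x0 = zeros(2)
p = [1.0, 100.0]

# The manifold now lives on the problem, not on the optimizer.
optf = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff())
prob = OptimizationProblem(optf, x0, p; manifold = R2)

# Solver-specific options are forwarded through `solve` to Manopt.
sol = Optimization.solve(prob, OptimizationManopt.GradientDescentOptimizer();
    stepsize = Manopt.ArmijoLinesearch(R2), maxiters = 100)

sol.minimum   # objective value at the returned minimizer

One consequence of this design is that a single `GradientDescentOptimizer()` instance can be reused across problems on different manifolds, since all manifold- and line-search-specific state now arrives at solve time.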