diff --git a/lib/OptimizationBBO/src/OptimizationBBO.jl b/lib/OptimizationBBO/src/OptimizationBBO.jl index defeba0c9..0e4d9fef1 100644 --- a/lib/OptimizationBBO/src/OptimizationBBO.jl +++ b/lib/OptimizationBBO/src/OptimizationBBO.jl @@ -117,10 +117,11 @@ function SciMLBase.__solve(cache::Optimization.OptimizationCache{ else n_steps = BlackBoxOptim.num_steps(trace) curr_u = decompose_trace(trace, cache.progress) - opt_state = Optimization.OptimizationState(iteration = n_steps, + opt_state = Optimization.OptimizationState(; + iter = n_steps, u = curr_u, objective = x[1], - solver_state = trace) + original = trace) cb_call = cache.callback(opt_state, x...) end diff --git a/lib/OptimizationCMAEvolutionStrategy/src/OptimizationCMAEvolutionStrategy.jl b/lib/OptimizationCMAEvolutionStrategy/src/OptimizationCMAEvolutionStrategy.jl index 43a94b99f..94aebe452 100644 --- a/lib/OptimizationCMAEvolutionStrategy/src/OptimizationCMAEvolutionStrategy.jl +++ b/lib/OptimizationCMAEvolutionStrategy/src/OptimizationCMAEvolutionStrategy.jl @@ -78,10 +78,10 @@ function SciMLBase.__solve(cache::OptimizationCache{ function _cb(opt, y, fvals, perm) curr_u = opt.logger.xbest[end] - opt_state = Optimization.OptimizationState(; iteration = length(opt.logger.fmedian), + opt_state = Optimization.OptimizationState(; iter = length(opt.logger.fmedian), u = curr_u, objective = opt.logger.fbest[end], - solver_state = opt.logger) + original = opt.logger) cb_call = cache.callback(opt_state, x...) if !(cb_call isa Bool) diff --git a/lib/OptimizationCMAEvolutionStrategy/test/runtests.jl b/lib/OptimizationCMAEvolutionStrategy/test/runtests.jl index 525772a87..8ce044e76 100644 --- a/lib/OptimizationCMAEvolutionStrategy/test/runtests.jl +++ b/lib/OptimizationCMAEvolutionStrategy/test/runtests.jl @@ -12,7 +12,7 @@ using Test @test 10 * sol.objective < l1 function cb(state, args...) 
- if state.iteration % 10 == 0 + if state.iter % 10 == 0 println(state.u) end return false diff --git a/lib/OptimizationEvolutionary/src/OptimizationEvolutionary.jl b/lib/OptimizationEvolutionary/src/OptimizationEvolutionary.jl index 283d8ac96..02663f1e5 100644 --- a/lib/OptimizationEvolutionary/src/OptimizationEvolutionary.jl +++ b/lib/OptimizationEvolutionary/src/OptimizationEvolutionary.jl @@ -86,10 +86,10 @@ function SciMLBase.__solve(cache::OptimizationCache{ function _cb(trace) curr_u = decompose_trace(trace).metadata["x"][end] opt_state = Optimization.OptimizationState(; - iteration = decompose_trace(trace).iteration, + iter = decompose_trace(trace).iteration, u = curr_u, objective = x[1], - solver_state = trace) + original = trace) cb_call = cache.callback(opt_state, trace.value...) if !(cb_call isa Bool) error("The callback should return a boolean `halt` for whether to stop the optimization process.") diff --git a/lib/OptimizationEvolutionary/test/runtests.jl b/lib/OptimizationEvolutionary/test/runtests.jl index e7cd3a97b..72b9c19ab 100644 --- a/lib/OptimizationEvolutionary/test/runtests.jl +++ b/lib/OptimizationEvolutionary/test/runtests.jl @@ -36,7 +36,7 @@ Random.seed!(1234) @test sol.objective < l1 function cb(state, args...) - if state.iteration % 10 == 0 + if state.iter % 10 == 0 println(state.u) end return false diff --git a/lib/OptimizationFlux/src/OptimizationFlux.jl b/lib/OptimizationFlux/src/OptimizationFlux.jl index c5a68c985..d9f160262 100644 --- a/lib/OptimizationFlux/src/OptimizationFlux.jl +++ b/lib/OptimizationFlux/src/OptimizationFlux.jl @@ -66,10 +66,10 @@ function SciMLBase.__solve(cache::OptimizationCache{ for (i, d) in enumerate(data) cache.f.grad(G, θ, d...) x = cache.f(θ, cache.p, d...) - opt_state = Optimization.OptimizationState(; iteration = i, + opt_state = Optimization.OptimizationState(; iter = i, u = θ, objective = x[1], - solver_state = opt) + original = opt) cb_call = cache.callback(opt_state, x...) 
if !(cb_call isa Bool) error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the sciml_train documentation for information.") @@ -85,14 +85,14 @@ function SciMLBase.__solve(cache::OptimizationCache{ min_err = x min_θ = copy(θ) end - if i == maxiters #Last iteration, revert to best. + if i == maxiters #Last iter, revert to best. opt = min_opt x = min_err θ = min_θ - opt_state = Optimization.OptimizationState(; iteration = i, + opt_state = Optimization.OptimizationState(; iter = i, u = θ, objective = x[1], - solver_state = opt) + original = opt) cache.callback(opt_state, x...) break end diff --git a/lib/OptimizationFlux/test/runtests.jl b/lib/OptimizationFlux/test/runtests.jl index 3ca49a562..a78a592c7 100644 --- a/lib/OptimizationFlux/test/runtests.jl +++ b/lib/OptimizationFlux/test/runtests.jl @@ -36,7 +36,7 @@ using Test end function cb(state, args...) - if state.iteration % 10 == 0 + if state.iter % 10 == 0 println(state.u) end return false diff --git a/lib/OptimizationMOI/src/nlp.jl b/lib/OptimizationMOI/src/nlp.jl index 4ad2ed9af..9ef2995b2 100644 --- a/lib/OptimizationMOI/src/nlp.jl +++ b/lib/OptimizationMOI/src/nlp.jl @@ -218,7 +218,7 @@ function MOI.eval_objective(evaluator::MOIOptimizationNLPEvaluator, x) else l = evaluator.f(x, evaluator.p) evaluator.iteration += 1 - state = Optimization.OptimizationState(iteration = evaluator.iteration, + state = Optimization.OptimizationState(iter = evaluator.iteration, u = x, objective = l[1]) evaluator.callback(state, l) diff --git a/lib/OptimizationOptimJL/src/OptimizationOptimJL.jl b/lib/OptimizationOptimJL/src/OptimizationOptimJL.jl index 5b4fbe483..500ef782c 100644 --- a/lib/OptimizationOptimJL/src/OptimizationOptimJL.jl +++ b/lib/OptimizationOptimJL/src/OptimizationOptimJL.jl @@ -135,10 +135,10 @@ function SciMLBase.__solve(cache::OptimizationCache{ function _cb(trace) θ = cache.opt isa Optim.NelderMead ? 
decompose_trace(trace).metadata["centroid"] : decompose_trace(trace).metadata["x"] - opt_state = Optimization.OptimizationState(iteration = trace.iteration, + opt_state = Optimization.OptimizationState(iter = trace.iteration, u = θ, objective = x[1], - solver_state = trace) + original = trace) cb_call = cache.callback(opt_state, x...) if !(cb_call isa Bool) error("The callback should return a boolean `halt` for whether to stop the optimization process.") @@ -255,10 +255,10 @@ function SciMLBase.__solve(cache::OptimizationCache{ θ = !(cache.opt isa Optim.SAMIN) && cache.opt.method == Optim.NelderMead() ? decompose_trace(trace).metadata["centroid"] : decompose_trace(trace).metadata["x"] - opt_state = Optimization.OptimizationState(iteration = trace.iteration, + opt_state = Optimization.OptimizationState(iter = trace.iteration, u = θ, objective = x[1], - solver_state = trace) + original = trace) cb_call = cache.callback(opt_state, x...) if !(cb_call isa Bool) error("The callback should return a boolean `halt` for whether to stop the optimization process.") @@ -342,10 +342,10 @@ function SciMLBase.__solve(cache::OptimizationCache{ cur, state = iterate(cache.data) function _cb(trace) - opt_state = Optimization.OptimizationState(iteration = trace.iteration, + opt_state = Optimization.OptimizationState(iter = trace.iteration, u = decompose_trace(trace).metadata["x"], objective = x[1], - solver_state = trace) + original = trace) cb_call = cache.callback(opt_state, x...) 
if !(cb_call isa Bool) error("The callback should return a boolean `halt` for whether to stop the optimization process.") diff --git a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl index d34299b27..21f107b05 100644 --- a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl +++ b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl @@ -67,11 +67,11 @@ function SciMLBase.__solve(cache::OptimizationCache{ for (i, d) in enumerate(data) cache.f.grad(G, θ, d...) x = cache.f(θ, cache.p, d...) - opt_state = Optimization.OptimizationState(iteration = i, + opt_state = Optimization.OptimizationState(iter = i, u = θ, objective = x[1], - gradient = G, - solver_state = state) + grad = G, + original = state) cb_call = cache.callback(opt_state, x...) if !(cb_call isa Bool) error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the `solve` documentation for information.") @@ -87,16 +87,16 @@ function SciMLBase.__solve(cache::OptimizationCache{ min_err = x min_θ = copy(θ) end - if i == maxiters #Last iteration, revert to best. + if i == maxiters #Last iter, revert to best. opt = min_opt x = min_err θ = min_θ cache.f.grad(G, θ, d...) - opt_state = Optimization.OptimizationState(iteration = i, + opt_state = Optimization.OptimizationState(iter = i, u = θ, objective = x[1], - gradient = G, - solver_state = state) + grad = G, + original = state) cache.callback(opt_state, x...) break end diff --git a/lib/OptimizationOptimisers/src/sophia.jl b/lib/OptimizationOptimisers/src/sophia.jl index 4625418a4..94dd34321 100644 --- a/lib/OptimizationOptimisers/src/sophia.jl +++ b/lib/OptimizationOptimisers/src/sophia.jl @@ -78,11 +78,11 @@ function SciMLBase.__solve(cache::OptimizationCache{ for (i, d) in enumerate(data) f.grad(gₜ, θ, d...) x = cache.f(θ, cache.p, d...) 
- opt_state = Optimization.OptimizationState(; iteration = i, + opt_state = Optimization.OptimizationState(; iter = i, u = θ, objective = first(x), - gradient = gₜ, - solver_state = nothing) + grad = gₜ, + original = nothing) cb_call = cache.callback(θ, x...) if !(cb_call isa Bool) error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the sciml_train documentation for information.") diff --git a/lib/OptimizationOptimisers/test/runtests.jl b/lib/OptimizationOptimisers/test/runtests.jl index bb91c07d5..d16036108 100644 --- a/lib/OptimizationOptimisers/test/runtests.jl +++ b/lib/OptimizationOptimisers/test/runtests.jl @@ -61,7 +61,7 @@ using Zygote prob = OptimizationProblem(optprob, x0, _p) function callback(state, l) - Optimisers.adjust!(state.solver_state, 0.1 / state.iteration) + Optimisers.adjust!(state.original, 0.1 / state.iter) return false end sol = solve(prob, diff --git a/src/stats_state.jl b/src/stats_state.jl index 775e78c28..59911a7f4 100644 --- a/src/stats_state.jl +++ b/src/stats_state.jl @@ -1,4 +1,20 @@ +""" +$(TYPEDEF) +Stores the optimization run's statistics that are returned +in the `stats` field of the `OptimizationResult`. + +## Fields +- `iterations`: number of iterations +- `time`: time taken to run the solver +- `fevals`: number of function evaluations +- `gevals`: number of gradient evaluations +- `hevals`: number of hessian evaluations + +Default values for all the fields are set to 0 and hence even when +you might expect non-zero values due to unavailability of the information +from the solver they would be 0. +""" struct OptimizationStats iterations::Int time::Float64 @@ -11,16 +27,30 @@ function OptimizationStats(; iterations = 0, time = 0.0, fevals = 0, gevals = 0, OptimizationStats(iterations, time, fevals, gevals, hevals) end +""" +$(TYPEDEF) + +Stores the optimization run's state at the current iteration +and is passed to the callback function as the first argument. 
+ +## Fields +- `iter`: current iteration +- `u`: current solution +- `objective`: current objective value +- `grad`: current gradient +- `hess`: current hessian +- `original`: if the solver has its own state object then it is stored here +""" struct OptimizationState{X, O, G, H, S} - iteration::Int + iter::Int u::X objective::O - gradient::G - hessian::H - solver_state::S + grad::G + hess::H + original::S end -function OptimizationState(; iteration = 0, u = nothing, objective = nothing, - gradient = nothing, hessian = nothing, solver_state = nothing) - OptimizationState(iteration, u, objective, gradient, hessian, solver_state) +function OptimizationState(; iter = 0, u = nothing, objective = nothing, + grad = nothing, hess = nothing, original = nothing) + OptimizationState(iter, u, objective, grad, hess, original) end