Add docstrings to state and stats #657

Merged: 8 commits, Jan 6, 2024
5 changes: 3 additions & 2 deletions lib/OptimizationBBO/src/OptimizationBBO.jl
@@ -117,10 +117,11 @@
    else
        n_steps = BlackBoxOptim.num_steps(trace)
        curr_u = decompose_trace(trace, cache.progress)
-       opt_state = Optimization.OptimizationState(iteration = n_steps,
+       opt_state = Optimization.OptimizationState(;
+           iter = n_steps,
            u = curr_u,
            objective = x[1],
-           solver_state = trace)
+           original = trace)
        cb_call = cache.callback(opt_state, x...)
    end

lib/OptimizationCMAEvolutionStrategy/src/OptimizationCMAEvolutionStrategy.jl
@@ -78,10 +78,10 @@

    function _cb(opt, y, fvals, perm)
        curr_u = opt.logger.xbest[end]
-       opt_state = Optimization.OptimizationState(; iteration = length(opt.logger.fmedian),
+       opt_state = Optimization.OptimizationState(; iter = length(opt.logger.fmedian),
            u = curr_u,
            objective = opt.logger.fbest[end],
-           solver_state = opt.logger)
+           original = opt.logger)

        cb_call = cache.callback(opt_state, x...)
        if !(cb_call isa Bool)
2 changes: 1 addition & 1 deletion lib/OptimizationCMAEvolutionStrategy/test/runtests.jl
@@ -12,7 +12,7 @@ using Test
    @test 10 * sol.objective < l1

    function cb(state, args...)
-       if state.iteration % 10 == 0
+       if state.iter % 10 == 0
            println(state.u)
        end
        return false
4 changes: 2 additions & 2 deletions lib/OptimizationEvolutionary/src/OptimizationEvolutionary.jl
@@ -86,10 +86,10 @@ function SciMLBase.__solve(cache::OptimizationCache{
    function _cb(trace)
        curr_u = decompose_trace(trace).metadata["x"][end]
        opt_state = Optimization.OptimizationState(;
-           iteration = decompose_trace(trace).iteration,
+           iter = decompose_trace(trace).iteration,
            u = curr_u,
            objective = x[1],
-           solver_state = trace)
+           original = trace)
        cb_call = cache.callback(opt_state, trace.value...)
        if !(cb_call isa Bool)
            error("The callback should return a boolean `halt` for whether to stop the optimization process.")
2 changes: 1 addition & 1 deletion lib/OptimizationEvolutionary/test/runtests.jl
@@ -36,7 +36,7 @@ Random.seed!(1234)
    @test sol.objective < l1

    function cb(state, args...)
-       if state.iteration % 10 == 0
+       if state.iter % 10 == 0
            println(state.u)
        end
        return false
10 changes: 5 additions & 5 deletions lib/OptimizationFlux/src/OptimizationFlux.jl
@@ -66,10 +66,10 @@
    for (i, d) in enumerate(data)
        cache.f.grad(G, θ, d...)
        x = cache.f(θ, cache.p, d...)
-       opt_state = Optimization.OptimizationState(; iteration = i,
+       opt_state = Optimization.OptimizationState(; iter = i,
            u = θ,
            objective = x[1],
-           solver_state = opt)
+           original = opt)
        cb_call = cache.callback(opt_state, x...)
        if !(cb_call isa Bool)
            error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the sciml_train documentation for information.")
@@ -85,14 +85,14 @@
            min_err = x
            min_θ = copy(θ)
        end
-       if i == maxiters #Last iteration, revert to best.
+       if i == maxiters #Last iter, revert to best.
            opt = min_opt
            x = min_err
            θ = min_θ
-           opt_state = Optimization.OptimizationState(; iteration = i,
+           opt_state = Optimization.OptimizationState(; iter = i,
                u = θ,
                objective = x[1],
-               solver_state = opt)
+               original = opt)
            cache.callback(opt_state, x...)
            break
        end
2 changes: 1 addition & 1 deletion lib/OptimizationFlux/test/runtests.jl
@@ -36,7 +36,7 @@ using Test
    end

    function cb(state, args...)
-       if state.iteration % 10 == 0
+       if state.iter % 10 == 0
            println(state.u)
        end
        return false
2 changes: 1 addition & 1 deletion lib/OptimizationMOI/src/nlp.jl
@@ -218,7 +218,7 @@
    else
        l = evaluator.f(x, evaluator.p)
        evaluator.iteration += 1
-       state = Optimization.OptimizationState(iteration = evaluator.iteration,
+       state = Optimization.OptimizationState(iter = evaluator.iteration,
            u = x,
            objective = l[1])
        evaluator.callback(state, l)
12 changes: 6 additions & 6 deletions lib/OptimizationOptimJL/src/OptimizationOptimJL.jl
@@ -135,10 +135,10 @@
    function _cb(trace)
        θ = cache.opt isa Optim.NelderMead ? decompose_trace(trace).metadata["centroid"] :
            decompose_trace(trace).metadata["x"]
-       opt_state = Optimization.OptimizationState(iteration = trace.iteration,
+       opt_state = Optimization.OptimizationState(iter = trace.iteration,
            u = θ,
            objective = x[1],
-           solver_state = trace)
+           original = trace)
        cb_call = cache.callback(opt_state, x...)
        if !(cb_call isa Bool)
            error("The callback should return a boolean `halt` for whether to stop the optimization process.")
@@ -255,10 +255,10 @@
        θ = !(cache.opt isa Optim.SAMIN) && cache.opt.method == Optim.NelderMead() ?
            decompose_trace(trace).metadata["centroid"] :
            decompose_trace(trace).metadata["x"]
-       opt_state = Optimization.OptimizationState(iteration = trace.iteration,
+       opt_state = Optimization.OptimizationState(iter = trace.iteration,
            u = θ,
            objective = x[1],
-           solver_state = trace)
+           original = trace)
        cb_call = cache.callback(opt_state, x...)
        if !(cb_call isa Bool)
            error("The callback should return a boolean `halt` for whether to stop the optimization process.")
@@ -342,10 +342,10 @@
    cur, state = iterate(cache.data)

    function _cb(trace)
-       opt_state = Optimization.OptimizationState(iteration = trace.iteration,
+       opt_state = Optimization.OptimizationState(iter = trace.iteration,
            u = decompose_trace(trace).metadata["x"],
            objective = x[1],
-           solver_state = trace)
+           original = trace)
        cb_call = cache.callback(opt_state, x...)
        if !(cb_call isa Bool)
            error("The callback should return a boolean `halt` for whether to stop the optimization process.")
14 changes: 7 additions & 7 deletions lib/OptimizationOptimisers/src/OptimizationOptimisers.jl
@@ -67,11 +67,11 @@ function SciMLBase.__solve(cache::OptimizationCache{
    for (i, d) in enumerate(data)
        cache.f.grad(G, θ, d...)
        x = cache.f(θ, cache.p, d...)
-       opt_state = Optimization.OptimizationState(iteration = i,
+       opt_state = Optimization.OptimizationState(iter = i,
            u = θ,
            objective = x[1],
-           gradient = G,
-           solver_state = state)
+           grad = G,
+           original = state)
        cb_call = cache.callback(opt_state, x...)
        if !(cb_call isa Bool)
            error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the `solve` documentation for information.")
@@ -87,16 +87,16 @@ function SciMLBase.__solve(cache::OptimizationCache{
            min_err = x
            min_θ = copy(θ)
        end
-       if i == maxiters #Last iteration, revert to best.
+       if i == maxiters #Last iter, revert to best.
            opt = min_opt
            x = min_err
            θ = min_θ
            cache.f.grad(G, θ, d...)
-           opt_state = Optimization.OptimizationState(iteration = i,
+           opt_state = Optimization.OptimizationState(iter = i,
                u = θ,
                objective = x[1],
-               gradient = G,
-               solver_state = state)
+               grad = G,
+               original = state)
            cache.callback(opt_state, x...)
            break
        end
6 changes: 3 additions & 3 deletions lib/OptimizationOptimisers/src/sophia.jl
@@ -78,11 +78,11 @@
    for (i, d) in enumerate(data)
        f.grad(gₜ, θ, d...)
        x = cache.f(θ, cache.p, d...)
-       opt_state = Optimization.OptimizationState(; iteration = i,
+       opt_state = Optimization.OptimizationState(; iter = i,
            u = θ,
            objective = first(x),
-           gradient = gₜ,
-           solver_state = nothing)
+           grad = gₜ,
+           original = nothing)
        cb_call = cache.callback(θ, x...)
        if !(cb_call isa Bool)
            error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the sciml_train documentation for information.")
2 changes: 1 addition & 1 deletion lib/OptimizationOptimisers/test/runtests.jl
@@ -61,7 +61,7 @@ using Zygote

    prob = OptimizationProblem(optprob, x0, _p)
    function callback(state, l)
-       Optimisers.adjust!(state.solver_state, 0.1 / state.iteration)
+       Optimisers.adjust!(state.original, 0.1 / state.iter)
        return false
    end
    sol = solve(prob,
44 changes: 37 additions & 7 deletions src/stats_state.jl
@@ -1,4 +1,20 @@
+"""
+$(TYPEDEF)
+
+Stores the statistics of the optimization run, which are returned
+in the `stats` field of the `OptimizationResult`.
+
+## Fields
+- `iterations`: number of iterations
+- `time`: time taken to run the solver
+- `fevals`: number of function evaluations
+- `gevals`: number of gradient evaluations
+- `hevals`: number of Hessian evaluations
+
+All fields default to 0, so when a solver does not report a statistic,
+the corresponding field stays 0 even where you might expect a
+non-zero value.
+"""
struct OptimizationStats
iterations::Int
time::Float64
@@ -11,16 +27,30 @@ function OptimizationStats(; iterations = 0, time = 0.0, fevals = 0, gevals = 0,
OptimizationStats(iterations, time, fevals, gevals, hevals)
end
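
For orientation, a minimal sketch of how these stats surface to a user, assuming the solution object exposes them as `sol.stats` per the docstring above; the Rosenbrock setup and the BFGS choice are illustrative, not taken from this PR:

```julia
using Optimization, OptimizationOptimJL

# Illustrative problem: 2-D Rosenbrock with parameters p = (a, b).
rosenbrock(u, p) = (p[1] - u[1])^2 + p[2] * (u[2] - u[1]^2)^2
optf = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff())
prob = OptimizationProblem(optf, zeros(2), [1.0, 100.0])

sol = solve(prob, BFGS())

# Fields documented above; a 0 can simply mean the solver did not report it.
stats = sol.stats
println("iterations = ", stats.iterations)
println("time (s)   = ", stats.time)
println("fevals = ", stats.fevals, ", gevals = ", stats.gevals, ", hevals = ", stats.hevals)
```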

"""
$(TYPEDEF)

Stores the optimization run's state at the current iteration
and is passed to the callback function as the first argument.

## Fields
- `iter`: current iteration
- `u`: current solution
- `objective`: current objective value
- `gradient`: current gradient
- `hessian`: current hessian
+- `original`: if the solver has its own state object then it is stored here
+"""

[Review comment from a Member on lines +40 to +41]
Suggested change:
    - `gradient`: current gradient
    - `hessian`: current hessian
    + `grad`: current gradient
    + `hess`: current hessian

"grad, hess. Make the naming match everything else, we can't just have one thing different."

struct OptimizationState{X, O, G, H, S}
-   iteration::Int
+   iter::Int
    u::X
    objective::O
-   gradient::G
-   hessian::H
-   solver_state::S
+   grad::G
+   hess::H
+   original::S
end

-function OptimizationState(; iteration = 0, u = nothing, objective = nothing,
-    gradient = nothing, hessian = nothing, solver_state = nothing)
-    OptimizationState(iteration, u, objective, gradient, hessian, solver_state)
+function OptimizationState(; iter = 0, u = nothing, objective = nothing,
+    grad = nothing, hess = nothing, original = nothing)
+    OptimizationState(iter, u, objective, grad, hess, original)
end
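
And a minimal sketch of a callback consuming this state under the renamed fields (`iter`, `grad`, `hess`, `original`); the callback pattern follows the updated tests above, while the problem setup and the NelderMead choice are illustrative assumptions:

```julia
using Optimization, OptimizationOptimJL

rosenbrock(u, p) = (p[1] - u[1])^2 + p[2] * (u[2] - u[1]^2)^2
optf = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff())
prob = OptimizationProblem(optf, zeros(2), [1.0, 100.0])

# The callback receives the OptimizationState as its first argument and must
# return a Bool `halt`; returning false lets the run continue.
function cb(state, args...)
    if state.iter % 10 == 0
        println("iter = ", state.iter, ", objective = ", state.objective)
    end
    return false
end

sol = solve(prob, NelderMead(); callback = cb)
```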