Skip to content

Commit

Permalink
Add suggestions from review, fieldnames changed
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaibhavdixit02 committed Jan 5, 2024
1 parent 4ae646c commit d026405
Show file tree
Hide file tree
Showing 7 changed files with 15 additions and 15 deletions.
2 changes: 1 addition & 1 deletion lib/OptimizationBBO/src/OptimizationBBO.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ function SciMLBase.__solve(cache::Optimization.OptimizationCache{
else
n_steps = BlackBoxOptim.num_steps(trace)
curr_u = decompose_trace(trace, cache.progress)
opt_state = Optimization.OptimizationState(iteration = n_steps, u = curr_u, objective = x[1], solver_state = trace)
opt_state = Optimization.OptimizationState(; iteration = n_steps, u = curr_u, objective = x[1], original = trace)

Check warning on line 120 in lib/OptimizationBBO/src/OptimizationBBO.jl

View check run for this annotation

Codecov / codecov/patch

lib/OptimizationBBO/src/OptimizationBBO.jl#L120

Added line #L120 was not covered by tests
cb_call = cache.callback(opt_state, x...)
end

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
opt_state = Optimization.OptimizationState(; iteration = length(opt.logger.fmedian),
u = curr_u,
objective = opt.logger.fbest[end],
solver_state = opt.logger)
original = opt.logger)

cb_call = cache.callback(opt_state, x...)
if !(cb_call isa Bool)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
opt_state = Optimization.OptimizationState(; iteration = decompose_trace(trace).iteration,
u = curr_u,
objective = x[1],
solver_state = trace)
original = trace)
cb_call = cache.callback(opt_state, trace.value...)
if !(cb_call isa Bool)
error("The callback should return a boolean `halt` for whether to stop the optimization process.")
Expand Down
4 changes: 2 additions & 2 deletions lib/OptimizationFlux/src/OptimizationFlux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
opt_state = Optimization.OptimizationState(; iteration = i,
u = θ,
objective = x[1],
solver_state = opt)
original = opt)
cb_call = cache.callback(opt_state, x...)
if !(cb_call isa Bool)
error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the sciml_train documentation for information.")
Expand All @@ -92,7 +92,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
opt_state = Optimization.OptimizationState(; iteration = i,
u = θ,
objective = x[1],
solver_state = opt)
original = opt)
cache.callback(opt_state, x...)
break
end
Expand Down
6 changes: 3 additions & 3 deletions lib/OptimizationOptimJL/src/OptimizationOptimJL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ function SciMLBase.__solve(cache::OptimizationCache{

function _cb(trace)
θ = cache.opt isa Optim.NelderMead ? decompose_trace(trace).metadata["centroid"] : decompose_trace(trace).metadata["x"]
opt_state = Optimization.OptimizationState(iteration = trace.iteration, u = θ, objective = x[1], solver_state = trace)
opt_state = Optimization.OptimizationState(iteration = trace.iteration, u = θ, objective = x[1], original = trace)

Check warning on line 137 in lib/OptimizationOptimJL/src/OptimizationOptimJL.jl

View check run for this annotation

Codecov / codecov/patch

lib/OptimizationOptimJL/src/OptimizationOptimJL.jl#L137

Added line #L137 was not covered by tests
cb_call = cache.callback(opt_state, x...)
if !(cb_call isa Bool)
error("The callback should return a boolean `halt` for whether to stop the optimization process.")
Expand Down Expand Up @@ -248,7 +248,7 @@ function SciMLBase.__solve(cache::OptimizationCache{

function _cb(trace)
θ = !(cache.opt isa Optim.SAMIN) && cache.opt.method == Optim.NelderMead() ? decompose_trace(trace).metadata["centroid"] : decompose_trace(trace).metadata["x"]
opt_state = Optimization.OptimizationState(iteration = trace.iteration, u = θ, objective = x[1], solver_state = trace)
opt_state = Optimization.OptimizationState(iteration = trace.iteration, u = θ, objective = x[1], original = trace)

Check warning on line 251 in lib/OptimizationOptimJL/src/OptimizationOptimJL.jl

View check run for this annotation

Codecov / codecov/patch

lib/OptimizationOptimJL/src/OptimizationOptimJL.jl#L251

Added line #L251 was not covered by tests
cb_call = cache.callback(opt_state, x...)
if !(cb_call isa Bool)
error("The callback should return a boolean `halt` for whether to stop the optimization process.")
Expand Down Expand Up @@ -331,7 +331,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
cur, state = iterate(cache.data)

function _cb(trace)
opt_state = Optimization.OptimizationState(iteration = trace.iteration, u = decompose_trace(trace).metadata["x"], objective = x[1], solver_state = trace)
opt_state = Optimization.OptimizationState(iteration = trace.iteration, u = decompose_trace(trace).metadata["x"], objective = x[1], original = trace)

Check warning on line 334 in lib/OptimizationOptimJL/src/OptimizationOptimJL.jl

View check run for this annotation

Codecov / codecov/patch

lib/OptimizationOptimJL/src/OptimizationOptimJL.jl#L334

Added line #L334 was not covered by tests
cb_call = cache.callback(opt_state, x...)
if !(cb_call isa Bool)
error("The callback should return a boolean `halt` for whether to stop the optimization process.")
Expand Down
4 changes: 2 additions & 2 deletions lib/OptimizationOptimisers/src/OptimizationOptimisers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
for (i, d) in enumerate(data)
cache.f.grad(G, θ, d...)
x = cache.f(θ, cache.p, d...)
opt_state = Optimization.OptimizationState(iteration = i, u = θ, objective = x[1], gradient = G, solver_state = state)
opt_state = Optimization.OptimizationState(iteration = i, u = θ, objective = x[1], grad = G, original = state)

Check warning on line 70 in lib/OptimizationOptimisers/src/OptimizationOptimisers.jl

View check run for this annotation

Codecov / codecov/patch

lib/OptimizationOptimisers/src/OptimizationOptimisers.jl#L70

Added line #L70 was not covered by tests
cb_call = cache.callback(opt_state, x...)
if !(cb_call isa Bool)
error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the `solve` documentation for information.")
Expand All @@ -88,7 +88,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
x = min_err
θ = min_θ
cache.f.grad(G, θ, d...)
opt_state = Optimization.OptimizationState(iteration = i, u = θ, objective = x[1], gradient = G, solver_state = state)
opt_state = Optimization.OptimizationState(iteration = i, u = θ, objective = x[1], grad = G, original = state)

Check warning on line 91 in lib/OptimizationOptimisers/src/OptimizationOptimisers.jl

View check run for this annotation

Codecov / codecov/patch

lib/OptimizationOptimisers/src/OptimizationOptimisers.jl#L91

Added line #L91 was not covered by tests
cache.callback(opt_state, x...)
break
end
Expand Down
10 changes: 5 additions & 5 deletions src/stats_state.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,20 @@ Stores the optimization run's state at the current iteration
and is passed to the callback function as the first argument.
## Fields
- `iteration`: current iteration
- `iter`: current iteration
- `u`: current solution
- `objective`: current objective value
- `gradient`: current gradient
- `hessian`: current hessian
- `solver_state`: if the solver has its own state object then it is stored here
- `original`: if the solver has its own state object then it is stored here
"""
struct OptimizationState{X, O, G, H, S}
iteration::Int
u::X
objective::O
gradient::G
hessian::H
solver_state::S
grad::G
hess::H
original::S
end

function OptimizationState(; iteration = 0, u = nothing, objective = nothing,

Check warning on line 53 in src/stats_state.jl

View check run for this annotation

Codecov / codecov/patch

src/stats_state.jl#L53

Added line #L53 was not covered by tests
Expand Down

0 comments on commit d026405

Please sign in to comment.