Skip to content

Commit

Permalink
Format and sparsearrays only 1.10
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaibhavdixit02 committed Jan 5, 2024
1 parent 880fc54 commit 623214b
Show file tree
Hide file tree
Showing 14 changed files with 79 additions and 36 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ ProgressLogging = "0.1"
Reexport = "1.2"
ReverseDiff = "1.14"
SciMLBase = "2.16.3"
SparseArrays = "1.9, 1.10"
SparseArrays = "1.10"
SparseDiffTools = "2.14"
SymbolicIndexingInterface = "0.3"
Symbolics = "5.12"
Expand Down
10 changes: 8 additions & 2 deletions lib/OptimizationBBO/src/OptimizationBBO.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,10 @@ function SciMLBase.__solve(cache::Optimization.OptimizationCache{
else
n_steps = BlackBoxOptim.num_steps(trace)
curr_u = decompose_trace(trace, cache.progress)
opt_state = Optimization.OptimizationState(iteration = n_steps, u = curr_u, objective = x[1], solver_state = trace)
opt_state = Optimization.OptimizationState(iteration = n_steps,

Check warning on line 120 in lib/OptimizationBBO/src/OptimizationBBO.jl

View check run for this annotation

Codecov / codecov/patch

lib/OptimizationBBO/src/OptimizationBBO.jl#L120

Added line #L120 was not covered by tests
u = curr_u,
objective = x[1],
solver_state = trace)
cb_call = cache.callback(opt_state, x...)
end

Expand Down Expand Up @@ -178,7 +181,10 @@ function SciMLBase.__solve(cache::Optimization.OptimizationCache{
t1 = time()

opt_ret = Symbol(opt_res.stop_reason)
stats = Optimization.OptimizationStats(; iterations = opt_res.iterations, time = t1 - t0, fevals = opt_res.f_calls)
stats = Optimization.OptimizationStats(;

Check warning on line 184 in lib/OptimizationBBO/src/OptimizationBBO.jl

View check run for this annotation

Codecov / codecov/patch

lib/OptimizationBBO/src/OptimizationBBO.jl#L184

Added line #L184 was not covered by tests
iterations = opt_res.iterations,
time = t1 - t0,
fevals = opt_res.f_calls)
SciMLBase.build_solution(cache, cache.opt,
BlackBoxOptim.best_candidate(opt_res),
BlackBoxOptim.best_fitness(opt_res);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ function __map_optimizer_args(prob::OptimizationCache, opt::CMAEvolutionStrategy
end

mapped_args = (; lower = prob.lb,
upper = prob.ub, logger = CMAEvolutionStrategy.BasicLogger(prob.u0; verbosity = 0, callback = callback))
upper = prob.ub,
logger = CMAEvolutionStrategy.BasicLogger(prob.u0;
verbosity = 0,
callback = callback))

if !isnothing(maxiters)
mapped_args = (; mapped_args..., maxiter = maxiters)
Expand Down Expand Up @@ -105,7 +108,10 @@ function SciMLBase.__solve(cache::OptimizationCache{
t1 = time()

opt_ret = opt_res.stop.reason
stats = Optimization.OptimizationStats(; iterations = length(opt_res.logger.fmedian), time = t1 - t0, fevals = length(opt_res.logger.fmedian))
stats = Optimization.OptimizationStats(;

Check warning on line 111 in lib/OptimizationCMAEvolutionStrategy/src/OptimizationCMAEvolutionStrategy.jl

View check run for this annotation

Codecov / codecov/patch

lib/OptimizationCMAEvolutionStrategy/src/OptimizationCMAEvolutionStrategy.jl#L111

Added line #L111 was not covered by tests
iterations = length(opt_res.logger.fmedian),
time = t1 - t0,
fevals = length(opt_res.logger.fmedian))
SciMLBase.build_solution(cache, cache.opt,
opt_res.logger.xbest[end],
opt_res.logger.fbest[end]; original = opt_res,
Expand Down
2 changes: 1 addition & 1 deletion lib/OptimizationCMAEvolutionStrategy/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ using Test
@test 10 * sol.objective < l1

function cb(state, args...)
if state.iteration %10 == 0
if state.iteration % 10 == 0
println(state.u)
end
return false
Expand Down
7 changes: 4 additions & 3 deletions lib/OptimizationEvolutionary/src/OptimizationEvolutionary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ function SciMLBase.__solve(cache::OptimizationCache{

function _cb(trace)
curr_u = decompose_trace(trace).metadata["x"][end]
opt_state = Optimization.OptimizationState(; iteration = decompose_trace(trace).iteration,
opt_state = Optimization.OptimizationState(;

Check warning on line 88 in lib/OptimizationEvolutionary/src/OptimizationEvolutionary.jl

View check run for this annotation

Codecov / codecov/patch

lib/OptimizationEvolutionary/src/OptimizationEvolutionary.jl#L88

Added line #L88 was not covered by tests
iteration = decompose_trace(trace).iteration,
u = curr_u,
objective = x[1],
solver_state = trace)
Expand Down Expand Up @@ -132,8 +133,8 @@ function SciMLBase.__solve(cache::OptimizationCache{
end
t1 = time()
opt_ret = Symbol(Evolutionary.converged(opt_res))
stats = Optimization.OptimizationStats(; iterations = opt_res.iterations
, time = t1 - t0, fevals = opt_res.f_calls)
stats = Optimization.OptimizationStats(; iterations = opt_res.iterations,

Check warning on line 136 in lib/OptimizationEvolutionary/src/OptimizationEvolutionary.jl

View check run for this annotation

Codecov / codecov/patch

lib/OptimizationEvolutionary/src/OptimizationEvolutionary.jl#L136

Added line #L136 was not covered by tests
time = t1 - t0, fevals = opt_res.f_calls)
SciMLBase.build_solution(cache, cache.opt,
Evolutionary.minimizer(opt_res),
Evolutionary.minimum(opt_res); original = opt_res,
Expand Down
2 changes: 1 addition & 1 deletion lib/OptimizationEvolutionary/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ Random.seed!(1234)
@test sol.objective < l1

function cb(state, args...)
if state.iteration %10 == 0
if state.iteration % 10 == 0
println(state.u)
end
return false
Expand Down
2 changes: 1 addition & 1 deletion lib/OptimizationFlux/src/OptimizationFlux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
end

t1 = time()
stats = Optimization.OptimizationStats(; iterations = maxiters,
stats = Optimization.OptimizationStats(; iterations = maxiters,

Check warning on line 105 in lib/OptimizationFlux/src/OptimizationFlux.jl

View check run for this annotation

Codecov / codecov/patch

lib/OptimizationFlux/src/OptimizationFlux.jl#L105

Added line #L105 was not covered by tests
time = t1 - t0, fevals = maxiters, gevals = maxiters)
SciMLBase.build_solution(cache, opt, θ, x[1], stats = stats)
# here should be build_solution to create the output message
Expand Down
8 changes: 4 additions & 4 deletions lib/OptimizationGCMAES/src/OptimizationGCMAES.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,12 @@ function SciMLBase.__solve(cache::OptimizationCache{
cache.ub; opt_args...)
end
t1 = time()
stats = Optimization.OptimizationStats(; iterations = maxiters === nothing ? 0 : maxiters,
time = t1 - t0)
stats = Optimization.OptimizationStats(;

Check warning on line 117 in lib/OptimizationGCMAES/src/OptimizationGCMAES.jl

View check run for this annotation

Codecov / codecov/patch

lib/OptimizationGCMAES/src/OptimizationGCMAES.jl#L117

Added line #L117 was not covered by tests
iterations = maxiters === nothing ? 0 : maxiters,
time = t1 - t0)
SciMLBase.build_solution(cache, cache.opt,
opt_xmin, opt_fmin; retcode = Symbol(Bool(opt_ret)),
stats = stats
)
stats = stats)
end

end
4 changes: 3 additions & 1 deletion lib/OptimizationMOI/src/nlp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,9 @@ function MOI.eval_objective(evaluator::MOIOptimizationNLPEvaluator, x)
else
l = evaluator.f(x, evaluator.p)
evaluator.iteration += 1
state = Optimization.OptimizationState(iteration = evaluator.iteration, u = x, objective = l[1])
state = Optimization.OptimizationState(iteration = evaluator.iteration,

Check warning on line 221 in lib/OptimizationMOI/src/nlp.jl

View check run for this annotation

Codecov / codecov/patch

lib/OptimizationMOI/src/nlp.jl#L221

Added line #L221 was not covered by tests
u = x,
objective = l[1])
evaluator.callback(state, l)
return l
end
Expand Down
37 changes: 26 additions & 11 deletions lib/OptimizationOptimJL/src/OptimizationOptimJL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,12 @@ function SciMLBase.__solve(cache::OptimizationCache{
error("Use OptimizationFunction to pass the derivatives or automatically generate them with one of the autodiff backends")

function _cb(trace)
θ = cache.opt isa Optim.NelderMead ? decompose_trace(trace).metadata["centroid"] : decompose_trace(trace).metadata["x"]
opt_state = Optimization.OptimizationState(iteration = trace.iteration, u = θ, objective = x[1], solver_state = trace)
θ = cache.opt isa Optim.NelderMead ? decompose_trace(trace).metadata["centroid"] :
decompose_trace(trace).metadata["x"]
opt_state = Optimization.OptimizationState(iteration = trace.iteration,
u = θ,
objective = x[1],
solver_state = trace)
cb_call = cache.callback(opt_state, x...)
if !(cb_call isa Bool)
error("The callback should return a boolean `halt` for whether to stop the optimization process.")
Expand Down Expand Up @@ -208,8 +212,9 @@ function SciMLBase.__solve(cache::OptimizationCache{
opt_res = Optim.optimize(optim_f, cache.u0, cache.opt, opt_args)
t1 = time()
opt_ret = Symbol(Optim.converged(opt_res))
stats = Optimization.OptimizationStats(; iterations = opt_res.iterations,
time = t1 - t0, fevals = opt_res.f_calls, gevals = opt_res.g_calls, hevals = opt_res.h_calls)
stats = Optimization.OptimizationStats(; iterations = opt_res.iterations,
time = t1 - t0, fevals = opt_res.f_calls, gevals = opt_res.g_calls,
hevals = opt_res.h_calls)
SciMLBase.build_solution(cache, cache.opt,
opt_res.minimizer,
cache.sense === Optimization.MaxSense ? -opt_res.minimum :
Expand Down Expand Up @@ -247,8 +252,13 @@ function SciMLBase.__solve(cache::OptimizationCache{
cur, state = iterate(cache.data)

function _cb(trace)
θ = !(cache.opt isa Optim.SAMIN) && cache.opt.method == Optim.NelderMead() ? decompose_trace(trace).metadata["centroid"] : decompose_trace(trace).metadata["x"]
opt_state = Optimization.OptimizationState(iteration = trace.iteration, u = θ, objective = x[1], solver_state = trace)
θ = !(cache.opt isa Optim.SAMIN) && cache.opt.method == Optim.NelderMead() ?

Check warning on line 255 in lib/OptimizationOptimJL/src/OptimizationOptimJL.jl

View check run for this annotation

Codecov / codecov/patch

lib/OptimizationOptimJL/src/OptimizationOptimJL.jl#L255

Added line #L255 was not covered by tests
decompose_trace(trace).metadata["centroid"] :
decompose_trace(trace).metadata["x"]
opt_state = Optimization.OptimizationState(iteration = trace.iteration,

Check warning on line 258 in lib/OptimizationOptimJL/src/OptimizationOptimJL.jl

View check run for this annotation

Codecov / codecov/patch

lib/OptimizationOptimJL/src/OptimizationOptimJL.jl#L258

Added line #L258 was not covered by tests
u = θ,
objective = x[1],
solver_state = trace)
cb_call = cache.callback(opt_state, x...)
if !(cb_call isa Bool)
error("The callback should return a boolean `halt` for whether to stop the optimization process.")
Expand Down Expand Up @@ -297,8 +307,9 @@ function SciMLBase.__solve(cache::OptimizationCache{
opt_res = Optim.optimize(optim_f, cache.lb, cache.ub, cache.u0, cache.opt, opt_args)
t1 = time()
opt_ret = Symbol(Optim.converged(opt_res))
stats = Optimization.OptimizationStats(; iterations = opt_res.iterations,
time = t1 - t0, fevals = opt_res.f_calls, gevals = opt_res.g_calls, hevals = opt_res.h_calls)
stats = Optimization.OptimizationStats(; iterations = opt_res.iterations,

Check warning on line 310 in lib/OptimizationOptimJL/src/OptimizationOptimJL.jl

View check run for this annotation

Codecov / codecov/patch

lib/OptimizationOptimJL/src/OptimizationOptimJL.jl#L310

Added line #L310 was not covered by tests
time = t1 - t0, fevals = opt_res.f_calls, gevals = opt_res.g_calls,
hevals = opt_res.h_calls)
SciMLBase.build_solution(cache, cache.opt,
opt_res.minimizer, opt_res.minimum;
original = opt_res, retcode = opt_ret, stats = stats)
Expand Down Expand Up @@ -331,7 +342,10 @@ function SciMLBase.__solve(cache::OptimizationCache{
cur, state = iterate(cache.data)

function _cb(trace)
opt_state = Optimization.OptimizationState(iteration = trace.iteration, u = decompose_trace(trace).metadata["x"], objective = x[1], solver_state = trace)
opt_state = Optimization.OptimizationState(iteration = trace.iteration,
u = decompose_trace(trace).metadata["x"],
objective = x[1],
solver_state = trace)
cb_call = cache.callback(opt_state, x...)
if !(cb_call isa Bool)
error("The callback should return a boolean `halt` for whether to stop the optimization process.")
Expand Down Expand Up @@ -412,8 +426,9 @@ function SciMLBase.__solve(cache::OptimizationCache{
opt_res = Optim.optimize(optim_f, optim_fc, cache.u0, cache.opt, opt_args)
t1 = time()
opt_ret = Symbol(Optim.converged(opt_res))
stats = Optimization.OptimizationStats(; iterations = opt_res.iterations,
time = t1 - t0, fevals = opt_res.f_calls, gevals = opt_res.g_calls, hevals = opt_res.h_calls)
stats = Optimization.OptimizationStats(; iterations = opt_res.iterations,
time = t1 - t0, fevals = opt_res.f_calls, gevals = opt_res.g_calls,
hevals = opt_res.h_calls)
SciMLBase.build_solution(cache, cache.opt,
opt_res.minimizer, opt_res.minimum;
original = opt_res, retcode = opt_ret,
Expand Down
12 changes: 10 additions & 2 deletions lib/OptimizationOptimisers/src/OptimizationOptimisers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,11 @@ function SciMLBase.__solve(cache::OptimizationCache{
for (i, d) in enumerate(data)
cache.f.grad(G, θ, d...)
x = cache.f(θ, cache.p, d...)
opt_state = Optimization.OptimizationState(iteration = i, u = θ, objective = x[1], gradient = G, solver_state = state)
opt_state = Optimization.OptimizationState(iteration = i,
u = θ,
objective = x[1],
gradient = G,
solver_state = state)
cb_call = cache.callback(opt_state, x...)
if !(cb_call isa Bool)
error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the `solve` documentation for information.")
Expand All @@ -88,7 +92,11 @@ function SciMLBase.__solve(cache::OptimizationCache{
x = min_err
θ = min_θ
cache.f.grad(G, θ, d...)
opt_state = Optimization.OptimizationState(iteration = i, u = θ, objective = x[1], gradient = G, solver_state = state)
opt_state = Optimization.OptimizationState(iteration = i,
u = θ,
objective = x[1],
gradient = G,
solver_state = state)
cache.callback(opt_state, x...)
break
end
Expand Down
10 changes: 7 additions & 3 deletions lib/OptimizationOptimisers/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,15 @@ using Zygote

prob = OptimizationProblem(optprob, x0, _p)
function callback(state, l)
Optimisers.adjust!(state.solver_state, 0.1/state.iteration)
Optimisers.adjust!(state.solver_state, 0.1 / state.iteration)
return false
end
sol = solve(prob, Optimisers.Adam(0.1), maxiters = 1000, progress = false, callback = callback)
sol = solve(prob,
Optimisers.Adam(0.1),
maxiters = 1000,
progress = false,
callback = callback)
end

@test_throws ArgumentError sol = solve(prob, Optimisers.Adam())
@test_throws ArgumentError sol=solve(prob, Optimisers.Adam())
end
8 changes: 5 additions & 3 deletions src/stats_state.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ struct OptimizationStats
hevals::Int
end

OptimizationStats(; iterations = 0, time = 0.0, fevals = 0, gevals = 0, hevals = 0) =
function OptimizationStats(; iterations = 0, time = 0.0, fevals = 0, gevals = 0, hevals = 0)
OptimizationStats(iterations, time, fevals, gevals, hevals)
end

struct OptimizationState{X, O, G, H, S}
iteration::Int
Expand All @@ -19,6 +20,7 @@ struct OptimizationState{X, O, G, H, S}
solver_state::S
end

OptimizationState(; iteration = 0, u = nothing, objective = nothing,
gradient = nothing, hessian = nothing, solver_state = nothing) =
function OptimizationState(; iteration = 0, u = nothing, objective = nothing,
gradient = nothing, hessian = nothing, solver_state = nothing)
OptimizationState(iteration, u, objective, gradient, hessian, solver_state)
end
1 change: 0 additions & 1 deletion test/diffeqfluxtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ end

iter = 0
callback = function (state, l, pred)

display(l)

# using `remake` to re-create our `prob` with current parameters `p`
Expand Down

0 comments on commit 623214b

Please sign in to comment.