Skip to content

Commit

Permalink
Merge pull request #600 from Zentrik/master
Browse files Browse the repository at this point in the history
Improve Performance for OptimizationBBO
  • Loading branch information
ChrisRackauckas authored Oct 1, 2023
2 parents 5a4517d + 87f72a9 commit 7acb4d5
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 13 deletions.
14 changes: 7 additions & 7 deletions lib/OptimizationBBO/src/OptimizationBBO.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ function SciMLBase.__solve(cache::Optimization.OptimizationCache{
cur, state = iterate(cache.data)

function _cb(trace)
if isnothing(cache.callback)
if cache.callback === Optimization.DEFAULT_CALLBACK
cb_call = false
else
cb_call = cache.callback(decompose_trace(trace, cache.progress), x...)
Expand All @@ -125,7 +125,7 @@ function SciMLBase.__solve(cache::Optimization.OptimizationCache{
BlackBoxOptim.shutdown_optimizer!(trace) #doesn't work
end

if !isnothing(cache.data)
if cache.data !== Optimization.DEFAULT_DATA
cur, state = iterate(cache.data, state)
end
cb_call
Expand All @@ -135,11 +135,11 @@ function SciMLBase.__solve(cache::Optimization.OptimizationCache{
maxtime = Optimization._check_and_convert_maxtime(cache.solver_args.maxtime)

_loss = function (θ)
if isnothing(cache.callback) && isnothing(cache.data)
if cache.callback === Optimization.DEFAULT_CALLBACK && cache.data === Optimization.DEFAULT_DATA
return first(cache.f(θ, cache.p))
elseif isnothing(cache.callback)
elseif cache.callback === Optimization.DEFAULT_CALLBACK
return first(cache.f(θ, cache.p, cur...))
elseif isnothing(cache.data)
elseif cache.data !== Optimization.DEFAULT_DATA
x = cache.f(θ, cache.p)
return first(x)
else
Expand All @@ -149,8 +149,8 @@ function SciMLBase.__solve(cache::Optimization.OptimizationCache{
end

opt_args = __map_optimizer_args(cache, cache.opt;
callback = isnothing(cache.callback) &&
isnothing(cache.data) ?
callback = cache.callback === Optimization.DEFAULT_CALLBACK &&
cache.data === Optimization.DEFAULT_DATA ?
nothing : _cb,
cache.solver_args...,
maxiters = maxiters,
Expand Down
2 changes: 2 additions & 0 deletions lib/OptimizationBBO/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ using Test
sol = solve(prob, BBO_adaptive_de_rand_1_bin_radiuslimited())
@test 10 * sol.objective < l1

@test (@allocated solve(prob, BBO_adaptive_de_rand_1_bin_radiuslimited())) < 1e7

prob = Optimization.OptimizationProblem(optprob, nothing, _p, lb = [-1.0, -1.0],
ub = [0.8, 0.8])
sol = solve(prob, BBO_adaptive_de_rand_1_bin_radiuslimited())
Expand Down
6 changes: 3 additions & 3 deletions lib/OptimizationOptimisers/src/sophia.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,11 @@ function SciMLBase.__solve(cache::OptimizationCache{
maxiters = Optimization._check_and_convert_maxiters(maxiters)

_loss = function (θ)
if isnothing(cache.callback) && isnothing(data)
if cache.callback === Optimization.DEFAULT_CALLBACK && data === Optimization.DEFAULT_DATA
return first(cache.f(θ, cache.p))
elseif isnothing(cache.callback)
elseif cache.callback === Optimization.DEFAULT_CALLBACK
return first(cache.f(θ, cache.p, cur...))
elseif isnothing(data)
elseif data === Optimization.DEFAULT_DATA
x = cache.f(θ, cache.p)
return first(x)
else
Expand Down
6 changes: 3 additions & 3 deletions src/cache.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

function Base.getproperty(cache::SciMLBase.AbstractOptimizationCache, x::Symbol)
if x in fieldnames(Optimization.ReInitCache)
if x in (:u0, :p)
return getfield(cache.reinit_cache, x)
end
return getfield(cache, x)
Expand Down Expand Up @@ -52,7 +52,7 @@ struct OptimizationCache{F, RC, LB, UB, LC, UC, S, O, D, P, C} <:
end

function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt, data;
callback = (args...) -> (false),
callback = Optimization.DEFAULT_CALLBACK,
maxiters::Union{Number, Nothing} = nothing,
maxtime::Union{Number, Nothing} = nothing,
abstol::Union{Number, Nothing} = nothing,
Expand All @@ -71,7 +71,7 @@ end

function SciMLBase.__init(prob::SciMLBase.OptimizationProblem, opt,
data = Optimization.DEFAULT_DATA;
callback = (args...) -> (false),
callback = Optimization.DEFAULT_CALLBACK,
maxiters::Union{Number, Nothing} = nothing,
maxtime::Union{Number, Nothing} = nothing,
abstol::Union{Number, Nothing} = nothing,
Expand Down
4 changes: 4 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
struct NullCallback end
(x::NullCallback)(args...) = false;
const DEFAULT_CALLBACK = NullCallback()

struct NullData end
const DEFAULT_DATA = Iterators.cycle((NullData(),))
Base.iterate(::NullData, i = 1) = nothing
Expand Down

0 comments on commit 7acb4d5

Please sign in to comment.