Skip to content

Commit

Permalink
Use EnsembleProblem instead of custom struct
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaibhavdixit02 committed Oct 27, 2023
1 parent 4eef579 commit b5709fe
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 7 deletions.
1 change: 1 addition & 0 deletions src/Optimization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ include("utils.jl")
include("function.jl")
include("adtypes.jl")
include("cache.jl")
include("ensemble.jl")

@static if !isdefined(Base, :get_extension)
function __init__()
Expand Down
21 changes: 14 additions & 7 deletions src/ensemble.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
struct EnsembleOptimizationProblem{T1} <: SciMLBase.AbstractEnsembleProblem
prob::OptimizationProblem{iip, F, T1} where {iip, F}
u0s::Vector{T1}
function SciMLBase.EnsembleProblem(prob::OptimizationProblem, u0s::Vector{Vector{T}}; kwargs...) where {T}
prob_func = (prob, i, repeat = nothing) -> remake(prob, u0 = u0s[i])
return SciMLBase.EnsembleProblem(prob; prob_func, kwargs...)

Check warning on line 3 in src/ensemble.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble.jl#L1-L3

Added lines #L1 - L3 were not covered by tests
end

function SciMLBase.__init(prob::EnsembleOptimizationProblem{T}, args...; kwargs...) where {T <: OptimizationProblem}
probs = [remake(prob.prob, u0=u0; kwargs...) for u0 in prob.u0s]
return [SciMLBase.__init(prob, args...; kwargs...) for prob in probs]
function SciMLBase.init(prob::EnsembleProblem{T}, args...; kwargs...) where {T <: OptimizationProblem}
SciMLBase.__init(prob, args...; kwargs...)

Check warning on line 7 in src/ensemble.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble.jl#L6-L7

Added lines #L6 - L7 were not covered by tests
end

function SciMLBase.__solve(caches::Vector{OptimizationCache}, args...; kwargs...)
function SciMLBase.__init(prob::EnsembleProblem{T}, args...; trajectories, kwargs...) where {T <: OptimizationProblem}
return [SciMLBase.__init(prob.prob_func(prob.prob, i), args...; kwargs...) for i in 1:trajectories]

Check warning on line 11 in src/ensemble.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble.jl#L10-L11

Added lines #L10 - L11 were not covered by tests
end

function SciMLBase.solve!(cache::Vector{<:OptimizationCache}; kwargs...)
return [SciMLBase.solve!(cache[i]; kwargs...) for i in eachindex(cache)]

Check warning on line 15 in src/ensemble.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble.jl#L14-L15

Added lines #L14 - L15 were not covered by tests
end

function SciMLBase.__solve(caches::Vector{<:OptimizationCache}, args...; kwargs...)
return [SciMLBase.__solve(cache, args...; kwargs...) for cache in caches]

Check warning on line 19 in src/ensemble.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble.jl#L18-L19

Added lines #L18 - L19 were not covered by tests
end
15 changes: 15 additions & 0 deletions test/ensemble.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
using Optimization, OptimizationOptimJL, ForwardDiff

x0 = zeros(2)
rosenbrock(x, p = nothing) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
l1 = rosenbrock(x0)

optf = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff())
prob = OptimizationProblem(optf, x0)
sol1 = Optimization.solve(prob, OptimizationOptimJL.BFGS(), maxiters = 5)


ensembleprob = Optimization.EnsembleProblem(prob, [x0, x0 .+ rand(2), x0 .+ rand(2), x0 .+ rand(2)])
sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), trajectories = 4, maxiters = 5)

@test findmin(i -> sol[i].objective, 1:4) < sol1.objective

0 comments on commit b5709fe

Please sign in to comment.