Skip to content

Commit

Permalink
Merge pull request #534 from SciML/optensemble
Browse files Browse the repository at this point in the history
Relax type of alg in ensemble solve for optimization
  • Loading branch information
ChrisRackauckas authored Nov 8, 2023
2 parents 88890d5 + ebcbae3 commit 06d5c2c
Show file tree
Hide file tree
Showing 9 changed files with 82 additions and 3 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
QuasiMonteCarlo = "8a4e6c94-4038-4cdc-81c3-7e6ffdb2a71b"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Expand Down Expand Up @@ -78,6 +79,7 @@ Statistics = "1"
SymbolicIndexingInterface = "0.2"
Tables = "1"
TruncatedStacktraces = "1"
QuasiMonteCarlo = "0.3"
Zygote = "0.6"
julia = "1.9"

Expand Down
2 changes: 1 addition & 1 deletion src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import EnumX
import TruncatedStacktraces
import ADTypes: AbstractADType
import FillArrays

import QuasiMonteCarlo
using Reexport
using SciMLOperators
using SciMLOperators:
Expand Down
4 changes: 2 additions & 2 deletions src/ensemble/basic_ensemble_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ function __solve(prob::EnsembleProblem{<:AbstractVector{<:AbstractSciMLProblem}}
end

function __solve(prob::AbstractEnsembleProblem,
alg::Union{AbstractDEAlgorithm, Nothing},
alg::A,
ensemblealg::BasicEnsembleAlgorithm;
trajectories, batch_size = trajectories,
pmap_batch_size = batch_size ÷ 100 > 0 ? batch_size ÷ 100 : 1, kwargs...)
pmap_batch_size = batch_size ÷ 100 > 0 ? batch_size ÷ 100 : 1, kwargs...) where {A}
num_batches = trajectories ÷ batch_size
num_batches < 1 &&
error("trajectories ÷ batch_size cannot be less than 1, got $num_batches")
Expand Down
17 changes: 17 additions & 0 deletions src/ensemble/ensemble_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,23 @@ function EnsembleProblem(; prob,
EnsembleProblem(prob; prob_func, output_func, reduction, u_init, safetycopy)
end

#since NonlinearProblem might want to use this dispatch as well
function SciMLBase.EnsembleProblem(prob::AbstractSciMLProblem, u0s::Vector{Vector{T}}; kwargs...) where {T}
prob_func = (prob, i, repeat = nothing) -> remake(prob, u0 = u0s[i])
return SciMLBase.EnsembleProblem(prob; prob_func, kwargs...)
end

#only makes sense for OptimizationProblem, might make sense for IntervalNonlinearProblem
function SciMLBase.EnsembleProblem(prob::OptimizationProblem, trajectories::Int; kwargs...)
if prob.lb !== nothing && prob.ub !== nothing
u0s = QuasiMonteCarlo.sample(trajectories, prob.lb, prob.ub, QuasiMonteCarlo.LatinHypercubeSample())
prob_func = (prob, i, repeat = nothing) -> remake(prob, u0 = u0s[:, i])
else
error("EnsembleProblem with `trajectories` as second argument requires lower and upper bounds to be defined in the `OptimizationProblem`.")
end
return SciMLBase.EnsembleProblem(prob; prob_func, kwargs...)
end

struct WeightedEnsembleProblem{T1 <: AbstractEnsembleProblem, T2 <: AbstractVector} <:
AbstractEnsembleProblem
ensembleprob::T1
Expand Down
4 changes: 4 additions & 0 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ function solve(prob::OptimizationProblem, alg, args...;
end
end

function SciMLBase.solve(prob::EnsembleProblem{T}, args...; kwargs...) where {T <: OptimizationProblem}
return SciMLBase.__solve(prob, args...; kwargs...)
end

function _check_opt_alg(prob::OptimizationProblem, alg; kwargs...)
!allowsbounds(alg) && (!isnothing(prob.lb) || !isnothing(prob.ub)) &&
throw(IncompatibleOptimizerError("The algorithm $(typeof(alg)) does not support box constraints. Either remove the `lb` or `ub` bounds passed to `OptimizationProblem` or use a different algorithm."))
Expand Down
3 changes: 3 additions & 0 deletions test/downstream/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
[deps]
BoundaryValueDiffEq = "764a87c0-6b3e-53db-9096-fe964310641d"
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Expand Down
9 changes: 9 additions & 0 deletions test/downstream/ensemble_diffeq.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
using DifferentialEquations

prob = ODEProblem((u, p, t) -> 1.01u, 0.5, (0.0, 1.0))
function prob_func(prob, i, repeat)
remake(prob, u0 = rand() * prob.u0)
end
ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), trajectories = 10)
@test sim isa EnsembleSolution
38 changes: 38 additions & 0 deletions test/downstream/ensemble_nondes.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
using Optimization, OptimizationOptimJL, ForwardDiff, Test

x0 = zeros(2)
rosenbrock(x, p = nothing) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
l1 = rosenbrock(x0)

optf = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff())
prob = OptimizationProblem(optf, x0)
sol1 = Optimization.solve(prob, OptimizationOptimJL.BFGS(), maxiters = 5)

ensembleprob = Optimization.EnsembleProblem(prob, [x0, x0 .+ rand(2), x0 .+ rand(2), x0 .+ rand(2)])

sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleThreads(), trajectories = 4, maxiters = 5)
@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective

sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleDistributed(), trajectories = 4, maxiters = 5)
@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective

prob = OptimizationProblem(optf, x0, lb = [-0.5, -0.5], ub = [0.5, 0.5])
ensembleprob = Optimization.EnsembleProblem(prob, 5, prob_func = (prob, i, repeat) -> remake(prob, u0 = rand(-0.5:0.001:0.5, 2)))

sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleThreads(), trajectories = 5, maxiters = 5)
@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective

sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleDistributed(), trajectories = 5, maxiters = 5)
@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective

using NonlinearSolve

f(u, p) = u .* u .- p
u0 = [1.0, 1.0]
p = 2.0
prob = NonlinearProblem(f, u0, p)
ensembleprob = EnsembleProblem(prob, [u0, u0 .+ rand(2), u0 .+ rand(2), u0 .+ rand(2)])

sol = solve(ensembleprob, EnsembleThreads(), trajectories = 4, maxiters = 100)

sol = solve(ensembleprob, EnsembleDistributed(), trajectories = 4, maxiters = 100)
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ end
@time @safetestset "Ensemble solution statistics" begin
include("downstream/ensemble_stats.jl")
end
@time @safetestset "Ensemble Optimization and Nonlinear problems" begin
include("downstream/ensemble_nondes.jl")
end
@time @safetestset "Ensemble with DifferentialEquations automatic algorithm selection" begin
include("downstream/ensemble_diffeq.jl")
end
@time @safetestset "Symbol and integer based indexing of interpolated solutions" begin
include("downstream/symbol_indexing.jl")
end
Expand Down

0 comments on commit 06d5c2c

Please sign in to comment.