Skip to content

Commit

Permalink
Run optimization tests and add constructors for it
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaibhavdixit02 committed Nov 7, 2023
1 parent 7e7367b commit dea66d1
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 5 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
QuasiMonteCarlo = "8a4e6c94-4038-4cdc-81c3-7e6ffdb2a71b"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Expand Down
2 changes: 1 addition & 1 deletion src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import ADTypes: AbstractADType
import ChainRulesCore
import ZygoteRules: @adjoint
import FillArrays

import QuasiMonteCarlo
using Reexport
using SciMLOperators
using SciMLOperators:
Expand Down
17 changes: 17 additions & 0 deletions src/ensemble/ensemble_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,23 @@ function EnsembleProblem(; prob,
EnsembleProblem(prob; prob_func, output_func, reduction, u_init, safetycopy)
end

#since NonlinearProblem might want to use this dispatch as well
function SciMLBase.EnsembleProblem(prob::AbstractSciMLProblem, u0s::Vector{Vector{T}}; kwargs...) where {T}
prob_func = (prob, i, repeat = nothing) -> remake(prob, u0 = u0s[i])
return SciMLBase.EnsembleProblem(prob; prob_func, kwargs...)

Check warning on line 50 in src/ensemble/ensemble_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/ensemble_problems.jl#L48-L50

Added lines #L48 - L50 were not covered by tests
end

#only makes sense for OptimizationProblem, might make sense for IntervalNonlinearProblem
function SciMLBase.EnsembleProblem(prob::OptimizationProblem, trajectories::Int; kwargs...)
if prob.lb !== nothing && prob.ub !== nothing
u0s = QuasiMonteCarlo.sample(trajectories, prob.lb, prob.ub, QuasiMonteCarlo.LatinHypercubeSample())
prob_func = (prob, i, repeat = nothing) -> remake(prob, u0 = u0s[:, i])

Check warning on line 57 in src/ensemble/ensemble_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/ensemble_problems.jl#L54-L57

Added lines #L54 - L57 were not covered by tests
else
error("EnsembleProblem with `trajectories` as second argument requires lower and upper bounds to be defined in the `OptimizationProblem`.")

Check warning on line 59 in src/ensemble/ensemble_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/ensemble_problems.jl#L59

Added line #L59 was not covered by tests
end
return SciMLBase.EnsembleProblem(prob; prob_func, kwargs...)

Check warning on line 61 in src/ensemble/ensemble_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/ensemble_problems.jl#L61

Added line #L61 was not covered by tests
end

struct WeightedEnsembleProblem{T1 <: AbstractEnsembleProblem, T2 <: AbstractVector} <:
AbstractEnsembleProblem
ensembleprob::T1
Expand Down
14 changes: 13 additions & 1 deletion test/downstream/ensemble_nondes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,16 @@ sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleThrea
@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective

sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleDistributed(), trajectories = 5, maxiters = 5)
@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective
@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective

using NonlinearSolve

f(u, p) = u .* u .- p
u0 = [1.0, 1.0]
p = 2.0
prob = NonlinearProblem(f, u0, p)
ensembleprob = EnsembleProblem(prob, [u0, u0 .+ rand(2), u0 .+ rand(2), u0 .+ rand(2)])

sol = solve(ensembleprob, EnsembleThreads(), trajectories = 4, maxiters = 100)

sol = solve(ensembleprob, EnsembleDistributed(), trajectories = 4, maxiters = 100)
6 changes: 3 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ end
@time @safetestset "Ensemble solution statistics" begin
include("downstream/ensemble_stats.jl")
end
# @time @safetestset "Ensemble Optimization and Nonlinear problems" begin
# include("downstream/ensemble_nondes.jl")
# end
@time @safetestset "Ensemble Optimization and Nonlinear problems" begin
include("downstream/ensemble_nondes.jl")
end
@time @safetestset "Ensemble with DifferentialEquations automatic algorithm selection" begin
include("downstream/ensemble_diffeq.jl")
end
Expand Down

0 comments on commit dea66d1

Please sign in to comment.