Skip to content

Commit

Permalink
Add more methods for EnsembleProblem and add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaibhavdixit02 committed Nov 1, 2023
1 parent 8c17212 commit 23754a8
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 0 deletions.
20 changes: 20 additions & 0 deletions src/ensemble/ensemble_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,26 @@ function EnsembleProblem(; prob,
EnsembleProblem(prob; prob_func, output_func, reduction, u_init, safetycopy)
end

function EnsembleProblem(; prob,

Check warning on line 47 in src/ensemble/ensemble_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/ensemble_problems.jl#L47

Added line #L47 was not covered by tests
u0s::Union{Nothing, Vector{uType}} = nothing,
prob_func = (prob, i, repeat) -> remake(prob, u0 = u0s[i]),
output_func = DEFAULT_OUTPUT_FUNC,
reduction = DEFAULT_REDUCTION,
u_init = nothing, p = nothing,
safetycopy = prob_func !== DEFAULT_PROB_FUNC) where {uType}
EnsembleProblem(prob; prob_func, output_func, reduction, u_init, safetycopy)

Check warning on line 54 in src/ensemble/ensemble_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/ensemble_problems.jl#L54

Added line #L54 was not covered by tests
end

function EnsembleProblem(; prob,

Check warning on line 57 in src/ensemble/ensemble_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/ensemble_problems.jl#L57

Added line #L57 was not covered by tests
trajectories::Int,
prob_func,
output_func = DEFAULT_OUTPUT_FUNC,
reduction = DEFAULT_REDUCTION,
u_init = nothing, p = nothing,
safetycopy = prob_func !== DEFAULT_PROB_FUNC)
EnsembleProblem(prob; prob_func, output_func, reduction, u_init, safetycopy)

Check warning on line 64 in src/ensemble/ensemble_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/ensemble_problems.jl#L64

Added line #L64 was not covered by tests
end

struct WeightedEnsembleProblem{T1 <: AbstractEnsembleProblem, T2 <: AbstractVector} <:
AbstractEnsembleProblem
ensembleprob::T1
Expand Down
1 change: 1 addition & 0 deletions test/downstream/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
BoundaryValueDiffEq = "764a87c0-6b3e-53db-9096-fe964310641d"
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
Expand Down
8 changes: 8 additions & 0 deletions test/downstream/ensemble_diffeq.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
using DifferentialEquations

f(u, p, t) = 1.01 * u
u0 = 1 / 2
tspan = (0.0, 1.0)
prob = ODEProblem(f, u0, tspan)
ensemble_prob = EnsembleProblem(prob, prob_func = (prob, i, repeat) -> remake(prob, u0 = rand()))
sim = solve(ensemble_prob, EnsembleThreads(), trajectories = 10, dt = 0.1)
26 changes: 26 additions & 0 deletions test/downstream/ensemble_nondes.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
using Optimization, OptimizationOptimJL, ForwardDiff, Test

x0 = zeros(2)
rosenbrock(x, p = nothing) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
l1 = rosenbrock(x0)

optf = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff())
prob = OptimizationProblem(optf, x0)
sol1 = Optimization.solve(prob, OptimizationOptimJL.BFGS(), maxiters = 5)

ensembleprob = Optimization.EnsembleProblem(prob, [x0, x0 .+ rand(2), x0 .+ rand(2), x0 .+ rand(2)])

sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleThreads(), trajectories = 4, maxiters = 5)
@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective

sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleDistributed(), trajectories = 4, maxiters = 5)
@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective

prob = OptimizationProblem(optf, x0, lb = [-0.5, -0.5], ub = [0.5, 0.5])
ensembleprob = Optimization.EnsembleProblem(prob, 5, prob_func = (prob, i, repeat) -> remake(prob, u0 = rand(-0.5:0.001:0.5, 2)))

sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleThreads(), trajectories = 5, maxiters = 5)
@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective

sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleDistributed(), trajectories = 5, maxiters = 5)
@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ end
@time @safetestset "Ensemble solution statistics" begin
include("downstream/ensemble_stats.jl")
end
@time @safetestset "Ensemble Optimization and Nonlinear problems" begin
include("downstream/ensemble_nondes.jl")
end
@time @safetestset "Ensemble with DifferentialEquations automatic algorithm selection" begin
include("downstream/ensemble_diffeq.jl")
end
@time @safetestset "Symbol and integer based indexing of interpolated solutions" begin
include("downstream/symbol_indexing.jl")
end
Expand Down

0 comments on commit 23754a8

Please sign in to comment.