diff --git a/src/ensemble/ensemble_problems.jl b/src/ensemble/ensemble_problems.jl index 4e493c380..49ee18a9b 100644 --- a/src/ensemble/ensemble_problems.jl +++ b/src/ensemble/ensemble_problems.jl @@ -44,6 +44,26 @@ function EnsembleProblem(; prob, EnsembleProblem(prob; prob_func, output_func, reduction, u_init, safetycopy) end +function EnsembleProblem(; prob, + u0s::Union{Nothing, Vector{uType}} = nothing, + prob_func = (prob, i, repeat) -> remake(prob, u0 = u0s[i]), + output_func = DEFAULT_OUTPUT_FUNC, + reduction = DEFAULT_REDUCTION, + u_init = nothing, p = nothing, + safetycopy = prob_func !== DEFAULT_PROB_FUNC) where {uType} + EnsembleProblem(prob; prob_func, output_func, reduction, u_init, safetycopy) +end + +function EnsembleProblem(; prob, + trajectories::Int, + prob_func, + output_func = DEFAULT_OUTPUT_FUNC, + reduction = DEFAULT_REDUCTION, + u_init = nothing, p = nothing, + safetycopy = prob_func !== DEFAULT_PROB_FUNC) + EnsembleProblem(prob; prob_func, output_func, reduction, u_init, safetycopy) +end + struct WeightedEnsembleProblem{T1 <: AbstractEnsembleProblem, T2 <: AbstractVector} <: AbstractEnsembleProblem ensembleprob::T1 diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml index db1a40a1c..9c4b9014a 100644 --- a/test/downstream/Project.toml +++ b/test/downstream/Project.toml @@ -1,5 +1,6 @@ [deps] BoundaryValueDiffEq = "764a87c0-6b3e-53db-9096-fe964310641d" +DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa" ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" diff --git a/test/downstream/ensemble_diffeq.jl b/test/downstream/ensemble_diffeq.jl new file mode 100644 index 000000000..f64f9b15e --- /dev/null +++ b/test/downstream/ensemble_diffeq.jl @@ -0,0 +1,8 @@ +using DifferentialEquations + +f(u, p, t) = 1.01 * u +u0 = 1 / 2 +tspan = (0.0, 1.0) +prob = ODEProblem(f, u0, tspan) +ensemble_prob = EnsembleProblem(prob, prob_func = (prob, i, repeat) -> remake(prob, u0 = rand())) +sim = solve(ensemble_prob, EnsembleThreads(), trajectories = 10, dt = 0.1) \ No newline at end of file diff --git a/test/downstream/ensemble_nondes.jl b/test/downstream/ensemble_nondes.jl new file mode 100644 index 000000000..c1274ba1d --- /dev/null +++ b/test/downstream/ensemble_nondes.jl @@ -0,0 +1,26 @@ +using Optimization, OptimizationOptimJL, ForwardDiff, Test + +x0 = zeros(2) +rosenbrock(x, p = nothing) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2 +l1 = rosenbrock(x0) + +optf = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff()) +prob = OptimizationProblem(optf, x0) +sol1 = Optimization.solve(prob, OptimizationOptimJL.BFGS(), maxiters = 5) + +ensembleprob = Optimization.EnsembleProblem(prob, [x0, x0 .+ rand(2), x0 .+ rand(2), x0 .+ rand(2)]) + +sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleThreads(), trajectories = 4, maxiters = 5) +@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective + +sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleDistributed(), trajectories = 4, maxiters = 5) +@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective + +prob = OptimizationProblem(optf, x0, lb = [-0.5, -0.5], ub = [0.5, 0.5]) +ensembleprob = Optimization.EnsembleProblem(prob, 5, prob_func = (prob, i, repeat) -> remake(prob, u0 = rand(-0.5:0.001:0.5, 2))) + +sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleThreads(), trajectories = 5, maxiters = 5) +@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective + +sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleDistributed(), trajectories = 5, maxiters = 5) +@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index ead9e0d72..0c58fb841 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -78,6 +78,12 @@ end @time @safetestset "Ensemble solution statistics" begin include("downstream/ensemble_stats.jl") end + @time @safetestset "Ensemble Optimization and Nonlinear problems" begin + include("downstream/ensemble_nondes.jl") + end + @time @safetestset "Ensemble with DifferentialEquations automatic algorithm selection" begin + include("downstream/ensemble_diffeq.jl") + end @time @safetestset "Symbol and integer based indexing of interpolated solutions" begin include("downstream/symbol_indexing.jl") end