-
-
Notifications
You must be signed in to change notification settings - Fork 104
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Relax type of alg in ensemble solve for optimization #534
Changes from 5 commits
8c17212
23754a8
9431771
275cfe1
7e7367b
dea66d1
1d87d19
e85bdb3
145f495
ebcbae3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
using DifferentialEquations | ||
|
||
prob = ODEProblem((u, p, t) -> 1.01u, 0.5, (0.0, 1.0)) | ||
function prob_func(prob, i, repeat) | ||
remake(prob, u0 = rand() * prob.u0) | ||
end | ||
ensemble_prob = EnsembleProblem(prob, prob_func = prob_func) | ||
sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), trajectories = 10) | ||
@test sim isa EnsembleSolution |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
using Optimization, OptimizationOptimJL, ForwardDiff, Test | ||
|
||
x0 = zeros(2) | ||
rosenbrock(x, p = nothing) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2 | ||
l1 = rosenbrock(x0) | ||
|
||
optf = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff()) | ||
prob = OptimizationProblem(optf, x0) | ||
sol1 = Optimization.solve(prob, OptimizationOptimJL.BFGS(), maxiters = 5) | ||
|
||
ensembleprob = Optimization.EnsembleProblem(prob, [x0, x0 .+ rand(2), x0 .+ rand(2), x0 .+ rand(2)]) | ||
|
||
sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleThreads(), trajectories = 4, maxiters = 5) | ||
@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective | ||
|
||
sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleDistributed(), trajectories = 4, maxiters = 5) | ||
@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective | ||
|
||
prob = OptimizationProblem(optf, x0, lb = [-0.5, -0.5], ub = [0.5, 0.5]) | ||
ensembleprob = Optimization.EnsembleProblem(prob, 5, prob_func = (prob, i, repeat) -> remake(prob, u0 = rand(-0.5:0.001:0.5, 2))) | ||
|
||
sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleThreads(), trajectories = 5, maxiters = 5) | ||
@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective | ||
|
||
sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleDistributed(), trajectories = 5, maxiters = 5) | ||
@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -78,6 +78,12 @@ end | |
@time @safetestset "Ensemble solution statistics" begin | ||
include("downstream/ensemble_stats.jl") | ||
end | ||
# @time @safetestset "Ensemble Optimization and Nonlinear problems" begin | ||
# include("downstream/ensemble_nondes.jl") | ||
# end | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this intended? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah it was to show that other tests all pass, to get this one working it needs the dispatches we have been discussing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm lost. What's the version that works with all tests? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The one without type assertion in the __solve method There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you put it this in the form you believe works for everything and we can check the downstream tests and these tests? |
||
@time @safetestset "Ensemble with DifferentialEquations automatic algorithm selection" begin | ||
include("downstream/ensemble_diffeq.jl") | ||
end | ||
@time @safetestset "Symbol and integer based indexing of interpolated solutions" begin | ||
include("downstream/symbol_indexing.jl") | ||
end | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think for this to work it needs a dispatch that's not
EnsembleAlgorithm
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because of the automatic alg selection in DifferentialEquations.jl?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I could work around this by making a stuct subtype of SciMLAlgorithm and wrap the optimizer's alg in that and handle it in
__solve
so we might not need to do this thenThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you double check some cases with this and defaults? Tests seem to pass