-
-
Notifications
You must be signed in to change notification settings - Fork 104
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Relax type of alg in ensemble solve for optimization #534
Changes from 2 commits
8c17212
23754a8
9431771
275cfe1
7e7367b
dea66d1
1d87d19
e85bdb3
145f495
ebcbae3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -44,6 +44,26 @@ | |
EnsembleProblem(prob; prob_func, output_func, reduction, u_init, safetycopy) | ||
end | ||
|
||
function EnsembleProblem(; prob, | ||
u0s::Union{Nothing, Vector{uType}} = nothing, | ||
prob_func = (prob, i, repeat) -> remake(prob, u0 = u0s[i]), | ||
output_func = DEFAULT_OUTPUT_FUNC, | ||
reduction = DEFAULT_REDUCTION, | ||
u_init = nothing, p = nothing, | ||
safetycopy = prob_func !== DEFAULT_PROB_FUNC) where {uType} | ||
EnsembleProblem(prob; prob_func, output_func, reduction, u_init, safetycopy) | ||
end | ||
|
||
function EnsembleProblem(; prob, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll use this dispatch for this constructor https://github.com/SciML/Optimization.jl/pull/620/files#diff-6eea19b562ee1839438c6fa37d876324d501594af403caf5c836620ad7e3b793R6 |
||
trajectories::Int, | ||
prob_func, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this doesn't define a prob_func? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, because it'll need to access generated u0s, like https://github.com/SciML/Optimization.jl/pull/620/files#diff-6eea19b562ee1839438c6fa37d876324d501594af403caf5c836620ad7e3b793R6 or as a global variable so |
||
output_func = DEFAULT_OUTPUT_FUNC, | ||
reduction = DEFAULT_REDUCTION, | ||
u_init = nothing, p = nothing, | ||
safetycopy = prob_func !== DEFAULT_PROB_FUNC) | ||
EnsembleProblem(prob; prob_func, output_func, reduction, u_init, safetycopy) | ||
end | ||
|
||
struct WeightedEnsembleProblem{T1 <: AbstractEnsembleProblem, T2 <: AbstractVector} <: | ||
AbstractEnsembleProblem | ||
ensembleprob::T1 | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
using DifferentialEquations | ||
|
||
f(u, p, t) = 1.01 * u | ||
u0 = 1 / 2 | ||
tspan = (0.0, 1.0) | ||
prob = ODEProblem(f, u0, tspan) | ||
ensemble_prob = EnsembleProblem(prob, prob_func = (prob, i, repeat) -> remake(prob, u0 = rand())) | ||
sim = solve(ensemble_prob, EnsembleThreads(), trajectories = 10, dt = 0.1) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
using Optimization, OptimizationOptimJL, ForwardDiff, Test | ||
|
||
x0 = zeros(2) | ||
rosenbrock(x, p = nothing) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2 | ||
l1 = rosenbrock(x0) | ||
|
||
optf = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff()) | ||
prob = OptimizationProblem(optf, x0) | ||
sol1 = Optimization.solve(prob, OptimizationOptimJL.BFGS(), maxiters = 5) | ||
|
||
ensembleprob = Optimization.EnsembleProblem(prob, [x0, x0 .+ rand(2), x0 .+ rand(2), x0 .+ rand(2)]) | ||
|
||
sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleThreads(), trajectories = 4, maxiters = 5) | ||
@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective | ||
|
||
sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleDistributed(), trajectories = 4, maxiters = 5) | ||
@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective | ||
|
||
prob = OptimizationProblem(optf, x0, lb = [-0.5, -0.5], ub = [0.5, 0.5]) | ||
ensembleprob = Optimization.EnsembleProblem(prob, 5, prob_func = (prob, i, repeat) -> remake(prob, u0 = rand(-0.5:0.001:0.5, 2))) | ||
|
||
sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleThreads(), trajectories = 5, maxiters = 5) | ||
@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective | ||
|
||
sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleDistributed(), trajectories = 5, maxiters = 5) | ||
@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think for this to work it needs a dispatch that's not
EnsembleAlgorithm
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because of the automatic alg selection in DifferentialEquations.jl?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I could work around this by making a stuct subtype of SciMLAlgorithm and wrap the optimizer's alg in that and handle it in
__solve
so we might not need to do this thenThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you double check some cases with this and defaults? Tests seem to pass