-
-
Notifications
You must be signed in to change notification settings - Fork 36
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Porting over Multiple Shooting and Single Shooting from NeuralBVP #110
Conversation
6f5da29
to
a1ba675
Compare
8d69044
to
ee7e240
Compare
c44d95b
to
060116a
Compare
Co-authored-by: Qingyu Qu <[email protected]>
060116a
to
a4694e8
Compare
b432e16
to
df3909c
Compare
bdcef85
to
8eedf1f
Compare
cfbeba1
to
a5b10cc
Compare
20b966e
to
6f40df9
Compare
src/solve/multiple_shooting.jl
Outdated
compute_bc_residual! = if prob.problem_type isa TwoPointBVProblem | ||
@views function compute_bc_residual_tp!(resid_bc, us::ArrayPartition, p, | ||
cur_nshoots, nodes, resid_nodes::Union{Nothing, MaybeDiffCache} = nothing) | ||
ua, ub0 = us.x | ||
# Just Recompute the last ODE Solution | ||
lastodeprob = ODEProblem{iip}(f, reshape(ub0, u0_size), | ||
(nodes[end - 1], nodes[end]), p) | ||
sol_ode_last = solve(lastodeprob, alg.ode_alg; odesolve_kwargs..., verbose, | ||
kwargs..., save_everystep = false, saveat = (), save_end = true) | ||
ub = vec(sol_ode_last.u[end]) | ||
|
||
resid_bc_a, resid_bc_b = if resid_bc isa ArrayPartition | ||
resid_bc.x | ||
else | ||
resid_bc[1:resida_len], resid_bc[(resida_len + 1):end] | ||
end | ||
|
||
if iip | ||
bc[1](resid_bc_a, ua, p) | ||
bc[2](resid_bc_b, ub, p) | ||
else | ||
resid_bc_a .= bc[1](ua, p) | ||
resid_bc_b .= bc[2](ub, p) | ||
end | ||
|
||
return resid_bc | ||
end | ||
else | ||
@views function compute_bc_residual_mp!(resid_bc, us, p, cur_nshoots, nodes, | ||
resid_nodes::Union{Nothing, MaybeDiffCache} = nothing) | ||
if resid_nodes === nothing | ||
_resid_nodes = similar(us, cur_nshoots * N) # This might be Dual based on `us` | ||
else | ||
_resid_nodes = get_tmp(resid_nodes, us) | ||
end | ||
|
||
# NOTE: We need to recompute this to correctly propagate the dual numbers / gradients | ||
_us, _ts = solve_internal_odes!(_resid_nodes, us, p, cur_nshoots, nodes) | ||
|
||
# Boundary conditions | ||
# Builds an ODESolution object to keep the framework for bc(,,) consistent | ||
odeprob = ODEProblem{iip}(f, reshape(us[1:N], u0_size), tspan, p) | ||
total_solution = SciMLBase.build_solution(odeprob, alg.ode_alg, _ts, _us) | ||
|
||
if iip | ||
eval_bc_residual!(resid_bc, prob.problem_type, bc, total_solution, p) | ||
else | ||
resid_bc .= eval_bc_residual(prob.problem_type, bc, total_solution, p) | ||
end | ||
|
||
return resid_bc | ||
end | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ChrisRackauckas can you take a look here? For the correct BC jacobian part we do need to forward solve again?
4c135e2
to
d4326d6
Compare
d4326d6
to
b9074b6
Compare
This PR got more bloated than I anticipated. But it should be done now. Takeaways:
TODOs for later:
|
🎉 It was a bit hard to read since there were multiple things mixed in there, but generally fine so merging and will deal with anything else in the future. |
Don't merge before #109odesolve_kwargs
andnlsolve_kwargs
to allow different kwargs. repetition inkwargs
overrides these choicesTODOs
Translate NeuralBVP.jl to use BoundaryValueDiffEq.jl(Upstream issue)Test BVP with NLS(Will do it in a separate PR)