Skip to content

Commit

Permalink
Generalize the Jacobian Computation for MultipleShooting
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 10, 2023
1 parent 97d0f5d commit 8eedf1f
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 79 deletions.
2 changes: 1 addition & 1 deletion src/BoundaryValueDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ end

export Shooting, MultipleShooting
export MIRK2, MIRK3, MIRK4, MIRK5, MIRK6
export MIRKJacobianComputationAlgorithm
export MIRKJacobianComputationAlgorithm, BVPJacobianAlgorithm
# From ODEInterface.jl
export BVPM2, BVPSOL

Expand Down
39 changes: 17 additions & 22 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
const DEFAULT_NLSOLVE_SHOOTING = NewtonRaphson()
const DEFAULT_NLSOLVE_MIRK = NewtonRaphson()
const DEFAULT_JACOBIAN_ALGORITHM_MIRK = MIRKJacobianComputationAlgorithm()

# Algorithms
abstract type BoundaryValueDiffEqAlgorithm <: SciMLBase.AbstractBVPAlgorithm end
abstract type AbstractMIRK <: BoundaryValueDiffEqAlgorithm end

"""
Shooting(ode_alg; nlsolve = DEFAULT_NLSOLVE_SHOOTING)
Shooting(ode_alg; nlsolve = NewtonRaphson())
Single shooting method, reduces BVP to an initial value problem and solves the IVP.
"""
Expand All @@ -16,24 +12,25 @@ struct Shooting{O, N} <: BoundaryValueDiffEqAlgorithm
nlsolve::N
end

Shooting(ode_alg; nlsolve = DEFAULT_NLSOLVE_SHOOTING) = Shooting(ode_alg, nlsolve)
Shooting(ode_alg; nlsolve = NewtonRaphson()) = Shooting(ode_alg, nlsolve)

"""
MultipleShooting(nshoots::Int, ode_alg; nlsolve = DEFAULT_NLSOLVE_SHOOTING,
MultipleShooting(nshoots::Int, ode_alg; nlsolve = NewtonRaphson(),
grid_coarsening = true)
Multiple Shooting method, reduces BVP to an initial value problem and solves the IVP.
Significantly more stable than Single Shooting.
"""
@concrete struct MultipleShooting
@concrete struct MultipleShooting{J <: BVPJacobianAlgorithm}
ode_alg
nlsolve
jac_alg::J
nshoots::Int
grid_coarsening
end

function MultipleShooting(nshoots::Int, ode_alg; nlsolve = DEFAULT_NLSOLVE_SHOOTING,
grid_coarsening = true)
function MultipleShooting(nshoots::Int, ode_alg; nlsolve = NewtonRaphson(),
grid_coarsening = true, jac_alg = BVPJacobianAlgorithm())
@assert grid_coarsening isa Bool || grid_coarsening isa Function ||
grid_coarsening isa AbstractVector{<:Integer} ||
grid_coarsening isa NTuple{N, <:Integer} where {N}
Expand All @@ -42,37 +39,35 @@ function MultipleShooting(nshoots::Int, ode_alg; nlsolve = DEFAULT_NLSOLVE_SHOOT
sort!(grid_coarsening; rev = true)
@assert all(grid_coarsening .> 0) && 1 grid_coarsening
end
return MultipleShooting(ode_alg, nlsolve, nshoots, grid_coarsening)
return MultipleShooting(ode_alg, nlsolve, jac_alg, nshoots, grid_coarsening)
end

for order in (2, 3, 4, 5, 6)
alg = Symbol("MIRK$(order)")

@eval begin
"""
$($alg)(; nlsolve = BoundaryValueDiffEq.DEFAULT_NLSOLVE_MIRK,
jac_alg = BoundaryValueDiffEq.DEFAULT_JACOBIAN_ALGORITHM_MIRK)
$($alg)(; nlsolve = NewtonRaphson(), jac_alg = BVPJacobianAlgorithm())
$($order)th order Monotonic Implicit Runge Kutta method, with Newton Raphson nonlinear solver as default.
## References
@article{Enright1996RungeKuttaSW,
title={Runge-Kutta Software with Defect Control for Boundary Value ODEs},
author={Wayne H. Enright and Paul H. Muir},
journal={SIAM J. Sci. Comput.},
year={1996},
volume={17},
pages={479-497}
title={Runge-Kutta Software with Defect Control for Boundary Value ODEs},
author={Wayne H. Enright and Paul H. Muir},
journal={SIAM J. Sci. Comput.},
year={1996},
volume={17},
pages={479-497}
}
"""
struct $(alg){N, J <: MIRKJacobianComputationAlgorithm} <: AbstractMIRK
struct $(alg){N, J <: BVPJacobianAlgorithm} <: AbstractMIRK
nlsolve::N
jac_alg::J
end

function $(alg)(; nlsolve = DEFAULT_NLSOLVE_MIRK,
jac_alg = DEFAULT_JACOBIAN_ALGORITHM_MIRK)
function $(alg)(; nlsolve = NewtonRaphson(), jac_alg = BVPJacobianAlgorithm())
return $(alg)(nlsolve, jac_alg)
end
end
Expand Down
12 changes: 5 additions & 7 deletions src/solve/mirk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -276,28 +276,26 @@ function generate_nlprob(cache::MIRKCache{iip}, y, loss_bc, loss_collocation, lo
cache_bc = __sparse_jacobian_cache(Val(iip), jac_alg.bc_diffmode, sd_bc, loss_bc,
resid_bc, y)

sd_collocation = if jac_alg.collocation_diffmode isa AbstractSparseADType
sd_collocation = if jac_alg.nonbc_diffmode isa AbstractSparseADType
Jₛ, cvec, rvec = construct_sparse_banded_jac_prototype(cache, y, cache.M, N)
PrecomputedJacobianColorvec(; jac_prototype = Jₛ, row_colorvec = rvec,
col_colorvec = cvec)
else
NoSparsityDetection()
end

cache_collocation = __sparse_jacobian_cache(Val(iip), jac_alg.collocation_diffmode,
cache_collocation = __sparse_jacobian_cache(Val(iip), jac_alg.nonbc_diffmode,
sd_collocation, loss_collocation, resid_collocation, y)

jac_prototype = vcat(init_jacobian(cache_bc),
jac_alg.collocation_diffmode isa AbstractSparseADType ? Jₛ :
init_jacobian(cache_collocation))
jac_prototype = vcat(init_jacobian(cache_bc), init_jacobian(cache_collocation))

# TODO: Pass `p` into `loss_bc` and `loss_collocation`. Currently leads to a Tag
# mismatch for ForwardDiff
jac = if iip
function jac_internal!(J, x, p)
sparse_jacobian!(@view(J[1:(cache.M), :]), jac_alg.bc_diffmode, cache_bc,
loss_bc, resid_bc, x)
sparse_jacobian!(@view(J[(cache.M + 1):end, :]), jac_alg.collocation_diffmode,
sparse_jacobian!(@view(J[(cache.M + 1):end, :]), jac_alg.nonbc_diffmode,
cache_collocation, loss_collocation, resid_collocation, x)
return J
end
Expand All @@ -306,7 +304,7 @@ function generate_nlprob(cache::MIRKCache{iip}, y, loss_bc, loss_collocation, lo
function jac_internal(x, p)
sparse_jacobian!(@view(J_[1:(cache.M), :]), jac_alg.bc_diffmode, cache_bc,
loss_bc, x)
sparse_jacobian!(@view(J_[(cache.M + 1):end, :]), jac_alg.collocation_diffmode,
sparse_jacobian!(@view(J_[(cache.M + 1):end, :]), jac_alg.nonbc_diffmode,
cache_collocation, loss_collocation, x)
return J_
end
Expand Down
39 changes: 20 additions & 19 deletions src/solve/multiple_shooting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ function SciMLBase.__solve(prob::BVProblem, alg::MultipleShooting; odesolve_kwar
@warn "Initial guess length != `nshoots + 1`! Adapting to `nshoots = $(nshoots)`"
end

# We will use colored AD for this parts!
# We will use colored AD for this part!
@views function solve_internal_odes!(resid_nodes, us, p, cur_nshoots, nodes)
ts_ = Vector{Vector{typeof(first(tspan))}}(undef, cur_nshoots)
us_ = Vector{Vector{typeof(us)}}(undef, cur_nshoots)
Expand Down Expand Up @@ -94,12 +94,11 @@ function SciMLBase.__solve(prob::BVProblem, alg::MultipleShooting; odesolve_kwar
J_bc = J[1:N, :]
J_c = J[(N + 1):end, :]

# FIXME: External control
sparse_jacobian!(J_c, AutoSparseForwardDiff(), ode_jac_cache, ode_fn,
sparse_jacobian!(J_c, alg.jac_alg.nonbc_diffmode, ode_jac_cache, ode_fn,
resid_nodes.du, us)

# For BC
sparse_jacobian!(J_bc, AutoForwardDiff(), bc_jac_cache, bc_fn, resid_bc, us)
sparse_jacobian!(J_bc, alg.jac_alg.bc_diffmode, bc_jac_cache, bc_fn, resid_bc, us)

return nothing
end
Expand All @@ -120,28 +119,30 @@ function SciMLBase.__solve(prob::BVProblem, alg::MultipleShooting; odesolve_kwar

resid_prototype = ArrayPartition(bcresid_prototype,
similar(u_at_nodes, cur_nshoot * N))
residbc_prototype = DiffCache(bcresid_prototype,
pickchunksize((cur_nshoot + 1) * N))
resid_nodes = maybe_allocate_diffcache(resid_prototype.x[2],
pickchunksize((cur_nshoot + 1) * N),
AutoForwardDiff())

J_c, col_colorvec, row_colorvec = __generate_sparse_jacobian_prototype(alg, _u0, N,
cur_nshoot)
pickchunksize((cur_nshoot + 1) * N), alg.jac_alg.bc_diffmode)

ode_fn = (du, u) -> solve_internal_odes!(du, u, prob.p, cur_nshoot, nodes)
ode_jac_cache = sparse_jacobian_cache(AutoSparseForwardDiff(),
PrecomputedJacobianColorvec(; jac_prototype = J_c, col_colorvec, row_colorvec),
ode_fn, copy(resid_prototype.x[2]), u_at_nodes)
sd_ode = if alg.jac_alg.nonbc_diffmode isa AbstractSparseADType
J_c, col_colorvec, row_colorvec = __generate_sparse_jacobian_prototype(alg, _u0,
N, cur_nshoot)
PrecomputedJacobianColorvec(; jac_prototype = J_c, row_colorvec, col_colorvec)
else
NoSparsityDetection()
end
ode_jac_cache = sparse_jacobian_cache(alg.jac_alg.nonbc_diffmode, sd_ode,
ode_fn, similar(u_at_nodes, cur_nshoot * N), u_at_nodes)

bc_fn = (du, u) -> compute_bc_residual!(du, u, prob.p, cur_nshoot,
nodes, resid_nodes)
bc_jac_cache = sparse_jacobian_cache(AutoForwardDiff(),
NoSparsityDetection(), bc_fn, copy(resid_prototype.x[1]), u_at_nodes)
sd_bc = alg.jac_alg.bc_diffmode isa AbstractSparseADType ?
SymbolicsSparsityDetection() : NoSparsityDetection()
bc_jac_cache = sparse_jacobian_cache(alg.jac_alg.bc_diffmode,
sd_bc, bc_fn, similar(bcresid_prototype), u_at_nodes)

jac_prototype = vcat(init_jacobian(bc_jac_cache), init_jacobian(ode_jac_cache))

jac_fn = (J, us, p) -> jac!(J, us, p, resid_prototype.x[1], resid_nodes,
jac_fn = (J, us, p) -> jac!(J, us, p, similar(bcresid_prototype), resid_nodes,
ode_jac_cache, bc_jac_cache, ode_fn, bc_fn, cur_nshoot, nodes)

loss_function! = NonlinearFunction{true}((args...) -> loss!(args..., cur_nshoot,
Expand All @@ -152,8 +153,8 @@ function SciMLBase.__solve(prob::BVProblem, alg::MultipleShooting; odesolve_kwar
end

single_shooting_prob = remake(prob; u0 = reshape(u_at_nodes[1:N], u0_size))
return SciMLBase.__solve(single_shooting_prob, Shooting(alg.ode_alg; alg.nlsolve);
odesolve_kwargs, nlsolve_kwargs, verbose, kwargs...)
return solve(single_shooting_prob, Shooting(alg.ode_alg; alg.nlsolve); odesolve_kwargs,
nlsolve_kwargs, verbose, kwargs...)
end

function multiple_shooting_initialize(prob, alg::MultipleShooting, has_initial_guess,
Expand Down
46 changes: 23 additions & 23 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,40 +33,40 @@ end
@truncate_stacktrace MIRKInterpTableau 1

# Sparsity Detection
@concrete struct MIRKJacobianComputationAlgorithm
@concrete struct BVPJacobianAlgorithm
bc_diffmode
collocation_diffmode
nonbc_diffmode
diffmode
end

function MIRKJacobianComputationAlgorithm(diffmode = missing;
collocation_diffmode = missing, bc_diffmode = missing)
function BVPJacobianAlgorithm(diffmode = missing; nonbc_diffmode = missing,
bc_diffmode = missing)
if diffmode !== missing
@assert collocation_diffmode === missing && bc_diffmode === missing
return MIRKJacobianComputationAlgorithm(diffmode, diffmode, diffmode)
@assert nonbc_diffmode === missing && bc_diffmode === missing
return BVPJacobianAlgorithm(diffmode, diffmode, diffmode)
else
@static if VERSION < v"1.9"
diffmode = AutoForwardDiff()
bc_diffmode = bc_diffmode === missing ? AutoForwardDiff() : bc_diffmode
collocation_diffmode = collocation_diffmode === missing ?
AutoForwardDiff() : collocation_diffmode
else
diffmode = AutoSparseForwardDiff()
bc_diffmode = bc_diffmode === missing ? AutoForwardDiff() : bc_diffmode
collocation_diffmode = collocation_diffmode === missing ?
AutoSparseForwardDiff() : collocation_diffmode
end
return MIRKJacobianComputationAlgorithm(bc_diffmode, collocation_diffmode,
collocation_diffmode)
diffmode = AutoSparseForwardDiff()
bc_diffmode = bc_diffmode === missing ? AutoForwardDiff() : bc_diffmode
nonbc_diffmode = nonbc_diffmode === missing ?
AutoSparseForwardDiff() : nonbc_diffmode
return BVPJacobianAlgorithm(bc_diffmode, nonbc_diffmode, nonbc_diffmode)
end
end

function MIRKJacobianComputationAlgorithm(diffmode = missing;
collocation_diffmode = missing, bc_diffmode = missing)
Base.depwarn("`MIRKJacobianComputationAlgorithm` has been deprecated in favor of \
`BVPJacobianAlgorithm`. Replace `collocation_diffmode` with `nonbc_diffmode",
:MIRKJacobianComputationAlgorithm)
return BVPJacobianAlgorithm(diffmode; nonbc_diffmode = collocation_diffmode,
bc_diffmode)
end

__needs_diffcache(::Union{AutoForwardDiff, AutoSparseForwardDiff}) = true
__needs_diffcache(_) = false
function __needs_diffcache(jac_alg::MIRKJacobianComputationAlgorithm)
return __needs_diffcache(jac_alg.diffmode) ||
__needs_diffcache(jac_alg.bc_diffmode) ||
__needs_diffcache(jac_alg.collocation_diffmode)
function __needs_diffcache(jac_alg::BVPJacobianAlgorithm)
return __needs_diffcache(jac_alg.diffmode) || __needs_diffcache(jac_alg.bc_diffmode) ||
__needs_diffcache(jac_alg.nonbc_diffmode)
end

# We don't need to always allocate a DiffCache. This works around that.
Expand Down
6 changes: 4 additions & 2 deletions test/shooting/orbital.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ for autodiff in (AutoForwardDiff(), AutoFiniteDiff(; fdtype = Val(:central)),
cur_bc!(resid_f, sol, nothing, sol.t)
@test norm(resid_f, Inf) < TestTol

@time sol = solve(bvp, MultipleShooting(10, DP5(); nlsolve); abstol = 1e-6,
jac_alg = BVPJacobianAlgorithm(; nonbc_diffmode = autodiff)
@time sol = solve(bvp, MultipleShooting(10, DP5(); nlsolve, jac_alg); abstol = 1e-6,
reltol = 1e-6)
@test SciMLBase.successful_retcode(sol)
cur_bc!(resid_f, sol, nothing, sol.t)
Expand All @@ -97,7 +98,8 @@ for autodiff in (AutoForwardDiff(), AutoFiniteDiff(; fdtype = Val(:central)),
cur_bc_2point_b!(resid_f_2p[2], sol(t1), nothing)
@test norm(reduce(vcat, resid_f_2p), Inf) < TestTol

@time sol = solve(bvp, MultipleShooting(10, DP5(); nlsolve); abstol = 1e-6,
jac_alg = BVPJacobianAlgorithm(; nonbc_diffmode = autodiff)
@time sol = solve(bvp, MultipleShooting(10, DP5(); nlsolve, jac_alg); abstol = 1e-6,
reltol = 1e-6)
@test SciMLBase.successful_retcode(sol)
cur_bc_2point_a!(resid_f_2p[1], sol(t0), nothing)
Expand Down
10 changes: 5 additions & 5 deletions test/shooting/shooting_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ using BoundaryValueDiffEq, LinearAlgebra, OrdinaryDiffEq, Test
sol = solve(bvp4, solver; abstol = 1e-13, reltol = 1e-13)
@test SciMLBase.successful_retcode(sol)
resid_f = reduce(vcat, (bc2a(sol(tspan[1]), nothing), bc2b(sol(tspan[2]), nothing)))
@test norm(resid_f) < 1e-12
@test norm(resid_f) < 1e-11
end
end

Expand All @@ -101,10 +101,10 @@ end
resid_f = Array{ComplexF64}(undef, 2)

nlsolve = NewtonRaphson(; autodiff = AutoFiniteDiff())
for solver in [Shooting(Tsit5(); nlsolve)]
# FIXME: Need to reenable MS. Currently it always uses ForwardDiff which is a
# regression and needs fixing
# , MultipleShooting(10, Tsit5(); nlsolve)]
jac_alg = BVPJacobianAlgorithm(; bc_diffmode = AutoFiniteDiff(),
nonbc_diffmode = AutoSparseFiniteDiff())
for solver in [Shooting(Tsit5(); nlsolve),
MultipleShooting(10, Tsit5(); nlsolve, jac_alg)]
sol = solve(bvp, solver; abstol = 1e-13, reltol = 1e-13)
@test SciMLBase.successful_retcode(sol)
bc1!(resid_f, sol, nothing, sol.t)
Expand Down

0 comments on commit 8eedf1f

Please sign in to comment.