Skip to content

Commit

Permalink
Use sparse Jacobian for Multiple Shooting
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 5, 2023
1 parent 4126b22 commit 3185e7e
Show file tree
Hide file tree
Showing 5 changed files with 220 additions and 99 deletions.
2 changes: 1 addition & 1 deletion src/BoundaryValueDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using Adapt, LinearAlgebra, PreallocationTools, Reexport, Setfield, SparseArrays
@reexport using ADTypes, DiffEqBase, NonlinearSolve, SparseDiffTools, SciMLBase

import ADTypes: AbstractADType
import ArrayInterface: matrix_colors, parameterless_type
import ArrayInterface: matrix_colors, parameterless_type, undefmatrix
import ConcreteStructs: @concrete
import DiffEqBase: solve
import ForwardDiff: pickchunksize
Expand Down
74 changes: 70 additions & 4 deletions src/solve/mirk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,7 @@ function construct_nlproblem(cache::MIRKCache{iip}, y::AbstractVector) where {ii
function loss_collocation_internal(u::AbstractVector, p = cache.p)
y_ = recursive_unflatten!(cache.y, u)
resids = Φ(cache, y_, u, p)
xxx = mapreduce(vec, vcat, resids)
return xxx
return mapreduce(vec, vcat, resids)
end
end

Expand Down Expand Up @@ -269,7 +268,7 @@ function generate_nlprob(cache::MIRKCache{iip}, y, loss_bc, loss_collocation, lo
resid_bc, y)

sd_collocation = if jac_alg.collocation_diffmode isa AbstractSparseADType
Jₛ, cvec, rvec = construct_sparse_banded_jac_prototype(y, cache.M, N)
Jₛ, cvec, rvec = construct_sparse_banded_jac_prototype(cache, y, cache.M, N)
PrecomputedJacobianColorvec(; jac_prototype = Jₛ, row_colorvec = rvec,
col_colorvec = cvec)
else
Expand Down Expand Up @@ -322,7 +321,7 @@ function generate_nlprob(cache::MIRKCache{iip}, y, loss_bc, loss_collocation, lo
end

sd = if jac_alg.diffmode isa AbstractSparseADType
Jₛ, cvec, rvec = construct_sparse_banded_jac_prototype(resid, cache.M, N)
Jₛ, cvec, rvec = construct_sparse_banded_jac_prototype(cache, resid, cache.M, N)
PrecomputedJacobianColorvec(; jac_prototype = Jₛ, row_colorvec = rvec,
col_colorvec = cvec)
else
Expand Down Expand Up @@ -351,3 +350,70 @@ function generate_nlprob(cache::MIRKCache{iip}, y, loss_bc, loss_collocation, lo

return NonlinearProblem(NonlinearFunction{iip}(loss; jac, jac_prototype), y, cache.p)
end

# Generating Banded Matrix
function construct_sparse_banded_jac_prototype(::MIRKCache, y, M, N)
l = sum(i -> min(2M + i, M * N) - max(1, i - 1) + 1, 1:(M * (N - 1)))
Is = Vector{Int}(undef, l)
Js = Vector{Int}(undef, l)
idx = 1
for i in 1:(M * (N - 1)), j in max(1, i - 1):min(2M + i, M * N)
Is[idx] = i
Js[idx] = j
idx += 1
end
col_colorvec = Vector{Int}(undef, M * N)
for i in eachindex(col_colorvec)
col_colorvec[i] = mod1(i, min(2M + 1, M * N) + 1)
end
row_colorvec = Vector{Int}(undef, M * (N - 1))
for i in eachindex(row_colorvec)
row_colorvec[i] = mod1(i, min(2M + 1, M * N) + 1)
end

y_ = similar(y, length(Is))
return (sparse(adapt(parameterless_type(y), Is), adapt(parameterless_type(y), Js),
y_, M * (N - 1), M * N), col_colorvec, row_colorvec)
end

# Two Point Specialization
function construct_sparse_banded_jac_prototype(::MIRKCache, y::ArrayPartition, M, N)
l = sum(i -> min(2M + i, M * N) - max(1, i - 1) + 1, 1:(M * (N - 1)))
l_top = M * length(y.x[1].x[1])
l_bot = M * length(y.x[1].x[2])

Is = Vector{Int}(undef, l + l_top + l_bot)
Js = Vector{Int}(undef, l + l_top + l_bot)
idx = 1

for i in 1:length(y.x[1].x[1]), j in 1:M
Is[idx] = i
Js[idx] = j
idx += 1
end

for i in 1:(M * (N - 1)), j in max(1, i - 1):min(2M + i, M * N)
Is[idx] = i + length(y.x[1].x[1])
Js[idx] = j
idx += 1
end

for i in 1:length(y.x[1].x[2]), j in 1:M
Is[idx] = i + length(y.x[1].x[1]) + M * (N - 1)
Js[idx] = j + M * (N - 1)
idx += 1
end

col_colorvec = Vector{Int}(undef, M * N)
for i in eachindex(col_colorvec)
col_colorvec[i] = mod1(i, min(2M + 1, M * N) + 1)
end
row_colorvec = Vector{Int}(undef, M * N)
for i in eachindex(row_colorvec)
row_colorvec[i] = mod1(i, min(2M + 1, M * N) + 1)
end

y_ = similar(y, length(Is))
return (sparse(adapt(parameterless_type(y), Is), adapt(parameterless_type(y), Js),
y_, M * N, M * N), col_colorvec, row_colorvec)
end
156 changes: 135 additions & 21 deletions src/solve/multiple_shooting.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
# TODO: incorporate `initial_guess` similar to MIRK methods
function SciMLBase.__solve(prob::BVProblem, alg::MultipleShooting; odesolve_kwargs = (;),
nlsolve_kwargs = (;), ensemblealg = EnsembleThreads(), kwargs...)
nlsolve_kwargs = (;), ensemblealg = EnsembleThreads(), verbose = true, kwargs...)
@unpack f, bc, tspan = prob
bcresid_prototype = prob.f.bcresid_prototype === nothing ? similar(prob.u0) :
has_initial_guess = prob.u0 isa AbstractVector{<:AbstractArray}
_u0 = has_initial_guess ? first(prob.u0) : prob.u0
N, u0_size, nshoots, iip = length(_u0), size(_u0), alg.nshoots, isinplace(prob)
bcresid_prototype = prob.f.bcresid_prototype === nothing ? similar(_u0) :
prob.f.bcresid_prototype
N, u0_size, nshoots, iip = length(prob.u0), size(prob.u0), alg.nshoots, isinplace(prob)

if has_initial_guess && length(prob.u0) != nshoots + 1
nshoots = length(prob.u0) - 1
verbose &&
@warn "Initial guess length != `nshoots + 1`! Adapting to `nshoots = $(nshoots)`"
end

@views function loss!(resid::ArrayPartition, us, p, cur_nshoots, nodes)
ts_ = Vector{Vector{typeof(first(tspan))}}(undef, cur_nshoots)
Expand Down Expand Up @@ -32,7 +40,8 @@ function SciMLBase.__solve(prob::BVProblem, alg::MultipleShooting; odesolve_kwar
ensemble_prob = EnsembleProblem(odeprob; prob_func, reduction, safetycopy = false,
u_init = (; us = us_, ts = ts_, resid = resid_nodes))
ensemble_sol = solve(ensemble_prob, alg.ode_alg, ensemblealg; odesolve_kwargs...,
kwargs..., save_end = true, save_everystep = false, trajectories = cur_nshoots)
verbose, kwargs..., save_end = true, save_everystep = false,
trajectories = cur_nshoots)

_us = reduce(vcat, ensemble_sol.u.us)
_ts = reduce(vcat, ensemble_sol.u.ts)
Expand All @@ -50,44 +59,123 @@ function SciMLBase.__solve(prob::BVProblem, alg::MultipleShooting; odesolve_kwar
return resid
end

@views function jac!(J::AbstractMatrix, us, p, cur_nshoots, nodes, resid_bc)
J_bc = J[1:N, :]
J_c = J[(N + 1):end, :]

# Threads.@threads :static
# FIXME: Threading here leads to segfaults
for i in 1:cur_nshoots
uᵢ = us[((i - 1) * N + 1):(i * N)]
idx = ((i - 1) * N + 1):(i * N)
probᵢ = ODEProblem{iip}(f, reshape(uᵢ, u0_size), (nodes[i], nodes[i + 1]), p)
function solve_func(u₀)
sJ = solve(probᵢ, alg.ode_alg; u0 = u₀, odesolve_kwargs...,
kwargs..., save_end = true, save_everystep = false, saveat = ())
return -last(sJ)
end
# @show sum(J_c[idx, idx]), sum(J_c[idx, idx .+ N])
ForwardDiff.jacobian!(J_c[idx, idx], solve_func, uᵢ)
J_c′ = J_c[idx, idx .+ N]
J_c′[diagind(J_c′)] .= 1
# @show sum(J_c[idx, idx]), sum(J_c[idx, idx .+ N])
end

function evaluate_boundary_condition(us)
ts_ = Vector{Vector{typeof(first(tspan))}}(undef, cur_nshoots)
us_ = Vector{Vector{typeof(us)}}(undef, cur_nshoots)

function prob_func(probᵢ, i, repeat)
return remake(probᵢ; u0 = reshape(us[((i - 1) * N + 1):(i * N)], u0_size),
tspan = (nodes[i], nodes[i + 1]))
end

function reduction(u, data, I)
for i in I
u.us[i] = data[i].u
u.ts[i] = data[i].t
end
return (u, false)
end

odeprob = ODEProblem{iip}(f, reshape(us[1:N], u0_size), tspan, p)

ensemble_prob = EnsembleProblem(odeprob; prob_func, reduction,
safetycopy = false, u_init = (; us = us_, ts = ts_))
ensemble_sol = solve(ensemble_prob, alg.ode_alg, ensemblealg;
odesolve_kwargs..., kwargs..., save_end = true, save_everystep = false,
trajectories = cur_nshoots)

_us = reduce(vcat, ensemble_sol.u.us)
_ts = reduce(vcat, ensemble_sol.u.ts)

# Boundary conditions
# Builds an ODESolution object to keep the framework for bc(,,) consistent
total_solution = SciMLBase.build_solution(odeprob, alg.ode_alg, _ts, _us)

if iip
_resid_bc = get_tmp(resid_bc, us)
eval_bc_residual!(_resid_bc, prob.problem_type, bc, total_solution, p)
return _resid_bc
else
return eval_bc_residual(prob.problem_type, bc, total_solution, p)
end
end

ForwardDiff.jacobian!(J_bc, evaluate_boundary_condition, us)

return nothing
end

# This gets all the nshoots except the final SingleShooting case
all_nshoots = get_all_nshoots(alg)
all_nshoots = get_all_nshoots(alg.grid_coarsening, nshoots)
u_at_nodes, nodes = nothing, nothing

for (i, cur_nshoot) in enumerate(all_nshoots)
if i == 1
nodes, u_at_nodes = multiple_shooting_initialize(prob, alg; odesolve_kwargs,
kwargs...)
nodes, u_at_nodes = multiple_shooting_initialize(prob, alg, has_initial_guess,
nshoots; odesolve_kwargs, verbose, kwargs...)
else
nodes, u_at_nodes = multiple_shooting_initialize(u_at_nodes, prob, alg, nodes,
cur_nshoot, all_nshoots[i - 1]; odesolve_kwargs, kwargs...)
cur_nshoot, all_nshoots[i - 1], has_initial_guess; odesolve_kwargs, verbose,
kwargs...)
end

resid_prototype = ArrayPartition(bcresid_prototype,
similar(u_at_nodes, cur_nshoot * N))
residbc_prototype = DiffCache(bcresid_prototype, pickchunksize(cur_nshoot * N))
jac_prototype = __generate_sparse_jacobian_prototype(alg, _u0, bcresid_prototype, N,
cur_nshoot)

loss_function! = NonlinearFunction{true}((args...) -> loss!(args..., cur_nshoot,
nodes); resid_prototype)
nodes); resid_prototype, jac = (args...) -> jac!(args..., cur_nshoot, nodes, residbc_prototype), jac_prototype)
nlprob = NonlinearProblem(loss_function!, u_at_nodes, prob.p)
sol_nlsolve = solve(nlprob, alg.nlsolve; nlsolve_kwargs..., kwargs...)
sol_nlsolve = solve(nlprob, alg.nlsolve; nlsolve_kwargs..., verbose, kwargs...)
u_at_nodes = sol_nlsolve.u
end

single_shooting_prob = remake(prob; u0 = reshape(u_at_nodes[1:N], u0_size))
return SciMLBase.__solve(single_shooting_prob, Shooting(alg.ode_alg; alg.nlsolve);
odesolve_kwargs, nlsolve_kwargs, kwargs...)
odesolve_kwargs, nlsolve_kwargs, verbose, kwargs...)
end

function multiple_shooting_initialize(prob, alg::MultipleShooting; odesolve_kwargs = (;),
kwargs...)
function multiple_shooting_initialize(prob, alg::MultipleShooting, has_initial_guess,
nshoots; odesolve_kwargs = (;), verbose = true, kwargs...)
@unpack f, bc, u0, tspan, p = prob
@unpack ode_alg, nshoots = alg
@unpack ode_alg = alg

N = length(u0)
nodes = range(tspan[1], tspan[2]; length = nshoots + 1)
N = has_initial_guess ? length(first(u0)) : length(u0)

if has_initial_guess
u_at_nodes = similar(first(u0), (nshoots + 1) * N)
recursive_flatten!(u_at_nodes, u0)
return nodes, u_at_nodes
end

# Ensures type stability in case the parameters are dual numbers
if !(typeof(p) <: SciMLBase.NullParameters)
if !isconcretetype(eltype(p))
if !isconcretetype(eltype(p)) && verbose
@warn "Type inference will fail if eltype(p) is not a concrete type"
end
u_at_nodes = similar(u0, promote_type(eltype(u0), eltype(p)), (nshoots + 1) * N)
Expand All @@ -97,7 +185,7 @@ function multiple_shooting_initialize(prob, alg::MultipleShooting; odesolve_kwar

# Assumes no initial guess for now
start_prob = ODEProblem{isinplace(prob)}(f, u0, tspan, p)
sol = solve(start_prob, ode_alg; odesolve_kwargs..., kwargs..., saveat = nodes)
sol = solve(start_prob, ode_alg; odesolve_kwargs..., verbose, kwargs..., saveat = nodes)

if SciMLBase.successful_retcode(sol)
u_at_nodes[1:N] .= sol.u[1]
Expand All @@ -114,10 +202,10 @@ function multiple_shooting_initialize(prob, alg::MultipleShooting; odesolve_kwar
end

@views @inline function multiple_shooting_initialize(u_at_nodes_prev, prob, alg,
prev_nodes, nshoots, old_nshoots; odesolve_kwargs = (;), kwargs...)
prev_nodes, nshoots, old_nshoots, has_initial_guess; odesolve_kwargs = (;), kwargs...)
@unpack f, bc, u0, tspan, p = prob
nodes = range(tspan[1], tspan[2]; length = nshoots + 1)
N = length(u0)
N = has_initial_guess ? length(first(u0)) : length(u0)

u_at_nodes = similar(u_at_nodes_prev, N + nshoots * N)
u_at_nodes[1:N] .= u_at_nodes_prev[1:N]
Expand Down Expand Up @@ -156,8 +244,7 @@ end
return nodes, u_at_nodes
end

@inline function get_all_nshoots(alg::MultipleShooting)
@unpack nshoots, grid_coarsening = alg
@inline function get_all_nshoots(grid_coarsening, nshoots)
if grid_coarsening isa Bool
!grid_coarsening && return [nshoots]
update_fn = Base.Fix2(÷, 2)
Expand All @@ -176,3 +263,30 @@ end
@assert !(1 in nshoots_vec)
return nshoots_vec
end

function __generate_sparse_jacobian_prototype(::MultipleShooting, u0, bcresid_prototype,
N::Int, nshoots::Int)
# Assume dense BC
J_bc = similar(bcresid_prototype, length(bcresid_prototype), N * (nshoots + 1))

# Sparse for Stitching solution together
Is = Vector{UInt32}(undef, (N^2 + N) * nshoots)
Js = Vector{UInt32}(undef, (N^2 + N) * nshoots)

idx = 1
for i in 1:nshoots
for (i₁, i₂) in Iterators.product(1:N, 1:N)
Is[idx] = i₁ + ((i - 1) * N)
Js[idx] = i₂ + ((i - 1) * N)
idx += 1
end
Is[idx:(idx + N - 1)] .= (1:N) .+ ((i - 1) * N)
Js[idx:(idx + N - 1)] .= (1:N) .+ (i * N)
idx += N
end

J_c = sparse(adapt(parameterless_type(u0), Is), adapt(parameterless_type(u0), Js),
similar(u0, length(Is)))

return vcat(J_bc, J_c)
end
20 changes: 14 additions & 6 deletions src/solve/single_shooting.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
function SciMLBase.__solve(prob::BVProblem, alg::Shooting; odesolve_kwargs = (;),
nlsolve_kwargs = (;), kwargs...)
iip, bc, u0, u0_size = isinplace(prob), prob.bc, deepcopy(prob.u0), size(prob.u0)
nlsolve_kwargs = (;), verbose = true, kwargs...)
has_initial_guess = prob.u0 isa AbstractVector{<:AbstractArray}
has_initial_guess && verbose &&
@warn "Initial guess provided, but will be ignored for Shooting!"
u0 = has_initial_guess ? first(prob.u0) : prob.u0

iip, bc, u0, u0_size = isinplace(prob), prob.bc, deepcopy(u0), size(u0)
resid_size = prob.f.bcresid_prototype === nothing ? u0_size :
size(prob.f.bcresid_prototype)

loss_fn = if iip
function loss!(resid, u0_, p)
u0_internal = reshape(u0_, u0_size)
tmp_prob = ODEProblem{iip}(prob.f, u0_internal, prob.tspan, p)
internal_sol = solve(tmp_prob, alg.ode_alg; odesolve_kwargs..., kwargs...)
internal_sol = solve(tmp_prob, alg.ode_alg; odesolve_kwargs..., verbose,
kwargs...)
eval_bc_residual!(reshape(resid, resid_size), prob.problem_type, bc,
internal_sol, p)
return nothing
Expand All @@ -16,15 +23,16 @@ function SciMLBase.__solve(prob::BVProblem, alg::Shooting; odesolve_kwargs = (;)
function loss(u0_, p)
u0_internal = reshape(u0_, u0_size)
tmp_prob = ODEProblem(prob.f, u0_internal, prob.tspan, p)
internal_sol = solve(tmp_prob, alg.ode_alg; odesolve_kwargs..., kwargs...)
internal_sol = solve(tmp_prob, alg.ode_alg; odesolve_kwargs..., verbose,
kwargs...)
return vec(eval_bc_residual(prob.problem_type, bc, internal_sol, p))
end
end
opt = solve(NonlinearProblem(NonlinearFunction{iip}(loss_fn; prob.f.jac_prototype,
resid_prototype = prob.f.bcresid_prototype), vec(u0), prob.p), alg.nlsolve;
nlsolve_kwargs..., kwargs...)
nlsolve_kwargs..., verbose, kwargs...)
newprob = ODEProblem{iip}(prob.f, reshape(opt.u, u0_size), prob.tspan, prob.p)
sol = solve(newprob, alg.ode_alg; odesolve_kwargs..., kwargs...)
sol = solve(newprob, alg.ode_alg; odesolve_kwargs..., verbose, kwargs...)

if !SciMLBase.successful_retcode(opt)
return SciMLBase.solution_new_retcode(sol, ReturnCode.Failure)
Expand Down
Loading

0 comments on commit 3185e7e

Please sign in to comment.