From 3185e7eef6a4226bb7e17acf4842a1da713c8f66 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 5 Oct 2023 17:36:33 -0400 Subject: [PATCH] Use sparse Jacobian for Multiple Shooting --- src/BoundaryValueDiffEq.jl | 2 +- src/solve/mirk.jl | 74 +++++++++++++++- src/solve/multiple_shooting.jl | 156 ++++++++++++++++++++++++++++----- src/solve/single_shooting.jl | 20 +++-- src/utils.jl | 67 -------------- 5 files changed, 220 insertions(+), 99 deletions(-) diff --git a/src/BoundaryValueDiffEq.jl b/src/BoundaryValueDiffEq.jl index f4fa5475..b6daf8dc 100644 --- a/src/BoundaryValueDiffEq.jl +++ b/src/BoundaryValueDiffEq.jl @@ -5,7 +5,7 @@ using Adapt, LinearAlgebra, PreallocationTools, Reexport, Setfield, SparseArrays @reexport using ADTypes, DiffEqBase, NonlinearSolve, SparseDiffTools, SciMLBase import ADTypes: AbstractADType -import ArrayInterface: matrix_colors, parameterless_type +import ArrayInterface: matrix_colors, parameterless_type, undefmatrix import ConcreteStructs: @concrete import DiffEqBase: solve import ForwardDiff: pickchunksize diff --git a/src/solve/mirk.jl b/src/solve/mirk.jl index 0041cfa7..1b48afd1 100644 --- a/src/solve/mirk.jl +++ b/src/solve/mirk.jl @@ -203,8 +203,7 @@ function construct_nlproblem(cache::MIRKCache{iip}, y::AbstractVector) where {ii function loss_collocation_internal(u::AbstractVector, p = cache.p) y_ = recursive_unflatten!(cache.y, u) resids = Φ(cache, y_, u, p) - xxx = mapreduce(vec, vcat, resids) - return xxx + return mapreduce(vec, vcat, resids) end end @@ -269,7 +268,7 @@ function generate_nlprob(cache::MIRKCache{iip}, y, loss_bc, loss_collocation, lo resid_bc, y) sd_collocation = if jac_alg.collocation_diffmode isa AbstractSparseADType - Jₛ, cvec, rvec = construct_sparse_banded_jac_prototype(y, cache.M, N) + Jₛ, cvec, rvec = construct_sparse_banded_jac_prototype(cache, y, cache.M, N) PrecomputedJacobianColorvec(; jac_prototype = Jₛ, row_colorvec = rvec, col_colorvec = cvec) else @@ -322,7 +321,7 @@ function generate_nlprob(cache::MIRKCache{iip}, y, loss_bc, loss_collocation, lo end sd = if jac_alg.diffmode isa AbstractSparseADType - Jₛ, cvec, rvec = construct_sparse_banded_jac_prototype(resid, cache.M, N) + Jₛ, cvec, rvec = construct_sparse_banded_jac_prototype(cache, resid, cache.M, N) PrecomputedJacobianColorvec(; jac_prototype = Jₛ, row_colorvec = rvec, col_colorvec = cvec) else @@ -351,3 +350,70 @@ function generate_nlprob(cache::MIRKCache{iip}, y, loss_bc, loss_collocation, lo return NonlinearProblem(NonlinearFunction{iip}(loss; jac, jac_prototype), y, cache.p) end + +# Generating Banded Matrix +function construct_sparse_banded_jac_prototype(::MIRKCache, y, M, N) + l = sum(i -> min(2M + i, M * N) - max(1, i - 1) + 1, 1:(M * (N - 1))) + Is = Vector{Int}(undef, l) + Js = Vector{Int}(undef, l) + idx = 1 + for i in 1:(M * (N - 1)), j in max(1, i - 1):min(2M + i, M * N) + Is[idx] = i + Js[idx] = j + idx += 1 + end + col_colorvec = Vector{Int}(undef, M * N) + for i in eachindex(col_colorvec) + col_colorvec[i] = mod1(i, min(2M + 1, M * N) + 1) + end + row_colorvec = Vector{Int}(undef, M * (N - 1)) + for i in eachindex(row_colorvec) + row_colorvec[i] = mod1(i, min(2M + 1, M * N) + 1) + end + + y_ = similar(y, length(Is)) + return (sparse(adapt(parameterless_type(y), Is), adapt(parameterless_type(y), Js), + y_, M * (N - 1), M * N), col_colorvec, row_colorvec) +end + +# Two Point Specialization +function construct_sparse_banded_jac_prototype(::MIRKCache, y::ArrayPartition, M, N) + l = sum(i -> min(2M + i, M * N) - max(1, i - 1) + 1, 1:(M * (N - 1))) + l_top = M * length(y.x[1].x[1]) + l_bot = M * length(y.x[1].x[2]) + + Is = Vector{Int}(undef, l + l_top + l_bot) + Js = Vector{Int}(undef, l + l_top + l_bot) + idx = 1 + + for i in 1:length(y.x[1].x[1]), j in 1:M + Is[idx] = i + Js[idx] = j + idx += 1 + end + + for i in 1:(M * (N - 1)), j in max(1, i - 1):min(2M + i, M * N) + Is[idx] = i + length(y.x[1].x[1]) + Js[idx] = j + idx += 1 + end + + for i in 1:length(y.x[1].x[2]), j in 1:M + Is[idx] = i + length(y.x[1].x[1]) + M * (N - 1) + Js[idx] = j + M * (N - 1) + idx += 1 + end + + col_colorvec = Vector{Int}(undef, M * N) + for i in eachindex(col_colorvec) + col_colorvec[i] = mod1(i, min(2M + 1, M * N) + 1) + end + row_colorvec = Vector{Int}(undef, M * N) + for i in eachindex(row_colorvec) + row_colorvec[i] = mod1(i, min(2M + 1, M * N) + 1) + end + + y_ = similar(y, length(Is)) + return (sparse(adapt(parameterless_type(y), Is), adapt(parameterless_type(y), Js), + y_, M * N, M * N), col_colorvec, row_colorvec) +end diff --git a/src/solve/multiple_shooting.jl b/src/solve/multiple_shooting.jl index e286e611..0ab9f87c 100644 --- a/src/solve/multiple_shooting.jl +++ b/src/solve/multiple_shooting.jl @@ -1,10 +1,18 @@ # TODO: incorporate `initial_guess` similar to MIRK methods function SciMLBase.__solve(prob::BVProblem, alg::MultipleShooting; odesolve_kwargs = (;), - nlsolve_kwargs = (;), ensemblealg = EnsembleThreads(), kwargs...) + nlsolve_kwargs = (;), ensemblealg = EnsembleThreads(), verbose = true, kwargs...) @unpack f, bc, tspan = prob - bcresid_prototype = prob.f.bcresid_prototype === nothing ? similar(prob.u0) : + has_initial_guess = prob.u0 isa AbstractVector{<:AbstractArray} + _u0 = has_initial_guess ? first(prob.u0) : prob.u0 + N, u0_size, nshoots, iip = length(_u0), size(_u0), alg.nshoots, isinplace(prob) + bcresid_prototype = prob.f.bcresid_prototype === nothing ? similar(_u0) : prob.f.bcresid_prototype - N, u0_size, nshoots, iip = length(prob.u0), size(prob.u0), alg.nshoots, isinplace(prob) + + if has_initial_guess && length(prob.u0) != nshoots + 1 + nshoots = length(prob.u0) - 1 + verbose && + @warn "Initial guess length != `nshoots + 1`! Adapting to `nshoots = $(nshoots)`" + end @views function loss!(resid::ArrayPartition, us, p, cur_nshoots, nodes) ts_ = Vector{Vector{typeof(first(tspan))}}(undef, cur_nshoots) @@ -32,7 +40,8 @@ function SciMLBase.__solve(prob::BVProblem, alg::MultipleShooting; odesolve_kwar ensemble_prob = EnsembleProblem(odeprob; prob_func, reduction, safetycopy = false, u_init = (; us = us_, ts = ts_, resid = resid_nodes)) ensemble_sol = solve(ensemble_prob, alg.ode_alg, ensemblealg; odesolve_kwargs..., - kwargs..., save_end = true, save_everystep = false, trajectories = cur_nshoots) + verbose, kwargs..., save_end = true, save_everystep = false, + trajectories = cur_nshoots) _us = reduce(vcat, ensemble_sol.u.us) _ts = reduce(vcat, ensemble_sol.u.ts) @@ -50,44 +59,123 @@ function SciMLBase.__solve(prob::BVProblem, alg::MultipleShooting; odesolve_kwar return resid end + @views function jac!(J::AbstractMatrix, us, p, cur_nshoots, nodes, resid_bc) + J_bc = J[1:N, :] + J_c = J[(N + 1):end, :] + + # Threads.@threads :static + # FIXME: Threading here leads to segfaults + for i in 1:cur_nshoots + uᵢ = us[((i - 1) * N + 1):(i * N)] + idx = ((i - 1) * N + 1):(i * N) + probᵢ = ODEProblem{iip}(f, reshape(uᵢ, u0_size), (nodes[i], nodes[i + 1]), p) + function solve_func(u₀) + sJ = solve(probᵢ, alg.ode_alg; u0 = u₀, odesolve_kwargs..., + kwargs..., save_end = true, save_everystep = false, saveat = ()) + return -last(sJ) + end + # @show sum(J_c[idx, idx]), sum(J_c[idx, idx .+ N]) + ForwardDiff.jacobian!(J_c[idx, idx], solve_func, uᵢ) + J_c′ = J_c[idx, idx .+ N] + J_c′[diagind(J_c′)] .= 1 + # @show sum(J_c[idx, idx]), sum(J_c[idx, idx .+ N]) + end + + function evaluate_boundary_condition(us) + ts_ = Vector{Vector{typeof(first(tspan))}}(undef, cur_nshoots) + us_ = Vector{Vector{typeof(us)}}(undef, cur_nshoots) + + function prob_func(probᵢ, i, repeat) + return remake(probᵢ; u0 = reshape(us[((i - 1) * N + 1):(i * N)], u0_size), + tspan = (nodes[i], nodes[i + 1])) + end + + function reduction(u, data, I) + for i in I + u.us[i] = data[i].u + u.ts[i] = data[i].t + end + return (u, false) + end + + odeprob = ODEProblem{iip}(f, reshape(us[1:N], u0_size), tspan, p) + + ensemble_prob = EnsembleProblem(odeprob; prob_func, reduction, + safetycopy = false, u_init = (; us = us_, ts = ts_)) + ensemble_sol = solve(ensemble_prob, alg.ode_alg, ensemblealg; + odesolve_kwargs..., kwargs..., save_end = true, save_everystep = false, + trajectories = cur_nshoots) + + _us = reduce(vcat, ensemble_sol.u.us) + _ts = reduce(vcat, ensemble_sol.u.ts) + + # Boundary conditions + # Builds an ODESolution object to keep the framework for bc(,,) consistent + total_solution = SciMLBase.build_solution(odeprob, alg.ode_alg, _ts, _us) + + if iip + _resid_bc = get_tmp(resid_bc, us) + eval_bc_residual!(_resid_bc, prob.problem_type, bc, total_solution, p) + return _resid_bc + else + return eval_bc_residual(prob.problem_type, bc, total_solution, p) + end + end + + ForwardDiff.jacobian!(J_bc, evaluate_boundary_condition, us) + + return nothing + end + # This gets all the nshoots except the final SingleShooting case - all_nshoots = get_all_nshoots(alg) + all_nshoots = get_all_nshoots(alg.grid_coarsening, nshoots) u_at_nodes, nodes = nothing, nothing for (i, cur_nshoot) in enumerate(all_nshoots) if i == 1 - nodes, u_at_nodes = multiple_shooting_initialize(prob, alg; odesolve_kwargs, - kwargs...) + nodes, u_at_nodes = multiple_shooting_initialize(prob, alg, has_initial_guess, + nshoots; odesolve_kwargs, verbose, kwargs...) else nodes, u_at_nodes = multiple_shooting_initialize(u_at_nodes, prob, alg, nodes, - cur_nshoot, all_nshoots[i - 1]; odesolve_kwargs, kwargs...) + cur_nshoot, all_nshoots[i - 1], has_initial_guess; odesolve_kwargs, verbose, + kwargs...) end resid_prototype = ArrayPartition(bcresid_prototype, similar(u_at_nodes, cur_nshoot * N)) + residbc_prototype = DiffCache(bcresid_prototype, pickchunksize(cur_nshoot * N)) + jac_prototype = __generate_sparse_jacobian_prototype(alg, _u0, bcresid_prototype, N, + cur_nshoot) + loss_function! = NonlinearFunction{true}((args...) -> loss!(args..., cur_nshoot, - nodes); resid_prototype) + nodes); resid_prototype, jac = (args...) -> jac!(args..., cur_nshoot, nodes, residbc_prototype), jac_prototype) nlprob = NonlinearProblem(loss_function!, u_at_nodes, prob.p) - sol_nlsolve = solve(nlprob, alg.nlsolve; nlsolve_kwargs..., kwargs...) + sol_nlsolve = solve(nlprob, alg.nlsolve; nlsolve_kwargs..., verbose, kwargs...) u_at_nodes = sol_nlsolve.u end single_shooting_prob = remake(prob; u0 = reshape(u_at_nodes[1:N], u0_size)) return SciMLBase.__solve(single_shooting_prob, Shooting(alg.ode_alg; alg.nlsolve); - odesolve_kwargs, nlsolve_kwargs, kwargs...) + odesolve_kwargs, nlsolve_kwargs, verbose, kwargs...) end -function multiple_shooting_initialize(prob, alg::MultipleShooting; odesolve_kwargs = (;), - kwargs...) +function multiple_shooting_initialize(prob, alg::MultipleShooting, has_initial_guess, + nshoots; odesolve_kwargs = (;), verbose = true, kwargs...) @unpack f, bc, u0, tspan, p = prob - @unpack ode_alg, nshoots = alg + @unpack ode_alg = alg - N = length(u0) nodes = range(tspan[1], tspan[2]; length = nshoots + 1) + N = has_initial_guess ? length(first(u0)) : length(u0) + + if has_initial_guess + u_at_nodes = similar(first(u0), (nshoots + 1) * N) + recursive_flatten!(u_at_nodes, u0) + return nodes, u_at_nodes + end # Ensures type stability in case the parameters are dual numbers if !(typeof(p) <: SciMLBase.NullParameters) - if !isconcretetype(eltype(p)) + if !isconcretetype(eltype(p)) && verbose @warn "Type inference will fail if eltype(p) is not a concrete type" end u_at_nodes = similar(u0, promote_type(eltype(u0), eltype(p)), (nshoots + 1) * N) @@ -97,7 +185,7 @@ function multiple_shooting_initialize(prob, alg::MultipleShooting; odesolve_kwar # Assumes no initial guess for now start_prob = ODEProblem{isinplace(prob)}(f, u0, tspan, p) - sol = solve(start_prob, ode_alg; odesolve_kwargs..., kwargs..., saveat = nodes) + sol = solve(start_prob, ode_alg; odesolve_kwargs..., verbose, kwargs..., saveat = nodes) if SciMLBase.successful_retcode(sol) u_at_nodes[1:N] .= sol.u[1] @@ -114,10 +202,10 @@ function multiple_shooting_initialize(prob, alg::MultipleShooting; odesolve_kwar end @views @inline function multiple_shooting_initialize(u_at_nodes_prev, prob, alg, - prev_nodes, nshoots, old_nshoots; odesolve_kwargs = (;), kwargs...) + prev_nodes, nshoots, old_nshoots, has_initial_guess; odesolve_kwargs = (;), kwargs...) @unpack f, bc, u0, tspan, p = prob nodes = range(tspan[1], tspan[2]; length = nshoots + 1) - N = length(u0) + N = has_initial_guess ? length(first(u0)) : length(u0) u_at_nodes = similar(u_at_nodes_prev, N + nshoots * N) u_at_nodes[1:N] .= u_at_nodes_prev[1:N] @@ -156,8 +244,7 @@ end return nodes, u_at_nodes end -@inline function get_all_nshoots(alg::MultipleShooting) - @unpack nshoots, grid_coarsening = alg +@inline function get_all_nshoots(grid_coarsening, nshoots) if grid_coarsening isa Bool !grid_coarsening && return [nshoots] update_fn = Base.Fix2(÷, 2) @@ -176,3 +263,30 @@ end @assert !(1 in nshoots_vec) return nshoots_vec end + +function __generate_sparse_jacobian_prototype(::MultipleShooting, u0, bcresid_prototype, + N::Int, nshoots::Int) + # Assume dense BC + J_bc = similar(bcresid_prototype, length(bcresid_prototype), N * (nshoots + 1)) + + # Sparse for Stitching solution together + Is = Vector{UInt32}(undef, (N^2 + N) * nshoots) + Js = Vector{UInt32}(undef, (N^2 + N) * nshoots) + + idx = 1 + for i in 1:nshoots + for (i₁, i₂) in Iterators.product(1:N, 1:N) + Is[idx] = i₁ + ((i - 1) * N) + Js[idx] = i₂ + ((i - 1) * N) + idx += 1 + end + Is[idx:(idx + N - 1)] .= (1:N) .+ ((i - 1) * N) + Js[idx:(idx + N - 1)] .= (1:N) .+ (i * N) + idx += N + end + + J_c = sparse(adapt(parameterless_type(u0), Is), adapt(parameterless_type(u0), Js), + similar(u0, length(Is))) + + return vcat(J_bc, J_c) +end diff --git a/src/solve/single_shooting.jl b/src/solve/single_shooting.jl index e1d53da1..2fc8ef21 100644 --- a/src/solve/single_shooting.jl +++ b/src/solve/single_shooting.jl @@ -1,13 +1,20 @@ function SciMLBase.__solve(prob::BVProblem, alg::Shooting; odesolve_kwargs = (;), - nlsolve_kwargs = (;), kwargs...) - iip, bc, u0, u0_size = isinplace(prob), prob.bc, deepcopy(prob.u0), size(prob.u0) + nlsolve_kwargs = (;), verbose = true, kwargs...) + has_initial_guess = prob.u0 isa AbstractVector{<:AbstractArray} + has_initial_guess && verbose && + @warn "Initial guess provided, but will be ignored for Shooting!" + u0 = has_initial_guess ? first(prob.u0) : prob.u0 + + iip, bc, u0, u0_size = isinplace(prob), prob.bc, deepcopy(u0), size(u0) resid_size = prob.f.bcresid_prototype === nothing ? u0_size : size(prob.f.bcresid_prototype) + loss_fn = if iip function loss!(resid, u0_, p) u0_internal = reshape(u0_, u0_size) tmp_prob = ODEProblem{iip}(prob.f, u0_internal, prob.tspan, p) - internal_sol = solve(tmp_prob, alg.ode_alg; odesolve_kwargs..., kwargs...) + internal_sol = solve(tmp_prob, alg.ode_alg; odesolve_kwargs..., verbose, + kwargs...) eval_bc_residual!(reshape(resid, resid_size), prob.problem_type, bc, internal_sol, p) return nothing @@ -16,15 +23,16 @@ function SciMLBase.__solve(prob::BVProblem, alg::Shooting; odesolve_kwargs = (;) function loss(u0_, p) u0_internal = reshape(u0_, u0_size) tmp_prob = ODEProblem(prob.f, u0_internal, prob.tspan, p) - internal_sol = solve(tmp_prob, alg.ode_alg; odesolve_kwargs..., kwargs...) + internal_sol = solve(tmp_prob, alg.ode_alg; odesolve_kwargs..., verbose, + kwargs...) return vec(eval_bc_residual(prob.problem_type, bc, internal_sol, p)) end end opt = solve(NonlinearProblem(NonlinearFunction{iip}(loss_fn; prob.f.jac_prototype, resid_prototype = prob.f.bcresid_prototype), vec(u0), prob.p), alg.nlsolve; - nlsolve_kwargs..., kwargs...) + nlsolve_kwargs..., verbose, kwargs...) newprob = ODEProblem{iip}(prob.f, reshape(opt.u, u0_size), prob.tspan, prob.p) - sol = solve(newprob, alg.ode_alg; odesolve_kwargs..., kwargs...) + sol = solve(newprob, alg.ode_alg; odesolve_kwargs..., verbose, kwargs...) if !SciMLBase.successful_retcode(opt) return SciMLBase.solution_new_retcode(sol, ReturnCode.Failure) diff --git a/src/utils.jl b/src/utils.jl index 4535d084..1f34a057 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -89,73 +89,6 @@ eval_bc_residual!(resid, _, bc!, sol, p, t) = bc!(resid, sol, p, t) return resid end -# Generating Banded Matrix -function construct_sparse_banded_jac_prototype(y, M, N) - l = sum(i -> min(2M + i, M * N) - max(1, i - 1) + 1, 1:(M * (N - 1))) - Is = Vector{Int}(undef, l) - Js = Vector{Int}(undef, l) - idx = 1 - for i in 1:(M * (N - 1)), j in max(1, i - 1):min(2M + i, M * N) - Is[idx] = i - Js[idx] = j - idx += 1 - end - col_colorvec = Vector{Int}(undef, M * N) - for i in eachindex(col_colorvec) - col_colorvec[i] = mod1(i, min(2M + 1, M * N) + 1) - end - row_colorvec = Vector{Int}(undef, M * (N - 1)) - for i in eachindex(row_colorvec) - row_colorvec[i] = mod1(i, min(2M + 1, M * N) + 1) - end - - y_ = similar(y, length(Is)) - return (sparse(adapt(parameterless_type(y), Is), adapt(parameterless_type(y), Js), - y_, M * (N - 1), M * N), col_colorvec, row_colorvec) -end - -# Two Point Specialization -function construct_sparse_banded_jac_prototype(y::ArrayPartition, M, N) - l = sum(i -> min(2M + i, M * N) - max(1, i - 1) + 1, 1:(M * (N - 1))) - l_top = M * length(y.x[1].x[1]) - l_bot = M * length(y.x[1].x[2]) - - Is = Vector{Int}(undef, l + l_top + l_bot) - Js = Vector{Int}(undef, l + l_top + l_bot) - idx = 1 - - for i in 1:length(y.x[1].x[1]), j in 1:M - Is[idx] = i - Js[idx] = j - idx += 1 - end - - for i in 1:(M * (N - 1)), j in max(1, i - 1):min(2M + i, M * N) - Is[idx] = i + length(y.x[1].x[1]) - Js[idx] = j - idx += 1 - end - - for i in 1:length(y.x[1].x[2]), j in 1:M - Is[idx] = i + length(y.x[1].x[1]) + M * (N - 1) - Js[idx] = j + M * (N - 1) - idx += 1 - end - - col_colorvec = Vector{Int}(undef, M * N) - for i in eachindex(col_colorvec) - col_colorvec[i] = mod1(i, min(2M + 1, M * N) + 1) - end - row_colorvec = Vector{Int}(undef, M * N) - for i in eachindex(row_colorvec) - row_colorvec[i] = mod1(i, min(2M + 1, M * N) + 1) - end - - y_ = similar(y, length(Is)) - return (sparse(adapt(parameterless_type(y), Is), adapt(parameterless_type(y), Js), - y_, M * N, M * N), col_colorvec, row_colorvec) -end - # Helpers for IIP/OOP functions function __sparse_jacobian_cache(::Val{iip}, ad, sd, fn, fx, y) where {iip} if iip