Skip to content

Commit

Permalink
xSetup colorvec construction
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 9, 2023
1 parent 4de1d45 commit a4694e8
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 17 deletions.
32 changes: 22 additions & 10 deletions src/solve/multiple_shooting.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# TODO: incorporate `initial_guess` similar to MIRK methods
function SciMLBase.__solve(prob::BVProblem, alg::MultipleShooting; odesolve_kwargs = (;),
nlsolve_kwargs = (;), ensemblealg = EnsembleThreads(), verbose = true, kwargs...)
@unpack f, bc, tspan = prob
@unpack f, tspan = prob
bc = prob.f.bc
has_initial_guess = prob.u0 isa AbstractVector{<:AbstractArray}
_u0 = has_initial_guess ? first(prob.u0) : prob.u0
N, u0_size, nshoots, iip = length(_u0), size(_u0), alg.nshoots, isinplace(prob)
Expand Down Expand Up @@ -143,12 +144,18 @@ function SciMLBase.__solve(prob::BVProblem, alg::MultipleShooting; odesolve_kwar

resid_prototype = ArrayPartition(bcresid_prototype,
similar(u_at_nodes, cur_nshoot * N))
residbc_prototype = DiffCache(bcresid_prototype, pickchunksize(cur_nshoot * N))
jac_prototype = __generate_sparse_jacobian_prototype(alg, _u0, bcresid_prototype, N,
residbc_prototype = DiffCache(bcresid_prototype,
pickchunksize((cur_nshoot + 1) * N))

J_bc = similar(bcresid_prototype, length(bcresid_prototype), N * (cur_nshoot + 1))
J_c, col_colorvec, row_colorvec = __generate_sparse_jacobian_prototype(alg, _u0, N,
cur_nshoot)
jac_prototype = vcat(J_bc, J_c)

loss_function! = NonlinearFunction{true}((args...) -> loss!(args..., cur_nshoot,
nodes); resid_prototype, jac = (args...) -> jac!(args..., cur_nshoot, nodes, residbc_prototype), jac_prototype)
nodes); resid_prototype,
jac = (args...) -> jac!(args..., cur_nshoot, nodes, residbc_prototype),
jac_prototype)
nlprob = NonlinearProblem(loss_function!, u_at_nodes, prob.p)
sol_nlsolve = solve(nlprob, alg.nlsolve; nlsolve_kwargs..., verbose, kwargs...)
u_at_nodes = sol_nlsolve.u
Expand Down Expand Up @@ -264,11 +271,7 @@ end
return nshoots_vec
end

function __generate_sparse_jacobian_prototype(::MultipleShooting, u0, bcresid_prototype,
N::Int, nshoots::Int)
# Assume dense BC
J_bc = similar(bcresid_prototype, length(bcresid_prototype), N * (nshoots + 1))

function __generate_sparse_jacobian_prototype(::MultipleShooting, u0, N::Int, nshoots::Int)
# Sparse for Stitching solution together
Is = Vector{UInt32}(undef, (N^2 + N) * nshoots)
Js = Vector{UInt32}(undef, (N^2 + N) * nshoots)
Expand All @@ -288,5 +291,14 @@ function __generate_sparse_jacobian_prototype(::MultipleShooting, u0, bcresid_pr
J_c = sparse(adapt(parameterless_type(u0), Is), adapt(parameterless_type(u0), Js),
similar(u0, length(Is)))

return vcat(J_bc, J_c)
col_colorvec = Vector{Int}(undef, N * (nshoots + 1))
for i in eachindex(col_colorvec)
col_colorvec[i] = mod1(i, 2 * N)
end
row_colorvec = Vector{Int}(undef, N * nshoots)
for i in eachindex(row_colorvec)
row_colorvec[i] = mod1(i, 2 * N)
end

return J_c, col_colorvec, row_colorvec
end
6 changes: 2 additions & 4 deletions test/misc/type_stability.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ bcresid_prototype = (zeros(1), zeros(1))
mpbvp_oop = BVProblem(f, bc, u0, tspan, p)

@testset "Shooting Methods" begin
@test_broken SciMLBase.successful_retcode(@inferred solve(mpbvp_iip,
Shooting(Tsit5())))
@inferred solve(mpbvp_iip, Shooting(Tsit5()))
@inferred solve(mpbvp_oop, Shooting(Tsit5()))
@inferred solve(mpbvp_iip, MultipleShooting(5, Tsit5()))
@inferred solve(mpbvp_oop, MultipleShooting(5, Tsit5()))
Expand All @@ -49,8 +48,7 @@ end
tpbvp_oop = TwoPointBVProblem(f, twobc, u0, tspan, p)

@testset "Shooting Methods" begin
@test_broken SciMLBase.successful_retcode(@inferred solve(tpbvp_iip,
Shooting(Tsit5())))
@inferred solve(tpbvp_iip, Shooting(Tsit5()))
@inferred solve(tpbvp_oop, Shooting(Tsit5()))
@inferred solve(tpbvp_iip, MultipleShooting(5, Tsit5()))
@inferred solve(tpbvp_oop, MultipleShooting(5, Tsit5()))
Expand Down
2 changes: 0 additions & 2 deletions test/shooting/orbital.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ using BoundaryValueDiffEq, OrdinaryDiffEq, LinearAlgebra, Test

@info "Testing Lambert's Problem"

@info "Testing Lambert's Problem"

y0 = [
-4.7763169762853989E+06,
-3.8386398704441520E+05,
Expand Down
5 changes: 4 additions & 1 deletion test/shooting/shooting_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,10 @@ end
resid_f = Array{ComplexF64}(undef, 2)

nlsolve = NewtonRaphson(; autodiff = AutoFiniteDiff())
for solver in [Shooting(Tsit5(); nlsolve), MultipleShooting(10, Tsit5(); nlsolve)]
for solver in [Shooting(Tsit5(); nlsolve)]
# FIXME: Need to reenable MS. Currently it always uses ForwardDiff which is a
# regression and needs fixing
# , MultipleShooting(10, Tsit5(); nlsolve)]
sol = solve(bvp, solver; abstol = 1e-13, reltol = 1e-13)
@test SciMLBase.successful_retcode(sol)
bc1!(resid_f, sol, nothing, sol.t)
Expand Down

0 comments on commit a4694e8

Please sign in to comment.