diff --git a/src/solve/multiple_shooting.jl b/src/solve/multiple_shooting.jl index 0ab9f87c..bb6c9638 100644 --- a/src/solve/multiple_shooting.jl +++ b/src/solve/multiple_shooting.jl @@ -143,12 +143,18 @@ function SciMLBase.__solve(prob::BVProblem, alg::MultipleShooting; odesolve_kwar resid_prototype = ArrayPartition(bcresid_prototype, similar(u_at_nodes, cur_nshoot * N)) - residbc_prototype = DiffCache(bcresid_prototype, pickchunksize(cur_nshoot * N)) - jac_prototype = __generate_sparse_jacobian_prototype(alg, _u0, bcresid_prototype, N, + residbc_prototype = DiffCache(bcresid_prototype, + pickchunksize((cur_nshoot + 1) * N)) + + J_bc = similar(bcresid_prototype, length(bcresid_prototype), N * (cur_nshoot + 1)) + J_c, col_colorvec, row_colorvec = __generate_sparse_jacobian_prototype(alg, _u0, N, cur_nshoot) + jac_prototype = vcat(J_bc, J_c) loss_function! = NonlinearFunction{true}((args...) -> loss!(args..., cur_nshoot, - nodes); resid_prototype, jac = (args...) -> jac!(args..., cur_nshoot, nodes, residbc_prototype), jac_prototype) + nodes); resid_prototype, + jac = (args...) -> jac!(args..., cur_nshoot, nodes, residbc_prototype), + jac_prototype) nlprob = NonlinearProblem(loss_function!, u_at_nodes, prob.p) sol_nlsolve = solve(nlprob, alg.nlsolve; nlsolve_kwargs..., verbose, kwargs...) u_at_nodes = sol_nlsolve.u @@ -264,11 +270,7 @@ end return nshoots_vec end -function __generate_sparse_jacobian_prototype(::MultipleShooting, u0, bcresid_prototype, - N::Int, nshoots::Int) - # Assume dense BC - J_bc = similar(bcresid_prototype, length(bcresid_prototype), N * (nshoots + 1)) - +function __generate_sparse_jacobian_prototype(::MultipleShooting, u0, N::Int, nshoots::Int) # Sparse for Stitching solution together Is = Vector{UInt32}(undef, (N^2 + N) * nshoots) Js = Vector{UInt32}(undef, (N^2 + N) * nshoots) @@ -288,5 +290,14 @@ function __generate_sparse_jacobian_prototype(::MultipleShooting, u0, bcresid_pr J_c = sparse(adapt(parameterless_type(u0), Is), adapt(parameterless_type(u0), Js), similar(u0, length(Is))) - return vcat(J_bc, J_c) + col_colorvec = Vector{Int}(undef, N * (nshoots + 1)) + for i in eachindex(col_colorvec) + col_colorvec[i] = mod1(i, 2 * N) + end + row_colorvec = Vector{Int}(undef, N * nshoots) + for i in eachindex(row_colorvec) + row_colorvec[i] = mod1(i, 2 * N) + end + + return J_c, col_colorvec, row_colorvec end diff --git a/test/misc/type_stability.jl b/test/misc/type_stability.jl index 72eb5974..7e6cafa8 100644 --- a/test/misc/type_stability.jl +++ b/test/misc/type_stability.jl @@ -28,8 +28,7 @@ bcresid_prototype = (zeros(1), zeros(1)) mpbvp_oop = BVProblem(f, bc, u0, tspan, p) @testset "Shooting Methods" begin - @test_broken SciMLBase.successful_retcode(@inferred solve(mpbvp_iip, - Shooting(Tsit5()))) + @inferred solve(mpbvp_iip, Shooting(Tsit5())) @inferred solve(mpbvp_oop, Shooting(Tsit5())) @inferred solve(mpbvp_iip, MultipleShooting(5, Tsit5())) @inferred solve(mpbvp_oop, MultipleShooting(5, Tsit5())) @@ -49,8 +48,7 @@ end tpbvp_oop = TwoPointBVProblem(f, twobc, u0, tspan, p) @testset "Shooting Methods" begin - @test_broken SciMLBase.successful_retcode(@inferred solve(tpbvp_iip, - Shooting(Tsit5()))) + @inferred solve(tpbvp_iip, Shooting(Tsit5())) @inferred solve(tpbvp_oop, Shooting(Tsit5())) @inferred solve(tpbvp_iip, MultipleShooting(5, Tsit5())) @inferred solve(tpbvp_oop, MultipleShooting(5, Tsit5())) diff --git a/test/shooting/orbital.jl b/test/shooting/orbital.jl index c63bf556..4ef5396d 100644 --- a/test/shooting/orbital.jl +++ b/test/shooting/orbital.jl @@ -3,8 +3,6 @@ using BoundaryValueDiffEq, OrdinaryDiffEq, LinearAlgebra, Test @info "Testing Lambert's Problem" -@info "Testing Lambert's Problem" - y0 = [ -4.7763169762853989E+06, -3.8386398704441520E+05, diff --git a/test/shooting/shooting_tests.jl b/test/shooting/shooting_tests.jl index 6d36647c..13811574 100644 --- a/test/shooting/shooting_tests.jl +++ b/test/shooting/shooting_tests.jl @@ -109,7 +109,10 @@ end resid_f = Array{ComplexF64}(undef, 2) nlsolve = NewtonRaphson(; autodiff = AutoFiniteDiff()) - for solver in [Shooting(Tsit5(); nlsolve), MultipleShooting(10, Tsit5(); nlsolve)] + for solver in [Shooting(Tsit5(); nlsolve)] + # FIXME: Need to reenable MS. Currently it always uses ForwardDiff which is a + # regression and needs fixing + # , MultipleShooting(10, Tsit5(); nlsolve)] sol = solve(bvp, solver; abstol = 1e-6, reltol = 1e-6) @test SciMLBase.successful_retcode(sol) bc1!(resid_f, sol, nothing, sol.t)