Skip to content

Commit

Permalink
Single Shooting Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 11, 2023
1 parent 6f40df9 commit b9074b6
Show file tree
Hide file tree
Showing 10 changed files with 69 additions and 48 deletions.
5 changes: 3 additions & 2 deletions ext/BoundaryValueDiffEqODEInterfaceExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module BoundaryValueDiffEqODEInterfaceExt

using SciMLBase, BoundaryValueDiffEq, ODEInterface
import SciMLBase: __solve
import ODEInterface: OptionsODE, OPT_ATOL, OPT_RTOL, OPT_METHODCHOICE, OPT_DIAGNOSTICOUTPUT,
OPT_ERRORCONTROL, OPT_SINGULARTERM, OPT_MAXSTEPS, OPT_BVPCLASS, OPT_SOLMETHOD,
OPT_RHS_CALLMODE, RHS_CALL_INSITU, evalSolution
Expand All @@ -18,7 +19,7 @@ end
# BVPM2
#------
## TODO: We can specify Drhs using forwarddiff if we want to
function SciMLBase.__solve(prob::BVProblem, alg::BVPM2; dt = 0.0, reltol = 1e-3, kwargs...)
function __solve(prob::BVProblem, alg::BVPM2; dt = 0.0, reltol = 1e-3, kwargs...)
_test_bvpm2_bvpsol_problem_criteria(prob, prob.problem_type, :BVPM2)

has_initial_guess = prob.u0 isa AbstractVector{<:AbstractArray}
Expand Down Expand Up @@ -64,7 +65,7 @@ end
#-------
# BVPSOL
#-------
function SciMLBase.__solve(prob::BVProblem, alg::BVPSOL; maxiters = 1000, reltol = 1e-3,
function __solve(prob::BVProblem, alg::BVPSOL; maxiters = 1000, reltol = 1e-3,
dt = 0.0, verbose = true, kwargs...)
_test_bvpm2_bvpsol_problem_criteria(prob, prob.problem_type, :BVPSOL)
@assert isa(prob.p, SciMLBase.NullParameters) "BVPSOL only supports NullParameters!"
Expand Down
5 changes: 2 additions & 3 deletions src/BoundaryValueDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import ConcreteStructs: @concrete
import DiffEqBase: solve
import ForwardDiff: pickchunksize
import RecursiveArrayTools: ArrayPartition, DiffEqArray
import SciMLBase: AbstractDiffEqInterpolation, StandardBVProblem
import SciMLBase: AbstractDiffEqInterpolation, StandardBVProblem, __solve
import RecursiveArrayTools: ArrayPartition
import SparseDiffTools: AbstractSparseADType
import TruncatedStacktraces: @truncate_stacktrace
Expand All @@ -33,8 +33,7 @@ include("sparse_jacobians.jl")
include("adaptivity.jl")
include("interpolation.jl")

function SciMLBase.__solve(prob::BVProblem, alg::BoundaryValueDiffEqAlgorithm, args...;
kwargs...)
function __solve(prob::BVProblem, alg::BoundaryValueDiffEqAlgorithm, args...; kwargs...)
cache = init(prob, alg, args...; kwargs...)
return solve!(cache)
end
Expand Down
15 changes: 0 additions & 15 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,6 @@ Significantly more stable than Single Shooting.
grid_coarsening
end

function concrete_jacobian_algorithm(jac_alg::BVPJacobianAlgorithm, prob,
alg::MultipleShooting)
diffmode = jac_alg.diffmode === nothing ? AutoSparseForwardDiff() : jac_alg.diffmode
bc_diffmode = if jac_alg.bc_diffmode === nothing
prob.problem_type isa TwoPointBVProblem ? AutoSparseForwardDiff() :
AutoForwardDiff()
else
jac_alg.bc_diffmode
end
nonbc_diffmode = jac_alg.nonbc_diffmode === nothing ? AutoSparseForwardDiff() :
jac_alg.nonbc_diffmode

return BVPJacobianAlgorithm(bc_diffmode, nonbc_diffmode, diffmode)
end

function MultipleShooting(nshoots::Int, ode_alg; nlsolve = NewtonRaphson(),
grid_coarsening = true, jac_alg = BVPJacobianAlgorithm())
@assert grid_coarsening isa Bool || grid_coarsening isa Function ||
Expand Down
8 changes: 6 additions & 2 deletions src/solve/mirk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0,
end
(__vecbc_a!, __vecbc_b!)
end
bcresid_prototype = vec(bcresid_prototype)
vecf!, vecbc!
else
vecf(u, p, t) = vec(prob.f(reshape(u, size(X)), p, t))
Expand All @@ -102,6 +103,7 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0,
__vecbc_b(ub, p) = vec(prob.f.bc[2](reshape(ub, size(X)), p))
(__vecbc_a, __vecbc_b)
end
bcresid_prototype = vec(bcresid_prototype)
vecf, vecbc
end

Expand Down Expand Up @@ -225,8 +227,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y::AbstractVector) where {
function loss_internal!(resid::AbstractVector, u::AbstractVector, p = cache.p)
y_ = recursive_unflatten!(cache.y, u)
resids = [get_tmp(r, u) for r in cache.residual]
eval_bc_residual!(resids[1], cache.problem_type, cache.bc, y_, p,
cache.mesh)
eval_bc_residual!(resids[1], cache.problem_type, cache.bc, y_, p, cache.mesh)
Φ!(resids[2:end], cache, y_, u, p)
if cache.problem_type isa TwoPointBVProblem
recursive_flatten_twopoint!(resid, resids)
Expand Down Expand Up @@ -305,6 +306,9 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc, loss_collocati

resid = ArrayPartition(cache.bcresid_prototype, similar(y, cache.M * (N - 1)))

# TODO: We can splitup the computation here as well similar to the Multiple Shooting
# TODO: code. That way for the BC part the actual jacobian computation is even cheaper
# TODO: Remember to not reorder if we end up using that implementation
sd = if jac_alg.diffmode isa AbstractSparseADType
PrecomputedJacobianColorvec(__generate_sparse_jacobian_prototype(cache,
cache.problem_type, resid.x[1], cache.M, N))
Expand Down
4 changes: 2 additions & 2 deletions src/solve/multiple_shooting.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function SciMLBase.__solve(prob::BVProblem, alg::MultipleShooting; odesolve_kwargs = (;),
function __solve(prob::BVProblem, alg::MultipleShooting; odesolve_kwargs = (;),
nlsolve_kwargs = (;), ensemblealg = EnsembleThreads(), verbose = true, kwargs...)
@unpack f, tspan = prob
bc = prob.f.bc
Expand Down Expand Up @@ -188,7 +188,7 @@ function SciMLBase.__solve(prob::BVProblem, alg::MultipleShooting; odesolve_kwar

resid_prototype = ArrayPartition(bcresid_prototype,
similar(u_at_nodes, cur_nshoot * N))
resid_nodes = maybe_allocate_diffcache(resid_prototype.x[2],
resid_nodes = __maybe_allocate_diffcache(resid_prototype.x[2],
pickchunksize((cur_nshoot + 1) * N), alg.jac_alg.bc_diffmode)

if prob.problem_type isa TwoPointBVProblem
Expand Down
29 changes: 12 additions & 17 deletions src/solve/single_shooting.jl
Original file line number Diff line number Diff line change
@@ -1,37 +1,32 @@
function SciMLBase.__solve(prob::BVProblem, alg::Shooting; odesolve_kwargs = (;),
function __solve(prob::BVProblem, alg::Shooting; odesolve_kwargs = (;),
nlsolve_kwargs = (;), verbose = true, kwargs...)
ig, T, _, _, u0 = __extract_problem_details(prob; dt = 0.1)
known(ig) && verbose &&
@warn "Initial guess provided, but will be ignored for Shooting!"

bcresid_prototype, resid_size = __get_bcresid_prototype(prob, u0)
iip, bc, u0, u0_size = isinplace(prob), prob.f.bc, deepcopy(u0), size(u0)
resid_size = prob.f.bcresid_prototype === nothing ? u0_size :
size(prob.f.bcresid_prototype)

loss_fn = if iip
function loss!(resid, u0_, p)
u0_internal = reshape(u0_, u0_size)
tmp_prob = ODEProblem{iip}(prob.f, u0_internal, prob.tspan, p)
internal_sol = solve(tmp_prob, alg.ode_alg; odesolve_kwargs..., verbose,
kwargs...)
eval_bc_residual!(reshape(resid, resid_size), prob.problem_type, bc,
internal_sol, p)
odeprob = ODEProblem{true}(prob.f, reshape(u0_, u0_size), prob.tspan, p)
odesol = __solve(odeprob, alg.ode_alg; odesolve_kwargs..., verbose, kwargs...)
eval_bc_residual!(__safe_reshape(resid, resid_size), prob.problem_type, bc,
odesol, p)
return nothing
end
else
function loss(u0_, p)
u0_internal = reshape(u0_, u0_size)
tmp_prob = ODEProblem(prob.f, u0_internal, prob.tspan, p)
internal_sol = solve(tmp_prob, alg.ode_alg; odesolve_kwargs..., verbose,
kwargs...)
return vec(eval_bc_residual(prob.problem_type, bc, internal_sol, p))
odeprob = ODEProblem{false}(prob.f, reshape(u0_, u0_size), prob.tspan, p)
odesol = __solve(odeprob, alg.ode_alg; odesolve_kwargs..., verbose, kwargs...)
return vec(eval_bc_residual(prob.problem_type, bc, odesol, p))
end
end
opt = solve(NonlinearProblem(NonlinearFunction{iip}(loss_fn; prob.f.jac_prototype,
resid_prototype = prob.f.bcresid_prototype), vec(u0), prob.p), alg.nlsolve;
opt = __solve(NonlinearProblem(NonlinearFunction{iip}(loss_fn; prob.f.jac_prototype,
resid_prototype = bcresid_prototype), vec(u0), prob.p), alg.nlsolve;
nlsolve_kwargs..., verbose, kwargs...)
newprob = ODEProblem{iip}(prob.f, reshape(opt.u, u0_size), prob.tspan, prob.p)
sol = solve(newprob, alg.ode_alg; odesolve_kwargs..., verbose, kwargs...)
sol = __solve(newprob, alg.ode_alg; odesolve_kwargs..., verbose, kwargs...)

if !SciMLBase.successful_retcode(opt)
return SciMLBase.solution_new_retcode(sol, ReturnCode.Failure)
Expand Down
2 changes: 1 addition & 1 deletion src/sparse_jacobians.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,4 @@ function __generate_sparse_jacobian_prototype(::MIRKCache, ::TwoPointBVProblem,
return ColoredMatrix(J, row_colorvec, col_colorvec)
end

# For Multiple Shooting
# For Multiple Shooting
29 changes: 28 additions & 1 deletion src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,23 @@ function BVPJacobianAlgorithm(diffmode = missing; nonbc_diffmode = missing,
end
end

function concrete_jacobian_algorithm(jac_alg::BVPJacobianAlgorithm, prob, alg)
"""
concrete_jacobian_algorithm(jac_alg, prob, alg)
concrete_jacobian_algorithm(jac_alg, problem_type, prob, alg)
If user provided all the required fields, then return the user provided algorithm.
Otherwise, based on the problem type and the algorithm, decide the missing fields.
For example, for `TwoPointBVProblem`, the `bc_diffmode` is set to
`AutoSparseForwardDiff` while for `StandardBVProblem`, the `bc_diffmode` is set to
`AutoForwardDiff`.
"""
function concrete_jacobian_algorithm(jac_alg::BVPJacobianAlgorithm, prob::BVProblem, alg)
return concrete_jacobian_algorithm(jac_alg, prob.problem_type, prob, alg)
end

function concrete_jacobian_algorithm(jac_alg::BVPJacobianAlgorithm, ::StandardBVProblem,
prob::BVProblem, alg)
diffmode = jac_alg.diffmode === nothing ? AutoSparseForwardDiff() : jac_alg.diffmode
bc_diffmode = jac_alg.bc_diffmode === nothing ? AutoForwardDiff() : jac_alg.bc_diffmode
nonbc_diffmode = jac_alg.nonbc_diffmode === nothing ? AutoSparseForwardDiff() :
Expand All @@ -61,6 +77,17 @@ function concrete_jacobian_algorithm(jac_alg::BVPJacobianAlgorithm, prob, alg)
return BVPJacobianAlgorithm(bc_diffmode, nonbc_diffmode, diffmode)
end

function concrete_jacobian_algorithm(jac_alg::BVPJacobianAlgorithm, ::TwoPointBVProblem,
prob::BVProblem, alg)
diffmode = jac_alg.diffmode === nothing ? AutoSparseForwardDiff() : jac_alg.diffmode
bc_diffmode = jac_alg.bc_diffmode === nothing ? AutoSparseForwardDiff() :
jac_alg.bc_diffmode
nonbc_diffmode = jac_alg.nonbc_diffmode === nothing ? AutoSparseForwardDiff() :
jac_alg.nonbc_diffmode

return BVPJacobianAlgorithm(bc_diffmode, nonbc_diffmode, diffmode)
end

function MIRKJacobianComputationAlgorithm(diffmode = missing;
collocation_diffmode = missing, bc_diffmode = missing)
Base.depwarn("`MIRKJacobianComputationAlgorithm` has been deprecated in favor of \
Expand Down
18 changes: 13 additions & 5 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ function __append_similar!(x::AbstractVector{<:MaybeDiffCache}, n, M)
N == 0 && return x
N < 0 && throw(ArgumentError("Cannot append a negative number of elements"))
chunksize = pickchunksize(M * (N + length(x)))
append!(x, [maybe_allocate_diffcache(first(x), chunksize) for _ in 1:N])
append!(x, [__maybe_allocate_diffcache(first(x), chunksize) for _ in 1:N])
return x
end

Expand All @@ -132,19 +132,22 @@ function __initial_state_from_prob(u0::AbstractVector{<:AbstractVector}, _)
return [copy(vec(u)) for u in u0]
end

function __get_bcresid_prototype(::TwoPointBVProblem, prob, u)
function __get_bcresid_prototype(prob::BVProblem, u)
return __get_bcresid_prototype(prob.problem_type, prob, u)
end
function __get_bcresid_prototype(::TwoPointBVProblem, prob::BVProblem, u)
prototype = if isinplace(prob)
prob.f.bcresid_prototype
elseif prob.f.bcresid_prototype === nothing
elseif prob.f.bcresid_prototype !== nothing
prob.f.bcresid_prototype
else
ArrayPartition(first(prob.f.bc)(u, prob.p), last(prob.f.bc)(u, prob.p))
end
return prototype, size.(prototype.x)
end
function __get_bcresid_prototype(::StandardBVProblem, prob, u)
function __get_bcresid_prototype(::StandardBVProblem, prob::BVProblem, u)
prototype = prob.f.bcresid_prototype !== nothing ? prob.f.bcresid_prototype :
fill!(similar(u), 0)
__zeros_like(u)
return prototype, size(prototype)
end

Expand All @@ -155,3 +158,8 @@ function __fill_like(v, x, args...)
end
__zeros_like(args...) = __fill_like(0, args...)
__ones_like(args...) = __fill_like(1, args...)

__safe_reshape(x, args...) = reshape(x, args...)
function __safe_reshape(x::ArrayPartition, sizes::NTuple)
return ArrayPartition(__safe_reshape.(x.x, sizes))
end
2 changes: 2 additions & 0 deletions test/misc/non_vector_inputs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,6 @@ probs = [
@test norm(boundary(sol, prob.p, nothing)) < 0.01
end
end

# TODO: Multiple Shooting
end

0 comments on commit b9074b6

Please sign in to comment.