Skip to content

Commit

Permalink
Merge pull request #110 from avik-pal/ap/multiple-shooting
Browse files Browse the repository at this point in the history
Porting over Multiple Shooting and Single Shooting from NeuralBVP
  • Loading branch information
ChrisRackauckas authored Oct 12, 2023
2 parents bdc87f1 + 4065e9c commit 9611bb7
Show file tree
Hide file tree
Showing 26 changed files with 1,476 additions and 633 deletions.
6 changes: 5 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ jobs:
strategy:
matrix:
group:
- Core
- Shooting
- MIRK
- Others
version:
- '1'
steps:
Expand All @@ -32,6 +34,8 @@ jobs:
${{ runner.os }}-
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
env:
GROUP: ${{ matrix.group }}
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v3
with:
Expand Down
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

Expand Down
5 changes: 3 additions & 2 deletions ext/BoundaryValueDiffEqODEInterfaceExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module BoundaryValueDiffEqODEInterfaceExt

using SciMLBase, BoundaryValueDiffEq, ODEInterface
import SciMLBase: __solve
import ODEInterface: OptionsODE, OPT_ATOL, OPT_RTOL, OPT_METHODCHOICE, OPT_DIAGNOSTICOUTPUT,
OPT_ERRORCONTROL, OPT_SINGULARTERM, OPT_MAXSTEPS, OPT_BVPCLASS, OPT_SOLMETHOD,
OPT_RHS_CALLMODE, RHS_CALL_INSITU, evalSolution
Expand All @@ -18,7 +19,7 @@ end
# BVPM2
#------
## TODO: We can specify Drhs using forwarddiff if we want to
function SciMLBase.__solve(prob::BVProblem, alg::BVPM2; dt = 0.0, reltol = 1e-3, kwargs...)
function __solve(prob::BVProblem, alg::BVPM2; dt = 0.0, reltol = 1e-3, kwargs...)
_test_bvpm2_bvpsol_problem_criteria(prob, prob.problem_type, :BVPM2)

has_initial_guess = prob.u0 isa AbstractVector{<:AbstractArray}
Expand Down Expand Up @@ -64,7 +65,7 @@ end
#-------
# BVPSOL
#-------
function SciMLBase.__solve(prob::BVProblem, alg::BVPSOL; maxiters = 1000, reltol = 1e-3,
function __solve(prob::BVProblem, alg::BVPSOL; maxiters = 1000, reltol = 1e-3,
dt = 0.0, verbose = true, kwargs...)
_test_bvpm2_bvpsol_problem_criteria(prob, prob.problem_type, :BVPSOL)
@assert isa(prob.p, SciMLBase.NullParameters) "BVPSOL only supports NullParameters!"
Expand Down
24 changes: 14 additions & 10 deletions src/BoundaryValueDiffEq.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
module BoundaryValueDiffEq

using Adapt, LinearAlgebra, PreallocationTools, Reexport, Setfield, SparseArrays, SciMLBase,
RecursiveArrayTools
Static, RecursiveArrayTools, ForwardDiff
@reexport using ADTypes, DiffEqBase, NonlinearSolve, SparseDiffTools, SciMLBase

import ADTypes: AbstractADType
import ArrayInterface: matrix_colors, parameterless_type
import ArrayInterface: matrix_colors, parameterless_type, undefmatrix
import ConcreteStructs: @concrete
import DiffEqBase: solve
import ForwardDiff: pickchunksize
import RecursiveArrayTools: ArrayPartition, DiffEqArray
import SciMLBase: AbstractDiffEqInterpolation
import SciMLBase: AbstractDiffEqInterpolation, StandardBVProblem, __solve
import RecursiveArrayTools: ArrayPartition
import SparseDiffTools: AbstractSparseADType
import TruncatedStacktraces: @truncate_stacktrace
import UnPack: @unpack
Expand All @@ -19,24 +20,27 @@ include("types.jl")
include("utils.jl")
include("algorithms.jl")
include("alg_utils.jl")

include("mirk_tableaus.jl")
include("cache.jl")
include("collocation.jl")
include("nlprob.jl")

include("solve/single_shooting.jl")
include("solve/multiple_shooting.jl")
include("solve/mirk.jl")

include("collocation.jl")
include("sparse_jacobians.jl")

include("adaptivity.jl")
include("interpolation.jl")

function SciMLBase.__solve(prob::BVProblem, alg::BoundaryValueDiffEqAlgorithm, args...;
kwargs...)
function __solve(prob::BVProblem, alg::BoundaryValueDiffEqAlgorithm, args...; kwargs...)
cache = init(prob, alg, args...; kwargs...)
return solve!(cache)
end

export Shooting
export Shooting, MultipleShooting
export MIRK2, MIRK3, MIRK4, MIRK5, MIRK6
export MIRKJacobianComputationAlgorithm
export MIRKJacobianComputationAlgorithm, BVPJacobianAlgorithm
# From ODEInterface.jl
export BVPM2, BVPSOL

Expand Down
25 changes: 7 additions & 18 deletions src/adaptivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@ end
Find the interval that `t` belongs to in `mesh`. Assumes that `mesh` is sorted.
"""
function interval(mesh, t)
t == first(mesh) && return 1
t == last(mesh) && return length(mesh) - 1
return searchsortedfirst(mesh, t) - 1
return clamp(searchsortedfirst(mesh, t) - 1, 1, length(mesh) - 1)
end

"""
Expand Down Expand Up @@ -237,11 +235,8 @@ function sum_stages!(z, cache::MIRKCache, w, i::Int, dt = cache.mesh_dt[i])

z .= zero(z)
__maybe_matmul!(z, k_discrete[i].du[:, 1:stage], w[1:stage])
__maybe_matmul!(z,
k_interp[i][:, 1:(s_star - stage)],
w[(stage + 1):s_star],
true,
true)
__maybe_matmul!(z, k_interp[i][:, 1:(s_star - stage)],
w[(stage + 1):s_star], true, true)
z .= z .* dt .+ cache.y₀[i]

return z
Expand All @@ -253,18 +248,12 @@ end

z .= zero(z)
__maybe_matmul!(z, k_discrete[i].du[:, 1:stage], w[1:stage])
__maybe_matmul!(z,
k_interp[i][:, 1:(s_star - stage)],
w[(stage + 1):s_star],
true,
true)
__maybe_matmul!(z, k_interp[i][:, 1:(s_star - stage)],
w[(stage + 1):s_star], true, true)
z′ .= zero(z′)
__maybe_matmul!(z′, k_discrete[i].du[:, 1:stage], w′[1:stage])
__maybe_matmul!(z′,
k_interp[i][:, 1:(s_star - stage)],
w′[(stage + 1):s_star],
true,
true)
__maybe_matmul!(z′, k_interp[i][:, 1:(s_star - stage)],
w′[(stage + 1):s_star], true, true)
z .= z .* dt[1] .+ cache.y₀[i]

return z, z′
Expand Down
67 changes: 50 additions & 17 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
const DEFAULT_NLSOLVE_SHOOTING = NewtonRaphson()
const DEFAULT_NLSOLVE_MIRK = NewtonRaphson()
const DEFAULT_JACOBIAN_ALGORITHM_MIRK = MIRKJacobianComputationAlgorithm()

# Algorithms
abstract type BoundaryValueDiffEqAlgorithm <: SciMLBase.AbstractBVPAlgorithm end
abstract type AbstractMIRK <: BoundaryValueDiffEqAlgorithm end

"""
Shooting(ode_alg; nlsolve = BoundaryValueDiffEq.DEFAULT_NLSOLVE_SHOOTING)
Shooting(ode_alg; nlsolve = NewtonRaphson())
Single shooting method, reduces BVP to an initial value problem and solves the IVP.
"""
Expand All @@ -16,36 +12,73 @@ struct Shooting{O, N} <: BoundaryValueDiffEqAlgorithm
nlsolve::N
end

Shooting(ode_alg; nlsolve = DEFAULT_NLSOLVE_SHOOTING) = Shooting(ode_alg, nlsolve)
Shooting(ode_alg; nlsolve = NewtonRaphson()) = Shooting(ode_alg, nlsolve)

"""
MultipleShooting(nshoots::Int, ode_alg; nlsolve = NewtonRaphson(),
grid_coarsening = true)
Multiple Shooting method, reduces BVP to an initial value problem and solves the IVP.
Significantly more stable than Single Shooting.
"""
@concrete struct MultipleShooting{J <: BVPJacobianAlgorithm}
ode_alg
nlsolve
jac_alg::J
nshoots::Int
grid_coarsening
end

function concretize_jacobian_algorithm(alg::MultipleShooting, prob)
jac_alg = concrete_jacobian_algorithm(alg.jac_alg, prob, alg)
return MultipleShooting(alg.ode_alg, alg.nlsolve, jac_alg, alg.nshoots,
alg.grid_coarsening)
end

function update_nshoots(alg::MultipleShooting, nshoots::Int)
return MultipleShooting(alg.ode_alg, alg.nlsolve, alg.jac_alg, nshoots,
alg.grid_coarsening)
end

function MultipleShooting(nshoots::Int, ode_alg; nlsolve = NewtonRaphson(),
grid_coarsening = true, jac_alg = BVPJacobianAlgorithm())
@assert grid_coarsening isa Bool || grid_coarsening isa Function ||
grid_coarsening isa AbstractVector{<:Integer} ||
grid_coarsening isa NTuple{N, <:Integer} where {N}
grid_coarsening isa Tuple && (grid_coarsening = Vector(grid_coarsening...))
if grid_coarsening isa AbstractVector
sort!(grid_coarsening; rev = true)
@assert all(grid_coarsening .> 0) && 1 grid_coarsening
end
return MultipleShooting(ode_alg, nlsolve, jac_alg, nshoots, grid_coarsening)
end

for order in (2, 3, 4, 5, 6)
alg = Symbol("MIRK$(order)")

@eval begin
"""
$($alg)(; nlsolve = BoundaryValueDiffEq.DEFAULT_NLSOLVE_MIRK,
jac_alg = BoundaryValueDiffEq.DEFAULT_JACOBIAN_ALGORITHM_MIRK)
$($alg)(; nlsolve = NewtonRaphson(), jac_alg = BVPJacobianAlgorithm())
$($order)th order Monotonic Implicit Runge Kutta method, with Newton Raphson nonlinear solver as default.
## References
@article{Enright1996RungeKuttaSW,
title={Runge-Kutta Software with Defect Control for Boundary Value ODEs},
author={Wayne H. Enright and Paul H. Muir},
journal={SIAM J. Sci. Comput.},
year={1996},
volume={17},
pages={479-497}
title={Runge-Kutta Software with Defect Control for Boundary Value ODEs},
author={Wayne H. Enright and Paul H. Muir},
journal={SIAM J. Sci. Comput.},
year={1996},
volume={17},
pages={479-497}
}
"""
struct $(alg){N, J <: MIRKJacobianComputationAlgorithm} <: AbstractMIRK
struct $(alg){N, J <: BVPJacobianAlgorithm} <: AbstractMIRK
nlsolve::N
jac_alg::J
end

function $(alg)(; nlsolve = DEFAULT_NLSOLVE_MIRK,
jac_alg = DEFAULT_JACOBIAN_ALGORITHM_MIRK)
function $(alg)(; nlsolve = NewtonRaphson(), jac_alg = BVPJacobianAlgorithm())
return $(alg)(nlsolve, jac_alg)
end
end
Expand Down
67 changes: 0 additions & 67 deletions src/cache.jl

This file was deleted.

6 changes: 0 additions & 6 deletions src/collocation.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
__initial_state_from_prob(prob::BVProblem, mesh) = __initial_state_from_prob(prob.u0, mesh)
__initial_state_from_prob(u0::AbstractArray, mesh) = [copy(vec(u0)) for _ in mesh]
function __initial_state_from_prob(u0::AbstractVector{<:AbstractVector}, _)
[copy(vec(u)) for u in u0]
end

function Φ!(residual, cache::MIRKCache, y, u, p = cache.p)
return Φ!(residual, cache.fᵢ_cache, cache.k_discrete, cache.f, cache.TU,
y, u, p, cache.mesh, cache.mesh_dt, cache.stage)
Expand Down
Loading

0 comments on commit 9611bb7

Please sign in to comment.