Skip to content

Commit

Permalink
Merge pull request #109 from avik-pal/ap/twopoint
Browse files Browse the repository at this point in the history
Fast path for Two Point BVPs
  • Loading branch information
ChrisRackauckas authored Sep 30, 2023
2 parents 186339d + 2d5dc83 commit 9a97340
Show file tree
Hide file tree
Showing 21 changed files with 883 additions and 245 deletions.
1 change: 0 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ jobs:
- Core
version:
- '1'
- '1.6'
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
Expand Down
20 changes: 15 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "BoundaryValueDiffEq"
uuid = "764a87c0-6b3e-53db-9096-fe964310641d"
version = "4.1.0"
version = "4.2.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -21,30 +21,40 @@ SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[weakdeps]
ODEInterface = "54ca160b-1b9f-5127-a996-1867f4bc2a2c"

[extensions]
BoundaryValueDiffEqODEInterfaceExt = "ODEInterface"

[compat]
ADTypes = "0.2"
Adapt = "3"
ArrayInterface = "7"
ConcreteStructs = "0.2"
DiffEqBase = "6.94.2"
ForwardDiff = "0.10"
NonlinearSolve = "1"
NonlinearSolve = "2"
ODEInterface = "0.5"
PreallocationTools = "0.4"
RecursiveArrayTools = "2.38.10"
Reexport = "0.2, 1.0"
SciMLBase = "1.70"
SciMLBase = "2"
Setfield = "1"
SparseDiffTools = "2.6"
TruncatedStacktraces = "1"
UnPack = "1"
julia = "1.6"
julia = "1.9"

[extras]
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
ODEInterface = "54ca160b-1b9f-5127-a996-1867f4bc2a2c"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["StaticArrays", "Random", "DiffEqDevTools", "OrdinaryDiffEq", "Test", "NonlinearSolve", "SafeTestsets"]
test = ["StaticArrays", "Random", "DiffEqDevTools", "OrdinaryDiffEq", "Test", "NonlinearSolve", "SafeTestsets", "ODEInterface"]
113 changes: 113 additions & 0 deletions ext/BoundaryValueDiffEqODEInterfaceExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
module BoundaryValueDiffEqODEInterfaceExt

using SciMLBase, BoundaryValueDiffEq, ODEInterface
import ODEInterface: OptionsODE, OPT_ATOL, OPT_RTOL, OPT_METHODCHOICE, OPT_DIAGNOSTICOUTPUT,
OPT_ERRORCONTROL, OPT_SINGULARTERM, OPT_MAXSTEPS, OPT_BVPCLASS, OPT_SOLMETHOD,
OPT_RHS_CALLMODE, RHS_CALL_INSITU, evalSolution
import ODEInterface: Bvpm2, bvpm2_init, bvpm2_solve, bvpm2_destroy, bvpm2_get_x
import ODEInterface: bvpsol

function _test_bvpm2_bvpsol_problem_criteria(_, ::SciMLBase.StandardBVProblem, alg::Symbol)
throw(ArgumentError("$(alg) does not support standard BVProblem. Only TwoPointBVProblem is supported."))
end
function _test_bvpm2_bvpsol_problem_criteria(prob, ::TwoPointBVProblem, alg::Symbol)
@assert isinplace(prob) "$(alg) only supports inplace TwoPointBVProblem!"
end

#------
# BVPM2
#------
## TODO: We can specify Drhs using forwarddiff if we want to
function SciMLBase.__solve(prob::BVProblem, alg::BVPM2; dt = 0.0, reltol = 1e-3, kwargs...)
_test_bvpm2_bvpsol_problem_criteria(prob, prob.problem_type, :BVPM2)

has_initial_guess = prob.u0 isa AbstractVector{<:AbstractArray}
no_odes, n, u0 = if has_initial_guess
length(first(prob.u0)), (length(prob.u0) - 1), reduce(hcat, prob.u0)
else
dt 0 && throw(ArgumentError("dt must be positive"))
length(prob.u0), Int(cld((prob.tspan[2] - prob.tspan[1]), dt)), prob.u0
end

mesh = collect(range(prob.tspan[1], stop = prob.tspan[2], length = n + 1))

no_left_bc = length(first(prob.f.bcresid_prototype.x))

initial_guess = Bvpm2()
bvpm2_init(initial_guess, no_odes, no_left_bc, mesh, u0, eltype(u0)[],
alg.max_num_subintervals)

bvp2m_f(t, u, du) = prob.f(du, u, prob.p, t)
bvp2m_bc(ya, yb, bca, bcb) = prob.bc((bca, bcb), (ya, yb), prob.p)

opt = OptionsODE(OPT_RTOL => reltol, OPT_METHODCHOICE => alg.method_choice,
OPT_DIAGNOSTICOUTPUT => alg.diagnostic_output,
OPT_SINGULARTERM => alg.singular_term, OPT_ERRORCONTROL => alg.error_control)

sol, retcode, stats = bvpm2_solve(initial_guess, bvp2m_f, bvp2m_bc, opt)
retcode = retcode 0 ? ReturnCode.Success : ReturnCode.Failure

x_mesh = bvpm2_get_x(sol)
sol_final = DiffEqBase.build_solution(prob, alg, x_mesh,
eachcol(evalSolution(sol, x_mesh)); retcode, stats)

bvpm2_destroy(initial_guess)
bvpm2_destroy(sol)

return sol_final
end

#-------
# BVPSOL
#-------
function SciMLBase.__solve(prob::BVProblem, alg::BVPSOL; maxiters = 1000, reltol = 1e-3,
dt = 0.0, verbose = true, kwargs...)
_test_bvpm2_bvpsol_problem_criteria(prob, prob.problem_type, :BVPSOL)
@assert isa(prob.p, SciMLBase.NullParameters) "BVPSOL only supports NullParameters!"
@assert isa(prob.u0, AbstractVector{<:AbstractArray}) "BVPSOL requires a vector of initial guesses!"
n, u0 = (length(prob.u0) - 1), reduce(hcat, prob.u0)
mesh = collect(range(prob.tspan[1], stop = prob.tspan[2], length = n + 1))

opt = OptionsODE(OPT_RTOL => reltol, OPT_MAXSTEPS => maxiters,
OPT_BVPCLASS => alg.bvpclass, OPT_SOLMETHOD => alg.sol_method,
OPT_RHS_CALLMODE => RHS_CALL_INSITU)

f!(t, u, du) = prob.f(du, u, prob.p, t)
function bc!(ya, yb, r)
ra = first(prob.f.bcresid_prototype.x)
rb = last(prob.f.bcresid_prototype.x)
prob.bc((ra, rb), (ya, yb), prob.p)
r[1:length(ra)] .= ra
r[(length(ra) + 1):(length(ra) + length(rb))] .= rb
return r
end

sol_t, sol_x, retcode, stats = bvpsol(f!, bc!, mesh, u0, alg.odesolver, opt)

if verbose
if retcode == -3
@warn "Integrator failed to complete the trajectory"
elseif retcode == -4
@warn "Gauss Newton method failed to converge"
elseif retcode == -5
@warn "Given initial values inconsistent with separable linear bc"
elseif retcode == -6
@warn """Iterative refinement faild to converge for `sol_method=0`
Termination since multiple shooting condition or
condition of Jacobian is too bad for `sol_method=1`"""
elseif retcode == -8
@warn "Condensing algorithm for linear block system fails, try `sol_method=1`"
elseif retcode == -9
@warn "Sparse linear solver failed"
elseif retcode == -10
@warn "Real or integer work-space exhausted"
elseif retcode == -11
@warn "Rank reduction failed - resulting rank is zero"
end
end

return DiffEqBase.build_solution(prob, alg, sol_t, eachcol(sol_x);
retcode = retcode 0 ? ReturnCode.Success : ReturnCode.Failure, stats)
end

end
13 changes: 11 additions & 2 deletions src/BoundaryValueDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import ArrayInterface: matrix_colors, parameterless_type
import ConcreteStructs: @concrete
import DiffEqBase: solve
import ForwardDiff: pickchunksize
import RecursiveArrayTools: DiffEqArray
import RecursiveArrayTools: ArrayPartition, DiffEqArray
import SciMLBase: AbstractDiffEqInterpolation
import SparseDiffTools: AbstractSparseADType
import TruncatedStacktraces: @truncate_stacktrace
Expand All @@ -23,12 +23,21 @@ include("mirk_tableaus.jl")
include("cache.jl")
include("collocation.jl")
include("nlprob.jl")
include("solve.jl")
include("solve/single_shooting.jl")
include("solve/mirk.jl")
include("adaptivity.jl")
include("interpolation.jl")

function SciMLBase.__solve(prob::BVProblem, alg::BoundaryValueDiffEqAlgorithm, args...;
kwargs...)
cache = init(prob, alg, args...; kwargs...)
return solve!(cache)
end

export Shooting
export MIRK2, MIRK3, MIRK4, MIRK5, MIRK6
export MIRKJacobianComputationAlgorithm
# From ODEInterface.jl
export BVPM2, BVPSOL

end
42 changes: 27 additions & 15 deletions src/adaptivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ function interval(mesh, t)
end

"""
mesh_selector!(cache::MIRKCache{T})
mesh_selector!(cache::MIRKCache)
Generate new mesh based on the defect.
"""
@views function mesh_selector!(cache::MIRKCache{T}) where {T}
@views function mesh_selector!(cache::MIRKCache{iip, T}) where {iip, T}
@unpack M, order, defect, mesh, mesh_dt = cache
(_, MxNsub, abstol, _, _), kwargs = __split_mirk_kwargs(; cache.kwargs...)
N = length(cache.mesh)
Expand Down Expand Up @@ -81,11 +81,12 @@ Generate new mesh based on the defect.
end

"""
redistribute!(cache::MIRKCache{T}, Nsub_star, ŝ, mesh, mesh_dt) where {T}
redistribute!(cache::MIRKCache, Nsub_star, ŝ, mesh, mesh_dt)
Generate a new mesh based on the `ŝ`.
"""
function redistribute!(cache::MIRKCache{T}, Nsub_star, ŝ, mesh, mesh_dt) where {T}
function redistribute!(cache::MIRKCache{iip, T}, Nsub_star, ŝ, mesh,
mesh_dt) where {iip, T}
N = length(mesh)
ζ = sum(ŝ .* mesh_dt) / Nsub_star
k, i = 1, 0
Expand Down Expand Up @@ -138,14 +139,14 @@ end
half_mesh!(cache::MIRKCache) = half_mesh!(cache.mesh, cache.mesh_dt)

"""
defect_estimate!(cache::MIRKCache{T})
defect_estimate!(cache::MIRKCache)
defect_estimate use the discrete solution approximation Y, plus stages of
the RK method in 'k_discrete', plus some new stages in 'k_interp' to construct
an interpolant
"""
@views function defect_estimate!(cache::MIRKCache{T}) where {T}
@unpack M, stage, f!, alg, mesh, mesh_dt, defect = cache
@views function defect_estimate!(cache::MIRKCache{iip, T}) where {iip, T}
@unpack M, stage, f, alg, mesh, mesh_dt, defect = cache
@unpack s_star, τ_star = cache.ITU

# Evaluate at the first sample point
Expand All @@ -157,16 +158,24 @@ an interpolant

for i in 1:(length(mesh) - 1)
dt = mesh_dt[i]
yᵢ₁ = cache.y[i].du
yᵢ₂ = cache.y[i + 1].du

z, z′ = sum_stages!(cache, w₁, w₁′, i)
f!(yᵢ₁, z, cache.p, mesh[i] + τ_star * dt)
if iip
yᵢ₁ = cache.y[i].du
f(yᵢ₁, z, cache.p, mesh[i] + τ_star * dt)
else
yᵢ₁ = f(z, cache.p, mesh[i] + τ_star * dt)
end
yᵢ₁ .= (z′ .- yᵢ₁) ./ (abs.(yᵢ₁) .+ T(1))
est₁ = maximum(abs, yᵢ₁)

z, z′ = sum_stages!(cache, w₂, w₂′, i)
f!(yᵢ₂, z, cache.p, mesh[i] + (T(1) - τ_star) * dt)
if iip
yᵢ₂ = cache.y[i + 1].du
f(yᵢ₂, z, cache.p, mesh[i] + (T(1) - τ_star) * dt)
else
yᵢ₂ = f(z, cache.p, mesh[i] + (T(1) - τ_star) * dt)
end
yᵢ₂ .= (z′ .- yᵢ₂) ./ (abs.(yᵢ₂) .+ T(1))
est₂ = maximum(abs, yᵢ₂)

Expand All @@ -182,11 +191,10 @@ end
`interp_setup!` prepare the extra stages in ki_interp for interpolant construction.
Here, the ki_interp is the stages in one subinterval.
"""
@views function interp_setup!(cache::MIRKCache{T}) where {T}
@views function interp_setup!(cache::MIRKCache{iip, T}) where {iip, T}
@unpack x_star, s_star, c_star, v_star = cache.ITU
@unpack k_interp, k_discrete, f!, stage, new_stages, y, p, mesh, mesh_dt = cache
@unpack k_interp, k_discrete, f, stage, new_stages, y, p, mesh, mesh_dt = cache

[fill!(s, zero(eltype(s))) for s in new_stages]
for r in 1:(s_star - stage)
idx₁ = ((1:stage) .- 1) .* (s_star - stage) .+ r
idx₂ = ((1:(r - 1)) .+ stage .- 1) .* (s_star - stage) .+ r
Expand All @@ -203,7 +211,11 @@ Here, the ki_interp is the stages in one subinterval.
new_stages[i] .= new_stages[i] .* mesh_dt[i] .+
(1 - v_star[r]) .* vec(y[i].du) .+
v_star[r] .* vec(y[i + 1].du)
f!(k_interp[i][:, r], new_stages[i], p, mesh[i] + c_star[r] * mesh_dt[i])
if iip
f(k_interp[i][:, r], new_stages[i], p, mesh[i] + c_star[r] * mesh_dt[i])
else
k_interp[i][:, r] .= f(new_stages[i], p, mesh[i] + c_star[r] * mesh_dt[i])
end
end
end

Expand Down
51 changes: 49 additions & 2 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
const DEFAULT_NLSOLVE_SHOOTING = TrustRegion(; autodiff = Val(true))
const DEFAULT_NLSOLVE_MIRK = NewtonRaphson(; autodiff = Val(true))
const DEFAULT_NLSOLVE_SHOOTING = NewtonRaphson()
const DEFAULT_NLSOLVE_MIRK = NewtonRaphson()
const DEFAULT_JACOBIAN_ALGORITHM_MIRK = MIRKJacobianComputationAlgorithm()

# Algorithms
Expand Down Expand Up @@ -50,3 +50,50 @@ for order in (2, 3, 4, 5, 6)
end
end
end

"""
BVPM2(; max_num_subintervals = 3000, method_choice = 4, diagnostic_output = 1,
error_control = 1, singular_term = nothing)
BVPM2(max_num_subintervals::Int, method_choice::Int, diagnostic_output::Int,
error_control::Int, singular_term)
Fortran code for solving two-point boundary value problems. For detailed documentation, see
[ODEInterface.jl](https://github.com/luchr/ODEInterface.jl/blob/master/doc/SolverOptions.md#bvpm2).
!!! warning
Only supports inplace two-point boundary value problems, with very limited forms of
input structures!
!!! note
Only available if the `ODEInterface` package is loaded.
"""
Base.@kwdef struct BVPM2{S} <: BoundaryValueDiffEqAlgorithm
max_num_subintervals::Int = 3000
method_choice::Int = 4
diagnostic_output::Int = -1
error_control::Int = 1
singular_term::S = nothing
end

"""
BVPSOL(; bvpclass = 2, sol_method = 0, odesolver = nothing)
BVPSOL(bvpclass::Int, sol_methods::Int, odesolver)
A FORTRAN77 code which solves highly nonlinear two point boundary value problems using a
local linear solver (condensing algorithm) or a global sparse linear solver for the solution
of the arising linear subproblems, by Peter Deuflhard, Georg Bader, Lutz Weimann.
For detailed documentation, see
[ODEInterface.jl](https://github.com/luchr/ODEInterface.jl/blob/master/doc/SolverOptions.md#bvpsol).
!!! warning
Only supports inplace two-point boundary value problems, with very limited forms of
input structures!
!!! note
Only available if the `ODEInterface` package is loaded.
"""
Base.@kwdef struct BVPSOL{O} <: BoundaryValueDiffEqAlgorithm
bvpclass::Int = 2
sol_method::Int = 0
odesolver::O = nothing
end
Loading

0 comments on commit 9a97340

Please sign in to comment.