Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fast path for Two Point BVPs #109

Merged
merged 15 commits into from
Sep 30, 2023
Merged
1 change: 0 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ jobs:
- Core
version:
- '1'
- '1.6'
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
Expand Down
20 changes: 15 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "BoundaryValueDiffEq"
uuid = "764a87c0-6b3e-53db-9096-fe964310641d"
version = "4.1.0"
version = "4.2.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -21,30 +21,40 @@ SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[weakdeps]
ODEInterface = "54ca160b-1b9f-5127-a996-1867f4bc2a2c"

[extensions]
BoundaryValueDiffEqODEInterfaceExt = "ODEInterface"

[compat]
ADTypes = "0.2"
Adapt = "3"
ArrayInterface = "7"
ConcreteStructs = "0.2"
DiffEqBase = "6.94.2"
ForwardDiff = "0.10"
NonlinearSolve = "1"
NonlinearSolve = "2"
ODEInterface = "0.5"
PreallocationTools = "0.4"
RecursiveArrayTools = "2.38.10"
Reexport = "0.2, 1.0"
SciMLBase = "1.70"
SciMLBase = "2"
Setfield = "1"
SparseDiffTools = "2.6"
TruncatedStacktraces = "1"
UnPack = "1"
julia = "1.6"
julia = "1.9"

[extras]
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
ODEInterface = "54ca160b-1b9f-5127-a996-1867f4bc2a2c"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["StaticArrays", "Random", "DiffEqDevTools", "OrdinaryDiffEq", "Test", "NonlinearSolve", "SafeTestsets"]
test = ["StaticArrays", "Random", "DiffEqDevTools", "OrdinaryDiffEq", "Test", "NonlinearSolve", "SafeTestsets", "ODEInterface"]
113 changes: 113 additions & 0 deletions ext/BoundaryValueDiffEqODEInterfaceExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
module BoundaryValueDiffEqODEInterfaceExt

using SciMLBase, BoundaryValueDiffEq, ODEInterface
import ODEInterface: OptionsODE, OPT_ATOL, OPT_RTOL, OPT_METHODCHOICE, OPT_DIAGNOSTICOUTPUT,
OPT_ERRORCONTROL, OPT_SINGULARTERM, OPT_MAXSTEPS, OPT_BVPCLASS, OPT_SOLMETHOD,
OPT_RHS_CALLMODE, RHS_CALL_INSITU, evalSolution
import ODEInterface: Bvpm2, bvpm2_init, bvpm2_solve, bvpm2_destroy, bvpm2_get_x
import ODEInterface: bvpsol

function _test_bvpm2_bvpsol_problem_criteria(_, ::SciMLBase.StandardBVProblem, alg::Symbol)
throw(ArgumentError("$(alg) does not support standard BVProblem. Only TwoPointBVProblem is supported."))
end
function _test_bvpm2_bvpsol_problem_criteria(prob, ::TwoPointBVProblem, alg::Symbol)
@assert isinplace(prob) "$(alg) only supports inplace TwoPointBVProblem!"
end

#------
# BVPM2
#------
## TODO: We can specify Drhs using forwarddiff if we want to
function SciMLBase.__solve(prob::BVProblem, alg::BVPM2; dt = 0.0, reltol = 1e-3, kwargs...)
_test_bvpm2_bvpsol_problem_criteria(prob, prob.problem_type, :BVPM2)

has_initial_guess = prob.u0 isa AbstractVector{<:AbstractArray}
no_odes, n, u0 = if has_initial_guess
length(first(prob.u0)), (length(prob.u0) - 1), reduce(hcat, prob.u0)
else
dt ≤ 0 && throw(ArgumentError("dt must be positive"))
length(prob.u0), Int(cld((prob.tspan[2] - prob.tspan[1]), dt)), prob.u0
end

mesh = collect(range(prob.tspan[1], stop = prob.tspan[2], length = n + 1))

no_left_bc = length(first(prob.f.bcresid_prototype.x))

initial_guess = Bvpm2()
bvpm2_init(initial_guess, no_odes, no_left_bc, mesh, u0, eltype(u0)[],
alg.max_num_subintervals)

bvp2m_f(t, u, du) = prob.f(du, u, prob.p, t)
bvp2m_bc(ya, yb, bca, bcb) = prob.bc((bca, bcb), (ya, yb), prob.p)

opt = OptionsODE(OPT_RTOL => reltol, OPT_METHODCHOICE => alg.method_choice,
OPT_DIAGNOSTICOUTPUT => alg.diagnostic_output,
OPT_SINGULARTERM => alg.singular_term, OPT_ERRORCONTROL => alg.error_control)

sol, retcode, stats = bvpm2_solve(initial_guess, bvp2m_f, bvp2m_bc, opt)
retcode = retcode ≥ 0 ? ReturnCode.Success : ReturnCode.Failure

x_mesh = bvpm2_get_x(sol)
sol_final = DiffEqBase.build_solution(prob, alg, x_mesh,
eachcol(evalSolution(sol, x_mesh)); retcode, stats)

bvpm2_destroy(initial_guess)
bvpm2_destroy(sol)

return sol_final
end

#-------
# BVPSOL
#-------
function SciMLBase.__solve(prob::BVProblem, alg::BVPSOL; maxiters = 1000, reltol = 1e-3,
dt = 0.0, verbose = true, kwargs...)
_test_bvpm2_bvpsol_problem_criteria(prob, prob.problem_type, :BVPSOL)
@assert isa(prob.p, SciMLBase.NullParameters) "BVPSOL only supports NullParameters!"
@assert isa(prob.u0, AbstractVector{<:AbstractArray}) "BVPSOL requires a vector of initial guesses!"
n, u0 = (length(prob.u0) - 1), reduce(hcat, prob.u0)
mesh = collect(range(prob.tspan[1], stop = prob.tspan[2], length = n + 1))

opt = OptionsODE(OPT_RTOL => reltol, OPT_MAXSTEPS => maxiters,
OPT_BVPCLASS => alg.bvpclass, OPT_SOLMETHOD => alg.sol_method,
OPT_RHS_CALLMODE => RHS_CALL_INSITU)

f!(t, u, du) = prob.f(du, u, prob.p, t)
function bc!(ya, yb, r)
ra = first(prob.f.bcresid_prototype.x)
rb = last(prob.f.bcresid_prototype.x)
prob.bc((ra, rb), (ya, yb), prob.p)
r[1:length(ra)] .= ra
r[(length(ra) + 1):(length(ra) + length(rb))] .= rb
return r
end

sol_t, sol_x, retcode, stats = bvpsol(f!, bc!, mesh, u0, alg.odesolver, opt)

if verbose
if retcode == -3
@warn "Integrator failed to complete the trajectory"
elseif retcode == -4
@warn "Gauss Newton method failed to converge"
elseif retcode == -5
@warn "Given initial values inconsistent with separable linear bc"
elseif retcode == -6
@warn """Iterative refinement faild to converge for `sol_method=0`
Termination since multiple shooting condition or
condition of Jacobian is too bad for `sol_method=1`"""
elseif retcode == -8
@warn "Condensing algorithm for linear block system fails, try `sol_method=1`"
elseif retcode == -9
@warn "Sparse linear solver failed"
elseif retcode == -10
@warn "Real or integer work-space exhausted"
elseif retcode == -11
@warn "Rank reduction failed - resulting rank is zero"
end
end

return DiffEqBase.build_solution(prob, alg, sol_t, eachcol(sol_x);
retcode = retcode ≥ 0 ? ReturnCode.Success : ReturnCode.Failure, stats)
end

end
13 changes: 11 additions & 2 deletions src/BoundaryValueDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import ArrayInterface: matrix_colors, parameterless_type
import ConcreteStructs: @concrete
import DiffEqBase: solve
import ForwardDiff: pickchunksize
import RecursiveArrayTools: DiffEqArray
import RecursiveArrayTools: ArrayPartition, DiffEqArray
import SciMLBase: AbstractDiffEqInterpolation
import SparseDiffTools: AbstractSparseADType
import TruncatedStacktraces: @truncate_stacktrace
Expand All @@ -23,12 +23,21 @@ include("mirk_tableaus.jl")
include("cache.jl")
include("collocation.jl")
include("nlprob.jl")
include("solve.jl")
include("solve/single_shooting.jl")
include("solve/mirk.jl")
include("adaptivity.jl")
include("interpolation.jl")

function SciMLBase.__solve(prob::BVProblem, alg::BoundaryValueDiffEqAlgorithm, args...;
kwargs...)
cache = init(prob, alg, args...; kwargs...)
return solve!(cache)
end

export Shooting
export MIRK2, MIRK3, MIRK4, MIRK5, MIRK6
export MIRKJacobianComputationAlgorithm
# From ODEInterface.jl
export BVPM2, BVPSOL

end
42 changes: 27 additions & 15 deletions src/adaptivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ function interval(mesh, t)
end

"""
mesh_selector!(cache::MIRKCache{T})
mesh_selector!(cache::MIRKCache)

Generate new mesh based on the defect.
"""
@views function mesh_selector!(cache::MIRKCache{T}) where {T}
@views function mesh_selector!(cache::MIRKCache{iip, T}) where {iip, T}
@unpack M, order, defect, mesh, mesh_dt = cache
(_, MxNsub, abstol, _, _), kwargs = __split_mirk_kwargs(; cache.kwargs...)
N = length(cache.mesh)
Expand Down Expand Up @@ -81,11 +81,12 @@ Generate new mesh based on the defect.
end

"""
redistribute!(cache::MIRKCache{T}, Nsub_star, ŝ, mesh, mesh_dt) where {T}
redistribute!(cache::MIRKCache, Nsub_star, ŝ, mesh, mesh_dt)

Generate a new mesh based on the `ŝ`.
"""
function redistribute!(cache::MIRKCache{T}, Nsub_star, ŝ, mesh, mesh_dt) where {T}
function redistribute!(cache::MIRKCache{iip, T}, Nsub_star, ŝ, mesh,
mesh_dt) where {iip, T}
N = length(mesh)
ζ = sum(ŝ .* mesh_dt) / Nsub_star
k, i = 1, 0
Expand Down Expand Up @@ -138,14 +139,14 @@ end
half_mesh!(cache::MIRKCache) = half_mesh!(cache.mesh, cache.mesh_dt)

"""
defect_estimate!(cache::MIRKCache{T})
defect_estimate!(cache::MIRKCache)

defect_estimate use the discrete solution approximation Y, plus stages of
the RK method in 'k_discrete', plus some new stages in 'k_interp' to construct
an interpolant
"""
@views function defect_estimate!(cache::MIRKCache{T}) where {T}
@unpack M, stage, f!, alg, mesh, mesh_dt, defect = cache
@views function defect_estimate!(cache::MIRKCache{iip, T}) where {iip, T}
@unpack M, stage, f, alg, mesh, mesh_dt, defect = cache
@unpack s_star, τ_star = cache.ITU

# Evaluate at the first sample point
Expand All @@ -157,16 +158,24 @@ an interpolant

for i in 1:(length(mesh) - 1)
dt = mesh_dt[i]
yᵢ₁ = cache.y[i].du
yᵢ₂ = cache.y[i + 1].du

z, z′ = sum_stages!(cache, w₁, w₁′, i)
f!(yᵢ₁, z, cache.p, mesh[i] + τ_star * dt)
if iip
yᵢ₁ = cache.y[i].du
f(yᵢ₁, z, cache.p, mesh[i] + τ_star * dt)
else
yᵢ₁ = f(z, cache.p, mesh[i] + τ_star * dt)
end
yᵢ₁ .= (z′ .- yᵢ₁) ./ (abs.(yᵢ₁) .+ T(1))
est₁ = maximum(abs, yᵢ₁)

z, z′ = sum_stages!(cache, w₂, w₂′, i)
f!(yᵢ₂, z, cache.p, mesh[i] + (T(1) - τ_star) * dt)
if iip
yᵢ₂ = cache.y[i + 1].du
f(yᵢ₂, z, cache.p, mesh[i] + (T(1) - τ_star) * dt)
else
yᵢ₂ = f(z, cache.p, mesh[i] + (T(1) - τ_star) * dt)
end
yᵢ₂ .= (z′ .- yᵢ₂) ./ (abs.(yᵢ₂) .+ T(1))
est₂ = maximum(abs, yᵢ₂)

Expand All @@ -182,11 +191,10 @@ end
`interp_setup!` prepare the extra stages in ki_interp for interpolant construction.
Here, the ki_interp is the stages in one subinterval.
"""
@views function interp_setup!(cache::MIRKCache{T}) where {T}
@views function interp_setup!(cache::MIRKCache{iip, T}) where {iip, T}
@unpack x_star, s_star, c_star, v_star = cache.ITU
@unpack k_interp, k_discrete, f!, stage, new_stages, y, p, mesh, mesh_dt = cache
@unpack k_interp, k_discrete, f, stage, new_stages, y, p, mesh, mesh_dt = cache

[fill!(s, zero(eltype(s))) for s in new_stages]
for r in 1:(s_star - stage)
idx₁ = ((1:stage) .- 1) .* (s_star - stage) .+ r
idx₂ = ((1:(r - 1)) .+ stage .- 1) .* (s_star - stage) .+ r
Expand All @@ -203,7 +211,11 @@ Here, the ki_interp is the stages in one subinterval.
new_stages[i] .= new_stages[i] .* mesh_dt[i] .+
(1 - v_star[r]) .* vec(y[i].du) .+
v_star[r] .* vec(y[i + 1].du)
f!(k_interp[i][:, r], new_stages[i], p, mesh[i] + c_star[r] * mesh_dt[i])
if iip
f(k_interp[i][:, r], new_stages[i], p, mesh[i] + c_star[r] * mesh_dt[i])
else
k_interp[i][:, r] .= f(new_stages[i], p, mesh[i] + c_star[r] * mesh_dt[i])
end
end
end

Expand Down
51 changes: 49 additions & 2 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
const DEFAULT_NLSOLVE_SHOOTING = TrustRegion(; autodiff = Val(true))
const DEFAULT_NLSOLVE_MIRK = NewtonRaphson(; autodiff = Val(true))
const DEFAULT_NLSOLVE_SHOOTING = NewtonRaphson()
const DEFAULT_NLSOLVE_MIRK = NewtonRaphson()
Comment on lines -1 to +2
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not so sure about that.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change to NewtonRaphson was before the TR bugs were fixed. We can revert this part. The autodiff = true is redundant wince that is the current default.

const DEFAULT_JACOBIAN_ALGORITHM_MIRK = MIRKJacobianComputationAlgorithm()

# Algorithms
Expand Down Expand Up @@ -50,3 +50,50 @@ for order in (2, 3, 4, 5, 6)
end
end
end

"""
BVPM2(; max_num_subintervals = 3000, method_choice = 4, diagnostic_output = 1,
error_control = 1, singular_term = nothing)
BVPM2(max_num_subintervals::Int, method_choice::Int, diagnostic_output::Int,
error_control::Int, singular_term)

Fortran code for solving two-point boundary value problems. For detailed documentation, see
[ODEInterface.jl](https://github.com/luchr/ODEInterface.jl/blob/master/doc/SolverOptions.md#bvpm2).

!!! warning
Only supports inplace two-point boundary value problems, with very limited forms of
input structures!

!!! note
Only available if the `ODEInterface` package is loaded.
"""
Base.@kwdef struct BVPM2{S} <: BoundaryValueDiffEqAlgorithm
max_num_subintervals::Int = 3000
method_choice::Int = 4
diagnostic_output::Int = -1
error_control::Int = 1
singular_term::S = nothing
end

"""
BVPSOL(; bvpclass = 2, sol_method = 0, odesolver = nothing)
BVPSOL(bvpclass::Int, sol_methods::Int, odesolver)

A FORTRAN77 code which solves highly nonlinear two point boundary value problems using a
local linear solver (condensing algorithm) or a global sparse linear solver for the solution
of the arising linear subproblems, by Peter Deuflhard, Georg Bader, Lutz Weimann.
For detailed documentation, see
[ODEInterface.jl](https://github.com/luchr/ODEInterface.jl/blob/master/doc/SolverOptions.md#bvpsol).

!!! warning
Only supports inplace two-point boundary value problems, with very limited forms of
input structures!

!!! note
Only available if the `ODEInterface` package is loaded.
"""
Base.@kwdef struct BVPSOL{O} <: BoundaryValueDiffEqAlgorithm
bvpclass::Int = 2
sol_method::Int = 0
odesolver::O = nothing
end
Loading