From 4889ba5ace64cf5b5fbfe9caafe9c433b07806a7 Mon Sep 17 00:00:00 2001
From: Avik Pal <avikpal@mit.edu>
Date: Tue, 19 Sep 2023 18:13:20 -0400
Subject: [PATCH] Add BVPSOL and BVPM2

---
 Project.toml               |  8 ++++
 ext/BVPODEInterfaceExt.jl  | 95 ++++++++++++++++++++++++++++++++++++++
 src/BoundaryValueDiffEq.jl |  2 +
 src/adaptivity.jl          |  5 +-
 src/algorithms.jl          | 47 +++++++++++++++++++
 5 files changed, 153 insertions(+), 4 deletions(-)
 create mode 100644 ext/BVPODEInterfaceExt.jl

diff --git a/Project.toml b/Project.toml
index 4b407102..b3286e45 100644
--- a/Project.toml
+++ b/Project.toml
@@ -29,6 +29,7 @@ ConcreteStructs = "0.2"
 DiffEqBase = "6.94.2"
 ForwardDiff = "0.10"
 NonlinearSolve = "2"
+ODEInterface = "0.5"
 Reexport = "0.2, 1.0"
 SciMLBase = "1"
 Setfield = "1"
@@ -36,9 +37,16 @@ TruncatedStacktraces = "1"
 UnPack = "1"
 julia = "1.6"
 
+[weakdeps]
+ODEInterface = "54ca160b-1b9f-5127-a996-1867f4bc2a2c"
+
+[extensions]
+BVPODEInterfaceExt = "ODEInterface"
+
 [extras]
 DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
 NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
+ODEInterface = "54ca160b-1b9f-5127-a996-1867f4bc2a2c"
 OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
diff --git a/ext/BVPODEInterfaceExt.jl b/ext/BVPODEInterfaceExt.jl
new file mode 100644
index 00000000..b9c40692
--- /dev/null
+++ b/ext/BVPODEInterfaceExt.jl
@@ -0,0 +1,95 @@
+module BVPODEInterfaceExt
+
+using SciMLBase, BoundaryValueDiffEq, ODEInterface
+import ODEInterface: OptionsODE, OPT_ATOL, OPT_RTOL, OPT_METHODCHOICE, OPT_DIAGNOSTICOUTPUT,
+    OPT_ERRORCONTROL, OPT_SINGULARTERM, OPT_MAXSTEPS, OPT_BVPCLASS, OPT_SOLMETHOD,
+    OPT_RHS_CALLMODE, RHS_CALL_INSITU, evalSolution
+import ODEInterface: Bvpm2, bvpm2_init, bvpm2_solve, bvpm2_destroy, bvpm2_get_x
+import ODEInterface: bvpsol
+
+function _test_bvpm2_bvpsol_problem_criteria(_, ::SciMLBase.StandardBVProblem, alg::Symbol)
+    throw(ArgumentError("$(alg) does not support standard BVProblem. Only TwoPointBVProblem is supported."))
+end
+function _test_bvpm2_bvpsol_problem_criteria(prob, ::TwoPointBVProblem, alg::Symbol)
+    @assert isinplace(prob) "$(alg) only supports inplace TwoPointBVProblem!"
+end
+
+#------
+# BVPM2
+#------
+_no_param(::SciMLBase.NullParameters) = Float64[]
+_no_param(p) = p
+
+bvpm2_bc(bc, ya, yb, bca, bcb) = bc((bca, bcb), (ya, yb), SciMLBase.NullParameters())
+bvpm2_bc(bc, ya, yb, p, bca, bcb) = bc((bca, bcb), (ya, yb), p)
+
+bvp2m_f(f, t, u, du) = f(du, u, SciMLBase.NullParameters(), t)
+bvp2m_f(f, t, u, p, du) = f(du, u, p, t)
+
+## TODO: We can specify Drhs using forwarddiff if we want to
+function SciMLBase.__solve(prob::BVProblem, alg::BVPM2; dt = 0.0, reltol = 1e-3, kwargs...)
+    _test_bvpm2_bvpsol_problem_criteria(prob, prob.problem_type, :BVPM2)
+
+    has_initial_guess = prob.u0 isa AbstractVector{<:AbstractArray}
+    no_odes, n, u0 = if has_initial_guess
+        length(first(prob.u0)), (length(prob.u0) - 1), reduce(hcat, prob.u0)
+    else
+        dt ≤ 0 && throw(ArgumentError("dt must be positive"))
+        length(prob.u0), Int(cld((prob.tspan[2] - prob.tspan[1]), dt)), prob.u0
+    end
+
+    mesh = collect(range(prob.tspan[1], stop = prob.tspan[2], length = n + 1))
+
+    no_left_bc = length(first(prob.f.bcresid_prototype.x))
+
+    initial_guess = Bvpm2()
+    bvpm2_init(initial_guess, no_odes, no_left_bc, mesh, u0, _no_param(prob.p),
+        alg.max_num_subintervals)
+
+    rhs = (args...) -> bvp2m_f(prob.f, args...)
+    bc = (args...) -> bvpm2_bc(prob.bc, args...)
+
+    opt = OptionsODE(OPT_RTOL => reltol, OPT_METHODCHOICE => alg.method_choice,
+        OPT_DIAGNOSTICOUTPUT => alg.diagnostic_output,
+        OPT_SINGULARTERM => alg.singular_term, OPT_ERRORCONTROL => alg.error_control)
+
+    sol, retcode, stats = bvpm2_solve(initial_guess, rhs, bc, opt)
+    retcode = retcode ≥ 0 ? ReturnCode.Success : ReturnCode.Failure
+
+    x_mesh = bvpm2_get_x(sol)
+    return DiffEqBase.build_solution(prob, alg, x_mesh, eachcol(evalSolution(sol, x_mesh));
+        retcode, stats)
+end
+
+#-------
+# BVPSOL
+#-------
+bvpsol_f(f, t, u, du) = f(du, u, SciMLBase.NullParameters(), t)
+function bvpsol_bc(bc, ra, rb, ya, yb, r)
+    bc((view(r, 1:(length(ra))), view(r, (length(ra) + 1):(length(ra) + length(rb)))),
+        (ya, yb), SciMLBase.NullParameters())
+end
+
+function SciMLBase.__solve(prob::BVProblem, alg::BVPSOL; maxiters = 1000, reltol = 1e-3,
+    dt = 0.0, kwargs...)
+    _test_bvpm2_bvpsol_problem_criteria(prob, prob.problem_type, :BVPSOL)
+    @assert isa(prob.p, SciMLBase.NullParameters) "BVPSOL only supports NullParameters!"
+    @assert isa(prob.u0, AbstractVector{<:AbstractArray}) "BVPSOL requires a vector of initial guesses!"
+    n, u0 = (length(prob.u0) - 1), reduce(hcat, prob.u0)
+    mesh = collect(range(prob.tspan[1], stop = prob.tspan[2], length = n + 1))
+
+    opt = OptionsODE(OPT_RTOL => reltol, OPT_MAXSTEPS => maxiters,
+        OPT_BVPCLASS => alg.bvpclass, OPT_SOLMETHOD => alg.sol_method,
+        OPT_RHS_CALLMODE => RHS_CALL_INSITU)
+
+    f! = (args...) -> bvpsol_f(prob.f, args...)
+    bc! = (args...) -> bvpsol_bc(prob.bc, first(prob.f.bcresid_prototype.x),
+        last(prob.f.bcresid_prototype.x), args...)
+
+    sol_t, sol_x, retcode, stats = bvpsol(f!, bc!, mesh, u0, alg.odesolver, opt)
+
+    return DiffEqBase.build_solution(prob, alg, sol_t, eachcol(sol_x);
+        retcode = retcode ≥ 0 ? ReturnCode.Success : ReturnCode.Failure, stats)
+end
+
+end
diff --git a/src/BoundaryValueDiffEq.jl b/src/BoundaryValueDiffEq.jl
index 6fa5f1ca..a9d923c4 100644
--- a/src/BoundaryValueDiffEq.jl
+++ b/src/BoundaryValueDiffEq.jl
@@ -34,5 +34,7 @@ include("adaptivity.jl")
 export Shooting
 export MIRK2, MIRK3, MIRK4, MIRK5, MIRK6
 export MIRKJacobianComputationAlgorithm
+# From ODEInterface.jl
+export BVPM2, BVPSOL
 
 end
diff --git a/src/adaptivity.jl b/src/adaptivity.jl
index d7d16b7d..4ed4ffff 100644
--- a/src/adaptivity.jl
+++ b/src/adaptivity.jl
@@ -85,10 +85,7 @@ end
 
 Generate a new mesh based on the `ŝ`.
 """
-function redistribute!(cache::MIRKCache{iip, T},
-    Nsub_star,
-    ŝ,
-    mesh,
+function redistribute!(cache::MIRKCache{iip, T}, Nsub_star, ŝ, mesh,
     mesh_dt) where {iip, T}
     N = length(mesh)
     ζ = sum(ŝ .* mesh_dt) / Nsub_star
diff --git a/src/algorithms.jl b/src/algorithms.jl
index a8aa25bc..cbba22df 100644
--- a/src/algorithms.jl
+++ b/src/algorithms.jl
@@ -50,3 +50,50 @@ for order in (2, 3, 4, 5, 6)
         end
     end
 end
+
+"""
+    BVPM2(; max_num_subintervals = 3000, method_choice = 4, diagnostic_output = 1,
+        error_control = 1, singular_term = nothing)
+    BVPM2(max_num_subintervals::Int, method_choice::Int, diagnostic_output::Int,
+        error_control::Int, singular_term)
+
+Fortran code for solving two-point boundary value problems. For detailed documentation, see
+[ODEInterface.jl](https://github.com/luchr/ODEInterface.jl/blob/master/doc/SolverOptions.md#bvpm2).
+
+!!! warning
+    Only supports inplace two-point boundary value problems, with very limited forms of
+    input structures!
+
+!!! note
+    Only available in julia 1.9+ and if the `ODEInterface` package is loaded.
+"""
+Base.@kwdef struct BVPM2{S} <: BoundaryValueDiffEqAlgorithm
+    max_num_subintervals::Int = 3000
+    method_choice::Int = 4
+    diagnostic_output::Int = -1
+    error_control::Int = 1
+    singular_term::S = nothing
+end
+
+"""
+    BVPSOL(; bvpclass = 2, sol_method = 0, odesolver = nothing)
+    BVPSOL(bvpclass::Int, sol_methods::Int, odesolver)
+
+A FORTRAN77 code which solves highly nonlinear two point boundary value problems using a
+local linear solver (condensing algorithm) or a global sparse linear solver for the solution
+of the arising linear subproblems, by Peter Deuflhard, Georg Bader, Lutz Weimann.
+For detailed documentation, see
+[ODEInterface.jl](https://github.com/luchr/ODEInterface.jl/blob/master/doc/SolverOptions.md#bvpsol).
+
+!!! warning
+    Only supports inplace two-point boundary value problems, with very limited forms of
+    input structures!
+
+!!! note
+    Only available in julia 1.9+ and if the `ODEInterface` package is loaded.
+"""
+Base.@kwdef struct BVPSOL{O} <: BoundaryValueDiffEqAlgorithm
+    bvpclass::Int = 2
+    sol_method::Int = 0
+    odesolver::O = nothing
+end