Skip to content

Commit

Permalink
Add tests for ODEInterface integration and remove compat with old ver…
Browse files Browse the repository at this point in the history
…sions
  • Loading branch information
Avik Pal committed Sep 21, 2023
1 parent ba61f36 commit 23d5f81
Show file tree
Hide file tree
Showing 11 changed files with 114 additions and 50 deletions.
1 change: 0 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ jobs:
- Core
version:
- '1'
- '1.6'
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ ForwardDiff = "0.10"
NonlinearSolve = "2"
ODEInterface = "0.5"
Reexport = "0.2, 1.0"
SciMLBase = "1"
SciMLBase = "2"
Setfield = "1"
TruncatedStacktraces = "1"
UnPack = "1"
julia = "1.6"
julia = "1.9"

[weakdeps]
ODEInterface = "54ca160b-1b9f-5127-a996-1867f4bc2a2c"
Expand Down
44 changes: 33 additions & 11 deletions ext/BoundaryValueDiffEqODEInterfaceExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,22 +52,16 @@ function SciMLBase.__solve(prob::BVProblem, alg::BVPM2; dt = 0.0, reltol = 1e-3,
eachcol(evalSolution(sol, x_mesh)); retcode, stats)

bvpm2_destroy(initial_guess)
bvpm2_destroy(sol_final)
bvpm2_destroy(sol)

return sol_final
end

#-------
# BVPSOL
#-------
bvpsol_f(f, t, u, du) = f(du, u, SciMLBase.NullParameters(), t)
function bvpsol_bc(bc, ra, rb, ya, yb, r)
bc((view(r, 1:(length(ra))), view(r, (length(ra) + 1):(length(ra) + length(rb)))),
(ya, yb), SciMLBase.NullParameters())
end

function SciMLBase.__solve(prob::BVProblem, alg::BVPSOL; maxiters = 1000, reltol = 1e-3,
dt = 0.0, kwargs...)
dt = 0.0, verbose = true, kwargs...)
_test_bvpm2_bvpsol_problem_criteria(prob, prob.problem_type, :BVPSOL)
@assert isa(prob.p, SciMLBase.NullParameters) "BVPSOL only supports NullParameters!"
@assert isa(prob.u0, AbstractVector{<:AbstractArray}) "BVPSOL requires a vector of initial guesses!"
Expand All @@ -78,12 +72,40 @@ function SciMLBase.__solve(prob::BVProblem, alg::BVPSOL; maxiters = 1000, reltol
OPT_BVPCLASS => alg.bvpclass, OPT_SOLMETHOD => alg.sol_method,
OPT_RHS_CALLMODE => RHS_CALL_INSITU)

f! = (args...) -> bvpsol_f(prob.f, args...)
bc! = (args...) -> bvpsol_bc(prob.bc, first(prob.f.bcresid_prototype.x),
last(prob.f.bcresid_prototype.x), args...)
f!(t, u, du) = prob.f(du, u, prob.p, t)
function bc!(ya, yb, r)
ra = first(prob.f.bcresid_prototype.x)
rb = last(prob.f.bcresid_prototype.x)
prob.bc((ra, rb), (ya, yb), prob.p)
r[1:length(ra)] .= ra
r[(length(ra) + 1):(length(ra) + length(rb))] .= rb
return r
end

sol_t, sol_x, retcode, stats = bvpsol(f!, bc!, mesh, u0, alg.odesolver, opt)

if verbose
if retcode == -3
@error "Integrator failed to complete the trajectory"
elseif retcode == -4
@error "Gauss Newton method failed to converge"
elseif retcode == -5
@error "Given initial values inconsistent with separable linear bc"
elseif retcode == -6
@error """Iterative refinement faild to converge for `sol_method=0`
Termination since multiple shooting condition or
condition of Jacobian is too bad for `sol_method=1`"""
elseif retcode == -8
@error "Condensing algorithm for linear block system fails, try `sol_method=1`"
elseif retcode == -9
@error "Sparse linear solver failed"
elseif retcode == -10
@error "Real or integer work-space exhausted"
elseif retcode == -11
@error "Rank reduction failed - resulting rank is zero"
end
end

return DiffEqBase.build_solution(prob, alg, sol_t, eachcol(sol_x);
retcode = retcode 0 ? ReturnCode.Success : ReturnCode.Failure, stats)
end
Expand Down
8 changes: 4 additions & 4 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
const DEFAULT_NLSOLVE_SHOOTING = TrustRegion(; autodiff = Val(true))
const DEFAULT_NLSOLVE_MIRK = NewtonRaphson(; autodiff = Val(true))
const DEFAULT_NLSOLVE_SHOOTING = NewtonRaphson(; autodiff = AutoForwardDiff())
const DEFAULT_NLSOLVE_MIRK = NewtonRaphson(; autodiff = AutoForwardDiff())
const DEFAULT_JACOBIAN_ALGORITHM_MIRK = MIRKJacobianComputationAlgorithm()

# Algorithms
Expand Down Expand Up @@ -65,7 +65,7 @@ Fortran code for solving two-point boundary value problems. For detailed documen
input structures!
!!! note
Only available in julia 1.9+ and if the `ODEInterface` package is loaded.
Only available if the `ODEInterface` package is loaded.
"""
Base.@kwdef struct BVPM2{S} <: BoundaryValueDiffEqAlgorithm
max_num_subintervals::Int = 3000
Expand All @@ -90,7 +90,7 @@ For detailed documentation, see
input structures!
!!! note
Only available in julia 1.9+ and if the `ODEInterface` package is loaded.
Only available if the `ODEInterface` package is loaded.
"""
Base.@kwdef struct BVPSOL{O} <: BoundaryValueDiffEqAlgorithm
bvpclass::Int = 2
Expand Down
6 changes: 2 additions & 4 deletions test/ensemble.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,13 @@ function bc!(residual, u, p, t)
residual[2] = u[end][1]
end

function prob_func(prob, i, repeat)
remake(prob, p = [rand()])
end
prob_func(prob, i, repeat) = remake(prob, p = [rand()])

initial_guess = [0.0, 1.0]
tspan = (0, pi / 2)
p = [rand()]
bvp = BVProblem(ode!, bc!, initial_guess, tspan, p)
ensemble_prob = EnsembleProblem(bvp, prob_func = prob_func)
ensemble_prob = EnsembleProblem(bvp; prob_func)

@testset "$(solver)" for solver in (MIRK2, MIRK3, MIRK4, MIRK5, MIRK6)
jac_algs = [MIRKJacobianComputationAlgorithm(),
Expand Down
2 changes: 1 addition & 1 deletion test/non_vector_inputs.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using BoundaryValueDiffEq, DiffEqBase, DiffEqDevTools, LinearAlgebra, OrdinaryDiffEq, Test
using BoundaryValueDiffEq, LinearAlgebra, OrdinaryDiffEq, Test

for order in (2, 3, 4, 5, 6)
s = Symbol("MIRK$(order)")
Expand Down
48 changes: 48 additions & 0 deletions test/odeinterface_ex7.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
using Test, BoundaryValueDiffEq, LinearAlgebra, ODEInterface, Random

# Adaptation of https://github.com/luchr/ODEInterface.jl/blob/958b6023d1dabf775033d0b89c5401b33100bca3/examples/BasicExamples/ex7.jl
function ex7_f!(du, u, p, t)
ϵ = p[1]
u₁, λ = u
du[1] = (sin(t)^2 + λ * sin(t)^4 / u₁) / ϵ
du[2] = 0
return nothing
end

function ex7_2pbc!((resa, resb), (ua, ub), p)
resa[1] = ua[1] - 1
resb[1] = ub[1] - 1
return nothing
end

u0 = [0.5, 1.0]
p = [0.1]
tspan = (-π / 2, π / 2)

tpprob = TwoPointBVProblem(ex7_f!, ex7_2pbc!, u0, tspan, p;
bcresid_prototype = (zeros(1), zeros(1)))

@info "BVPM2"

sol_bvpm2 = solve(tpprob, BVPM2(); dt = π / 20)
@test SciMLBase.successful_retcode(sol_bvpm2)
resid_f = (Array{Float64, 1}(undef, 1), Array{Float64, 1}(undef, 1))
ex7_2pbc!(resid_f, (sol_bvpm2(tspan[1]), sol_bvpm2(tspan[2])), nothing)
@test norm(resid_f) < 1e-6

function ex7_f2!(du, u, p, t)
u₁, λ = u
du[1] = (sin(t)^2 + λ * sin(t)^4 / u₁) / 0.1
du[2] = 0
return nothing
end

@info "BVPSOL"

initial_u0 = [sol_bvpm2(t) .+ rand() for t in tspan[1]:/ 20):tspan[2]]
tpprob = TwoPointBVProblem(ex7_f2!, ex7_2pbc!, initial_u0, tspan;
bcresid_prototype = (zeros(1), zeros(1)))

# Just test that it runs. BVPSOL only works with linearly separable BCs.
# TODO: Implement appolo reentry example from ODEInterface.jl
sol_bvpsol = solve(tpprob, BVPSOL(); dt = π / 20)
5 changes: 1 addition & 4 deletions test/orbital.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
# Lambert's Problem

using BoundaryValueDiffEq
using DiffEqBase, OrdinaryDiffEq, LinearAlgebra, NonlinearSolve
using Test
using BoundaryValueDiffEq, OrdinaryDiffEq, LinearAlgebra, Test

@info "Testing Lambert's Problem"

Expand Down
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ using Test, SafeTestsets
end
end

@time @testset "ODE Interface Solvers" begin
@time @safetestset "ODE Interface Tests" begin
include("odeinterface_ex7.jl")
end
end

@time @testset "Non Vector Inputs Tests" begin
@time @safetestset "Non Vector Inputs" begin
include("non_vector_inputs.jl")
Expand Down
28 changes: 13 additions & 15 deletions test/shooting_tests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
using BoundaryValueDiffEq
using DiffEqBase, OrdinaryDiffEq, DiffEqDevTools
using Test, LinearAlgebra, PreallocationTools
using BoundaryValueDiffEq, LinearAlgebra, OrdinaryDiffEq, Test

@info "Shooting method"

Expand All @@ -26,15 +24,13 @@ end
bvp1 = BVProblem(f1!, bc1!, u0, tspan)
@test SciMLBase.isinplace(bvp1)
resid_f = Array{Float64}(undef, 2)
sol = solve(bvp1, Shooting(Tsit5()))
sol = solve(bvp1, Shooting(Tsit5()); abstol = 1e-6, reltol = 1e-6)
@test SciMLBase.successful_retcode(sol)
bc1!(resid_f, sol, nothing, sol.t)
@test norm(resid_f) < 1e-4
@test norm(resid_f) < 1e-6

# Out of Place
function f1(u, p, t)
return [u[2], -u[1]]
end
f1(u, p, t) = [u[2], -u[1]]

function bc1(sol, p, t)
t₀, t₁ = first(t), last(t)
Expand All @@ -46,10 +42,10 @@ end

bvp2 = BVProblem(f1, bc1, u0, tspan)
@test !SciMLBase.isinplace(bvp2)
sol = solve(bvp2, Shooting(Tsit5()))
sol = solve(bvp2, Shooting(Tsit5()); abstol = 1e-6, reltol = 1e-6)
@test SciMLBase.successful_retcode(sol)
resid_f = bc1(sol, nothing, sol.t)
@test norm(resid_f) < 1e-4
@test norm(resid_f) < 1e-6

@info "Two Point BVProblem" # Not really but using that API

Expand All @@ -63,10 +59,11 @@ end
bvp3 = TwoPointBVProblem(f1!, bc2!, u0, tspan;
bcresid_prototype = (Array{Float64}(undef, 1), Array{Float64}(undef, 1)))
@test SciMLBase.isinplace(bvp3)
sol = solve(bvp3, Shooting(Tsit5(), TrustRegion(; autodiff = false)))
sol = solve(bvp3, Shooting(Tsit5()); abstol = 1e-6, reltol = 1e-6)
@test SciMLBase.successful_retcode(sol)
resid_f = (Array{Float64, 1}(undef, 1), Array{Float64, 1}(undef, 1))
bc2!(resid_f, (sol(tspan[1]), sol(tspan[2])), nothing)
@test_broken norm(resid_f) < 1e-6
@test norm(resid_f) < 1e-4

# Out of Place
Expand All @@ -76,16 +73,17 @@ end

bvp4 = TwoPointBVProblem(f1, bc2, u0, tspan)
@test !SciMLBase.isinplace(bvp4)
sol = solve(bvp4, Shooting(Tsit5()))
sol = solve(bvp4, Shooting(Tsit5()); abstol = 1e-6, reltol = 1e-6)
@test SciMLBase.successful_retcode(sol)
resid_f = reduce(vcat, bc2((sol(tspan[1]), sol(tspan[2])), nothing))
@test norm(resid_f) < 1e-4
@test norm(resid_f) < 1e-6

#Test for complex values
u0 = [0.0, 1.0] .+ 1im
bvp = BVProblem(f1!, bc1!, u0, tspan)
resid_f = Array{ComplexF64}(undef, 2)
sol = solve(bvp, Shooting(Tsit5(); nlsolve = NewtonRaphson(; autodiff = AutoFiniteDiff())))
sol = solve(bvp, Shooting(Tsit5(); nlsolve = NewtonRaphson(; autodiff = AutoFiniteDiff()));
abstol = 1e-6, reltol = 1e-6)
resid_f = Array{ComplexF64}(undef, 2)
bc1!(resid_f, sol, nothing, sol.t)
@test norm(resid_f) < 1e-4
@test norm(resid_f) < 1e-6
12 changes: 4 additions & 8 deletions test/vectorofvector_initials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ prob = ODEProblem(TC!, u0, tspan, dt = 0.01)
sol = solve(prob, Rodas4P(), reltol = 1e-12, abstol = 1e-12, saveat = 0.5)

# The BVP set up
# This is not really kind of Two-Point BVP we support. However,
# This is not really kind of Two-Point BVP we support.
function bc_po!(residual, u, p, t)
residual[1] = u[1][1] - u[end][1]
residual[2] = u[1][2] - u[end][2]
Expand All @@ -65,10 +65,6 @@ bvp1 = BVProblem(TC!, bc_po!, sol.u, tspan)
sol6 = solve(bvp1, MIRK6(); dt = 0.5)
@test SciMLBase.successful_retcode(sol6.retcode)

@static if VERSION v"1.9"
# 1.6 runs without sparsity support. This takes over 2 hrs to run there :(
# Setup to test the mesh_selector! code
bvp1 = BVProblem(TC!, bc_po!, zero(first(sol.u)), tspan)
sol6 = solve(bvp1, MIRK6(); dt = 0.1, abstol = 1e-16)
@test SciMLBase.successful_retcode(sol6.retcode)
end
bvp1 = BVProblem(TC!, bc_po!, zero(first(sol.u)), tspan)
sol6 = solve(bvp1, MIRK6(); dt = 0.1, abstol = 1e-16)
@test SciMLBase.successful_retcode(sol6.retcode)

0 comments on commit 23d5f81

Please sign in to comment.