Skip to content

Commit

Permalink
Add type stability tests
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 23, 2023
1 parent 95d6205 commit 77e10cb
Show file tree
Hide file tree
Showing 10 changed files with 67 additions and 15 deletions.
5 changes: 2 additions & 3 deletions src/solve/single_shooting.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
# TODO: Add in u0/p into `__solve`: Needed for differentiation
# TODO: Support Non-Vector Inputs
function SciMLBase.__solve(prob::BVProblem, alg::Shooting; odesolve_kwargs = (;),
nlsolve_kwargs = (;), kwargs...)
iip, bc, u0, u0_size = isinplace(prob), prob.bc, deepcopy(prob.u0), size(prob.u0)
resid_size = prob.f.bcresid_prototype === nothing ? u0_size :
size(prob.f.bcresid_prototype)
loss_fn = if iip
resid_size = prob.f.bcresid_prototype === nothing ? u0_size :
size(prob.f.bcresid_prototype)
function loss!(resid, u0_, p)
u0_internal = reshape(u0_, u0_size)
tmp_prob = ODEProblem{iip}(prob.f, u0_internal, prob.tspan, p)
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
51 changes: 51 additions & 0 deletions test/misc/type_stability.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
using BoundaryValueDiffEq, OrdinaryDiffEq, LinearAlgebra, Test

f(u, p, t) = [p[1] * u[1] - p[2] * u[1] * u[2], p[3] * u[1] * u[2] - p[4] * u[2]]
function f!(du, u, p, t)
du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
du[2] = p[3] * u[1] * u[2] - p[4] * u[2]
end

bc(sol, p, t) = [sol[1][1] - 1, sol[end][2] - 2]
function bc!(res, sol, p, t)
res[1] = sol[1][1] - 1
res[2] = sol[end][2] - 2
end
twobc((ua, ub), p) = ([ua[1] - 1], [ub[2] - 2])
function twobc!((resa, resb), (ua, ub), p)
resa[1] = ua[1] - 1
resb[1] = ub[2] - 2
end

u0 = Float64[0, 0]
tspan = (0.0, 1.0)
p = [1.0, 1.0, 1.0, 1.0]
bcresid_prototype = (zeros(1), zeros(1))

# Multi-Point BVP
mpbvp_iip = BVProblem(f!, bc!, u0, tspan, p)
mpbvp_oop = BVProblem(f, bc, u0, tspan, p)

@inferred solve(mpbvp_iip, Shooting(Tsit5()))
@inferred solve(mpbvp_oop, Shooting(Tsit5()))
@inferred solve(mpbvp_iip, MultipleShooting(5, Tsit5()))
@inferred solve(mpbvp_oop, MultipleShooting(5, Tsit5()))

for solver in (MIRK2(), MIRK3(), MIRK4(), MIRK5(), MIRK6())
@inferred solve(mpbvp_iip, solver; dt = 0.2)
@inferred solve(mpbvp_oop, solver; dt = 0.2)
end

# Two-Point BVP
tpbvp_iip = TwoPointBVProblem(f!, twobc!, u0, tspan, p; bcresid_prototype)
tpbvp_oop = TwoPointBVProblem(f, twobc, u0, tspan, p)

@inferred solve(tpbvp_iip, Shooting(Tsit5()))
@inferred solve(tpbvp_oop, Shooting(Tsit5()))
@inferred solve(tpbvp_iip, MultipleShooting(5, Tsit5()))
@inferred solve(tpbvp_oop, MultipleShooting(5, Tsit5()))

for solver in (MIRK2(), MIRK3(), MIRK4(), MIRK5(), MIRK6())
@inferred solve(tpbvp_iip, solver; dt = 0.2)
@inferred solve(tpbvp_oop, solver; dt = 0.2)
end
26 changes: 14 additions & 12 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,36 @@ using Test, SafeTestsets
@testset "Boundary Value Problem Tests" begin
@time @testset "Shooting Method Tests" begin
@time @safetestset "Shooting Tests" begin
include("shooting_tests.jl")
include("shooting/shooting_tests.jl")
end
@time @safetestset "Orbital" begin
include("orbital.jl")
include("shooting/orbital.jl")
end
end

@time @testset "Collocation Method (MIRK) Tests" begin
@time @safetestset "Ensemble" begin
include("ensemble.jl")
include("mirk/ensemble.jl")
end
@time @safetestset "MIRK Convergence Tests" begin
include("mirk_convergence_tests.jl")
include("mirk/mirk_convergence_tests.jl")
end
@time @safetestset "Vector of Vector" begin
include("vectorofvector_initials.jl")
include("mirk/vectorofvector_initials.jl")
end
end

@time @testset "ODE Interface Solvers" begin
@time @safetestset "ODE Interface Tests" begin
include("odeinterface_ex7.jl")
@time @testset "Miscelleneous" begin
@time @safetestset "Non Vector Inputs" begin
include("misc/non_vector_inputs.jl")
end
end

@time @testset "Non Vector Inputs Tests" begin
@time @safetestset "Non Vector Inputs" begin
include("non_vector_inputs.jl")
@time @safetestset "Type Stability" begin
include("misc/type_stability.jl")
end

@time @safetestset "ODE Interface Tests" begin
include("misc/odeinterface_ex7.jl")
end
end
end
File renamed without changes.
File renamed without changes.

0 comments on commit 77e10cb

Please sign in to comment.