diff --git a/src/solve/single_shooting.jl b/src/solve/single_shooting.jl index 327d052a..18e11eed 100644 --- a/src/solve/single_shooting.jl +++ b/src/solve/single_shooting.jl @@ -1,10 +1,9 @@ -# TODO: Support Non-Vector Inputs function SciMLBase.__solve(prob::BVProblem, alg::Shooting; odesolve_kwargs = (;), nlsolve_kwargs = (;), kwargs...) iip, bc, u0, u0_size = isinplace(prob), prob.bc, deepcopy(prob.u0), size(prob.u0) + resid_size = prob.f.bcresid_prototype === nothing ? u0_size : + size(prob.f.bcresid_prototype) loss_fn = if iip - resid_size = prob.f.bcresid_prototype === nothing ? u0_size : - size(prob.f.bcresid_prototype) function loss!(resid, u0_, p) u0_internal = reshape(u0_, u0_size) tmp_prob = ODEProblem{iip}(prob.f, u0_internal, prob.tspan, p) diff --git a/test/ensemble.jl b/test/mirk/ensemble.jl similarity index 100% rename from test/ensemble.jl rename to test/mirk/ensemble.jl diff --git a/test/mirk_convergence_tests.jl b/test/mirk/mirk_convergence_tests.jl similarity index 100% rename from test/mirk_convergence_tests.jl rename to test/mirk/mirk_convergence_tests.jl diff --git a/test/vectorofvector_initials.jl b/test/mirk/vectorofvector_initials.jl similarity index 100% rename from test/vectorofvector_initials.jl rename to test/mirk/vectorofvector_initials.jl diff --git a/test/non_vector_inputs.jl b/test/misc/non_vector_inputs.jl similarity index 100% rename from test/non_vector_inputs.jl rename to test/misc/non_vector_inputs.jl diff --git a/test/odeinterface_ex7.jl b/test/misc/odeinterface_ex7.jl similarity index 100% rename from test/odeinterface_ex7.jl rename to test/misc/odeinterface_ex7.jl diff --git a/test/misc/type_stability.jl b/test/misc/type_stability.jl new file mode 100644 index 00000000..a035f4af --- /dev/null +++ b/test/misc/type_stability.jl @@ -0,0 +1,51 @@ +using BoundaryValueDiffEq, OrdinaryDiffEq, LinearAlgebra, Test + +f(u, p, t) = [p[1] * u[1] - p[2] * u[1] * u[2], p[3] * u[1] * u[2] - p[4] * u[2]] +function f!(du, u, p, t) + du[1] = p[1] * u[1] - p[2] * u[1] * u[2] + du[2] = p[3] * u[1] * u[2] - p[4] * u[2] +end + +bc(sol, p, t) = [sol[1][1] - 1, sol[end][2] - 2] +function bc!(res, sol, p, t) + res[1] = sol[1][1] - 1 + res[2] = sol[end][2] - 2 +end +twobc((ua, ub), p) = ([ua[1] - 1], [ub[2] - 2]) +function twobc!((resa, resb), (ua, ub), p) + resa[1] = ua[1] - 1 + resb[1] = ub[2] - 2 +end + +u0 = Float64[0, 0] +tspan = (0.0, 1.0) +p = [1.0, 1.0, 1.0, 1.0] +bcresid_prototype = (zeros(1), zeros(1)) + +# Multi-Point BVP +mpbvp_iip = BVProblem(f!, bc!, u0, tspan, p) +mpbvp_oop = BVProblem(f, bc, u0, tspan, p) + +@inferred solve(mpbvp_iip, Shooting(Tsit5())) +@inferred solve(mpbvp_oop, Shooting(Tsit5())) +@inferred solve(mpbvp_iip, MultipleShooting(5, Tsit5())) +@inferred solve(mpbvp_oop, MultipleShooting(5, Tsit5())) + +for solver in (MIRK2(), MIRK3(), MIRK4(), MIRK5(), MIRK6()) + @inferred solve(mpbvp_iip, solver; dt = 0.2) + @inferred solve(mpbvp_oop, solver; dt = 0.2) +end + +# Two-Point BVP +tpbvp_iip = TwoPointBVProblem(f!, twobc!, u0, tspan, p; bcresid_prototype) +tpbvp_oop = TwoPointBVProblem(f, twobc, u0, tspan, p) + +@inferred solve(tpbvp_iip, Shooting(Tsit5())) +@inferred solve(tpbvp_oop, Shooting(Tsit5())) +@inferred solve(tpbvp_iip, MultipleShooting(5, Tsit5())) +@inferred solve(tpbvp_oop, MultipleShooting(5, Tsit5())) + +for solver in (MIRK2(), MIRK3(), MIRK4(), MIRK5(), MIRK6()) + @inferred solve(tpbvp_iip, solver; dt = 0.2) + @inferred solve(tpbvp_oop, solver; dt = 0.2) +end diff --git a/test/runtests.jl b/test/runtests.jl index af8b3bbd..0d404a56 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,34 +3,36 @@ using Test, SafeTestsets @testset "Boundary Value Problem Tests" begin @time @testset "Shooting Method Tests" begin @time @safetestset "Shooting Tests" begin - include("shooting_tests.jl") + include("shooting/shooting_tests.jl") end @time @safetestset "Orbital" begin - include("orbital.jl") + include("shooting/orbital.jl") end end @time @testset "Collocation Method (MIRK) Tests" begin @time @safetestset "Ensemble" begin - include("ensemble.jl") + include("mirk/ensemble.jl") end @time @safetestset "MIRK Convergence Tests" begin - include("mirk_convergence_tests.jl") + include("mirk/mirk_convergence_tests.jl") end @time @safetestset "Vector of Vector" begin - include("vectorofvector_initials.jl") + include("mirk/vectorofvector_initials.jl") end end - @time @testset "ODE Interface Solvers" begin - @time @safetestset "ODE Interface Tests" begin - include("odeinterface_ex7.jl") + @time @testset "Miscelleneous" begin + @time @safetestset "Non Vector Inputs" begin + include("misc/non_vector_inputs.jl") end - end - @time @testset "Non Vector Inputs Tests" begin - @time @safetestset "Non Vector Inputs" begin - include("non_vector_inputs.jl") + @time @safetestset "Type Stability" begin + include("misc/type_stability.jl") + end + + @time @safetestset "ODE Interface Tests" begin + include("misc/odeinterface_ex7.jl") end end end diff --git a/test/orbital.jl b/test/shooting/orbital.jl similarity index 100% rename from test/orbital.jl rename to test/shooting/orbital.jl diff --git a/test/shooting_tests.jl b/test/shooting/shooting_tests.jl similarity index 100% rename from test/shooting_tests.jl rename to test/shooting/shooting_tests.jl