diff --git a/ext/BoundaryValueDiffEqODEInterfaceExt.jl b/ext/BoundaryValueDiffEqODEInterfaceExt.jl index 55fba78a..268bbb55 100644 --- a/ext/BoundaryValueDiffEqODEInterfaceExt.jl +++ b/ext/BoundaryValueDiffEqODEInterfaceExt.jl @@ -17,15 +17,6 @@ end #------ # BVPM2 #------ -_no_param(::SciMLBase.NullParameters) = Float64[] -_no_param(p) = p - -bvpm2_bc(bc, ya, yb, bca, bcb) = bc((bca, bcb), (ya, yb), SciMLBase.NullParameters()) -bvpm2_bc(bc, ya, yb, p, bca, bcb) = bc((bca, bcb), (ya, yb), p) - -bvp2m_f(f, t, u, du) = f(du, u, SciMLBase.NullParameters(), t) -bvp2m_f(f, t, u, p, du) = f(du, u, p, t) - ## TODO: We can specify Drhs using forwarddiff if we want to function SciMLBase.__solve(prob::BVProblem, alg::BVPM2; dt = 0.0, reltol = 1e-3, kwargs...) _test_bvpm2_bvpsol_problem_criteria(prob, prob.problem_type, :BVPM2) @@ -43,23 +34,27 @@ function SciMLBase.__solve(prob::BVProblem, alg::BVPM2; dt = 0.0, reltol = 1e-3, no_left_bc = length(first(prob.f.bcresid_prototype.x)) initial_guess = Bvpm2() - bvpm2_init(initial_guess, no_odes, no_left_bc, mesh, u0, _no_param(prob.p), + bvpm2_init(initial_guess, no_odes, no_left_bc, mesh, u0, eltype(u0)[], alg.max_num_subintervals) - rhs = (args...) -> bvp2m_f(prob.f, args...) - bc = (args...) -> bvpm2_bc(prob.bc, args...) + bvp2m_f(t, u, du) = prob.f(du, u, prob.p, t) + bvp2m_bc(ya, yb, bca, bcb) = prob.bc((bca, bcb), (ya, yb), prob.p) opt = OptionsODE(OPT_RTOL => reltol, OPT_METHODCHOICE => alg.method_choice, OPT_DIAGNOSTICOUTPUT => alg.diagnostic_output, OPT_SINGULARTERM => alg.singular_term, OPT_ERRORCONTROL => alg.error_control) - sol, retcode, stats = bvpm2_solve(initial_guess, rhs, bc, opt) + sol, retcode, stats = bvpm2_solve(initial_guess, bvp2m_f, bvp2m_bc, opt) retcode = retcode ≥ 0 ? ReturnCode.Success : ReturnCode.Failure - bvpm2_destroy(initial_guess) x_mesh = bvpm2_get_x(sol) - return DiffEqBase.build_solution(prob, alg, x_mesh, eachcol(evalSolution(sol, x_mesh)); - retcode, stats) + sol_final = DiffEqBase.build_solution(prob, alg, x_mesh, + eachcol(evalSolution(sol, x_mesh)); retcode, stats) + + bvpm2_destroy(initial_guess) + bvpm2_destroy(sol_final) + + return sol_final end #-------