Skip to content

Commit

Permalink
Make it consistent with out parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 21, 2023
1 parent ff583d1 commit ba61f36
Showing 1 changed file with 11 additions and 16 deletions.
27 changes: 11 additions & 16 deletions ext/BoundaryValueDiffEqODEInterfaceExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,6 @@ end
#------
# BVPM2
#------
_no_param(::SciMLBase.NullParameters) = Float64[]
_no_param(p) = p

bvpm2_bc(bc, ya, yb, bca, bcb) = bc((bca, bcb), (ya, yb), SciMLBase.NullParameters())
bvpm2_bc(bc, ya, yb, p, bca, bcb) = bc((bca, bcb), (ya, yb), p)

bvp2m_f(f, t, u, du) = f(du, u, SciMLBase.NullParameters(), t)
bvp2m_f(f, t, u, p, du) = f(du, u, p, t)

## TODO: We can specify Drhs using forwarddiff if we want to
function SciMLBase.__solve(prob::BVProblem, alg::BVPM2; dt = 0.0, reltol = 1e-3, kwargs...)
_test_bvpm2_bvpsol_problem_criteria(prob, prob.problem_type, :BVPM2)
Expand All @@ -43,23 +34,27 @@ function SciMLBase.__solve(prob::BVProblem, alg::BVPM2; dt = 0.0, reltol = 1e-3,
no_left_bc = length(first(prob.f.bcresid_prototype.x))

initial_guess = Bvpm2()
bvpm2_init(initial_guess, no_odes, no_left_bc, mesh, u0, _no_param(prob.p),
bvpm2_init(initial_guess, no_odes, no_left_bc, mesh, u0, eltype(u0)[],
alg.max_num_subintervals)

rhs = (args...) -> bvp2m_f(prob.f, args...)
bc = (args...) -> bvpm2_bc(prob.bc, args...)
bvp2m_f(t, u, du) = prob.f(du, u, prob.p, t)
bvp2m_bc(ya, yb, bca, bcb) = prob.bc((bca, bcb), (ya, yb), prob.p)

opt = OptionsODE(OPT_RTOL => reltol, OPT_METHODCHOICE => alg.method_choice,
OPT_DIAGNOSTICOUTPUT => alg.diagnostic_output,
OPT_SINGULARTERM => alg.singular_term, OPT_ERRORCONTROL => alg.error_control)

sol, retcode, stats = bvpm2_solve(initial_guess, rhs, bc, opt)
sol, retcode, stats = bvpm2_solve(initial_guess, bvp2m_f, bvp2m_bc, opt)
retcode = retcode 0 ? ReturnCode.Success : ReturnCode.Failure
bvpm2_destroy(initial_guess)

x_mesh = bvpm2_get_x(sol)
return DiffEqBase.build_solution(prob, alg, x_mesh, eachcol(evalSolution(sol, x_mesh));
retcode, stats)
sol_final = DiffEqBase.build_solution(prob, alg, x_mesh,
eachcol(evalSolution(sol, x_mesh)); retcode, stats)

bvpm2_destroy(initial_guess)
bvpm2_destroy(sol_final)

return sol_final
end

#-------
Expand Down

0 comments on commit ba61f36

Please sign in to comment.