From 8924e951ef3af10f4440aa159222c8347260f4a1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 28 Sep 2023 13:28:35 -0400 Subject: [PATCH] Fix Dispatches if Differential Equations is loaded --- src/BoundaryValueDiffEq.jl | 12 ++++++------ src/solve/mirk.jl | 4 ++-- src/solve/single_shooting.jl | 1 - 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/BoundaryValueDiffEq.jl b/src/BoundaryValueDiffEq.jl index a4513606..157e6f69 100644 --- a/src/BoundaryValueDiffEq.jl +++ b/src/BoundaryValueDiffEq.jl @@ -15,12 +15,6 @@ import SparseDiffTools: AbstractSparseADType import TruncatedStacktraces: @truncate_stacktrace import UnPack: @unpack -function SciMLBase.__solve(prob::BVProblem, alg; kwargs...) - # If dispatch not directly defined - cache = init(prob, alg; kwargs...) - return solve!(cache) -end - include("types.jl") include("utils.jl") include("algorithms.jl") @@ -34,6 +28,12 @@ include("solve/mirk.jl") include("adaptivity.jl") include("interpolation.jl") +function SciMLBase.__solve(prob::BVProblem, alg::BoundaryValueDiffEqAlgorithm, args...; + kwargs...) + cache = init(prob, alg, args...; kwargs...) + return solve!(cache) +end + export Shooting export MIRK2, MIRK3, MIRK4, MIRK5, MIRK6 export MIRKJacobianComputationAlgorithm diff --git a/src/solve/mirk.jl b/src/solve/mirk.jl index 581db9ad..65978c4e 100644 --- a/src/solve/mirk.jl +++ b/src/solve/mirk.jl @@ -1,5 +1,5 @@ -function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0, abstol = 1e-3, - adaptive = true, kwargs...) +function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0, + abstol = 1e-3, adaptive = true, kwargs...) has_initial_guess = prob.u0 isa AbstractVector{<:AbstractArray} iip = isinplace(prob) (T, M, n) = if has_initial_guess diff --git a/src/solve/single_shooting.jl b/src/solve/single_shooting.jl index b0c7c94a..8b8cc9ab 100644 --- a/src/solve/single_shooting.jl +++ b/src/solve/single_shooting.jl @@ -1,5 +1,4 @@ # TODO: Differentiate between nlsolve kwargs and odesolve kwargs -# TODO: Add in u0/p into `__solve`: Needed for differentiation # TODO: Support Non-Vector Inputs function SciMLBase.__solve(prob::BVProblem, alg::Shooting; kwargs...) iip = isinplace(prob)