diff --git a/src/initialize_dae.jl b/src/initialize_dae.jl index 947e3aa5..50b52b1e 100644 --- a/src/initialize_dae.jl +++ b/src/initialize_dae.jl @@ -1,13 +1,13 @@ struct SDEDefaultInit <: DiffEqBase.DAEInitializationAlgorithm end -function DiffEqBase.initialize_dae!(integrator::SDEIntegrator, initializealg = integrator.initializealg) +function DiffEqBase.initialize_dae!(integrator::AbstractSDEIntegrator, initializealg = integrator.initializealg) OrdinaryDiffEqCore._initialize_dae!(integrator, integrator.sol.prob, initializealg, Val(DiffEqBase.isinplace(integrator.sol.prob))) end -function OrdinaryDiffEqCore._initialize_dae!(integrator::SDEIntegrator, prob, ::SDEDefaultInit, isinplace) +function OrdinaryDiffEqCore._initialize_dae!(integrator::AbstractSDEIntegrator, prob, ::SDEDefaultInit, isinplace) if SciMLBase.has_initializeprob(prob.f) OrdinaryDiffEqCore._initialize_dae!(integrator, prob, SciMLBase.OverrideInit(), isinplace) - else + elseif prob.f.mass_matrix isa Tuple && any(mm != I for mm in prob.f.mass_matrix) || prob.f.mass_matrix != I OrdinaryDiffEqCore._initialize_dae!(integrator, prob, SciMLBase.CheckInit(), isinplace) end end