diff --git a/src/initialize_dae.jl b/src/initialize_dae.jl index 5d38c3945d..0067eeaaaf 100644 --- a/src/initialize_dae.jl +++ b/src/initialize_dae.jl @@ -71,6 +71,12 @@ function DiffEqBase.initialize_dae!(integrator::ODEIntegrator, Val(DiffEqBase.isinplace(integrator.sol.prob))) end +function DiffEqBase.initialize_dae!(integrator::DiffEqBase.NullODEIntegrator) + _initialize_dae!(integrator, integrator.sol.prob, + OverrideInit(), + Val(DiffEqBase.isinplace(integrator.sol.prob))) +end + ## Default algorithms function _initialize_dae!(integrator, prob::ODEProblem, @@ -135,11 +141,15 @@ function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem}, alg::OverrideInit, isinplace::Union{Val{true}, Val{false}}) initializeprob = prob.f.initializeprob if initializeprob.f.sys !== nothing && prob.f.sys !== nothing - initu0vars = variable_symbols(initializeprob) - initu0order = variable_index.((initializeprob,), initu0vars) - # Variable symbols are not guaranteed to be in order - invpermute!(initu0vars, initu0order) - initu0 = getu(prob.f.initializeprob, initu0vars)(prob) + if initializeprob.u0 === nothing || isempty(initializeprob.u0) + initu0 = Float64[] + else + initu0vars = variable_symbols(initializeprob) + initu0order = variable_index.((initializeprob,), initu0vars) + # Variable symbols are not guaranteed to be in order + invpermute!(initu0vars, initu0order) + initu0 = getu(prob.f.initializeprob, initu0vars)(prob) + end initp = remake_buffer(initializeprob, parameter_values(initializeprob), Dict(sym => getu(prob, sym)(prob) for sym in parameter_symbols(initializeprob))) initializeprob = remake(initializeprob; u0 = initu0, p = initp) @@ -147,15 +157,30 @@ function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem}, # If it doesn't have autodiff, assume it comes from symbolic system like ModelingToolkit # Since then it's the case of not a DAE but has initializeprob # In which case, it should be differentiable - isAD = has_autodiff(integrator.alg) ? alg_autodiff(integrator.alg) isa AutoForwardDiff : - true + isAD = if !isdefined(integrator, :alg) + false + elseif has_autodiff(integrator.alg) + alg_autodiff(integrator.alg) isa AutoForwardDiff + else + true + end alg = default_nlsolve(alg.nlsolve, isinplace, initializeprob.u0, initializeprob, isAD) nlsol = solve(initializeprob, alg) if isinplace === Val{true}() - integrator.u .= prob.f.initializeprobmap(nlsol) + if prob.u0 !== nothing && !isempty(prob.u0) + integrator.u .= prob.f.initializeprobmap(nlsol) + end + if SciMLBase.has_initializeprob_updatep(prob.f) + prob.f.initializeprob_updatep!(integrator.p, nlsol) + end elseif isinplace === Val{false}() - integrator.u = prob.f.initializeprobmap(nlsol) + if prob.u0 !== nothing && !isempty(prob.u0) + integrator.u .= prob.f.initializeprobmap(nlsol) + end + if SciMLBase.has_initializeprob_updatep(prob.f) + prob.f.initializeprob_updatep!(integrator.p, nlsol) + end else error("Unreachable reached. Report this error.") end