Skip to content

Commit

Permalink
feat: allow initialization of null integrators
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed May 28, 2024
1 parent e3b7d7a commit f0720b7
Showing 1 changed file with 34 additions and 9 deletions.
43 changes: 34 additions & 9 deletions src/initialize_dae.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ function DiffEqBase.initialize_dae!(integrator::ODEIntegrator,
Val(DiffEqBase.isinplace(integrator.sol.prob)))
end

function DiffEqBase.initialize_dae!(integrator::DiffEqBase.NullODEIntegrator)
_initialize_dae!(integrator, integrator.sol.prob,
OverrideInit(),
Val(DiffEqBase.isinplace(integrator.sol.prob)))
end

## Default algorithms

function _initialize_dae!(integrator, prob::ODEProblem,
Expand Down Expand Up @@ -135,27 +141,46 @@ function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem},
alg::OverrideInit, isinplace::Union{Val{true}, Val{false}})
initializeprob = prob.f.initializeprob
if initializeprob.f.sys !== nothing && prob.f.sys !== nothing
initu0vars = variable_symbols(initializeprob)
initu0order = variable_index.((initializeprob,), initu0vars)
# Variable symbols are not guaranteed to be in order
invpermute!(initu0vars, initu0order)
initu0 = getu(prob.f.initializeprob, initu0vars)(prob)
if initializeprob.u0 === nothing || isempty(initializeprob.u0)
initu0 = Float64[]
else
initu0vars = variable_symbols(initializeprob)
initu0order = variable_index.((initializeprob,), initu0vars)
# Variable symbols are not guaranteed to be in order
invpermute!(initu0vars, initu0order)
initu0 = getu(prob.f.initializeprob, initu0vars)(prob)
end
initp = remake_buffer(initializeprob, parameter_values(initializeprob),
Dict(sym => getu(prob, sym)(prob) for sym in parameter_symbols(initializeprob)))
initializeprob = remake(initializeprob; u0 = initu0, p = initp)
end
# If it doesn't have autodiff, assume it comes from symbolic system like ModelingToolkit
# Since then it's the case of not a DAE but has initializeprob
# In which case, it should be differentiable
isAD = has_autodiff(integrator.alg) ? alg_autodiff(integrator.alg) isa AutoForwardDiff :
true
isAD = if !isdefined(integrator, :alg)
false
elseif has_autodiff(integrator.alg)
alg_autodiff(integrator.alg) isa AutoForwardDiff
else
true
end

alg = default_nlsolve(alg.nlsolve, isinplace, initializeprob.u0, initializeprob, isAD)
nlsol = solve(initializeprob, alg)
if isinplace === Val{true}()
integrator.u .= prob.f.initializeprobmap(nlsol)
if prob.u0 !== nothing && !isempty(prob.u0)
integrator.u .= prob.f.initializeprobmap(nlsol)
end
if SciMLBase.has_initializeprob_updatep(prob.f)
prob.f.initializeprob_updatep!(integrator.p, nlsol)
end
elseif isinplace === Val{false}()
integrator.u = prob.f.initializeprobmap(nlsol)
if prob.u0 !== nothing && !isempty(prob.u0)
integrator.u .= prob.f.initializeprobmap(nlsol)
end
if SciMLBase.has_initializeprob_updatep(prob.f)
prob.f.initializeprob_updatep!(integrator.p, nlsol)
end
else
error("Unreachable reached. Report this error.")
end
Expand Down

0 comments on commit f0720b7

Please sign in to comment.