diff --git a/lib/OrdinaryDiffEqCore/Project.toml b/lib/OrdinaryDiffEqCore/Project.toml index 274351cdde..44ad3c80d3 100644 --- a/lib/OrdinaryDiffEqCore/Project.toml +++ b/lib/OrdinaryDiffEqCore/Project.toml @@ -5,6 +5,7 @@ version = "1.6.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" @@ -42,6 +43,7 @@ OrdinaryDiffEqCoreEnzymeCoreExt = "EnzymeCore" [compat] ADTypes = "0.2, 1" +Accessors = "0.1.36" Adapt = "3.0, 4" ArrayInterface = "7" DataStructures = "0.18" diff --git a/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl b/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl index d3ae689ec0..9276a467f9 100644 --- a/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl +++ b/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl @@ -70,6 +70,7 @@ import DiffEqBase: calculate_residuals, import Polyester using MacroTools, Adapt import ADTypes: AutoFiniteDiff, AutoForwardDiff +import Accessors: @reset using SciMLStructures: canonicalize, Tunable, isscimlstructure diff --git a/lib/OrdinaryDiffEqCore/src/initialize_dae.jl b/lib/OrdinaryDiffEqCore/src/initialize_dae.jl index c352cb3ddf..e88deeef1a 100644 --- a/lib/OrdinaryDiffEqCore/src/initialize_dae.jl +++ b/lib/OrdinaryDiffEqCore/src/initialize_dae.jl @@ -153,6 +153,10 @@ function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem}, alg::OverrideInit, isinplace::Union{Val{true}, Val{false}}) initializeprob = prob.f.initializeprob + if SciMLBase.has_update_initializeprob!(prob.f) + prob.f.update_initializeprob!(initializeprob, prob) + end + # If it doesn't have autodiff, assume it comes from symbolic system like ModelingToolkit # Since then it's the case of not a DAE but has initializeprob # In which case, it should be differentiable @@ -173,6 +177,12 @@ function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem}, else error("Unreachable reached. Report this error.") end + if SciMLBase.has_initializeprobpmap(prob.f) + integrator.p = prob.f.initializeprobpmap(prob, nlsol) + sol = integrator.sol + @reset sol.prob.p = integrator.p + integrator.sol = sol + end if nlsol.retcode != ReturnCode.Success integrator.sol = SciMLBase.solution_new_retcode(integrator.sol,