Skip to content

Commit

Permalink
feat: support parameter updates in initialize_dae!
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Oct 7, 2024
1 parent 457611c commit bcd324a
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 0 deletions.
2 changes: 2 additions & 0 deletions lib/OrdinaryDiffEqCore/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "1.6.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Expand Down Expand Up @@ -42,6 +43,7 @@ OrdinaryDiffEqCoreEnzymeCoreExt = "EnzymeCore"

[compat]
ADTypes = "0.2, 1"
Accessors = "0.1.36"
Adapt = "3.0, 4"
ArrayInterface = "7"
DataStructures = "0.18"
Expand Down
1 change: 1 addition & 0 deletions lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ import DiffEqBase: calculate_residuals,
import Polyester
using MacroTools, Adapt
import ADTypes: AutoFiniteDiff, AutoForwardDiff
import Accessors: @reset

using SciMLStructures: canonicalize, Tunable, isscimlstructure

Expand Down
10 changes: 10 additions & 0 deletions lib/OrdinaryDiffEqCore/src/initialize_dae.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,10 @@ function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem},
alg::OverrideInit, isinplace::Union{Val{true}, Val{false}})
initializeprob = prob.f.initializeprob

if SciMLBase.has_update_initializeprob!(prob.f)
prob.f.update_initializeprob!(initializeprob, prob)
end

# If it doesn't have autodiff, assume it comes from symbolic system like ModelingToolkit
# Since then it's the case of not a DAE but has initializeprob
# In which case, it should be differentiable
Expand All @@ -173,6 +177,12 @@ function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem},
else
error("Unreachable reached. Report this error.")
end
if SciMLBase.has_initializeprobpmap(prob.f)
integrator.p = prob.f.initializeprobpmap(prob, nlsol)
sol = integrator.sol
@reset sol.prob.p = integrator.p
integrator.sol = sol
end

if nlsol.retcode != ReturnCode.Success
integrator.sol = SciMLBase.solution_new_retcode(integrator.sol,
Expand Down

0 comments on commit bcd324a

Please sign in to comment.