Skip to content

Commit

Permalink
Merge pull request #2314 from AayushSabharwal/as/param-init
Browse files Browse the repository at this point in the history
feat: support parameter updates in `initialize_dae!`
  • Loading branch information
ChrisRackauckas authored Oct 8, 2024
2 parents 457611c + fef8e0f commit 8fa520f
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 2 deletions.
6 changes: 4 additions & 2 deletions lib/OrdinaryDiffEqCore/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "1.6.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Expand Down Expand Up @@ -42,10 +43,11 @@ OrdinaryDiffEqCoreEnzymeCoreExt = "EnzymeCore"

[compat]
ADTypes = "0.2, 1"
Accessors = "0.1.36"
Adapt = "3.0, 4"
ArrayInterface = "7"
DataStructures = "0.18"
DiffEqBase = "6.147"
DiffEqBase = "6.157"
DiffEqDevTools = "2.44.4"
DocStringExtensions = "0.9"
EnumX = "1"
Expand All @@ -65,7 +67,7 @@ Random = "<0.0.1, 1"
RecursiveArrayTools = "2.36, 3"
Reexport = "1.0"
SafeTestsets = "0.1.0"
SciMLBase = "2.50.4"
SciMLBase = "2.56"
SciMLOperators = "0.3"
SciMLStructures = "1"
SimpleUnPack = "1"
Expand Down
1 change: 1 addition & 0 deletions lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ import DiffEqBase: calculate_residuals,
import Polyester
using MacroTools, Adapt
import ADTypes: AutoFiniteDiff, AutoForwardDiff
import Accessors: @reset

using SciMLStructures: canonicalize, Tunable, isscimlstructure

Expand Down
10 changes: 10 additions & 0 deletions lib/OrdinaryDiffEqCore/src/initialize_dae.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,10 @@ function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem},
alg::OverrideInit, isinplace::Union{Val{true}, Val{false}})
initializeprob = prob.f.initializeprob

if SciMLBase.has_update_initializeprob!(prob.f)
prob.f.update_initializeprob!(initializeprob, prob)
end

# If it doesn't have autodiff, assume it comes from symbolic system like ModelingToolkit
# Since then it's the case of not a DAE but has initializeprob
# In which case, it should be differentiable
Expand All @@ -173,6 +177,12 @@ function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem},
else
error("Unreachable reached. Report this error.")
end
if SciMLBase.has_initializeprobpmap(prob.f)
integrator.p = prob.f.initializeprobpmap(prob, nlsol)
sol = integrator.sol
@reset sol.prob.p = integrator.p
integrator.sol = sol
end

if nlsol.retcode != ReturnCode.Success
integrator.sol = SciMLBase.solution_new_retcode(integrator.sol,
Expand Down

0 comments on commit 8fa520f

Please sign in to comment.