From 5b702cab0620e79b92637929bb773ef2a2805a21 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 21 Nov 2024 17:37:24 +0530 Subject: [PATCH] fix: propagate `initializeprobpmap` and `update_initializeprob!` to `DAEFunction` --- src/systems/diffeqs/abstractodesystem.jl | 6 +++++- test/initializationsystem.jl | 17 +++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index b6d750c621..cbf835f48d 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -498,6 +498,8 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys) checkbounds = false, initializeprob = nothing, initializeprobmap = nothing, + initializeprobpmap = nothing, + update_initializeprob! = nothing, kwargs...) where {iip} if !iscomplete(sys) error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEFunction`") @@ -551,7 +553,9 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys) jac_prototype = jac_prototype, observed = observedfun, initializeprob = initializeprob, - initializeprobmap = initializeprobmap) + initializeprobmap = initializeprobmap, + initializeprobpmap = initializeprobpmap, + update_initializeprob! = update_initializeprob!) end function DiffEqBase.DDEFunction(sys::AbstractODESystem, args...; kwargs...) diff --git a/test/initializationsystem.jl b/test/initializationsystem.jl index 920a73a723..770f9f12bd 100644 --- a/test/initializationsystem.jl +++ b/test/initializationsystem.jl @@ -958,3 +958,20 @@ end @test_warn ["structurally singular", "initialization", "Guess", "heuristic"] ODEProblem( pend, [x => 1, y => 0], (0.0, 1.5), [g => 1], guesses = [λ => 1]) end + +@testset "DAEProblem initialization" begin + @variables x(t) [guess = 1.0] y(t) [guess = 1.0] + @parameters p=missing [guess = 1.0] q=missing [guess = 1.0] + @mtkbuild sys = ODESystem( + [D(x) ~ p * y + q * t, x^3 + y^3 ~ 5], t; initialization_eqs = [p^2 + q^3 ~ 3]) + + # FIXME: solve for du0 + prob = DAEProblem( + sys, [D(x) => cbrt(4), D(y) => -1 / cbrt(4)], [x => 1.0], (0.0, 1.0), [p => 1.0]) + + integ = init(prob, DImplicitEuler()) + @test integ[x] ≈ 1.0 + @test integ[y]≈cbrt(4) rtol=1e-6 + @test integ.ps[p] ≈ 1.0 + @test integ.ps[q]≈cbrt(2) rtol=1e-6 +end