diff --git a/src/initialization.jl b/src/initialization.jl index 778843382..605300a61 100644 --- a/src/initialization.jl +++ b/src/initialization.jl @@ -23,7 +23,7 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap} """ initializeprobmap::IProbMap """ - A function which takes the solution of `initializeprob` and returns + A function which takes `value_provider` and the solution of `initializeprob` and returns the parameter object of the original problem. If absent (`nothing`), this will not be called and the parameters of the problem being initialized will be returned as-is. @@ -210,7 +210,7 @@ function get_initial_values(prob, valp, f, alg::OverrideInit, u0 = initdata.initializeprobmap(nlsol) if initdata.initializeprobpmap !== nothing - p = initdata.initializeprobpmap(nlsol) + p = initdata.initializeprobpmap(valp, nlsol) end return u0, p, SciMLBase.successful_retcode(nlsol) diff --git a/test/downstream/initialization.jl b/test/downstream/initialization.jl index a7d5ee671..6ff062811 100644 --- a/test/downstream/initialization.jl +++ b/test/downstream/initialization.jl @@ -1,4 +1,4 @@ -using OrdinaryDiffEq, Sundials, SciMLBase, Test +using ModelingToolkit, NonlinearSolve, OrdinaryDiffEq, Sundials, SciMLBase, Test @testset "CheckInit" begin abstol = 1e-10 @@ -59,3 +59,26 @@ using OrdinaryDiffEq, Sundials, SciMLBase, Test end end end + +@testset "OverrideInit with MTK" begin + abstol = 1e-10 + reltol = 1e-8 + + @variables x(t) [guess = 1.0] y(t) [guess = 1.0] + @parameters p=missing [guess = 1.0] q=missing [guess = 1.0] + @mtkbuild sys = ODESystem([D(x) ~ p * y + q * t, D(y) ~ 5x + q], t; + initialization_eqs = [p^2 + q^2 ~ 3, x^3 + y^3 ~ 5]) + prob = ODEProblem( + sys, [x => 1.0], (0.0, 1.0), [p => 1.0]; initializealg = SciMLBase.NoInit()) + + @test prob.f.initialization_data isa SciMLBase.OverrideInitData + integ = init(prob, Tsit5()) + u0, pobj, success = SciMLBase.get_initial_values( + prob, integ, prob.f, SciMLBase.OverrideInit(), Val(true); + nlsolve_alg = NewtonRaphson(), abstol, reltol) + + @test getu(sys, x)(u0) ≈ 1.0 + @test getu(sys, y)(u0) ≈ cbrt(4) + @test getp(sys, p)(pobj) ≈ 1.0 + @test getp(sys, q)(pobj) ≈ sqrt(2) +end