Skip to content

Commit

Permalink
fix: fix initializeprobpmap call in OverrideInit
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Nov 21, 2024
1 parent b44504b commit f7f5630
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap}
"""
initializeprobmap::IProbMap
"""
A function which takes the solution of `initializeprob` and returns
A function which takes `value_provider` and the solution of `initializeprob` and returns
the parameter object of the original problem. If absent (`nothing`),
this will not be called and the parameters of the problem being
initialized will be returned as-is.
Expand Down Expand Up @@ -210,7 +210,7 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,

u0 = initdata.initializeprobmap(nlsol)
if initdata.initializeprobpmap !== nothing
p = initdata.initializeprobpmap(nlsol)
p = initdata.initializeprobpmap(valp, nlsol)
end

return u0, p, SciMLBase.successful_retcode(nlsol)
Expand Down
25 changes: 24 additions & 1 deletion test/downstream/initialization.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using OrdinaryDiffEq, Sundials, SciMLBase, Test
using ModelingToolkit, NonlinearSolve, OrdinaryDiffEq, Sundials, SciMLBase, Test

@testset "CheckInit" begin
abstol = 1e-10
Expand Down Expand Up @@ -59,3 +59,26 @@ using OrdinaryDiffEq, Sundials, SciMLBase, Test
end
end
end

@testset "OverrideInit with MTK" begin
abstol = 1e-10
reltol = 1e-8

@variables x(t) [guess = 1.0] y(t) [guess = 1.0]
@parameters p=missing [guess = 1.0] q=missing [guess = 1.0]
@mtkbuild sys = ODESystem([D(x) ~ p * y + q * t, D(y) ~ 5x + q], t;
initialization_eqs = [p^2 + q^2 ~ 3, x^3 + y^3 ~ 5])
prob = ODEProblem(
sys, [x => 1.0], (0.0, 1.0), [p => 1.0]; initializealg = SciMLBase.NoInit())

@test prob.f.initialization_data isa SciMLBase.OverrideInitData
integ = init(prob, Tsit5())
u0, pobj, success = SciMLBase.get_initial_values(
prob, integ, prob.f, SciMLBase.OverrideInit(), Val(true);
nlsolve_alg = NewtonRaphson(), abstol, reltol)

@test getu(sys, x)(u0) 1.0
@test getu(sys, y)(u0) cbrt(4)
@test getp(sys, p)(pobj) 1.0
@test getp(sys, q)(pobj) sqrt(2)
end

0 comments on commit f7f5630

Please sign in to comment.