Skip to content

Commit

Permalink
Merge pull request #866 from AayushSabharwal/as/hotfix-initialization
Browse files Browse the repository at this point in the history
fix: fix `initializeprobpmap` call in `OverrideInit`
  • Loading branch information
AayushSabharwal authored Nov 21, 2024
2 parents b44504b + cfdb2f8 commit b85639c
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 62 deletions.
4 changes: 2 additions & 2 deletions src/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap}
"""
initializeprobmap::IProbMap
"""
A function which takes the solution of `initializeprob` and returns
A function which takes `value_provider` and the solution of `initializeprob` and returns
the parameter object of the original problem. If absent (`nothing`),
this will not be called and the parameters of the problem being
initialized will be returned as-is.
Expand Down Expand Up @@ -210,7 +210,7 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,

u0 = initdata.initializeprobmap(nlsol)
if initdata.initializeprobpmap !== nothing
p = initdata.initializeprobpmap(nlsol)
p = initdata.initializeprobpmap(valp, nlsol)
end

return u0, p, SciMLBase.successful_retcode(nlsol)
Expand Down
8 changes: 4 additions & 4 deletions test/downstream/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationMOI = "fd9f6733-72f4-499f-8506-86b2bdd0dea1"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Expand All @@ -22,6 +23,7 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

Expand All @@ -31,12 +33,10 @@ DelayDiffEq = "5"
DiffEqCallbacks = "3, 4"
ForwardDiff = "0.10"
JumpProcesses = "9.10"
ModelingToolkit = "9"
ModelingToolkit = "9.52"
ModelingToolkitStandardLibrary = "2.7"
NonlinearSolve = "2, 3, 4"
Optimization = "3"
OptimizationMOI = "0.4"
OptimizationOptimJL = "0.1, 0.2, 0.3"
Optimization = "4"
OrdinaryDiffEq = "6.33"
Plots = "1.40"
RecursiveArrayTools = "3"
Expand Down
2 changes: 1 addition & 1 deletion test/downstream/adjoints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ gs_ts, = Zygote.gradient(sol) do sol
sum(sum.(sol[[lorenz1.x, lorenz2.x], :]))
end

@test all(map(x -> x == true_grad_vecsym, gs_ts))
@test_broken all(map(x -> x == true_grad_vecsym, gs_ts))

# BatchedInterface AD
@variables x(t)=1.0 y(t)=1.0 z(t)=1.0 w(t)=1.0
Expand Down
53 changes: 26 additions & 27 deletions test/downstream/initialization.jl
Original file line number Diff line number Diff line change
@@ -1,33 +1,9 @@
using OrdinaryDiffEq, Sundials, SciMLBase, Test
using ModelingToolkit, NonlinearSolve, OrdinaryDiffEq, Sundials, SciMLBase, Test
using SymbolicIndexingInterface
using ModelingToolkit: t_nounits as t, D_nounits as D

@testset "CheckInit" begin
abstol = 1e-10
@testset "Sundials + ODEProblem" begin
function rhs(u, p, t)
return [u[1] * t, u[1]^2 - u[2]^2]
end
function rhs!(du, u, p, t)
du[1] = u[1] * t
du[2] = u[1]^2 - u[2]^2
end

oopfn = ODEFunction{false}(rhs, mass_matrix = [1 0; 0 0])
iipfn = ODEFunction{true}(rhs!, mass_matrix = [1 0; 0 0])

@testset "Inplace = $(SciMLBase.isinplace(f))" for f in [oopfn, iipfn]
prob = ODEProblem(f, [1.0, 1.0], (0.0, 1.0))
integ = init(prob, Sundials.ARKODE())
u0, _, success = SciMLBase.get_initial_values(
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)); abstol)
@test success
@test u0 == prob.u0

integ.u[2] = 2.0
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values(
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)); abstol)
end
end

@testset "Sundials + DAEProblem" begin
function daerhs(du, u, p, t)
return [du[1] - u[1] * t - p, u[1]^2 - u[2]^2]
Expand Down Expand Up @@ -59,3 +35,26 @@ using OrdinaryDiffEq, Sundials, SciMLBase, Test
end
end
end

@testset "OverrideInit with MTK" begin
abstol = 1e-10
reltol = 1e-8

@variables x(t) [guess = 1.0] y(t) [guess = 1.0]
@parameters p=missing [guess = 1.0] q=missing [guess = 1.0]
@mtkbuild sys = ODESystem([D(x) ~ p * y + q * t, D(y) ~ 5x + q], t;
initialization_eqs = [p^2 + q^2 ~ 3, x^3 + y^3 ~ 5])
prob = ODEProblem(
sys, [x => 1.0], (0.0, 1.0), [p => 1.0]; initializealg = SciMLBase.NoInit())

@test prob.f.initialization_data isa SciMLBase.OverrideInitData
integ = init(prob, Tsit5())
u0, pobj, success = SciMLBase.get_initial_values(
prob, integ, prob.f, SciMLBase.OverrideInit(), Val(true);
nlsolve_alg = NewtonRaphson(), abstol, reltol)

@test getu(sys, x)(u0) 1.0
@test getu(sys, y)(u0) cbrt(4)
@test getp(sys, p)(pobj) 1.0
@test getp(sys, q)(pobj) sqrt(2)
end
8 changes: 4 additions & 4 deletions test/downstream/problem_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -261,12 +261,12 @@ eprob = EnsembleProblem(oprob)
@test eprob.ps[osys.p] == 0.1

@test state_values(remake(eprob; u0 = [X => 0.1])) == [0.1]
@test state_values(remake(eprob; u0 = [:X => 0.2])) == [0.2]
@test_broken state_values(remake(eprob; u0 = [:X => 0.2])) == [0.2]
@test state_values(remake(eprob; u0 = [osys.X => 0.3])) == [0.3]

@test remake(eprob; p = [d => 0.4]).ps[d] == 0.4
@test remake(eprob; p = [:d => 0.5]).ps[d] == 0.5
@test remake(eprob; p = [osys.d => 0.6]).ps[d] == 0.6
@test_broken remake(eprob; p = [d => 0.4]).ps[d] == 0.4
@test_broken remake(eprob; p = [:d => 0.5]).ps[d] == 0.5
@test_broken remake(eprob; p = [osys.d => 0.6]).ps[d] == 0.6

# SteadyStateProblem Indexing
# Issue#660
Expand Down
28 changes: 5 additions & 23 deletions test/downstream/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,21 +191,13 @@ end
ss1.state_map == ss2.state_map
end

ode_sol = solve(prob, Tsit5(); save_idxs = xidx)
subsys = SciMLBase.SavedSubsystem(sys, prob.p, [xidx])
@test SciMLBase.get_saved_state_idxs(subsys) == [xidx]

# FIXME: hack for save_idxs
SciMLBase.@reset ode_sol.saved_subsystem = subsys
ode_sol = solve(prob, Tsit5(); save_idxs = [x])

@mtkbuild sys = ODESystem([D(x) ~ x + p * y, 1 ~ sin(y) + cos(x)], t)
xidx = variable_index(sys, x)
prob = DAEProblem(sys, [D(x) => x + p * y, D(y) => 1 / sqrt(1 - (1 - cos(x))^2)],
[x => 1.0, y => asin(1 - cos(x))], (0.0, 1.0), [p => 2.0])
dae_sol = solve(prob, DFBDF(); save_idxs = xidx)
subsys = SciMLBase.SavedSubsystem(sys, prob.p, [xidx])
# FIXME: hack for save_idxs
SciMLBase.@reset dae_sol.saved_subsystem = subsys
[x => 1.0, y => asin(1 - cos(x))], (0.0, 1.0), [p => 2.0]; build_initializeprob = false)
dae_sol = solve(prob, DFBDF(); save_idxs = [x])

@brownian a b
@mtkbuild sys = System([D(x) ~ x + p * y + x * a, D(y) ~ 2p + x^2 + y * b], t)
Expand Down Expand Up @@ -256,21 +248,11 @@ end

@test SciMLBase.SavedSubsystem(sys, prob.p, [x, y, q, r, s, u]) === nothing

sol = solve(prob; save_idxs = xidx)
sol = solve(prob; save_idxs = [x, q, r])
xvals = sol[x]
subsys = SciMLBase.SavedSubsystem(sys, prob.p, [x, q, r])
@test SciMLBase.get_saved_state_idxs(subsys) == [xidx]
@test SciMLBase.get_saved_state_idxs(sol.saved_subsystem) == [xidx]
qvals = sol.ps[q]
rvals = sol.ps[r]
# FIXME: hack for save_idxs
SciMLBase.@reset sol.saved_subsystem = subsys
discq = DiffEqArray(SciMLBase.TupleOfArraysWrapper.(tuple.(Base.vect.(qvals))),
sol.discretes[qpidx.timeseries_idx].t, (1, 1))
discr = DiffEqArray(SciMLBase.TupleOfArraysWrapper.(tuple.(Base.vect.(rvals))),
sol.discretes[rpidx.timeseries_idx].t, (1, 1))
SciMLBase.@reset sol.discretes.collection[qpidx.timeseries_idx] = discq
SciMLBase.@reset sol.discretes.collection[rpidx.timeseries_idx] = discr

@test sol[x] == xvals

@test all(Base.Fix1(is_parameter, sol), [p, q, r, s, u])
Expand Down
2 changes: 1 addition & 1 deletion test/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ end
initprobmap = function (nlsol)
return [parameter_values(nlsol)[1], nlsol.u[1]]
end
initprobpmap = function (nlsol)
initprobpmap = function (_, nlsol)
return nlsol.u[2]
end
initialization_data = SciMLBase.OverrideInitData(
Expand Down

0 comments on commit b85639c

Please sign in to comment.