Skip to content

Commit

Permalink
test: test implementation of SII parameter timeseries interface
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed May 21, 2024
1 parent ee65ce1 commit 2eb3a80
Showing 1 changed file with 222 additions and 21 deletions.
243 changes: 222 additions & 21 deletions test/downstream/symbol_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -371,30 +371,231 @@ sol = solve(prob, Tsit5())
end

dt = 0.1
@variables x(t) y(t) u(t) yd(t) ud(t) r(t)
@parameters kp

eqs = [yd ~ Sample(t, dt)(y)
ud ~ kp * (r - yd)
r ~ 1.0
dt2 = 0.2
@variables x(t)=0 y(t)=0 u(t)=0 yd1(t)=0 ud1(t)=0 yd2(t)=0 ud2(t)=0
@parameters kp=1 r=2

eqs = [
# controller (time discrete part `dt=0.1`)
yd1 ~ Sample(t, dt)(y)
ud1 ~ kp * (r - yd1)
# controller (time discrete part `dt=0.2`)
yd2 ~ Sample(t, dt2)(y)
ud2 ~ kp * (r - yd2)

# plant (time continuous part)
u ~ Hold(ud)
u ~ Hold(ud1) + Hold(ud2)
D(x) ~ -x + u
y ~ x]

@mtkbuild sys = ODESystem(eqs, t)
prob = ODEProblem(sys, [x => 1.0], (0.0, 10.0), [kp => 10.0, Hold(ud) => 1.0])
sol = solve(prob, Tsit5(), kwargshandle = KeywordArgSilent)
@test sol.discretes isa DiffEqArray
@test length(sol.discretes) == 101

getter = getp(sol, Hold(ud))
@test sol(0.1, idxs = Hold(ud)) == sol(0.1 + eps(0.1), idxs = Hold(ud)) ==
sol(0.2 - eps(0.2), idxs = Hold(ud)) == getter(parameter_values_at_time(sol, 2))
@test sol.ps[Hold(ud)] == getter(sol.prob.p)
@test sol.ps[Hold(ud), :] isa Vector{Float64}
@test length(sol.ps[Hold(ud), :]) == 101
for i in 1:100
@test sol.ps[Hold(ud), i] == getter(parameter_values_at_time(sol, i))
@mtkbuild cl = ODESystem(eqs, t)
prob = ODEProblem(cl, [x => 0.0], (0.0, 1.0), [kp => 1.0])
sol = solve(remake(prob), Tsit5())

kpidx = parameter_index(cl, kp)
kpval = 1.0
ud1idx = parameter_index(cl, Hold(ud1))
ud1val = [val[ud1idx.parameter_idx] for val in sol.discretes[ud1idx.timeseries_idx].u]
ud2idx = parameter_index(cl, Hold(ud2))
ud2val = [val[ud2idx.parameter_idx] for val in sol.discretes[ud2idx.timeseries_idx].u]
ridx = parameter_index(cl, r)
rval = 2.0

for (sym, val, buffer, check_inference) in [
(kp, kpval, nothing, true),
(kpidx, kpval, nothing, true),
([kp, r], [kpval, rval], zeros(2), true),
((kp, r), (kpval, rval), zeros(2), true),
([kpidx, ridx], [kpval, rval], zeros(2), true),
((kpidx, ridx), (kpval, rval), zeros(2), true),
([kp, ridx], [kpval, rval], zeros(2), true),
((kp, ridx), (kpval, rval), zeros(2), true),
([kp, Hold(ud1)], [kpval, ud1val[end]], zeros(2), true),
((kp, Hold(ud1)), (kpval, ud1val[end]), zeros(2), true),
([kpidx, Hold(ud1)], [kpval, ud1val[end]], zeros(2), true),
((kpidx, Hold(ud1)), (kpval, ud1val[end]), zeros(2), true),
# not technically valid, but need to test getp behavior
([Hold(ud1) + Hold(ud2), kp], [ud1val[end] + ud2val[end], kpval], zeros(2), true),
((Hold(ud1) + Hold(ud2), kp), (ud1val[end] + ud2val[end], kpval), zeros(2), true),
]
getter = getp(sys, sym)
if check_inference
@inferred getter(sol)
@inferred getter(prob)
end
@test getter(sol) == val
@test getter(prob) == val
if buffer !== nothing
getter(buffer, sol)
@test buffer == collect(val)
buffer .= 0
getter(buffer, prob)
@test buffer == collect(val)
if check_inference
@inferred getter(buffer, sol)
@inferred getter(buffer, prob)
end
end
end

int = init(sol.prob.p, Tsit5())
step!(int, 0.1, true)
ud1obsval = vcat(ud1val[2:end], int.ps[Hold(ud1)])

for (sym, val, buffer, check_prob, check_inference) in [
(Hold(ud1), ud1val, zeros(length(ud1val)), true, true),
(Hold(ud2), ud2val, zeros(length(ud2val)), true, true),
(ud1idx, ud1val, zeros(length(ud1val)), false, true),
([Hold(ud1), Hold(ud1)], vcat.(ud1val, ud1val), map(_ -> zeros(2), ud1val), true, true),
((Hold(ud1), Hold(ud1)), tuple.(ud1val, ud1val), map(_ -> zeros(2), ud1val), true, true),
([ud2idx, Hold(ud2)], vcat.(ud2val, ud2val), map(_ -> zeros(2), ud2val), false, true),
((ud2idx, Hold(ud2)), tuple.(ud2val, ud2val), map(_ -> zeros(2), ud2val), false, true),
([ud2idx, ud2idx], vcat.(ud2val, ud2val), map(_ -> zeros(2), ud2val), false, true),
((ud2idx, ud2idx), tuple.(ud2val, ud2val), map(_ -> zeros(2), ud2val), false, true),
(2Shift(t, 1)(ud1), 2ud1obsval, zeros(length(ud1val)), true, true),
([2Shift(t, 1)(ud1), 2Hold(ud1)], vcat.(2ud1obsval, 2ud1val), map(_ -> zeros(2), ud1val), true, true),
((2Shift(t, 1)(ud1), 2Hold(ud1)), tuple.(2ud1obsval, 2ud1val), map(_ -> zeros(2), ud1val), true, true),
]
getter = getp(sys, sym)
if check_inference
@inferred getter(sol)
if check_prob
@inferred getter(prob)
end
end
@test getter(sol) == val
if check_prob
@test getter(prob) == val[1]
end
buf = deepcopy(buffer)
if check_inference
@inferred getter(deepcopy(buffer), sol)
if check_prob
@inferred getter(deepcopy(buffer[1]), prob)
end
end
getter(buf, sol)
@test buf == collect.(val)
if check_prob
buf = deepcopy(buffer[1])
getter(buf, prob)
@test buf == collect(val[1])
end

for subidx in [rand(eachindex(val)), CartesianIndex(rand(eachindex(val))), :, rand(Bool, length(val)), rand(eachindex(val), 5), 2:5]
if check_inference
@inferred getter(sol, subidx)
end
@test getter(sol, subidx) == collect.(val[subidx])
buf = deepcopy(buffer[subidx])
getter(buf, sol, subidx)
@test buf == collect.(val[subidx])
end
end

for sym in [[kp, ud1idx], (kp, ud1idx), [kpidx, ud1idx], (kpidx, ud1idx), [ud1idx, Hold(ud2)], (ud1idx, Hold(ud2)), [ud1idx, ud2idx], (ud1idx, ud2idx)]
@test_throws ArgumentError getp(sys, sym)
end

for (sym, val) in [
([Hold(ud1), Hold(ud2)], [ud1val[1], ud2val[1]]),
((Hold(ud1), Hold(ud2)), (ud1val[1], ud2val[1])),
]
getter = getp(sys, sym)
@test_throws SymbolicIndexingInterface.MixedParameterTimeseriesIndexError getter(sol)
@test getter(prob) == val
end

xval = getindex.(sol.u)
ud1ts = sol.discretes[ud1idx.timeseries_idx]
ud1val_state = [ud1ts.u[searchsortedlast(ud1ts.t, t)][] for t in sol.t]
ud2ts = sol.discretes[ud2idx.timeseries_idx]
ud2val_state = [ud2ts.u[searchsortedlast(ud2ts.t, t)][] for t in sol.t]

for (sym, val_is_timeseries, val, check_inference) in [
(kp, false, kpval, true),
([kp, r], false, [kpval, rval], true),
((kp, r), false, (kpval, rval), true),
(Hold(ud1), true, ud1val_state, true),
([kp, Hold(ud1)], true, vcat.(kpval, ud1val_state), true),
((kp, Hold(ud1)), true, tuple.(kpval, ud1val_state), true),
([Hold(ud1), Hold(ud2)], true, vcat.(ud1val_state, ud2val_state), true),
((Hold(ud1), Hold(ud2)), true, tuple.(ud1val_state, ud2val_state), true),
([kp, Hold(ud1), Hold(ud2)], true, vcat.(kpval, ud1val_state, ud2val_state), true),
((kp, Hold(ud1), Hold(ud2)), true, tuple.(kpval, ud1val_state, ud2val_state), true),
([x, Hold(ud1)], true, vcat.(xval, ud1val_state), true),
((x, Hold(ud1)), true, tuple.(xval, ud1val_state), true),
([x, Hold(ud1), Hold(ud2)], true, vcat.(xval, ud1val_state, ud2val_state), true),
((x, Hold(ud1), Hold(ud2)), true, tuple.(xval, ud1val_state, ud2val_state), true),
([x, Hold(ud1), Hold(ud2), kp], true, vcat.(xval, ud1val_state, ud2val_state, kpval), true),
((x, Hold(ud1), Hold(ud2), kp), true, tuple.(xval, ud1val_state, ud2val_state, kpval), true),
(2Hold(ud1), true, 2ud1val_state, true),
([x, 2Hold(ud1), 3Hold(ud2)], true, vcat.(xval, 2ud1val_state, 3ud2val_state), true),
((x, 2Hold(ud1), 3Hold(ud2)), true, tuple.(xval, 2ud1val_state, 3ud2val_state), true),
]
getter = getu(sys, sym)
if check_inference
@inferred getter(sol)
@inferred sol[sym]
end
@test getter(sol) == val
@test sol[sym] == val

for subidx in [rand(eachindex(sol.t)), CartesianIndex(rand(eachindex(sol.t))), :, rand(Bool, length(sol.t)), rand(eachindex(sol.t), 5), 3:6]
if check_inference
@inferred getter(sol, subidx)
@inferred sol[sym, subidx]
end

target = if val_is_timeseries
val[subidx]
else
if fs.t[subidx] isa AbstractArray
len = length(fs.t[subidx])
fill(val, len)
else
val
end
end
@test getter(sol, subidx) == target
@test sol[sym, subidx] == target
end
end

@parameters σ ρ β
@variables x y z

eqs = [0 ~ σ * (y - x),
0 ~ x *- z) - y,
0 ~ x * y - β * z]
@mtkbuild ns = NonlinearSystem(eqs)
u0 = [x => 1.0,
y => 0.0,
z => 0.0]

p ==> 28.0,
ρ => 10.0,
β => 8 / 3]
prob = NonlinearProblem(ns, u0, p)
sol = solve(prob)

for (sym, val, check_inference) in [
(σ, 28.0, true),
([σ, ρ], [28.0, 10.0], true),
((σ, ρ), (28.0, 10.0), true),
+ ρ, 38.0, true),
([σ + ρ, ρ + β], [38.0, 10 + 8/3], true),
((σ + ρ, ρ + β), (38.0, 10 + 8/3), true),
]
getter = getp(sys, sym)
if check_inference
@inferred getter(sol)
end
@test getter(sol) == val

if sym isa Union{Array, Tuple}
buffer = zeros(length(sym))
@inferred getter(buffer, sol)
@test buffer == collect(val)
end
end

0 comments on commit 2eb3a80

Please sign in to comment.