From 2eb3a806c2a757f6d9957a51ed14a7bd6681bde2 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 21 May 2024 11:45:48 +0530 Subject: [PATCH] test: test implementation of SII parameter timeseries interface --- test/downstream/symbol_indexing.jl | 243 ++++++++++++++++++++++++++--- 1 file changed, 222 insertions(+), 21 deletions(-) diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl index 1b1e82a36..817cfb7c3 100644 --- a/test/downstream/symbol_indexing.jl +++ b/test/downstream/symbol_indexing.jl @@ -371,30 +371,231 @@ sol = solve(prob, Tsit5()) end dt = 0.1 -@variables x(t) y(t) u(t) yd(t) ud(t) r(t) -@parameters kp - -eqs = [yd ~ Sample(t, dt)(y) - ud ~ kp * (r - yd) - r ~ 1.0 +dt2 = 0.2 +@variables x(t)=0 y(t)=0 u(t)=0 yd1(t)=0 ud1(t)=0 yd2(t)=0 ud2(t)=0 +@parameters kp=1 r=2 + +eqs = [ + # controller (time discrete part `dt=0.1`) + yd1 ~ Sample(t, dt)(y) + ud1 ~ kp * (r - yd1) + # controller (time discrete part `dt=0.2`) + yd2 ~ Sample(t, dt2)(y) + ud2 ~ kp * (r - yd2) # plant (time continuous part) - u ~ Hold(ud) + u ~ Hold(ud1) + Hold(ud2) D(x) ~ -x + u y ~ x] -@mtkbuild sys = ODESystem(eqs, t) -prob = ODEProblem(sys, [x => 1.0], (0.0, 10.0), [kp => 10.0, Hold(ud) => 1.0]) -sol = solve(prob, Tsit5(), kwargshandle = KeywordArgSilent) -@test sol.discretes isa DiffEqArray -@test length(sol.discretes) == 101 - -getter = getp(sol, Hold(ud)) -@test sol(0.1, idxs = Hold(ud)) == sol(0.1 + eps(0.1), idxs = Hold(ud)) == - sol(0.2 - eps(0.2), idxs = Hold(ud)) == getter(parameter_values_at_time(sol, 2)) -@test sol.ps[Hold(ud)] == getter(sol.prob.p) -@test sol.ps[Hold(ud), :] isa Vector{Float64} -@test length(sol.ps[Hold(ud), :]) == 101 -for i in 1:100 - @test sol.ps[Hold(ud), i] == getter(parameter_values_at_time(sol, i)) +@mtkbuild cl = ODESystem(eqs, t) +prob = ODEProblem(cl, [x => 0.0], (0.0, 1.0), [kp => 1.0]) +sol = solve(remake(prob), Tsit5()) + +kpidx = parameter_index(cl, kp) +kpval = 1.0 +ud1idx = parameter_index(cl, Hold(ud1)) +ud1val = [val[ud1idx.parameter_idx] for val in sol.discretes[ud1idx.timeseries_idx].u] +ud2idx = parameter_index(cl, Hold(ud2)) +ud2val = [val[ud2idx.parameter_idx] for val in sol.discretes[ud2idx.timeseries_idx].u] +ridx = parameter_index(cl, r) +rval = 2.0 + +for (sym, val, buffer, check_inference) in [ + (kp, kpval, nothing, true), + (kpidx, kpval, nothing, true), + ([kp, r], [kpval, rval], zeros(2), true), + ((kp, r), (kpval, rval), zeros(2), true), + ([kpidx, ridx], [kpval, rval], zeros(2), true), + ((kpidx, ridx), (kpval, rval), zeros(2), true), + ([kp, ridx], [kpval, rval], zeros(2), true), + ((kp, ridx), (kpval, rval), zeros(2), true), + ([kp, Hold(ud1)], [kpval, ud1val[end]], zeros(2), true), + ((kp, Hold(ud1)), (kpval, ud1val[end]), zeros(2), true), + ([kpidx, Hold(ud1)], [kpval, ud1val[end]], zeros(2), true), + ((kpidx, Hold(ud1)), (kpval, ud1val[end]), zeros(2), true), + # not technically valid, but need to test getp behavior + ([Hold(ud1) + Hold(ud2), kp], [ud1val[end] + ud2val[end], kpval], zeros(2), true), + ((Hold(ud1) + Hold(ud2), kp), (ud1val[end] + ud2val[end], kpval), zeros(2), true), + ] + getter = getp(sys, sym) + if check_inference + @inferred getter(sol) + @inferred getter(prob) + end + @test getter(sol) == val + @test getter(prob) == val + if buffer !== nothing + getter(buffer, sol) + @test buffer == collect(val) + buffer .= 0 + getter(buffer, prob) + @test buffer == collect(val) + if check_inference + @inferred getter(buffer, sol) + @inferred getter(buffer, prob) + end + end +end + +int = init(sol.prob.p, Tsit5()) +step!(int, 0.1, true) +ud1obsval = vcat(ud1val[2:end], int.ps[Hold(ud1)]) + +for (sym, val, buffer, check_prob, check_inference) in [ + (Hold(ud1), ud1val, zeros(length(ud1val)), true, true), + (Hold(ud2), ud2val, zeros(length(ud2val)), true, true), + (ud1idx, ud1val, zeros(length(ud1val)), false, true), + ([Hold(ud1), Hold(ud1)], vcat.(ud1val, ud1val), map(_ -> zeros(2), ud1val), true, true), + ((Hold(ud1), Hold(ud1)), tuple.(ud1val, ud1val), map(_ -> zeros(2), ud1val), true, true), + ([ud2idx, Hold(ud2)], vcat.(ud2val, ud2val), map(_ -> zeros(2), ud2val), false, true), + ((ud2idx, Hold(ud2)), tuple.(ud2val, ud2val), map(_ -> zeros(2), ud2val), false, true), + ([ud2idx, ud2idx], vcat.(ud2val, ud2val), map(_ -> zeros(2), ud2val), false, true), + ((ud2idx, ud2idx), tuple.(ud2val, ud2val), map(_ -> zeros(2), ud2val), false, true), + (2Shift(t, 1)(ud1), 2ud1obsval, zeros(length(ud1val)), true, true), + ([2Shift(t, 1)(ud1), 2Hold(ud1)], vcat.(2ud1obsval, 2ud1val), map(_ -> zeros(2), ud1val), true, true), + ((2Shift(t, 1)(ud1), 2Hold(ud1)), tuple.(2ud1obsval, 2ud1val), map(_ -> zeros(2), ud1val), true, true), + ] + getter = getp(sys, sym) + if check_inference + @inferred getter(sol) + if check_prob + @inferred getter(prob) + end + end + @test getter(sol) == val + if check_prob + @test getter(prob) == val[1] + end + buf = deepcopy(buffer) + if check_inference + @inferred getter(deepcopy(buffer), sol) + if check_prob + @inferred getter(deepcopy(buffer[1]), prob) + end + end + getter(buf, sol) + @test buf == collect.(val) + if check_prob + buf = deepcopy(buffer[1]) + getter(buf, prob) + @test buf == collect(val[1]) + end + + for subidx in [rand(eachindex(val)), CartesianIndex(rand(eachindex(val))), :, rand(Bool, length(val)), rand(eachindex(val), 5), 2:5] + if check_inference + @inferred getter(sol, subidx) + end + @test getter(sol, subidx) == collect.(val[subidx]) + buf = deepcopy(buffer[subidx]) + getter(buf, sol, subidx) + @test buf == collect.(val[subidx]) + end +end + +for sym in [[kp, ud1idx], (kp, ud1idx), [kpidx, ud1idx], (kpidx, ud1idx), [ud1idx, Hold(ud2)], (ud1idx, Hold(ud2)), [ud1idx, ud2idx], (ud1idx, ud2idx)] + @test_throws ArgumentError getp(sys, sym) +end + +for (sym, val) in [ + ([Hold(ud1), Hold(ud2)], [ud1val[1], ud2val[1]]), + ((Hold(ud1), Hold(ud2)), (ud1val[1], ud2val[1])), + ] + getter = getp(sys, sym) + @test_throws SymbolicIndexingInterface.MixedParameterTimeseriesIndexError getter(sol) + @test getter(prob) == val +end + +xval = getindex.(sol.u) +ud1ts = sol.discretes[ud1idx.timeseries_idx] +ud1val_state = [ud1ts.u[searchsortedlast(ud1ts.t, t)][] for t in sol.t] +ud2ts = sol.discretes[ud2idx.timeseries_idx] +ud2val_state = [ud2ts.u[searchsortedlast(ud2ts.t, t)][] for t in sol.t] + +for (sym, val_is_timeseries, val, check_inference) in [ + (kp, false, kpval, true), + ([kp, r], false, [kpval, rval], true), + ((kp, r), false, (kpval, rval), true), + (Hold(ud1), true, ud1val_state, true), + ([kp, Hold(ud1)], true, vcat.(kpval, ud1val_state), true), + ((kp, Hold(ud1)), true, tuple.(kpval, ud1val_state), true), + ([Hold(ud1), Hold(ud2)], true, vcat.(ud1val_state, ud2val_state), true), + ((Hold(ud1), Hold(ud2)), true, tuple.(ud1val_state, ud2val_state), true), + ([kp, Hold(ud1), Hold(ud2)], true, vcat.(kpval, ud1val_state, ud2val_state), true), + ((kp, Hold(ud1), Hold(ud2)), true, tuple.(kpval, ud1val_state, ud2val_state), true), + ([x, Hold(ud1)], true, vcat.(xval, ud1val_state), true), + ((x, Hold(ud1)), true, tuple.(xval, ud1val_state), true), + ([x, Hold(ud1), Hold(ud2)], true, vcat.(xval, ud1val_state, ud2val_state), true), + ((x, Hold(ud1), Hold(ud2)), true, tuple.(xval, ud1val_state, ud2val_state), true), + ([x, Hold(ud1), Hold(ud2), kp], true, vcat.(xval, ud1val_state, ud2val_state, kpval), true), + ((x, Hold(ud1), Hold(ud2), kp), true, tuple.(xval, ud1val_state, ud2val_state, kpval), true), + (2Hold(ud1), true, 2ud1val_state, true), + ([x, 2Hold(ud1), 3Hold(ud2)], true, vcat.(xval, 2ud1val_state, 3ud2val_state), true), + ((x, 2Hold(ud1), 3Hold(ud2)), true, tuple.(xval, 2ud1val_state, 3ud2val_state), true), + ] + getter = getu(sys, sym) + if check_inference + @inferred getter(sol) + @inferred sol[sym] + end + @test getter(sol) == val + @test sol[sym] == val + + for subidx in [rand(eachindex(sol.t)), CartesianIndex(rand(eachindex(sol.t))), :, rand(Bool, length(sol.t)), rand(eachindex(sol.t), 5), 3:6] + if check_inference + @inferred getter(sol, subidx) + @inferred sol[sym, subidx] + end + + target = if val_is_timeseries + val[subidx] + else + if fs.t[subidx] isa AbstractArray + len = length(fs.t[subidx]) + fill(val, len) + else + val + end + end + @test getter(sol, subidx) == target + @test sol[sym, subidx] == target + end +end + +@parameters σ ρ β +@variables x y z + +eqs = [0 ~ σ * (y - x), + 0 ~ x * (ρ - z) - y, + 0 ~ x * y - β * z] +@mtkbuild ns = NonlinearSystem(eqs) +u0 = [x => 1.0, + y => 0.0, + z => 0.0] + +p = [σ => 28.0, + ρ => 10.0, + β => 8 / 3] +prob = NonlinearProblem(ns, u0, p) +sol = solve(prob) + +for (sym, val, check_inference) in [ + (σ, 28.0, true), + ([σ, ρ], [28.0, 10.0], true), + ((σ, ρ), (28.0, 10.0), true), + (σ + ρ, 38.0, true), + ([σ + ρ, ρ + β], [38.0, 10 + 8/3], true), + ((σ + ρ, ρ + β), (38.0, 10 + 8/3), true), + ] + getter = getp(sys, sym) + if check_inference + @inferred getter(sol) + end + @test getter(sol) == val + + if sym isa Union{Array, Tuple} + buffer = zeros(length(sym)) + @inferred getter(buffer, sol) + @test buffer == collect(val) + end end