Skip to content

Commit

Permalink
test: extensively test getu/setu and their type-stability
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jan 18, 2024
1 parent 00f3e4e commit bc4057c
Showing 1 changed file with 106 additions and 11 deletions.
117 changes: 106 additions & 11 deletions test/downstream/symbol_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,11 @@ l2x_idx = variable_index(sol, lorenz2.x)
l1y_idx = variable_index(sol, lorenz1.y)
l2y_idx = variable_index(sol, lorenz2.y)

@test getx(sol) == sol[:, l1x_idx]
@test get_arr(sol) == sol[:, [l1x_idx, l2x_idx]]
@test get_tuple(sol) == tuple.(sol[:, l1x_idx], sol[:, l2x_idx])
@test get_obs(sol) == sol[:, l1x_idx] + sol[:, l2x_idx]
@test get_obs_arr(sol) == vcat.(sol[:, l1x_idx] + sol[:, l2x_idx], sol[:, l1y_idx] + sol[:, l2y_idx])
@test getx(sol) == sol[l1x_idx, :]
@test get_arr(sol) == vcat.(sol[l1x_idx, :], sol[l2x_idx, :])
@test get_tuple(sol) == tuple.(sol[l1x_idx, :], sol[l2x_idx, :])
@test get_obs(sol) == sol[l1x_idx, :] + sol[l2x_idx, :]
@test get_obs_arr(sol) == vcat.(sol[l1x_idx, :] + sol[l2x_idx, :], sol[l1y_idx, :] + sol[l2y_idx, :])

#=
using Plots
Expand All @@ -217,14 +217,109 @@ sol = solve(prob, Tsit5())
@test sol[@nonamespace sys.x] isa Vector{<:Vector}
@test sol.ps[p] == [1, 2, 3]

getx = getu(sys, x)
get_mix_arr = getu(sys, [x, y])
get_mix_tuple = getu(sys, (x, y))
x_idx = variable_index.((sys,), [x[1], x[2], x[3]])
y_idx = variable_index(sys, y)
@test getx(sol) == sol[:, x_idx]
@test get_mix_arr(sol) == vcat.(sol[:, x_idx], sol[:, y_idx])
@test get_mix_tuple(sol) == tuple.(sol[:, x_idx], sol[:, y_idx])
x_val = vcat.(getindex.((sol,), x_idx, :)...)
y_val = sol[y_idx, :]
obs_val = sol[x[1] + y]

# checking inference for mixed-type arrays will always fail
for (sym, val, check_inference) in [
(x, x_val, true),
(y, y_val, true),
(y_idx, y_val, true),
(x_idx, x_val, true),
(x[1] + y, obs_val, true),
([x[1], x[2]], sol[[x[1], x[2]]], true),
([x[1], x_idx[2]], sol[[x[1], x[2]]], true),
([x, x[1] + y], [[i, j] for (i, j) in zip(x_val, obs_val)], false),
([x, y], [[i, j] for (i, j) in zip(x_val, y_val)], false),
([x, y_idx], [[i, j] for (i, j) in zip(x_val, y_val)], false),
([x, x], [[i, i] for i in x_val], true),
([x, x_idx], [[i, i] for i in x_val], false),
((x, y), [(i, j) for (i, j) in zip(x_val, y_val)], true),
((x, y_idx), [(i, j) for (i, j) in zip(x_val, y_val)], true),
((x, x), [(i, i) for i in x_val], true),
((x, x_idx), [(i, i) for i in x_val], true),
((x, x[1]+y), [(i, j) for (i, j) in zip(x_val, obs_val)], true),
((x, (x[1] + y, y)), [(i, (k, j)) for (i, j, k) in zip(x_val, y_val, obs_val)], true),
([x, [x[1] + y, y]], [[i, [k, j]] for (i, j, k) in zip(x_val, y_val, obs_val)], false),
((x, [x[1] + y, y], (x[1] + y, y_idx)), [(i, [k, j], (k, j)) for (i, j, k) in zip(x_val, y_val, obs_val)], missing),
([x, [x[1] + y, y], (x[1] + y, y_idx)], [[i, [k, j], (k, j)] for (i, j, k) in zip(x_val, y_val, obs_val)], false),
]
if check_inference === missing
@test_broken @inferred getu(prob, sym)(sol)
elseif check_inference
@inferred getu(prob, sym)(sol)
end
@test getu(prob, sym)(sol) == val
end

x_newval = [3.0, 6.0, 9.0]
y_newval = 4.0
x_probval = prob[x]
y_probval = prob[y]

for (sym, oldval, newval, check_inference) in [
(x, x_probval, x_newval, true),
(y, y_probval, y_newval, true),
(x_idx, x_probval, x_newval, true),
(y_idx, y_probval, y_newval, true),
((x, y), (x_probval, y_probval), (x_newval, y_newval), true),
([x, y], [x_probval, y_probval], [x_newval, y_newval], false),
((x, y_idx), (x_probval, y_probval), (x_newval, y_newval), true),
([x, y_idx], [x_probval, y_probval], [x_newval, y_newval], false),
((x_idx, y), (x_probval, y_probval), (x_newval, y_newval), true),
([x_idx, y], [x_probval, y_probval], [x_newval, y_newval], false),
([x[1:2], [y_idx, x[3]]], [x_probval[1:2], [y_probval, x_probval[3]]], [x_newval[1:2], [y_newval, x_newval[3]]], true),
([x[1:2], (y_idx, x[3])], [x_probval[1:2], (y_probval, x_probval[3])], [x_newval[1:2], (y_newval, x_newval[3])], false),
((x[1:2], [y_idx, x[3]]), (x_probval[1:2], [y_probval, x_probval[3]]), (x_newval[1:2], [y_newval, x_newval[3]]), true),
((x[1:2], (y_idx, x[3])), (x_probval[1:2], (y_probval, x_probval[3])), (x_newval[1:2], (y_newval, x_newval[3])), true),
]
getter = getu(prob, sym)
setter! = setu(prob, sym)
if check_inference
@inferred getter(prob)
end
@test getter(prob) == oldval
if check_inference
@inferred setter!(prob, newval)
else
setter!(prob, newval)
end
@test getter(prob) == newval
setter!(prob, oldval)
@test getter(prob) == oldval
end

pval = [1.0, 2.0, 3.0]
pval_new = [4.0, 5.0, 6.0]

for (sym, oldval, newval, check_inference) in [
(p[1], pval[1], pval_new[1], true),
(p, pval, pval_new, true),
((p[1], p[2]), Tuple(pval[1:2]), Tuple(pval_new[1:2]), true),
([p[1], p[2]], pval[1:2], pval_new[1:2], true),
((p[1], p[2:3]), (pval[1], pval[2:3]), (pval_new[1], pval_new[2:3]), true),
([p[1], p[2:3]], [pval[1], pval[2:3]], [pval_new[1], pval_new[2:3]], false),
((p[1], (p[2],), [p[3]]), (pval[1], (pval[2],), [pval[3]]), (pval_new[1], (pval_new[2],), [pval_new[3]]), true),
([p[1], (p[2],), [p[3]]], [pval[1], (pval[2],), [pval[3]]], [pval_new[1], (pval_new[2],), [pval_new[3]]], false),
]
getter = getp(prob, sym)
setter! = setp(prob, sym)
if check_inference
@inferred getter(prob)
end
@test getter(prob) == oldval
if check_inference
@inferred setter!(prob, newval)
else
setter!(prob, newval)
end
@test getter(prob) == newval
setter!(prob, oldval)
@test getter(prob) == oldval
end

# accessing parameters
@variables t x(t)
Expand Down

0 comments on commit bc4057c

Please sign in to comment.