diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl index 397f5d9b6a..a03c107cae 100644 --- a/test/downstream/symbol_indexing.jl +++ b/test/downstream/symbol_indexing.jl @@ -188,11 +188,11 @@ l2x_idx = variable_index(sol, lorenz2.x) l1y_idx = variable_index(sol, lorenz1.y) l2y_idx = variable_index(sol, lorenz2.y) -@test getx(sol) == sol[:, l1x_idx] -@test get_arr(sol) == sol[:, [l1x_idx, l2x_idx]] -@test get_tuple(sol) == tuple.(sol[:, l1x_idx], sol[:, l2x_idx]) -@test get_obs(sol) == sol[:, l1x_idx] + sol[:, l2x_idx] -@test get_obs_arr(sol) == vcat.(sol[:, l1x_idx] + sol[:, l2x_idx], sol[:, l1y_idx] + sol[:, l2y_idx]) +@test getx(sol) == sol[l1x_idx, :] +@test get_arr(sol) == vcat.(sol[l1x_idx, :], sol[l2x_idx, :]) +@test get_tuple(sol) == tuple.(sol[l1x_idx, :], sol[l2x_idx, :]) +@test get_obs(sol) == sol[l1x_idx, :] + sol[l2x_idx, :] +@test get_obs_arr(sol) == vcat.(sol[l1x_idx, :] + sol[l2x_idx, :], sol[l1y_idx, :] + sol[l2y_idx, :]) #= using Plots @@ -217,14 +217,109 @@ sol = solve(prob, Tsit5()) @test sol[@nonamespace sys.x] isa Vector{<:Vector} @test sol.ps[p] == [1, 2, 3] -getx = getu(sys, x) -get_mix_arr = getu(sys, [x, y]) -get_mix_tuple = getu(sys, (x, y)) x_idx = variable_index.((sys,), [x[1], x[2], x[3]]) y_idx = variable_index(sys, y) -@test getx(sol) == sol[:, x_idx] -@test get_mix_arr(sol) == vcat.(sol[:, x_idx], sol[:, y_idx]) -@test get_mix_tuple(sol) == tuple.(sol[:, x_idx], sol[:, y_idx]) +x_val = vcat.(getindex.((sol,), x_idx, :)...) +y_val = sol[y_idx, :] +obs_val = sol[x[1] + y] + +# checking inference for mixed-type arrays will always fail +for (sym, val, check_inference) in [ + (x, x_val, true), + (y, y_val, true), + (y_idx, y_val, true), + (x_idx, x_val, true), + (x[1] + y, obs_val, true), + ([x[1], x[2]], sol[[x[1], x[2]]], true), + ([x[1], x_idx[2]], sol[[x[1], x[2]]], true), + ([x, x[1] + y], [[i, j] for (i, j) in zip(x_val, obs_val)], false), + ([x, y], [[i, j] for (i, j) in zip(x_val, y_val)], false), + ([x, y_idx], [[i, j] for (i, j) in zip(x_val, y_val)], false), + ([x, x], [[i, i] for i in x_val], true), + ([x, x_idx], [[i, i] for i in x_val], false), + ((x, y), [(i, j) for (i, j) in zip(x_val, y_val)], true), + ((x, y_idx), [(i, j) for (i, j) in zip(x_val, y_val)], true), + ((x, x), [(i, i) for i in x_val], true), + ((x, x_idx), [(i, i) for i in x_val], true), + ((x, x[1]+y), [(i, j) for (i, j) in zip(x_val, obs_val)], true), + ((x, (x[1] + y, y)), [(i, (k, j)) for (i, j, k) in zip(x_val, y_val, obs_val)], true), + ([x, [x[1] + y, y]], [[i, [k, j]] for (i, j, k) in zip(x_val, y_val, obs_val)], false), + ((x, [x[1] + y, y], (x[1] + y, y_idx)), [(i, [k, j], (k, j)) for (i, j, k) in zip(x_val, y_val, obs_val)], missing), + ([x, [x[1] + y, y], (x[1] + y, y_idx)], [[i, [k, j], (k, j)] for (i, j, k) in zip(x_val, y_val, obs_val)], false), +] + if check_inference === missing + @test_broken @inferred getu(prob, sym)(sol) + elseif check_inference + @inferred getu(prob, sym)(sol) + end + @test getu(prob, sym)(sol) == val +end + +x_newval = [3.0, 6.0, 9.0] +y_newval = 4.0 +x_probval = prob[x] +y_probval = prob[y] + +for (sym, oldval, newval, check_inference) in [ + (x, x_probval, x_newval, true), + (y, y_probval, y_newval, true), + (x_idx, x_probval, x_newval, true), + (y_idx, y_probval, y_newval, true), + ((x, y), (x_probval, y_probval), (x_newval, y_newval), true), + ([x, y], [x_probval, y_probval], [x_newval, y_newval], false), + ((x, y_idx), (x_probval, y_probval), (x_newval, y_newval), true), + ([x, y_idx], [x_probval, y_probval], [x_newval, y_newval], false), + ((x_idx, y), (x_probval, y_probval), (x_newval, y_newval), true), + ([x_idx, y], [x_probval, y_probval], [x_newval, y_newval], false), + ([x[1:2], [y_idx, x[3]]], [x_probval[1:2], [y_probval, x_probval[3]]], [x_newval[1:2], [y_newval, x_newval[3]]], true), + ([x[1:2], (y_idx, x[3])], [x_probval[1:2], (y_probval, x_probval[3])], [x_newval[1:2], (y_newval, x_newval[3])], false), + ((x[1:2], [y_idx, x[3]]), (x_probval[1:2], [y_probval, x_probval[3]]), (x_newval[1:2], [y_newval, x_newval[3]]), true), + ((x[1:2], (y_idx, x[3])), (x_probval[1:2], (y_probval, x_probval[3])), (x_newval[1:2], (y_newval, x_newval[3])), true), +] + getter = getu(prob, sym) + setter! = setu(prob, sym) + if check_inference + @inferred getter(prob) + end + @test getter(prob) == oldval + if check_inference + @inferred setter!(prob, newval) + else + setter!(prob, newval) + end + @test getter(prob) == newval + setter!(prob, oldval) + @test getter(prob) == oldval +end + +pval = [1.0, 2.0, 3.0] +pval_new = [4.0, 5.0, 6.0] + +for (sym, oldval, newval, check_inference) in [ + (p[1], pval[1], pval_new[1], true), + (p, pval, pval_new, true), + ((p[1], p[2]), Tuple(pval[1:2]), Tuple(pval_new[1:2]), true), + ([p[1], p[2]], pval[1:2], pval_new[1:2], true), + ((p[1], p[2:3]), (pval[1], pval[2:3]), (pval_new[1], pval_new[2:3]), true), + ([p[1], p[2:3]], [pval[1], pval[2:3]], [pval_new[1], pval_new[2:3]], false), + ((p[1], (p[2],), [p[3]]), (pval[1], (pval[2],), [pval[3]]), (pval_new[1], (pval_new[2],), [pval_new[3]]), true), + ([p[1], (p[2],), [p[3]]], [pval[1], (pval[2],), [pval[3]]], [pval_new[1], (pval_new[2],), [pval_new[3]]], false), +] + getter = getp(prob, sym) + setter! = setp(prob, sym) + if check_inference + @inferred getter(prob) + end + @test getter(prob) == oldval + if check_inference + @inferred setter!(prob, newval) + else + setter!(prob, newval) + end + @test getter(prob) == newval + setter!(prob, oldval) + @test getter(prob) == oldval +end # accessing parameters @variables t x(t)