diff --git a/src/parameter_indexing.jl b/src/parameter_indexing.jl index 5243072..e28720b 100644 --- a/src/parameter_indexing.jl +++ b/src/parameter_indexing.jl @@ -2,9 +2,13 @@ parameter_values(p) Return an indexable collection containing the value of each parameter in `p`. + +If this function is called with an `AbstractArray`, it will return the same array. """ function parameter_values end +parameter_values(arr::AbstractArray) = arr + """ set_parameter!(sys, val, idx) @@ -22,8 +26,10 @@ end """ getp(sys, p) -Return a function that takes an integrator or solution of `sys`, and returns the value of -the parameter `p`. Note that `p` can be a direct numerical index or a symbolic value. +Return a function that takes an array representing the parameter vector or an integrator +or solution of `sys`, and returns the value of the parameter `p`. Note that `p` can be a +direct numerical index or a symbolic value, or an array/tuple of the aforementioned. + Requires that the integrator or solution implement [`parameter_values`](@ref). This function typically does not need to be implemented, and has a default implementation relying on [`parameter_values`](@ref). @@ -31,44 +37,49 @@ typically does not need to be implemented, and has a default implementation rely function getp(sys, p) symtype = symbolic_type(p) elsymtype = symbolic_type(eltype(p)) - if symtype != NotSymbolic() - return _getp(sys, symtype, p) - else - return _getp(sys, elsymtype, p) - end + _getp(sys, symtype, elsymtype, p) end -function _getp(sys, ::NotSymbolic, p) +function _getp(sys, ::NotSymbolic, ::NotSymbolic, p) return function getter(sol) return parameter_values(sol)[p] end end -function _getp(sys, ::ScalarSymbolic, p) +function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p) idx = parameter_index(sys, p) return function getter(sol) return parameter_values(sol)[idx] end end -function _getp(sys, ::ScalarSymbolic, p::Union{Tuple, AbstractArray}) - idxs = parameter_index.((sys,), p) - return function getter(sol) - return getindex.((parameter_values(sol),), idxs) +for (t1, t2) in [ + (ArraySymbolic, Any), + (ScalarSymbolic, Any), + (NotSymbolic, Union{<:Tuple, <:AbstractArray}), +] + @eval function _getp(sys, ::NotSymbolic, ::$t1, p::$t2) + getters = getp.((sys,), p) + + return function getter(sol) + map(g -> g(sol), getters) + end end end -function _getp(sys, ::ArraySymbolic, p) +function _getp(sys, ::ArraySymbolic, ::NotSymbolic, p) return getp(sys, collect(p)) end """ setp(sys, p) -Return a function that takes an integrator of `sys` and a value, and sets -the parameter `p` to that value. Note that `p` can be a direct numerical index or a -symbolic value. Requires that the integrator implement [`parameter_values`](@ref) and the -returned collection be a mutable reference to the parameter vector in the integrator. In +Return a function that takes an array representing the parameter vector or an integrator +or problem of `sys`, and a value, and sets the parameter `p` to that value. Note that `p` +can be a direct numerical index or a symbolic value. + +Requires that the integrator implement [`parameter_values`](@ref) and the returned +collection be a mutable reference to the parameter vector in the integrator. In case `parameter_values` cannot return such a mutable reference, or additional actions need to be performed when updating parameters, [`set_parameter!`](@ref) must be implemented. @@ -76,33 +87,35 @@ implemented. function setp(sys, p) symtype = symbolic_type(p) elsymtype = symbolic_type(eltype(p)) - if symtype != NotSymbolic() - return _setp(sys, symtype, p) - else - return _setp(sys, elsymtype, p) - end + _setp(sys, symtype, elsymtype, p) end -function _setp(sys, ::NotSymbolic, p) +function _setp(sys, ::NotSymbolic, ::NotSymbolic, p) return function setter!(sol, val) set_parameter!(sol, val, p) end end -function _setp(sys, ::ScalarSymbolic, p) +function _setp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p) idx = parameter_index(sys, p) return function setter!(sol, val) set_parameter!(sol, val, idx) end end -function _setp(sys, ::ScalarSymbolic, p::Union{Tuple, AbstractArray}) - idxs = parameter_index.((sys,), p) - return function setter!(sol, val) - set_parameter!.((sol,), val, idxs) +for (t1, t2) in [ + (ArraySymbolic, Any), + (ScalarSymbolic, Any), + (NotSymbolic, Union{<:Tuple, <:AbstractArray}), +] + @eval function _setp(sys, ::NotSymbolic, ::$t1, p::$t2) + setters = setp.((sys,), p) + return function setter!(sol, val) + map((s!, v) -> s!(sol, v), setters, val) + end end end -function _setp(sys, ::ArraySymbolic, p) +function _setp(sys, ::ArraySymbolic, ::NotSymbolic, p) return setp(sys, collect(p)) end diff --git a/src/state_indexing.jl b/src/state_indexing.jl index 94a02a6..d84e869 100644 --- a/src/state_indexing.jl +++ b/src/state_indexing.jl @@ -39,14 +39,22 @@ is_timeseries(::Type) = NotTimeseries() """ state_values(p) + state_values(p, i) Return an indexable collection containing the values of all states in the integrator or problem `p`. If `is_timeseries(p)` is [`Timeseries`](@ref), return a vector of arrays, -each of which contain the state values at the corresponding timestep. +each of which contain the state values at the corresponding timestep. In this case, the +two-argument version of the function can also be implemented to efficiently return +the state values at timestep `i`. By default, the two-argument method calls +`state_values(p)[i]` + +If this function is called with an `AbstractArray`, it will return the same array. See: [`is_timeseries`](@ref) """ function state_values end +state_values(arr::AbstractArray) = arr +state_values(arr, i) = state_values(arr)[i] """ set_state!(sys, val, idx) @@ -64,23 +72,32 @@ end """ current_time(p) + current_time(p, i) Return the current time in the integrator or problem `p`. If `is_timeseries(p)` is [`Timeseries`](@ref), return the vector of timesteps at which -the state value is saved. +the state value is saved. In this case, the two-argument version of the function can +also be implemented to efficiently return the time at timestep `i`. By default, the two- +argument method calls `current_time(p)[i]` See: [`is_timeseries`](@ref) """ function current_time end +current_time(p, i) = current_time(p)[i] + """ getu(sys, sym) Return a function that takes an integrator, problem or solution of `sys`, and returns -the value of the symbolic `sym`. `sym` can be a direct index into the state vector, a -symbolic state, a symbolic expression involving symbolic quantities in the system -`sys`, or an array/tuple of the aforementioned. +the value of the symbolic `sym`. If `sym` is not an observed quantity, the returned +function can also directly be called with an array of values representing the state +vector. `sym` can be a direct index into the state vector, a symbolic state, a symbolic +expression involving symbolic quantities in the system `sys`, a parameter symbol, or the +independent variable symbol, or an array/tuple of the aforementioned. If the returned +function is called with a timeseries object, it can also be given a second argument +representing the index at which to find the value of `sym`. At minimum, this requires that the integrator, problem or solution implement [`state_values`](@ref). To support symbolic expressions, the integrator or problem @@ -93,26 +110,41 @@ relying on the above functions. function getu(sys, sym) symtype = symbolic_type(sym) elsymtype = symbolic_type(eltype(sym)) - - if symtype != NotSymbolic() - _getu(sys, symtype, sym) - else - _getu(sys, elsymtype, sym) - end + _getu(sys, symtype, elsymtype, sym) end -function _getu(sys, ::NotSymbolic, sym) +function _getu(sys, ::NotSymbolic, ::NotSymbolic, sym) _getter(::Timeseries, prob) = getindex.(state_values(prob), (sym,)) + _getter(::Timeseries, prob, i) = getindex(state_values(prob, i), sym) _getter(::NotTimeseries, prob) = state_values(prob)[sym] - return function getter(prob) - return _getter(is_timeseries(prob), prob) + return let _getter = _getter + function getter(prob) + return _getter(is_timeseries(prob), prob) + end + function getter(prob, i) + return _getter(is_timeseries(prob), prob, i) + end + getter end end -function _getu(sys, ::ScalarSymbolic, sym) +function _getu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym) if is_variable(sys, sym) idx = variable_index(sys, sym) return getu(sys, idx) + elseif is_parameter(sys, sym) + return let fn = getp(sys, sym) + getter(prob, args...) = fn(prob) + getter + end + elseif is_independent_variable(sys, sym) + _getter(::IsTimeseriesTrait, prob) = current_time(prob) + _getter(::Timeseries, prob, i) = current_time(prob, i) + return let _getter = _getter + getter(prob) = _getter(is_timeseries(prob), prob) + getter(prob, i) = _getter(is_timeseries(prob), prob, i) + getter + end elseif is_observed(sys, sym) fn = observed(sys, sym) if is_time_dependent(sys) @@ -121,61 +153,80 @@ function _getu(sys, ::ScalarSymbolic, sym) (parameter_values(prob),), current_time(prob)) end + function _getter2(::Timeseries, prob, i) + return fn(state_values(prob, i), + parameter_values(prob), + current_time(prob, i)) + end function _getter2(::NotTimeseries, prob) return fn(state_values(prob), parameter_values(prob), current_time(prob)) end - return function getter2(prob) - return _getter2(is_timeseries(prob), prob) + return let _getter2 = _getter2 + function getter2(prob) + return _getter2(is_timeseries(prob), prob) + end + function getter2(prob, i) + return _getter2(is_timeseries(prob), prob, i) + end + getter2 end else - function _getter3(::Timeseries, prob) - return fn.(state_values(prob), (parameter_values(prob),)) - end - function _getter3(::NotTimeseries, prob) - return fn(state_values(prob), parameter_values(prob)) - end - - return function getter3(prob) - return _getter3(is_timeseries(prob), prob) + # if there is no time, there is no timeseries + return let fn = fn + function getter3(prob) + return fn(state_values(prob), parameter_values(prob)) + end end end end error("Invalid symbol $sym for `getu`") end -struct TimeseriesIndexWrapper{T, I} - timeseries::T - idx::I -end - -state_values(t::TimeseriesIndexWrapper) = state_values(t.timeseries)[t.idx] -parameter_values(t::TimeseriesIndexWrapper) = parameter_values(t.timeseries) -current_time(t::TimeseriesIndexWrapper) = current_time(t.timeseries)[t.idx] - -function _getu(sys, ::ScalarSymbolic, sym::Union{<:Tuple, <:AbstractArray}) - getters = getu.((sys,), sym) - _call(getter, prob) = getter(prob) +for (t1, t2) in [ + (ScalarSymbolic, Any), + (ArraySymbolic, Any), + (NotSymbolic, Union{<:Tuple, <:AbstractArray}), +] + @eval function _getu(sys, ::NotSymbolic, ::$t1, sym::$t2) + getters = getu.((sys,), sym) + _call(getter, args...) = getter(args...) + return let getters = getters, _call = _call + _getter(::NotTimeseries, prob) = map(g -> g(prob), getters) + function _getter(::Timeseries, prob) + broadcast(i -> map(g -> _call(g, prob, i), getters), + eachindex(state_values(prob))) + end + function _getter(::Timeseries, prob, i) + return map(g -> _call(g, prob, i), getters) + end - function _getter(::Timeseries, prob) - tiws = TimeseriesIndexWrapper.((prob,), eachindex(state_values(prob))) - return [_getter(NotTimeseries(), tiw) for tiw in tiws] - end - _getter(::NotTimeseries, prob) = _call.(getters, (prob,)) - return function getter(prob) - return _getter(is_timeseries(prob), prob) + # Need another scope for this to not box `_getter` + let _getter = _getter + function getter(prob) + return _getter(is_timeseries(prob), prob) + end + function getter(prob, i) + return _getter(is_timeseries(prob), prob, i) + end + getter + end + end end end -function _getu(sys, ::ArraySymbolic, sym) +function _getu(sys, ::ArraySymbolic, ::NotSymbolic, sym) return getu(sys, collect(sym)) end +# setu doesn't need the same `let` blocks to be inferred for some reason + """ setu(sys, sym) -Return a function that takes an integrator or problem of `sys` and a value, and sets the -the state `sym` to that value. Note that `sym` can be a direct numerical index, a symbolic state, or an array/tuple of the aforementioned. +Return a function that takes an array representing the state vector or an integrator or +problem of `sys`, and a value, and sets the the state `sym` to that value. Note that `sym` +can be a direct numerical index, a symbolic state, or an array/tuple of the aforementioned. Requires that the integrator implement [`state_values`](@ref) and the returned collection be a mutable reference to the state vector in the integrator/problem. Alternatively, if this is not possible or additional actions need to @@ -186,36 +237,40 @@ This function does not work on types for which [`is_timeseries`](@ref) is function setu(sys, sym) symtype = symbolic_type(sym) elsymtype = symbolic_type(eltype(sym)) - - if symtype != NotSymbolic() - _setu(sys, symtype, sym) - else - _setu(sys, elsymtype, sym) - end + _setu(sys, symtype, elsymtype, sym) end -function _setu(sys, ::NotSymbolic, sym) +function _setu(sys, ::NotSymbolic, ::NotSymbolic, sym) return function setter!(prob, val) set_state!(prob, val, sym) end end -function _setu(sys, ::ScalarSymbolic, sym) - is_variable(sys, sym) || error("Invalid symbol $sym for `setu`") - idx = variable_index(sys, sym) - return function setter!(prob, val) - set_state!(prob, val, idx) +function _setu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym) + if is_variable(sys, sym) + idx = variable_index(sys, sym) + return function setter!(prob, val) + set_state!(prob, val, idx) + end + elseif is_parameter(sys, sym) + return setp(sys, sym) end + error("Invalid symbol $sym for `setu`") end -function _setu(sys, ::ScalarSymbolic, sym::Union{<:Tuple, <:AbstractArray}) - setters = setu.((sys,), sym) - _call!(setter!, prob, val) = setter!(prob, val) - return function setter!(prob, val) - _call!.(setters, (prob,), val) +for (t1, t2) in [ + (ScalarSymbolic, Any), + (ArraySymbolic, Any), + (NotSymbolic, Union{<:Tuple, <:AbstractArray}), +] + @eval function _setu(sys, ::NotSymbolic, ::$t1, sym::$t2) + setters = setu.((sys,), sym) + return function setter!(prob, val) + map((s!, v) -> s!(prob, v), setters, val) + end end end -function _setu(sys, ::ArraySymbolic, sym) +function _setu(sys, ::ArraySymbolic, ::NotSymbolic, sym) return setu(sys, collect(sym)) end diff --git a/test/parameter_indexing_test.jl b/test/parameter_indexing_test.jl index 8f858e6..42d98f9 100644 --- a/test/parameter_indexing_test.jl +++ b/test/parameter_indexing_test.jl @@ -9,15 +9,50 @@ end SymbolicIndexingInterface.symbolic_container(fp::FakeIntegrator) = fp.sys SymbolicIndexingInterface.parameter_values(fp::FakeIntegrator) = fp.p -sys = SymbolCache([:x, :y, :z], [:a, :b], [:t]) -p = [1.0, 2.0] +sys = SymbolCache([:x, :y, :z], [:a, :b, :c], [:t]) +p = [1.0, 2.0, 3.0] fi = FakeIntegrator(sys, copy(p)) -for (i, sym) in [(1, :a), (2, :b), ([1, 2], [:a, :b]), ((1, 2), (:a, :b))] +new_p = [4.0, 5.0, 6.0] +for (sym, oldval, newval, check_inference) in [ + (:a, p[1], new_p[1], true), + (1, p[1], new_p[1], true), + ([:a, :b], p[1:2], new_p[1:2], true), + (1:2, p[1:2], new_p[1:2], true), + ((1, 2), Tuple(p[1:2]), Tuple(new_p[1:2]), true), + ([:a, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false), + ([:a, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false), + ((:a, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true), + ((:a, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true), + ([1, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false), + ([1, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false), + ((1, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true), + ((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true), +] get = getp(sys, sym) set! = setp(sys, sym) - true_value = i isa Tuple ? getindex.((p,), i) : p[i] - @test get(fi) == ParameterIndexingProxy(fi)[sym] == true_value - set!(fi, 0.5 .* i) - @test get(fi) == ParameterIndexingProxy(fi)[sym] == 0.5 .* i - set!(fi, true_value) + if check_inference + @inferred get(fi) + end + @test get(fi) == oldval + if check_inference + @inferred set!(fi, newval) + else + set!(fi, newval) + end + @test get(fi) == newval + set!(fi, oldval) + @test get(fi) == oldval + + if check_inference + @inferred get(p) + end + @test get(p) == oldval + if check_inference + @inferred set!(p, newval) + else + set!(p, newval) + end + @test get(p) == newval + set!(p, oldval) + @test get(p) == oldval end diff --git a/test/state_indexing_test.jl b/test/state_indexing_test.jl index 7fc3a61..e47a460 100644 --- a/test/state_indexing_test.jl +++ b/test/state_indexing_test.jl @@ -1,44 +1,174 @@ using SymbolicIndexingInterface -struct FakeIntegrator{S, U} +struct FakeIntegrator{S, U, P, T} sys::S u::U + p::P + t::T end SymbolicIndexingInterface.symbolic_container(fp::FakeIntegrator) = fp.sys SymbolicIndexingInterface.state_values(fp::FakeIntegrator) = fp.u +SymbolicIndexingInterface.parameter_values(fp::FakeIntegrator) = fp.p +SymbolicIndexingInterface.current_time(fp::FakeIntegrator) = fp.t -sys = SymbolCache([:x, :y, :z], [:a, :b], [:t]) +sys = SymbolCache([:x, :y, :z], [:a, :b, :c], [:t]) u = [1.0, 2.0, 3.0] -fi = FakeIntegrator(sys, copy(u)) -for (i, sym) in [(1, :x), (2, :y), (3, :z), ([1, 2], [:x, :y]), ((3, 2), (:z, :y))] +p = [11.0, 12.0, 13.0] +t = 0.5 +fi = FakeIntegrator(sys, copy(u), copy(p), t) +# checking inference for non-concretely typed arrays will always fail +for (sym, val, newval, check_inference) in [(:x, u[1], 4.0, true) + (:y, u[2], 4.0, true) + (:z, u[3], 4.0, true) + (1, u[1], 4.0, true) + ([:x, :y], u[1:2], 4ones(2), true) + ([1, 2], u[1:2], 4ones(2), true) + ((:z, :y), (u[3], u[2]), (4.0, 5.0), true) + ((3, 2), (u[3], u[2]), (4.0, 5.0), true) + ([:x, [:y, :z]], [u[1], u[2:3]], [4.0, [5.0, 6.0]], false) + ([:x, 2:3], [u[1], u[2:3]], [4.0, [5.0, 6.0]], false) + ([:x, (:y, :z)], [u[1], (u[2], u[3])], [4.0, (5.0, 6.0)], false) + ([:x, Tuple(2:3)], [u[1], (u[2], u[3])], [4.0, (5.0, 6.0)], false) + ([:x, [:y], (:z,)], [u[1], [u[2]], (u[3],)], [4.0, [5.0], (6.0,)], false) + ([:x, [:y], (3,)], [u[1], [u[2]], (u[3],)], [4.0, [5.0], (6.0,)], false) + ((:x, [:y, :z]), (u[1], u[2:3]), (4.0, [5.0, 6.0]), true) + ((:x, (:y, :z)), (u[1], (u[2], u[3])), (4.0, (5.0, 6.0)), true) + ((1, (:y, :z)), (u[1], (u[2], u[3])), (4.0, (5.0, 6.0)), true) + ((:x, [:y], (:z,)), (u[1], [u[2]], (u[3],)), (4.0, [5.0], (6.0,)), true)] get = getu(sys, sym) set! = setu(sys, sym) - true_value = i isa Tuple ? getindex.((u,), i) : u[i] - @test get(fi) == true_value - set!(fi, 0.5 .* i) - @test get(fi) == 0.5 .* i - set!(fi, true_value) + if check_inference + @inferred get(fi) + end + @test get(fi) == val + if check_inference + @inferred set!(fi, newval) + else + set!(fi, newval) + end + @test get(fi) == newval + set!(fi, val) + @test get(fi) == val + + if check_inference + @inferred get(u) + end + @test get(u) == val + if check_inference + @inferred set!(u, newval) + else + set!(u, newval) + end + @test get(u) == newval + set!(u, val) + @test get(u) == val end -struct FakeSolution{S, U} +for (sym, oldval, newval, check_inference) in [(:a, p[1], 4.0, true) + (:b, p[2], 5.0, true) + (:c, p[3], 6.0, true) + ([:a, :b], p[1:2], [4.0, 5.0], true) + ((:c, :b), (p[3], p[2]), (6.0, 5.0), true) + ([:x, :a], [u[1], p[1]], [4.0, 5.0], false) + ((:y, :b), (u[2], p[2]), (5.0, 6.0), true)] + get = getu(fi, sym) + set! = setu(fi, sym) + if check_inference + @inferred get(fi) + end + @test get(fi) == oldval + if check_inference + @inferred set!(fi, newval) + else + set!(fi, newval) + end + @test get(fi) == newval + set!(fi, oldval) + @test get(fi) == oldval +end + +for (sym, val, check_inference) in [ + (:t, t, true), + ([:x, :a, :t], [u[1], p[1], t], false), + ((:x, :a, :t), (u[1], p[1], t), true), +] + get = getu(fi, sym) + if check_inference + @inferred get(fi) + end + @test get(fi) == val +end + +struct FakeSolution{S, U, P, T} sys::S u::U + p::P + t::T end SymbolicIndexingInterface.is_timeseries(::Type{<:FakeSolution}) = Timeseries() SymbolicIndexingInterface.symbolic_container(fp::FakeSolution) = fp.sys SymbolicIndexingInterface.state_values(fp::FakeSolution) = fp.u +SymbolicIndexingInterface.parameter_values(fp::FakeSolution) = fp.p +SymbolicIndexingInterface.current_time(fp::FakeSolution) = fp.t -sys = SymbolCache([:x, :y, :z], [:a, :b], [:t]) +sys = SymbolCache([:x, :y, :z], [:a, :b, :c], [:t]) u = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] -sol = FakeSolution(sys, u) -for (i, sym) in [(1, :x), (2, :y), (3, :z), ([1, 2], [:x, :y]), ((3, 2), (:z, :y))] +t = [1.5, 2.0] +sol = FakeSolution(sys, u, p, t) + +xvals = getindex.(sol.u, 1) +yvals = getindex.(sol.u, 2) +zvals = getindex.(sol.u, 3) + +for (sym, ans, check_inference) in [(:x, xvals, true) + (:y, yvals, true) + (:z, zvals, true) + (1, xvals, true) + ([:x, :y], vcat.(xvals, yvals), true) + (1:2, vcat.(xvals, yvals), true) + ([:x, 2], vcat.(xvals, yvals), false) + ((:z, :y), tuple.(zvals, yvals), true) + ((3, 2), tuple.(zvals, yvals), true) + ([:x, [:y, :z]], vcat.(xvals, [[x] for x in vcat.(yvals, zvals)]), false) + ([:x, (:y, :z)], vcat.(xvals, tuple.(yvals, zvals)), false) + ([1, (:y, :z)], vcat.(xvals, tuple.(yvals, zvals)), false) + ([:x, [:y, :z], (:x, :z)], + vcat.(xvals, [[x] for x in vcat.(yvals, zvals)], tuple.(xvals, zvals)), + false) + ([:x, [:y, 3], (1, :z)], + vcat.(xvals, [[x] for x in vcat.(yvals, zvals)], tuple.(xvals, zvals)), + false) + ((:x, [:y, :z]), tuple.(xvals, vcat.(yvals, zvals)), true) + ((:x, (:y, :z)), tuple.(xvals, tuple.(yvals, zvals)), true) + ((:x, [:y, :z], (:z, :y)), + tuple.(xvals, vcat.(yvals, zvals), tuple.(zvals, yvals)), + true) + ([:x, :a], vcat.(xvals, p[1]), false) + ((:y, :b), tuple.(yvals, p[2]), true) + (:t, t, true) + ([:x, :a, :t], vcat.(xvals, p[1], t), false) + ((:x, :a, :t), tuple.(xvals, p[1], t), true)] get = getu(sys, sym) - true_value = if i isa Tuple - [getindex.((v,), i) for v in u] - else - getindex.(u, (i,)) + if check_inference + @inferred get(sol) + end + @test get(sol) == ans + for i in eachindex(u) + if check_inference + @inferred get(sol, i) + end + @test get(sol, i) == ans[i] end - @test get(sol) == true_value +end + +for (sym, val) in [(:a, p[1]) + (:b, p[2]) + (:c, p[3]) + ([:a, :b], p[1:2]) + ((:c, :b), (p[3], p[2]))] + get = getu(fi, sym) + @inferred get(fi) + @test get(fi) == val end