diff --git a/src/state_indexing.jl b/src/state_indexing.jl index d84e869..f2cd7f2 100644 --- a/src/state_indexing.jl +++ b/src/state_indexing.jl @@ -189,27 +189,67 @@ for (t1, t2) in [ (NotSymbolic, Union{<:Tuple, <:AbstractArray}), ] @eval function _getu(sys, ::NotSymbolic, ::$t1, sym::$t2) - getters = getu.((sys,), sym) - _call(getter, args...) = getter(args...) - return let getters = getters, _call = _call - _getter(::NotTimeseries, prob) = map(g -> g(prob), getters) - function _getter(::Timeseries, prob) - broadcast(i -> map(g -> _call(g, prob, i), getters), - eachindex(state_values(prob))) - end - function _getter(::Timeseries, prob, i) - return map(g -> _call(g, prob, i), getters) - end + num_observed = count(x -> is_observed(sys, x), sym) + if num_observed <= 1 + getters = getu.((sys,), sym) + _call(getter, args...) = getter(args...) + return let getters = getters, _call = _call + _getter(::NotTimeseries, prob) = map(g -> g(prob), getters) + function _getter(::Timeseries, prob) + broadcast(i -> map(g -> _call(g, prob, i), getters), + eachindex(state_values(prob))) + end + function _getter(::Timeseries, prob, i) + return map(g -> _call(g, prob, i), getters) + end - # Need another scope for this to not box `_getter` - let _getter = _getter - function getter(prob) - return _getter(is_timeseries(prob), prob) + # Need another scope for this to not box `_getter` + let _getter = _getter + function getter(prob) + return _getter(is_timeseries(prob), prob) + end + function getter(prob, i) + return _getter(is_timeseries(prob), prob, i) + end + getter + end + end + else + obs = observed(sys, sym isa Tuple ? collect(sym) : sym) + return let obs = obs, is_tuple = sym isa Tuple + function _getter2(::NotTimeseries, prob) + obs(state_values(prob), parameter_values(prob), current_time(prob)) + end + function _getter2(::Timeseries, prob) + obs.(state_values(prob), (parameter_values(prob),), current_time(prob)) end - function getter(prob, i) - return _getter(is_timeseries(prob), prob, i) + function _getter2(::Timeseries, prob, i) + obs(state_values(prob, i), + parameter_values(prob), + current_time(prob, i)) + end + + if is_tuple + let _getter2 = _getter2 + function getter2(prob) + Tuple(_getter2(is_timeseries(prob), prob)) + end + function getter2(prob, i) + Tuple(_getter2(is_timeseries(prob), prob, i)) + end + getter2 + end + else + let _getter2 = _getter2 + function getter3(prob) + _getter2(is_timeseries(prob), prob) + end + function getter3(prob, i) + _getter2(is_timeseries(prob), prob, i) + end + getter3 + end end - getter end end end