From 0d6f00d606602e8a2e8b4fdffd88d9290d7c9a9d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 25 Jun 2024 18:34:16 +0530 Subject: [PATCH] fixup! refactor: rework discrete indexing behavior --- src/state_indexing.jl | 18 +++++++++++++++++- test/parameter_indexing_test.jl | 32 +++++++++++++++++++++++--------- 2 files changed, 40 insertions(+), 10 deletions(-) diff --git a/src/state_indexing.jl b/src/state_indexing.jl index abe2960..71bd0f4 100644 --- a/src/state_indexing.jl +++ b/src/state_indexing.jl @@ -157,6 +157,9 @@ end function is_indexer_timeseries(::Type{G}) where {G <: MultipleGetters{<:Vector}} return IndexerMixedTimeseries() end +function is_indexer_timeseries(::Type{G}) where {G <: MultipleGetters{Nothing}} + return IndexerNotTimeseries() +end function (mg::MultipleGetters)(ts::IsTimeseriesTrait, prob, args...) return mg(ts, is_indexer_timeseries(mg), prob, args...) @@ -180,7 +183,7 @@ end function (mg::MultipleGetters)(ts::Timeseries, ::IndexerBoth, prob, i) mg.((ts,), (prob,), i) end -function (mg::MultipleGetters)(::NotTimeseries, ::IndexerBoth, prob) +function (mg::MultipleGetters)(::NotTimeseries, ::Union{IndexerBoth, IndexerNotTimeseries}, prob) return map(g -> g(prob), mg.getters) end @@ -222,6 +225,19 @@ for (t1, t2) in [ return MultipleGetters(ContinuousTimeseries(), sym) end sym_arr = sym isa Tuple ? collect(sym) : sym + num_observed = count(x -> is_observed(sys, x), sym) + if !is_time_dependent(sys) + if num_observed == 0 || num_observed == 1 && sym isa Tuple + return MultipleGetters(nothing, getu.((sys,), sym)) + else + obs = observed(sys, sym_arr) + getter = TimeIndependentObservedFunction(obs) + if sym isa Tuple + getter = AsTupleWrapper{length(sym)}(getter) + end + return getter + end + end ts_idxs = get_all_timeseries_indexes(sys, sym_arr) if !(ContinuousTimeseries() in ts_idxs) return getp(sys, sym) diff --git a/test/parameter_indexing_test.jl b/test/parameter_indexing_test.jl index b276ce4..abd8cd5 100644 --- a/test/parameter_indexing_test.jl +++ b/test/parameter_indexing_test.jl @@ -3,8 +3,7 @@ using SymbolicIndexingInterface: IndexerOnlyTimeseries, IndexerNotTimeseries, In IndexerMixedTimeseries, is_indexer_timeseries, indexer_timeseries_index, ParameterTimeseriesValueIndexMismatchError, - MixedParameterTimeseriesIndexError, - MixedContinuousParameterTimeseriesError + MixedParameterTimeseriesIndexError using Test arr = [1.0, 2.0, 3.0] @@ -465,14 +464,29 @@ for (sym, val_is_timeseries, val, check_inference) in [ end end -for sym in [ - [:x, :b], - (:x, :c), - :(x + b), - [:(2b), :(3x)], - (:(2b), :(3x)) +temp_state = ProblemState(; u = fs.u[1], + p = with_updated_parameter_timeseries_values( + parameter_values(fs), 1 => fs.p_ts[1, 1], 2 => fs.p_ts[2, 1]), + t = fs.t[1]) +_xval = temp_state.u[1] +_bval = bval[1] +_cval = cval[1] +for (sym, val, check_inference) in [ + ([:x, :b], [_xval, _bval], false), + ((:x, :c), (_xval, _cval), true), + (:(x + b), _xval + _bval, true), + ([:(2b), :(3x)], [2_bval, 3_xval], true), + ((:(2b), :(3x)), (2_bval, 3_xval), true) ] - @test_throws MixedContinuousParameterTimeseriesError getu(sys, sym) + getter = getu(sys, sym) + @test_throws MixedParameterTimeseriesIndexError getter(fs) + for subidx in [1, CartesianIndex(2), :, rand(Bool, 4), rand(1:4, 3), 1:2] + @test_throws MixedParameterTimeseriesIndexError getter(fs, subidx) + end + if check_inference + @inferred getter(temp_state) + end + @test getter(temp_state) == val end for sym in [