Skip to content

Commit

Permalink
fixup! refactor: rework discrete indexing behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jun 25, 2024
1 parent 1fdb6c1 commit 0d6f00d
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 10 deletions.
18 changes: 17 additions & 1 deletion src/state_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,9 @@ end
function is_indexer_timeseries(::Type{G}) where {G <: MultipleGetters{<:Vector}}
return IndexerMixedTimeseries()
end
function is_indexer_timeseries(::Type{G}) where {G <: MultipleGetters{Nothing}}
return IndexerNotTimeseries()
end

function (mg::MultipleGetters)(ts::IsTimeseriesTrait, prob, args...)
return mg(ts, is_indexer_timeseries(mg), prob, args...)
Expand All @@ -180,7 +183,7 @@ end
function (mg::MultipleGetters)(ts::Timeseries, ::IndexerBoth, prob, i)
mg.((ts,), (prob,), i)
end
function (mg::MultipleGetters)(::NotTimeseries, ::IndexerBoth, prob)
function (mg::MultipleGetters)(::NotTimeseries, ::Union{IndexerBoth, IndexerNotTimeseries}, prob)
return map(g -> g(prob), mg.getters)
end

Expand Down Expand Up @@ -222,6 +225,19 @@ for (t1, t2) in [
return MultipleGetters(ContinuousTimeseries(), sym)
end
sym_arr = sym isa Tuple ? collect(sym) : sym
num_observed = count(x -> is_observed(sys, x), sym)
if !is_time_dependent(sys)
if num_observed == 0 || num_observed == 1 && sym isa Tuple
return MultipleGetters(nothing, getu.((sys,), sym))
else
obs = observed(sys, sym_arr)
getter = TimeIndependentObservedFunction(obs)
if sym isa Tuple
getter = AsTupleWrapper{length(sym)}(getter)
end
return getter
end
end
ts_idxs = get_all_timeseries_indexes(sys, sym_arr)
if !(ContinuousTimeseries() in ts_idxs)
return getp(sys, sym)
Expand Down
32 changes: 23 additions & 9 deletions test/parameter_indexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ using SymbolicIndexingInterface: IndexerOnlyTimeseries, IndexerNotTimeseries, In
IndexerMixedTimeseries,
is_indexer_timeseries, indexer_timeseries_index,
ParameterTimeseriesValueIndexMismatchError,
MixedParameterTimeseriesIndexError,
MixedContinuousParameterTimeseriesError
MixedParameterTimeseriesIndexError
using Test

arr = [1.0, 2.0, 3.0]
Expand Down Expand Up @@ -465,14 +464,29 @@ for (sym, val_is_timeseries, val, check_inference) in [
end
end

for sym in [
[:x, :b],
(:x, :c),
:(x + b),
[:(2b), :(3x)],
(:(2b), :(3x))
temp_state = ProblemState(; u = fs.u[1],
p = with_updated_parameter_timeseries_values(
parameter_values(fs), 1 => fs.p_ts[1, 1], 2 => fs.p_ts[2, 1]),
t = fs.t[1])
_xval = temp_state.u[1]
_bval = bval[1]
_cval = cval[1]
for (sym, val, check_inference) in [
([:x, :b], [_xval, _bval], false),
((:x, :c), (_xval, _cval), true),
(:(x + b), _xval + _bval, true),
([:(2b), :(3x)], [2_bval, 3_xval], true),
((:(2b), :(3x)), (2_bval, 3_xval), true)
]
@test_throws MixedContinuousParameterTimeseriesError getu(sys, sym)
getter = getu(sys, sym)
@test_throws MixedParameterTimeseriesIndexError getter(fs)
for subidx in [1, CartesianIndex(2), :, rand(Bool, 4), rand(1:4, 3), 1:2]
@test_throws MixedParameterTimeseriesIndexError getter(fs, subidx)
end
if check_inference
@inferred getter(temp_state)
end
@test getter(temp_state) == val
end

for sym in [
Expand Down

0 comments on commit 0d6f00d

Please sign in to comment.