Skip to content

Commit

Permalink
fixup! refactor: rework discrete indexing behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jun 21, 2024
1 parent 3e2e26d commit fd7daee
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 39 deletions.
23 changes: 12 additions & 11 deletions src/state_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,20 +190,21 @@ for (t1, t2) in [
(NotSymbolic, Union{<:Tuple, <:AbstractArray})
]
@eval function _getu(sys, ::NotSymbolic, ::$t1, sym::$t2)
num_observed = count(x -> is_observed(sys, x), sym)
num_params = count(x -> is_parameter(sys, x), sym)
num_timeseries_params = count(
x -> is_parameter(sys, x) && is_timeseries_parameter(sys, x), sym)
if num_timeseries_params > 0 && num_params < length(sym) # we have timeseries params and continuous variables
if isempty(sym)
return MultipleGetters(sym)
end
ts_idxs = get_all_timeseries_indexes(sys, sym isa Tuple ? collect(sym) : sym)
if ContinuousTimeseries() in ts_idxs && length(ts_idxs) > 1
throw(MixedContinuousParameterTimeseriesError(sym))
end
if !(ContinuousTimeseries() in ts_idxs)
return getp(sys, sym)
end

num_observed = count(x -> is_observed(sys, x), sym)
if num_observed == 0 || num_observed == 1 && sym isa Tuple
if !isempty(sym) && num_timeseries_params > 0
getp(sys, sym)
else
getters = getu.((sys,), sym)
return MultipleGetters(getters)
end
getters = getu.((sys,), sym)
return MultipleGetters(getters)
else
obs = observed(sys, sym isa Tuple ? collect(sym) : sym)
getter = if is_time_dependent(sys)
Expand Down
2 changes: 1 addition & 1 deletion src/symbol_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ function get_all_timeseries_indexes(sc::SymbolCache, sym::Expr)
Base.Fix1(get_all_timeseries_indexes, sc), union, exs.declared; init = Set())
end
function get_all_timeseries_indexes(sc::SymbolCache, sym::AbstractArray)
return mapreduce(Base.Fix1(get_all_timeseries_indexes, sc), union, sym)
return mapreduce(Base.Fix1(get_all_timeseries_indexes, sc), union, sym; init = Set())
end
function is_independent_variable(sc::SymbolCache, sym)
sc.independent_variables === nothing && return false
Expand Down
58 changes: 31 additions & 27 deletions test/parameter_indexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ using SymbolicIndexingInterface: IndexerOnlyTimeseries, IndexerNotTimeseries, In
IndexerMixedTimeseries,
is_indexer_timeseries, indexer_timeseries_index,
ParameterTimeseriesValueIndexMismatchError,
MixedParameterTimeseriesIndexError
MixedParameterTimeseriesIndexError,
MixedContinuousParameterTimeseriesError
using Test

arr = [1.0, 2.0, 3.0]
Expand Down Expand Up @@ -409,64 +410,68 @@ for sym in [[:a, bidx], (:a, bidx), [1, bidx], (1, bidx)]
@test_throws ArgumentError getp(sys, sym)
end

for (sym, val) in [
([:b, :c], [bval[end], cval[end]]),
((:b, :c), (bval[end], cval[end]))
]
for (sym, val) in [([:b, :c], [bval[end], cval[end]])
((:b, :c), (bval[end], cval[end]))]
getter = getp(sys, sym)
@test is_indexer_timeseries(getter) == IndexerMixedTimeseries()
@test_throws MixedParameterTimeseriesIndexError getter(fs)
@test getter(parameter_values(fs)) == val
end

xval = getindex.(fs.u, 1)

for (sym, val_is_timeseries, val, check_inference) in [
(:a, false, aval, true),
([:a, :d], false, [aval, dval], true),
((:a, :d), false, (aval, dval), true),
(:b, true, bval, true),
([:a, :b], true, vcat.(aval, bval), false),
((:a, :b), true, tuple.(aval, bval), true),
# ([:b, :c], true, vcat.(bval_state, cval_state), true),
# ((:b, :c), true, tuple.(bval_state, cval_state), true),
# ([:a, :b, :c], true, vcat.(aval, bval_state, cval_state), false),
# ((:a, :b, :c), true, tuple.(aval, bval_state, cval_state), true),
# ([:x, :b], true, vcat.(xval, bval_state), false),
# ((:x, :b), true, tuple.(xval, bval_state), true),
# ([:x, :b, :c], true, vcat.(xval, bval_state, cval_state), false),
# ((:x, :b, :c), true, tuple.(xval, bval_state, cval_state), true),
# ([:a, :b, :x], true, vcat.(aval, bval_state, xval), false),
# ((:a, :b, :x), true, tuple.(aval, bval_state, xval), true),
([:a, :x], true, vcat.(aval, xval), false),
((:a, :x), true, tuple.(aval, xval), true),
(:(2b), true, 2 .* bval, true),
([:a, :(2b)], true, vcat.(aval, 2 .* bval), true) # ([:x, :(2b), :(3c)], true, vcat.(xval, 2 .* bval_state, 3 .* cval_state), true), # ((:x, :(2b), :(3c)), true, tuple.(xval, 2 .* bval_state, 3 .* cval_state), true)
([:a, :(2b)], true, vcat.(aval, 2 .* bval), true),
((:a, :(2b)), true, tuple.(aval, 2 .* bval), true)
]
getter = getu(sys, sym)
if check_inference
@inferred getter(fs)
end
@test getter(fs) == val

reference = val_is_timeseries ? val : xval
for subidx in [
1, CartesianIndex(2), :, rand(Bool, length(bval)), rand(eachindex(bval), 3), 1:2]
@show sym subidx
1, CartesianIndex(2), :, rand(Bool, length(reference)),
rand(eachindex(reference), 3), 1:2]
if check_inference
@inferred getter(fs, subidx)
end
target = if val_is_timeseries
val[subidx]
else
if bval[subidx] isa AbstractArray
len = length(bval[subidx])
fill(val, len)
else
val
end
val
end
@test getter(fs, subidx) == target
end
end

#=
@test_throws ErrorException getp(sys, :not_a_param)
for sym in [
[:x, :b],
(:x, :c),
:(x + b),
[:(2b), :(3x)],
(:(2b), :(3x))
]
@test_throws MixedContinuousParameterTimeseriesError getu(sys, sym)
end

for sym in [
:err,
[:err, :b],
(:err, :b)
]
@test_throws ErrorException getp(sys, sym)
end

let fs = fs, sys = sys
getter = getp(sys, [])
Expand Down Expand Up @@ -511,4 +516,3 @@ for (sym, val, check_inference) in [
@test buffer == collect(val)
end
end
=#

0 comments on commit fd7daee

Please sign in to comment.