Skip to content

Commit

Permalink
fixup! refactor: rework discrete indexing behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jun 21, 2024
1 parent fd7daee commit aff02b3
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 115 deletions.
1 change: 1 addition & 0 deletions src/SymbolicIndexingInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ export is_variable, variable_index, variable_symbols, is_parameter, parameter_in
is_timeseries_parameter, timeseries_parameter_index, ParameterTimeseriesIndex,
parameter_symbols, is_independent_variable, independent_variable_symbols,
is_observed, observed, parameter_observed, ParameterObservedFunction,
ContinuousTimeseries, get_all_timeseries_indexes,
is_time_dependent, constant_structure, symbolic_container,
all_variable_symbols, all_symbols, solvedvariables, allvariables, default_values,
symbolic_evaluate
Expand Down
24 changes: 12 additions & 12 deletions src/parameter_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,14 @@ as_timeseries_indexer(::IndexerBoth, gpo::GetParameterObserved) = gpo
as_not_timeseries_indexer(::IndexerBoth, gpo::GetParameterObserved) = gpo
as_not_timeseries_indexer(::IndexerMixedTimeseries, gpo::GetParameterObserved) = gpo

function (gpo::GetParameterObserved{Nothing})(::Timeseries, prob)
gpo.obsfn(parameter_values(prob), current_time(prob)[end])
end
function (gpo::GetParameterObserved{Nothing, true})(
buffer::AbstractArray, ::Timeseries, prob)
gpo.obsfn(buffer, parameter_values(prob), current_time(prob)[end])
return buffer
end
for argType in [Union{Int, CartesianIndex}, Colon, AbstractArray{Bool}, Any]
@eval function (gpo::GetParameterObserved{Nothing})(::Timeseries, prob, args::$argType)
gpo.obsfn(parameter_values(prob), current_time(prob)[end])
Expand Down Expand Up @@ -535,7 +543,6 @@ end
function (atw::AsParameterTupleWrapper)(ts::Timeseries, ::AtLeastTimeseriesIndexer, prob, i)
wrap_tuple.((atw,), atw.getter(ts, prob, i))
end
# args is just so it throws
function (atw::AsParameterTupleWrapper)(
ts::Timeseries, ::IndexerNotTimeseries, prob, args...)
wrap_tuple(atw, atw.getter(ts, prob, args...))
Expand All @@ -554,16 +561,6 @@ is_observed_getter(::GetParameterObserved) = true
is_observed_getter(::GetParameterObservedNoTime) = true
is_observed_getter(mpg::MultipleParametersGetter) = any(is_observed_getter, mpg.getters)

function _supports_parameter_observed(indp, sym)
if hasmethod(parameter_observed, Tuple{typeof(indp), typeof(sym)})
return true
elseif hasmethod(symbolic_container, Tuple{typeof(indp)})
return _supports_parameter_observed(symbolic_container(indp), sym)
else
return false
end
end

for (t1, t2) in [
(ArraySymbolic, Any),
(ScalarSymbolic, Any),
Expand All @@ -576,10 +573,13 @@ for (t1, t2) in [
getters = getp.((sys,), p)
num_observed = count(is_observed_getter, getters)

if num_observed == 0 || !_supports_parameter_observed(sys, p)
if num_observed == 0
return MultipleParametersGetter(getters)
else
pofn = parameter_observed(sys, p isa Tuple ? collect(p) : p)
if pofn === nothing
return MultipleParametersGetter.(getters)
end
if is_time_dependent(sys)
getter = GetParameterObserved{true}(pofn.timeseries_idx, pofn.observed_fn)
else
Expand Down
64 changes: 0 additions & 64 deletions src/parameter_timeseries_collection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,30 +124,6 @@ function parameter_values_at_time(valp, t)
(ts_idx => _timeseries_value(ptc, ts_idx, t) for ts_idx in eachindex(ptc))...)
end

"""
parameter_values_at_state_time(valp, i)
parameter_values_at_state_time(valp)
Return an indexable collection containing the value of all parameters in `valp` at time
index `i` in the state timeseries.
By default, this function relies on [`parameter_values_at_time`](@ref) and
[`current_time`](@ref) for a default implementation.
The single-argument version of this function is a shorthand to return parameter values
at each point in the state timeseries. This also has a default implementation relying on
[`parameter_values_at_time`](@ref) and [`current_time`](@ref).
"""
function parameter_values_at_state_time end

function parameter_values_at_state_time(p, i)
state_time = current_time(p, i)
return parameter_values_at_time(p, state_time)
end
function parameter_values_at_state_time(p)
return (parameter_values_at_time(p, t) for t in current_time(p))
end

"""
parameter_timeseries(valp, i)
Expand All @@ -163,43 +139,3 @@ function parameter_timeseries end
function parameter_timeseries(valp, i)
return parameter_timeseries(get_parameter_timeseries_collection(valp), i)
end

"""
parameter_timeseries_at_state_time(valp, i, j)
parameter_timeseries_at_state_time(valp, i)
Return the index of the timestep in the parameter timeseries at timeseries index `i` which
occurs just before or at the same time as the state timestep with index `j`. The two-
argument version of this function returns an iterable of indexes, one for each timestep in
the state timeseries. If `j` is an object that refers to multiple values in the state
timeseries (e.g. `Colon`), return an iterable of the indexes in the parameter timeseries
at the appropriate points.
Both versions of this function have default implementations relying on
[`current_time`](@ref) and [`parameter_timeseries`](@ref), for the cases where `j` is one
of: `Int`, `CartesianIndex`, `AbstractArray{Bool}`, `Colon` or an iterable of the
aforementioned.
"""
function parameter_timeseries_at_state_time end

function parameter_timeseries_at_state_time(valp, i, j::Union{Int, CartesianIndex})
state_time = current_time(valp, j)
timeseries = parameter_timeseries(valp, i)
searchsortedlast(timeseries, state_time)
end

function parameter_timeseries_at_state_time(valp, i, ::Colon)
parameter_timeseries_at_state_time(valp, i)
end

function parameter_timeseries_at_state_time(valp, i, j::AbstractArray{Bool})
parameter_timeseries_at_state_time(valp, i, only(to_indices(current_time(valp), (j,))))
end

function parameter_timeseries_at_state_time(valp, i, j)
(parameter_timeseries_at_state_time(valp, i, jj) for jj in j)
end

function parameter_timeseries_at_state_time(valp, i)
parameter_timeseries_at_state_time(valp, i, eachindex(current_time(valp)))
end
37 changes: 9 additions & 28 deletions src/state_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,43 +61,24 @@ struct TimeDependentObservedFunction{F} <: AbstractStateGetIndexer
end

function (o::TimeDependentObservedFunction)(ts::Timeseries, prob)
return o(ts, is_parameter_timeseries(prob), prob)
end
function (o::TimeDependentObservedFunction)(::Timeseries, ::Timeseries, prob)
map(o.obsfn, state_values(prob),
parameter_values_at_state_time(prob), current_time(prob))
end
function (o::TimeDependentObservedFunction)(::Timeseries, ::NotTimeseries, prob)
o.obsfn.(state_values(prob),
return o.obsfn.(state_values(prob),
(parameter_values(prob),),
current_time(prob))
end
function (o::TimeDependentObservedFunction)(ts::Timeseries, prob, i)
return o(ts, is_parameter_timeseries(prob), prob, i)
end
function (o::TimeDependentObservedFunction)(
::Timeseries, ::Timeseries, prob, i::Union{Int, CartesianIndex})
return o.obsfn(state_values(prob, i),
parameter_values_at_state_time(prob, i),
current_time(prob, i))
::Timeseries, prob, i::Union{Int, CartesianIndex})
return o.obsfn(state_values(prob, i), parameter_values(prob), current_time(prob, i))
end
function (o::TimeDependentObservedFunction)(
ts::Timeseries, p_ts::IsTimeseriesTrait, prob, ::Colon)
return o(ts, p_ts, prob)
function (o::TimeDependentObservedFunction)(ts::Timeseries, prob, ::Colon)
return o(ts, prob)
end
function (o::TimeDependentObservedFunction)(
ts::Timeseries, p_ts::IsTimeseriesTrait, prob, i::AbstractArray{Bool})
function (o::TimeDependentObservedFunction)(ts::Timeseries, prob, i::AbstractArray{Bool})
map(only(to_indices(current_time(prob), (i,)))) do idx
o(ts, p_ts, prob, idx)
o(ts, prob, idx)
end
end
function (o::TimeDependentObservedFunction)(
ts::Timeseries, p_ts::IsTimeseriesTrait, prob, i)
o.((ts,), (p_ts,), (prob,), i)
end
function (o::TimeDependentObservedFunction)(
::Timeseries, ::NotTimeseries, prob, i::Union{Int, CartesianIndex})
o.obsfn(state_values(prob, i), parameter_values(prob), current_time(prob, i))
function (o::TimeDependentObservedFunction)(ts::Timeseries, prob, i)
o.((ts,), (prob,), i)
end
function (o::TimeDependentObservedFunction)(::NotTimeseries, prob)
return o.obsfn(state_values(prob), parameter_values(prob), current_time(prob))
Expand Down
28 changes: 19 additions & 9 deletions test/parameter_indexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ for sys in [
([1, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false),
([1, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false),
((1, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true),
((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true)
]
((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true)]
get = getp(sys, sym)
set! = setp(sys, sym)
if check_inference
Expand Down Expand Up @@ -223,27 +222,38 @@ for (sym, val, buffer, check_inference) in [
([1, 4], [aval, dval], zeros(2), true),
((1, 4), (aval, dval), zeros(2), true),
([:a, 4], [aval, dval], zeros(2), true),
((:a, 4), (aval, dval), zeros(2), true)
((:a, 4), (aval, dval), zeros(2), true),
(:(a + d), aval + dval, nothing, true),
([:(a + d), :(a * d)], [aval + dval, aval * dval], zeros(2), true),
((:(a + d), :(a * d)), (aval + dval, aval * dval), zeros(2), true)
]
getter = getp(sys, sym)
@test is_indexer_timeseries(getter) isa IndexerNotTimeseries
test_inplace = buffer !== nothing

is_observed = sym isa Expr ||
sym isa Union{AbstractArray, Tuple} && any(x -> x isa Expr, sym)
if check_inference
@inferred getter(fs)
@inferred getter(parameter_values(fs))
if !is_observed
@inferred getter(parameter_values(fs))
end
if test_inplace
@inferred getter(deepcopy(buffer), fs)
@inferred getter(deepcopy(buffer), parameter_values(fs))
if !is_observed
@inferred getter(deepcopy(buffer), parameter_values(fs))
end
end
end
@test getter(fs) == val
@test getter(parameter_values(fs)) == val
if !is_observed
@test getter(parameter_values(fs)) == val
end
if test_inplace
target = collect(val)
for obj in (fs, parameter_values(fs))
valps = is_observed ? (fs,) : (fs, parameter_values(fs))
for valp in valps
tmp = deepcopy(buffer)
getter(tmp, obj)
getter(tmp, valp)
@test tmp == target
end
end
Expand Down
4 changes: 2 additions & 2 deletions test/symbol_cache_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ sc = SymbolCache([:x, :y], [:a, :b, :c], :t;
@test parameter_observed(sc, :(2a)).timeseries_idx === nothing
@test parameter_observed(sc, [:(2a), :(3a)]).timeseries_idx === nothing
@test parameter_observed(sc, (:(2a), :(3a))).timeseries_idx === nothing
@test parameter_observed(sc, [:b, :c]).timeseries_idx === nothing
@test parameter_observed(sc, (:b, :c)).timeseries_idx === nothing
@test sort(parameter_observed(sc, [:b, :c]).timeseries_idx) == [1, 2]
@test sort(parameter_observed(sc, (:b, :c)).timeseries_idx) == [1, 2]

@test_throws ArgumentError SymbolCache([:x, :y], [:a, :b], :t;
timeseries_parameters = Dict(:c => ParameterTimeseriesIndex(1, 1)))
Expand Down

0 comments on commit aff02b3

Please sign in to comment.