Skip to content

Commit

Permalink
refactor: rework discrete indexing behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jun 20, 2024
1 parent 4183188 commit 3e2e26d
Show file tree
Hide file tree
Showing 6 changed files with 429 additions and 298 deletions.
43 changes: 38 additions & 5 deletions src/index_provider_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,16 +120,23 @@ struct ParameterObservedFunction{I, F <: Function}
observed_fn::F
end

function ParameterObservedFunction(ts_idx, f)
ParameterObservedFunction{typeof(ts_idx), typeof(f)}(ts_idx, f)
end
ParameterObservedFunction(f) = ParameterObservedFunction(nothing, f)

"""
parameter_observed(indp, sym)
Return the observed function of `sym` in `indp` as a [`ParameterObservedFunction`](@ref).
If `sym` only involves variables from a single parameter timeseries (optionally along
with non-timeseries parameters) the timeseries index of the parameter timeseries should
be provided in the [`ParameterObservedFunction`](@ref). In all other cases, just the
observed function should be returned as part of the `ParameterObservedFunction` object.
If `sym` only involves non-timeseries parameters, the timeseries index in the
`ParameterObservedFunction` should be `nothing`. If `sym` involves timeseries parameters,
the timeseries index should be specified in the `ParameterObservedFunction`. If `sym`
involves timeseries parameters from multiple different timeseries, all the relevant
timeseries indexes should be specified in the `ParameterObservedFunction` as a `Vector`.
By default, this function returns `nothing`.
By default, this function returns `nothing`, indicating that the index provider does not
support generating parameter observed functions.
"""
function parameter_observed(indp, sym)
if hasmethod(symbolic_container, Tuple{typeof(indp)})
Expand All @@ -139,6 +146,32 @@ function parameter_observed(indp, sym)
end
end

"""
struct ContinuousTimeseries end
A singleton struct corresponding to the timeseries index of the continuous timeseries.
"""
struct ContinuousTimeseries end

"""
get_all_timeseries_indexes(indp, sym)
Return a `Set` of all unique timeseries indexes of variables in symbolic expression
`sym`. `sym` may be a symbolic, or an array of symbolics. Continuous variables
correspond to the [`ContinuousTimeseries`](@ref) timeseries index. Non-timeseries
parameters do not have a timeseries index. Timeseries parameters have the same timeseries
index as that returned by [`timeseries_parameter_index`](@ref).
By default, this function returns `Set([ContinuousTimeseries()])`.
"""
function get_all_timeseries_indexes(indp, sym)
if hasmethod(symbolic_container, Tuple{typeof(indp)})
return get_all_timeseries_indexes(symbolic_container(indp), sym)
else
return Set([ContinuousTimeseries()])
end
end

"""
parameter_symbols(indp)
Expand Down
167 changes: 106 additions & 61 deletions src/parameter_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,19 @@ end
is_indexer_timeseries(::Type{GetParameterIndex{I}}) where {I} = IndexerNotTimeseries()
function is_indexer_timeseries(::Type{GetParameterIndex{I}}) where {I <:
ParameterTimeseriesIndex}
IndexerTimeseries()
IndexerOnlyTimeseries()
end
function indexer_timeseries_index(gpi::GetParameterIndex{<:ParameterTimeseriesIndex})
gpi.idx.timeseries_idx
end
function (gpi::GetParameterIndex)(::IsTimeseriesTrait, prob)
parameter_values(prob, gpi.idx)
end
function (gpi::GetParameterIndex)(::IsTimeseriesTrait, prob, arg)
parameter_values(prob, gpi.idx)
end
function (gpi::GetParameterIndex)(::Timeseries, prob, args)
throw(ParameterTimeseriesValueIndexMismatchError{Timeseries}(prob, gpi, args))
parameter_values(prob, gpi.idx)
end
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(::Timeseries, prob)
get_parameter_timeseries_collection(prob)[gpi.idx]
Expand Down Expand Up @@ -129,10 +132,12 @@ is_indexer_timeseries(::Type{G}) where {G <: GetParameterTimeseriesIndex} = Inde
function indexer_timeseries_index(gpti::GetParameterTimeseriesIndex)
indexer_timeseries_index(gpti.param_timeseries_idx)
end
as_not_timeseries_indexer(::IndexerBoth, gpti::GetParameterTimeseriesIndex) = gpti.param_idx
function as_timeseries_indexer(::IndexerBoth, gpti::GetParameterTimeseriesIndex)
gpti.param_timeseries_idx
end
function as_not_timeseries_indexer(::IndexerBoth, gpti::GetParameterTimeseriesIndex)
gpti.param_idx
end

function (gpti::GetParameterTimeseriesIndex)(ts::Timeseries, prob, args...)
gpti.param_timeseries_idx(ts, prob, args...)
Expand Down Expand Up @@ -163,35 +168,51 @@ const SingleGetParameterObserved = GetParameterObserved{I, false} where {I}
function is_indexer_timeseries(::Type{G}) where {G <: GetParameterObserved{Nothing}}
IndexerNotTimeseries()
end
function is_indexer_timeseries(::Type{G}) where {I <: Vector, G <: GetParameterObserved{I}}
IndexerMixedTimeseries()
end
is_indexer_timeseries(::Type{G}) where {G <: GetParameterObserved} = IndexerBoth()
indexer_timeseries_index(gpo::GetParameterObserved) = gpo.timeseries_idx
function as_not_timeseries_indexer(
::IndexerBoth, gpo::GetParameterObserved{I, M}) where {I, M}
return GetParameterObserved{M}(nothing, gpo.obsfn)
end
as_timeseries_indexer(::IndexerBoth, gpo::GetParameterObserved) = gpo
as_not_timeseries_indexer(::IndexerBoth, gpo::GetParameterObserved) = gpo
as_not_timeseries_indexer(::IndexerMixedTimeseries, gpo::GetParameterObserved) = gpo

function (gpo::GetParameterObserved{Nothing})(::Timeseries, prob)
gpo.obsfn(parameter_values(prob), current_time(prob)[end])
end
for multiple in [true, false]
@eval function (gpo::GetParameterObserved{Nothing, $multiple})(
buffer::AbstractArray, ::Timeseries, prob)
for argType in [Union{Int, CartesianIndex}, Colon, AbstractArray{Bool}, Any]
@eval function (gpo::GetParameterObserved{Nothing})(::Timeseries, prob, args::$argType)
gpo.obsfn(parameter_values(prob), current_time(prob)[end])
end
@eval function (gpo::GetParameterObserved{Nothing, true})(
buffer::AbstractArray, ::Timeseries, prob, args::$argType)
gpo.obsfn(buffer, parameter_values(prob), current_time(prob)[end])
return buffer
end
end
for argType in [Union{Int, CartesianIndex}, Colon, AbstractArray{Bool}, Any]
@eval function (gpo::GetParameterObserved{Nothing})(::Timeseries, prob, args::$argType)
throw(ParameterTimeseriesValueIndexMismatchError{Timeseries}(prob, gpo, args))
@eval function (gpo::GetParameterObserved{<:Vector})(::Timeseries, prob, args::$argType)
throw(MixedParameterTimeseriesIndexError(prob, indexer_timeseries_index(gpo)))
end
for multiple in [true, false]
@eval function (gpo::GetParameterObserved{Nothing, $multiple})(
@eval function (gpo::GetParameterObserved{<:Vector, $multiple})(
::AbstractArray, ::Timeseries, prob, args::$argType)
throw(ParameterTimeseriesValueIndexMismatchError{Timeseries}(prob, gpo, args))
throw(MixedParameterTimeseriesIndexError(prob, indexer_timeseries_index(gpo)))
end
end
end

function (gpo::GetParameterObserved{<:Vector})(::NotTimeseries, prob)
gpo.obsfn(parameter_values(prob), current_time(prob))
end
function (gpo::GetParameterObserved{<:Vector, true})(
buffer::AbstractArray, ::NotTimeseries, prob)
gpo.obsfn(buffer, parameter_values(prob), current_time(prob))
end
function (gpo::GetParameterObserved{<:Vector})(::Timeseries, prob)
throw(MixedParameterTimeseriesIndexError(prob, indexer_timeseries_index(gpo)))
end
function (gpo::GetParameterObserved{<:Vector, true})(::AbstractArray, ::Timeseries, prob)
throw(MixedParameterTimeseriesIndexError(prob, indexer_timeseries_index(gpo)))
end
function (gpo::GetParameterObserved{<:Vector, false})(::AbstractArray, ::Timeseries, prob)
throw(MixedParameterTimeseriesIndexError(prob, indexer_timeseries_index(gpo)))
end
function (gpo::GetParameterObserved)(::NotTimeseries, prob)
gpo.obsfn(parameter_values(prob), current_time(prob))
end
Expand Down Expand Up @@ -312,11 +333,7 @@ function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p)
error("Invalid symbol $p for `getp`")
end

struct MixedTimeseriesIndexes
indexes::Any
end

struct MultipleParametersGetter{T <: IsIndexerTimeseries, G, I} <:
struct MultipleParametersGetter{T <: IndexerTimeseriesType, G, I} <:
AbstractParameterGetIndexer
getters::G
timeseries_idx::I
Expand All @@ -327,63 +344,78 @@ function MultipleParametersGetter(getters)
return MultipleParametersGetter{IndexerNotTimeseries, typeof(getters), Nothing}(
getters, nothing)
end
has_timeseries_indexers = any(getters) do g
is_indexer_timeseries(g) == IndexerTimeseries()
has_only_timeseries_indexers = any(getters) do g
is_indexer_timeseries(g) == IndexerOnlyTimeseries()
end
has_non_timeseries_indexers = any(getters) do g
is_indexer_timeseries(g) == IndexerNotTimeseries()
end
if has_timeseries_indexers && has_non_timeseries_indexers
throw(ArgumentError("Cannot mix timeseries and non-timeseries indexers in `$MultipleParametersGetter`"))
all_non_timeseries_indexers = all(getters) do g
is_indexer_timeseries(g) == IndexerNotTimeseries()
end
if has_only_timeseries_indexers && has_non_timeseries_indexers
throw(ArgumentError("Cannot mix timeseries indexes with non-timeseries variables in `$MultipleParametersGetter`"))
end
indexer_type = if has_timeseries_indexers
# If any getters are only timeseries, all of them should be
indexer_type = if has_only_timeseries_indexers
getters = as_timeseries_indexer.(getters)
timeseries_idx = indexer_timeseries_index(first(getters))
IndexerTimeseries
elseif has_non_timeseries_indexers
getters = as_not_timeseries_indexer.(getters)
timeseries_idx = nothing
IndexerOnlyTimeseries
# If all indexers are non-timeseries, so should their combination
elseif all_non_timeseries_indexers
IndexerNotTimeseries
else
timeseries_idx = indexer_timeseries_index(first(getters))
IndexerBoth
end

if indexer_type != IndexerNotTimeseries &&
!allequal(indexer_timeseries_index(g) for g in getters)
if indexer_type == IndexerTimeseries
throw(ArgumentError("All parameters must belong to the same timeseries"))
else
indexer_type = IndexerNotTimeseries
timeseries_idx = MixedTimeseriesIndexes(indexer_timeseries_index.(getters))
timeseries_idx = if indexer_type == IndexerNotTimeseries
nothing
else
timeseries_idxs = Set(Iterators.flatten(indexer_timeseries_index(g)
for g in getters if is_indexer_timeseries(g) != IndexerNotTimeseries()))
if length(timeseries_idxs) > 1
indexer_type = IndexerMixedTimeseries
getters = as_not_timeseries_indexer.(getters)
collect(timeseries_idxs)
else
only(timeseries_idxs)
end
end

return MultipleParametersGetter{indexer_type, typeof(getters), typeof(timeseries_idx)}(
getters, timeseries_idx)
end

const AtLeastTimeseriesMPG = Union{
MultipleParametersGetter{IndexerTimeseries}, MultipleParametersGetter{IndexerBoth}}
const MixedTimeseriesIndexMPG = MultipleParametersGetter{
IndexerNotTimeseries, G, MixedTimeseriesIndexes} where {G}
const OnlyTimeseriesMPG = MultipleParametersGetter{IndexerOnlyTimeseries}
const BothMPG = MultipleParametersGetter{IndexerBoth}
const NonTimeseriesMPG = MultipleParametersGetter{IndexerNotTimeseries}
const MixedTimeseriesMPG = MultipleParametersGetter{IndexerMixedTimeseries}
const AtLeastTimeseriesMPG = Union{OnlyTimeseriesMPG, BothMPG}

is_indexer_timeseries(::Type{<:MultipleParametersGetter{T}}) where {T} = T()
function indexer_timeseries_index(mpg::MultipleParametersGetter)
mpg.timeseries_idx
end
function as_timeseries_indexer(::IndexerBoth, mpg::MultipleParametersGetter)
MultipleParametersGetter(as_timeseries_indexer.(mpg.getters))
end
function as_not_timeseries_indexer(::IndexerBoth, mpg::MultipleParametersGetter)
MultipleParametersGetter(as_not_timeseries_indexer.(mpg.getters))
end
function as_not_timeseries_indexer(::IndexerMixedTimeseries, mpg::MultipleParametersGetter)
return mpg
end

function as_timeseries_indexer(::IndexerBoth, mpg::MultipleParametersGetter)
MultipleParametersGetter(as_timeseries_indexer.(mpg.getters))
function (mpg::MixedTimeseriesMPG)(::Timeseries, prob, args...)
throw(MixedParameterTimeseriesIndexError(prob, indexer_timeseries_index(mpg)))
end
function (mpg::MixedTimeseriesMPG)(::AbstractArray, ::Timeseries, prob, args...)
throw(MixedParameterTimeseriesIndexError(prob, indexer_timeseries_index(mpg)))
end

for (indexerTimeseriesType, timeseriesType) in [
(IndexerNotTimeseries, IsTimeseriesTrait),
(IndexerBoth, NotTimeseries)
(IndexerBoth, NotTimeseries),
(IndexerMixedTimeseries, NotTimeseries)
]
@eval function (mpg::MultipleParametersGetter{$indexerTimeseriesType})(
::$timeseriesType, prob)
Expand All @@ -398,17 +430,16 @@ for (indexerTimeseriesType, timeseriesType) in [
end
end

function (mpg::MixedTimeseriesIndexMPG)(::Timeseries, prob, args...)
throw(MixedParameterTimeseriesIndexError(prob, mpg.timeseries_idx.indexes))
end

function (mpg::MultipleParametersGetter{IndexerNotTimeseries})(::Timeseries, prob, args)
throw(ParameterTimeseriesValueIndexMismatchError{Timeseries}(prob, mpg, args))
function (mpg::NonTimeseriesMPG)(::IsTimeseriesTrait, prob, arg)
return _call.(mpg.getters, (prob,), (arg,))
end
function (mpg::MultipleParametersGetter{IndexerNotTimeseries})(
::AbstractArray, ::Timeseries, prob, args)
throw(ParameterTimeseriesValueIndexMismatchError{Timeseries}(prob, mpg, args))
function (mpg::NonTimeseriesMPG)(buffer::AbstractArray, ::IsTimeseriesTrait, prob, arg)
for (buf_idx, getter) in zip(eachindex(buffer), mpg.getters)
buffer[buf_idx] = getter(prob)
end
return buffer
end

function (mpg::AtLeastTimeseriesMPG)(ts::Timeseries, prob)
map(eachindex(parameter_timeseries(prob, indexer_timeseries_index(mpg)))) do i
mpg(ts, prob, i)
Expand Down Expand Up @@ -457,10 +488,10 @@ function (mpg::AtLeastTimeseriesMPG)(buffer::AbstractArray, ts::Timeseries, prob
end
return buffer
end
function (mpg::MultipleParametersGetter{IndexerTimeseries})(::NotTimeseries, prob)
function (mpg::OnlyTimeseriesMPG)(::NotTimeseries, prob)
throw(ParameterTimeseriesValueIndexMismatchError{NotTimeseries}(prob, mpg))
end
function (mpg::MultipleParametersGetter{IndexerTimeseries})(
function (mpg::OnlyTimeseriesMPG)(
::AbstractArray, ::NotTimeseries, prob)
throw(ParameterTimeseriesValueIndexMismatchError{NotTimeseries}(prob, mpg))
end
Expand Down Expand Up @@ -490,6 +521,10 @@ wrap_tuple(::AsParameterTupleWrapper{N}, val) where {N} = ntuple(i -> val[i], Va
function (atw::AsParameterTupleWrapper)(ts::IsTimeseriesTrait, prob, args...)
atw(ts, is_indexer_timeseries(atw), prob, args...)
end
function (atw::AsParameterTupleWrapper)(
ts::IsTimeseriesTrait, ::IndexerMixedTimeseries, prob, args...)
wrap_tuple(atw, atw.getter(ts, prob, args...))
end
function (atw::AsParameterTupleWrapper)(ts::Timeseries, ::AtLeastTimeseriesIndexer, prob)
wrap_tuple.((atw,), atw.getter(ts, prob))
end
Expand Down Expand Up @@ -519,6 +554,16 @@ is_observed_getter(::GetParameterObserved) = true
is_observed_getter(::GetParameterObservedNoTime) = true
is_observed_getter(mpg::MultipleParametersGetter) = any(is_observed_getter, mpg.getters)

function _supports_parameter_observed(indp, sym)
if hasmethod(parameter_observed, Tuple{typeof(indp), typeof(sym)})
return true
elseif hasmethod(symbolic_container, Tuple{typeof(indp)})
return _supports_parameter_observed(symbolic_container(indp), sym)
else
return false
end
end

for (t1, t2) in [
(ArraySymbolic, Any),
(ScalarSymbolic, Any),
Expand All @@ -531,7 +576,7 @@ for (t1, t2) in [
getters = getp.((sys,), p)
num_observed = count(is_observed_getter, getters)

if num_observed == 0
if num_observed == 0 || !_supports_parameter_observed(sys, p)
return MultipleParametersGetter(getters)
else
pofn = parameter_observed(sys, p isa Tuple ? collect(p) : p)
Expand Down
Loading

0 comments on commit 3e2e26d

Please sign in to comment.