Skip to content

Commit

Permalink
fixup! wip: better parameter indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed May 13, 2024
1 parent 1eeb5c5 commit 6c9c517
Show file tree
Hide file tree
Showing 8 changed files with 213 additions and 127 deletions.
108 changes: 74 additions & 34 deletions src/parameter_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ struct GetParameterIndex{I} <: AbstractParameterGetIndexer
end

is_indexer_timeseries(::Type{GetParameterIndex{I}}) where {I} = IndexerNotTimeseries()
is_indexer_timeseries(::Type{GetParameterIndex{I}}) where {I <: ParameterTimeseriesIndex} = IndexerTimeseries()
function is_indexer_timeseries(::Type{GetParameterIndex{I}}) where {I <:
ParameterTimeseriesIndex}
IndexerTimeseries()
end
function indexer_timeseries_index(gpi::GetParameterIndex{<:ParameterTimeseriesIndex})
gpi.idx.timeseries_idx
end
Expand All @@ -56,38 +59,48 @@ function (gpi::GetParameterIndex)(::Timeseries, prob, args)
throw(ParameterTimeseriesValueIndexMismatchError{Timeseries}(prob, gpi, args))
end
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(ts::Timeseries, prob)
gpi.((ts,), (prob,), eachindex(parameter_timeseries(prob, indexer_timeseries_index(gpi))))
gpi.((ts,), (prob,),
eachindex(parameter_timeseries(prob, indexer_timeseries_index(gpi))))
end
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(buffer::AbstractArray, ts::Timeseries, prob)
for (buf_idx, ts_idx) in zip(eachindex(buffer), eachindex(parameter_timeseries(prob, indexer_timeseries_index(gpi))))
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(
buffer::AbstractArray, ts::Timeseries, prob)
for (buf_idx, ts_idx) in zip(eachindex(buffer),
eachindex(parameter_timeseries(prob, indexer_timeseries_index(gpi))))
buffer[buf_idx] = gpi(ts, prob, ts_idx)
end
return buffer
end
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(::Timeseries, prob, i::Union{Int, CartesianIndex})
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(
::Timeseries, prob, i::Union{Int, CartesianIndex})
parameter_values(prob, gpi.idx, i)
end
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(ts::Timeseries, prob, ::Colon)
gpi(ts, prob)
end
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(buffer::AbstractArray, ts::Timeseries, prob, ::Colon)
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(
buffer::AbstractArray, ts::Timeseries, prob, ::Colon)
gpi(buffer, ts, prob)
end
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(ts::Timeseries, prob, i::AbstractArray{Bool})
map(only(to_indices(parameter_timeseries(prob, indexer_timeseries_index(gpi)), (i,)))) do idx
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(
ts::Timeseries, prob, i::AbstractArray{Bool})
map(only(to_indices(
parameter_timeseries(prob, indexer_timeseries_index(gpi)), (i,)))) do idx
gpi(ts, prob, idx)
end
end
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(buffer::AbstractArray, ts::Timeseries, prob, i::AbstractArray{Bool})
for (buf_idx, ts_idx) in zip(eachindex(buffer), only(to_indices(parameter_timeseries(prob, indexer_timeseries_index(gpi)), (i,))))
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(
buffer::AbstractArray, ts::Timeseries, prob, i::AbstractArray{Bool})
for (buf_idx, ts_idx) in zip(eachindex(buffer),
only(to_indices(parameter_timeseries(prob, indexer_timeseries_index(gpi)), (i,))))
buffer[buf_idx] = gpi(ts, prob, ts_idx)
end
return buffer
end
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(ts::Timeseries, prob, i)
gpi.((ts,), (prob,), i)
end
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(buffer::AbstractArray, ts::Timeseries, prob, i)
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(
buffer::AbstractArray, ts::Timeseries, prob, i)
for (buf_idx, subidx) in zip(eachindex(buffer), i)
buffer[buf_idx] = gpi(ts, prob, subidx)
end
Expand All @@ -96,28 +109,36 @@ end
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(::NotTimeseries, prob)
throw(ParameterTimeseriesValueIndexMismatchError{NotTimeseries}(prob, gpi))
end
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(::AbstractArray, ::NotTimeseries, prob)
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(
::AbstractArray, ::NotTimeseries, prob)
throw(ParameterTimeseriesValueIndexMismatchError{NotTimeseries}(prob, gpi))
end

function _getp(sys, ::NotSymbolic, ::NotSymbolic, p)
return GetParameterIndex(p)
end

struct GetParameterTimeseriesIndex{I <: GetParameterIndex, J <: GetParameterIndex{<:ParameterTimeseriesIndex}} <: AbstractParameterGetIndexer
struct GetParameterTimeseriesIndex{
I <: GetParameterIndex, J <: GetParameterIndex{<:ParameterTimeseriesIndex}} <:
AbstractParameterGetIndexer
param_idx::I
param_timeseries_idx::J
end

is_indexer_timeseries(::Type{G}) where {G <: GetParameterTimeseriesIndex} = IndexerBoth()
indexer_timeseries_index(gpti::GetParameterTimeseriesIndex) = indexer_timeseries_index(gpti.param_timeseries_idx)
function indexer_timeseries_index(gpti::GetParameterTimeseriesIndex)
indexer_timeseries_index(gpti.param_timeseries_idx)
end
as_not_timeseries_indexer(::IndexerBoth, gpti::GetParameterTimeseriesIndex) = gpti.param_idx
as_timeseries_indexer(::IndexerBoth, gpti::GetParameterTimeseriesIndex) = gpti.param_timeseries_idx
function as_timeseries_indexer(::IndexerBoth, gpti::GetParameterTimeseriesIndex)
gpti.param_timeseries_idx
end

function (gpti::GetParameterTimeseriesIndex)(ts::Timeseries, prob, args...)
gpti.param_timeseries_idx(ts, prob, args...)
end
function (gpti::GetParameterTimeseriesIndex)(buffer::AbstractArray, ts::Timeseries, prob, args...)
function (gpti::GetParameterTimeseriesIndex)(
buffer::AbstractArray, ts::Timeseries, prob, args...)
gpti.param_timeseries_idx(buffer, ts, prob, args...)
end
function (gpti::GetParameterTimeseriesIndex)(ts::NotTimeseries, prob)
Expand All @@ -128,17 +149,19 @@ function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p)
idx = parameter_index(sys, p)
if is_timeseries_parameter(sys, p)
ts_idx = timeseries_parameter_index(sys, p)
return GetParameterTimeseriesIndex(GetParameterIndex(idx), GetParameterIndex(ts_idx))
return GetParameterTimeseriesIndex(
GetParameterIndex(idx), GetParameterIndex(ts_idx))
else
return GetParameterIndex(idx)
end
end

struct MixedTimeseriesIndexes
indexes
indexes::Any
end

struct MultipleParametersGetter{T <: IsIndexerTimeseries, G, I} <: AbstractParameterGetIndexer
struct MultipleParametersGetter{T <: IsIndexerTimeseries, G, I} <:
AbstractParameterGetIndexer
getters::G
timeseries_idx::I

Expand All @@ -165,7 +188,8 @@ struct MultipleParametersGetter{T <: IsIndexerTimeseries, G, I} <: AbstractParam
IndexerBoth
end

if indexer_type != IndexerNotTimeseries && !allequal(indexer_timeseries_index(g) for g in getters)
if indexer_type != IndexerNotTimeseries &&
!allequal(indexer_timeseries_index(g) for g in getters)
if indexer_type == IndexerTimeseries
throw(ArgumentError("All parameters must belong to the same timeseries"))
else
Expand All @@ -175,29 +199,38 @@ struct MultipleParametersGetter{T <: IsIndexerTimeseries, G, I} <: AbstractParam
end
end

return new{indexer_type, typeof(getters), typeof(timeseries_idx)}(getters, timeseries_idx)
return new{indexer_type, typeof(getters), typeof(timeseries_idx)}(
getters, timeseries_idx)
end
end

const AtLeastTimeseriesMPG = Union{MultipleParametersGetter{IndexerTimeseries}, MultipleParametersGetter{IndexerBoth}}
const MixedTimeseriesIndexMPG = MultipleParametersGetter{IndexerNotTimeseries, G, MixedTimeseriesIndexes} where {G}
const AtLeastTimeseriesMPG = Union{
MultipleParametersGetter{IndexerTimeseries}, MultipleParametersGetter{IndexerBoth}}
const MixedTimeseriesIndexMPG = MultipleParametersGetter{
IndexerNotTimeseries, G, MixedTimeseriesIndexes} where {G}

is_indexer_timeseries(::Type{<:MultipleParametersGetter{T}}) where {T} = T()
function indexer_timeseries_index(mpg::MultipleParametersGetter)
mpg.timeseries_idx
end
as_not_timeseries_indexer(::IndexerBoth, mpg::MultipleParametersGetter) = MultipleParametersGetter(as_not_timeseries_indexer.(mpg.getters))
function as_not_timeseries_indexer(::IndexerBoth, mpg::MultipleParametersGetter)
MultipleParametersGetter(as_not_timeseries_indexer.(mpg.getters))
end

as_timeseries_indexer(::IndexerBoth, mpg::MultipleParametersGetter) = MultipleParametersGetter(as_timeseries_indexer.(mpg.getters))
function as_timeseries_indexer(::IndexerBoth, mpg::MultipleParametersGetter)
MultipleParametersGetter(as_timeseries_indexer.(mpg.getters))
end

for (indexerTimeseriesType, timeseriesType) in [
(IndexerNotTimeseries, IsTimeseriesTrait),
(IndexerBoth, NotTimeseries)
]
@eval function (mpg::MultipleParametersGetter{$indexerTimeseriesType})(::$timeseriesType, prob)
@eval function (mpg::MultipleParametersGetter{$indexerTimeseriesType})(
::$timeseriesType, prob)
CallWith(prob).(mpg.getters)
end
@eval function (mpg::MultipleParametersGetter{$indexerTimeseriesType})(buffer::AbstractArray, ::$timeseriesType, prob)
@eval function (mpg::MultipleParametersGetter{$indexerTimeseriesType})(
buffer::AbstractArray, ::$timeseriesType, prob)
for (buf_idx, getter) in zip(eachindex(buffer), mpg.getters)
buffer[buf_idx] = getter(prob)
end
Expand All @@ -212,7 +245,8 @@ end
function (mpg::MultipleParametersGetter{IndexerNotTimeseries})(::Timeseries, prob, args)
throw(ParameterTimeseriesValueIndexMismatchError{Timeseries}(prob, mpg, args))
end
function (mpg::MultipleParametersGetter{IndexerNotTimeseries})(::AbstractArray, ::Timeseries, prob, args)
function (mpg::MultipleParametersGetter{IndexerNotTimeseries})(
::AbstractArray, ::Timeseries, prob, args)
throw(ParameterTimeseriesValueIndexMismatchError{Timeseries}(prob, mpg, args))
end
function (mpg::AtLeastTimeseriesMPG)(ts::Timeseries, prob)
Expand All @@ -227,20 +261,23 @@ function (mpg::AtLeastTimeseriesMPG)(ts::Timeseries, prob, ::Colon)
mpg(ts, prob)
end
function (mpg::AtLeastTimeseriesMPG)(ts::Timeseries, prob, i::AbstractArray{Bool})
map(only(to_indices(parameter_timeseries(prob, indexer_timeseries_index(mpg)), (i,)))) do idx
map(only(to_indices(
parameter_timeseries(prob, indexer_timeseries_index(mpg)), (i,)))) do idx
mpg(ts, prob, idx)
end
end
function (mpg::AtLeastTimeseriesMPG)(ts::Timeseries, prob, i)
mpg.((ts,), (prob,), i)
end
function (mpg::AtLeastTimeseriesMPG)(buffer::AbstractArray, ts::Timeseries, prob)
for (buf_idx, ts_idx) in zip(eachindex(buffer), eachindex(parameter_timeseries(prob, indexer_timeseries_index(mpg))))
for (buf_idx, ts_idx) in zip(eachindex(buffer),
eachindex(parameter_timeseries(prob, indexer_timeseries_index(mpg))))
mpg(buffer[buf_idx], ts, prob, ts_idx)
end
return buffer
end
function (mpg::AtLeastTimeseriesMPG)(buffer::AbstractArray, ::Timeseries, prob, i::Union{Int, CartesianIndex})
function (mpg::AtLeastTimeseriesMPG)(
buffer::AbstractArray, ::Timeseries, prob, i::Union{Int, CartesianIndex})
for (buf_idx, getter) in zip(eachindex(buffer), mpg.getters)
buffer[buf_idx] = getter(prob, i)
end
Expand All @@ -249,8 +286,10 @@ end
function (mpg::AtLeastTimeseriesMPG)(buffer::AbstractArray, ts::Timeseries, prob, ::Colon)
mpg(buffer, ts, prob)
end
function (mpg::AtLeastTimeseriesMPG)(buffer::AbstractArray, ts::Timeseries, prob, i::AbstractArray{Bool})
mpg(buffer, ts, prob, only(to_indices(parameter_timeseries(prob, indexer_timeseries_index(mpg)), (i,))))
function (mpg::AtLeastTimeseriesMPG)(
buffer::AbstractArray, ts::Timeseries, prob, i::AbstractArray{Bool})
mpg(buffer, ts, prob,
only(to_indices(parameter_timeseries(prob, indexer_timeseries_index(mpg)), (i,))))
end
function (mpg::AtLeastTimeseriesMPG)(buffer::AbstractArray, ts::Timeseries, prob, i)
for (buf_idx, ts_idx) in zip(eachindex(buffer), i)
Expand All @@ -261,7 +300,8 @@ end
function (mpg::MultipleParametersGetter{IndexerTimeseries})(::NotTimeseries, prob)
throw(ParameterTimeseriesValueIndexMismatchError{NotTimeseries}(prob, mpg))
end
function (mpg::MultipleParametersGetter{IndexerTimeseries})(::AbstractArray, ::NotTimeseries, prob)
function (mpg::MultipleParametersGetter{IndexerTimeseries})(
::AbstractArray, ::NotTimeseries, prob)
throw(ParameterTimeseriesValueIndexMismatchError{NotTimeseries}(prob, mpg))
end

Expand Down
6 changes: 4 additions & 2 deletions src/parameter_timeseries_collection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ function Base.getindex(ptc::ParameterTimeseriesCollection, idx::ParameterTimeser
timeseries = ptc.collection[idx.timeseries_idx]
return getu(timeseries, idx.parameter_idx)(timeseries)
end
function Base.getindex(ptc::ParameterTimeseriesCollection, idx::ParameterTimeseriesIndex, subidx)
function Base.getindex(
ptc::ParameterTimeseriesCollection, idx::ParameterTimeseriesIndex, subidx)
timeseries = ptc.collection[idx.timeseries_idx]
return getu(timeseries, idx.parameter_idx)(timeseries, subidx)
end
Expand All @@ -68,7 +69,8 @@ function Base.getindex(ptc::ParameterTimeseriesCollection, ts_idx, subidx, param
return ptc[ParameterTimeseriesIndex(ts_idx, param_idx), subidx]
end

function parameter_values(ptc::ParameterTimeseriesCollection, idx::ParameterTimeseriesIndex, subidx)
function parameter_values(
ptc::ParameterTimeseriesCollection, idx::ParameterTimeseriesIndex, subidx)
return ptc[idx, subidx]
end
function parameter_timeseries(ptc::ParameterTimeseriesCollection, idx)
Expand Down
40 changes: 26 additions & 14 deletions src/state_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ function set_state!(sys, val, idx)
state_values(sys)[idx] = val
end


"""
getu(indp, sym)
Expand Down Expand Up @@ -69,25 +68,32 @@ end
function (g::GetpAtStateTime)(ts::Timeseries, p_ts::Timeseries, prob)
g(ts, p_ts, is_indexer_timeseries(g.getter), prob)
end
function (g::GetpAtStateTime)(::Timeseries, ::Timeseries, ::Union{IndexerTimeseries, IndexerBoth}, prob)
g.getter.((prob,), parameter_timeseries_at_state_time(prob, indexer_timeseries_index(g.getter)))
function (g::GetpAtStateTime)(
::Timeseries, ::Timeseries, ::Union{IndexerTimeseries, IndexerBoth}, prob)
g.getter.((prob,),
parameter_timeseries_at_state_time(prob, indexer_timeseries_index(g.getter)))
end
function (g::GetpAtStateTime)(::Timeseries, ::Timeseries, ::IndexerNotTimeseries, prob)
g.getter(prob)
end
function (g::GetpAtStateTime)(ts::Timeseries, p_ts::Timeseries, prob, i)
g(ts, p_ts, is_indexer_timeseries(g.getter), prob, i)
end
function (g::GetpAtStateTime)(::Timeseries, ::Timeseries, ::Union{IndexerTimeseries, IndexerBoth}, prob, i)
g.getter(prob, parameter_timeseries_at_state_time(prob, indexer_timeseries_index(g.getter), i))
function (g::GetpAtStateTime)(
::Timeseries, ::Timeseries, ::Union{IndexerTimeseries, IndexerBoth}, prob, i)
g.getter(prob,
parameter_timeseries_at_state_time(prob, indexer_timeseries_index(g.getter), i))
end
function (g::GetpAtStateTime)(::Timeseries, ::Timeseries, ::IndexerNotTimeseries, prob, ::Union{Int, CartesianIndex})
function (g::GetpAtStateTime)(::Timeseries, ::Timeseries, ::IndexerNotTimeseries,
prob, ::Union{Int, CartesianIndex})
g.getter(prob)
end
function (g::GetpAtStateTime)(::Timeseries, ::Timeseries, ::IndexerNotTimeseries, prob, ::Colon)
function (g::GetpAtStateTime)(
::Timeseries, ::Timeseries, ::IndexerNotTimeseries, prob, ::Colon)
map(_ -> g.getter(prob), current_time(prob))
end
function (g::GetpAtStateTime)(::Timeseries, ::Timeseries, ::IndexerNotTimeseries, prob, i::AbstractArray{Bool})
function (g::GetpAtStateTime)(
::Timeseries, ::Timeseries, ::IndexerNotTimeseries, prob, i::AbstractArray{Bool})
num_ones = sum(i)
map(_ -> g.getter(prob), 1:num_ones)
end
Expand Down Expand Up @@ -123,23 +129,28 @@ end
function (o::TimeDependentObservedFunction)(ts::Timeseries, prob, i)
return o(ts, is_parameter_timeseries(prob), prob, i)
end
function (o::TimeDependentObservedFunction)(::Timeseries, ::Timeseries, prob, i::Union{Int, CartesianIndex})
function (o::TimeDependentObservedFunction)(
::Timeseries, ::Timeseries, prob, i::Union{Int, CartesianIndex})
return o.obsfn(state_values(prob, i),
parameter_values_at_state_time(prob, i),
current_time(prob, i))
end
function (o::TimeDependentObservedFunction)(ts::Timeseries, p_ts::IsTimeseriesTrait, prob, ::Colon)
function (o::TimeDependentObservedFunction)(
ts::Timeseries, p_ts::IsTimeseriesTrait, prob, ::Colon)
return o(ts, p_ts, prob)
end
function (o::TimeDependentObservedFunction)(ts::Timeseries, p_ts::IsTimeseriesTrait, prob, i::AbstractArray{Bool})
function (o::TimeDependentObservedFunction)(
ts::Timeseries, p_ts::IsTimeseriesTrait, prob, i::AbstractArray{Bool})
map(only(to_indices(current_time(prob), (i,)))) do idx
o(ts, p_ts, prob, idx)
end
end
function (o::TimeDependentObservedFunction)(ts::Timeseries, p_ts::IsTimeseriesTrait, prob, i)
function (o::TimeDependentObservedFunction)(
ts::Timeseries, p_ts::IsTimeseriesTrait, prob, i)
o.((ts,), (p_ts,), (prob,), i)
end
function (o::TimeDependentObservedFunction)(::Timeseries, ::NotTimeseries, prob, i::Union{Int, CartesianIndex})
function (o::TimeDependentObservedFunction)(
::Timeseries, ::NotTimeseries, prob, i::Union{Int, CartesianIndex})
o.obsfn(state_values(prob, i), parameter_values(prob), current_time(prob, i))
end
function (o::TimeDependentObservedFunction)(::NotTimeseries, prob)
Expand Down Expand Up @@ -231,7 +242,8 @@ for (t1, t2) in [
@eval function _getu(sys, ::NotSymbolic, ::$t1, sym::$t2)
num_observed = count(x -> is_observed(sys, x), sym)
if num_observed == 0
if all(Base.Fix1(is_parameter, sys), sym) && all(!Base.Fix1(is_timeseries_parameter, sys), sym)
if all(Base.Fix1(is_parameter, sys), sym) &&
all(!Base.Fix1(is_timeseries_parameter, sys), sym)
GetpAtStateTime(getp(sys, sym))
else
getters = getu.((sys,), sym)
Expand Down
9 changes: 6 additions & 3 deletions src/symbol_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,13 @@ function SymbolCache(vars = nothing, params = nothing, indepvars = nothing;
throw(ArgumentError("Timeseries parameter $k must also be present in parameters."))
end
if !isa(v, ParameterTimeseriesIndex)
throw(TypeError(:SymbolCache, "index of timeseries parameter $k", ParameterTimeseriesIndex, v))
throw(TypeError(:SymbolCache, "index of timeseries parameter $k",
ParameterTimeseriesIndex, v))
end
end
end
return SymbolCache{typeof(vars), typeof(params), typeof(timeseries_parameters), typeof(indepvars), typeof(defaults)}(
return SymbolCache{typeof(vars), typeof(params), typeof(timeseries_parameters),
typeof(indepvars), typeof(defaults)}(
vars,
params,
timeseries_parameters,
Expand Down Expand Up @@ -93,7 +95,8 @@ function is_timeseries_parameter(sc::SymbolCache, sym)
sc.timeseries_parameters !== nothing && haskey(sc.timeseries_parameters, sym)
end
function timeseries_parameter_index(sc::SymbolCache, sym)
sc.timeseries_parameters === nothing ? nothing : get(sc.timeseries_parameters, sym, nothing)
sc.timeseries_parameters === nothing ? nothing :
get(sc.timeseries_parameters, sym, nothing)
end
function is_independent_variable(sc::SymbolCache, sym)
sc.independent_variables === nothing && return false
Expand Down
Loading

0 comments on commit 6c9c517

Please sign in to comment.