From 6c9c5174a0796022f556a4d488dc40dc151aab25 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 13 May 2024 13:56:11 +0530 Subject: [PATCH] fixup! wip: better parameter indexing --- src/parameter_indexing.jl | 108 +++++++++++++++++-------- src/parameter_timeseries_collection.jl | 6 +- src/state_indexing.jl | 40 +++++---- src/symbol_cache.jl | 9 ++- src/value_provider_interface.jl | 13 +-- test/parameter_indexing_test.jl | 97 +++++++++++++--------- test/state_indexing_test.jl | 51 ++++++------ test/symbol_cache_test.jl | 16 ++-- 8 files changed, 213 insertions(+), 127 deletions(-) diff --git a/src/parameter_indexing.jl b/src/parameter_indexing.jl index 32e1ca26..239ff702 100644 --- a/src/parameter_indexing.jl +++ b/src/parameter_indexing.jl @@ -45,7 +45,10 @@ struct GetParameterIndex{I} <: AbstractParameterGetIndexer end is_indexer_timeseries(::Type{GetParameterIndex{I}}) where {I} = IndexerNotTimeseries() -is_indexer_timeseries(::Type{GetParameterIndex{I}}) where {I <: ParameterTimeseriesIndex} = IndexerTimeseries() +function is_indexer_timeseries(::Type{GetParameterIndex{I}}) where {I <: + ParameterTimeseriesIndex} + IndexerTimeseries() +end function indexer_timeseries_index(gpi::GetParameterIndex{<:ParameterTimeseriesIndex}) gpi.idx.timeseries_idx end @@ -56,30 +59,39 @@ function (gpi::GetParameterIndex)(::Timeseries, prob, args) throw(ParameterTimeseriesValueIndexMismatchError{Timeseries}(prob, gpi, args)) end function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(ts::Timeseries, prob) - gpi.((ts,), (prob,), eachindex(parameter_timeseries(prob, indexer_timeseries_index(gpi)))) + gpi.((ts,), (prob,), + eachindex(parameter_timeseries(prob, indexer_timeseries_index(gpi)))) end -function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(buffer::AbstractArray, ts::Timeseries, prob) - for (buf_idx, ts_idx) in zip(eachindex(buffer), eachindex(parameter_timeseries(prob, indexer_timeseries_index(gpi)))) +function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})( + buffer::AbstractArray, ts::Timeseries, prob) + for (buf_idx, ts_idx) in zip(eachindex(buffer), + eachindex(parameter_timeseries(prob, indexer_timeseries_index(gpi)))) buffer[buf_idx] = gpi(ts, prob, ts_idx) end return buffer end -function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(::Timeseries, prob, i::Union{Int, CartesianIndex}) +function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})( + ::Timeseries, prob, i::Union{Int, CartesianIndex}) parameter_values(prob, gpi.idx, i) end function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(ts::Timeseries, prob, ::Colon) gpi(ts, prob) end -function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(buffer::AbstractArray, ts::Timeseries, prob, ::Colon) +function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})( + buffer::AbstractArray, ts::Timeseries, prob, ::Colon) gpi(buffer, ts, prob) end -function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(ts::Timeseries, prob, i::AbstractArray{Bool}) - map(only(to_indices(parameter_timeseries(prob, indexer_timeseries_index(gpi)), (i,)))) do idx +function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})( + ts::Timeseries, prob, i::AbstractArray{Bool}) + map(only(to_indices( + parameter_timeseries(prob, indexer_timeseries_index(gpi)), (i,)))) do idx gpi(ts, prob, idx) end end -function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(buffer::AbstractArray, ts::Timeseries, prob, i::AbstractArray{Bool}) - for (buf_idx, ts_idx) in zip(eachindex(buffer), only(to_indices(parameter_timeseries(prob, indexer_timeseries_index(gpi)), (i,)))) +function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})( + buffer::AbstractArray, ts::Timeseries, prob, i::AbstractArray{Bool}) + for (buf_idx, ts_idx) in zip(eachindex(buffer), + only(to_indices(parameter_timeseries(prob, indexer_timeseries_index(gpi)), (i,)))) buffer[buf_idx] = gpi(ts, prob, ts_idx) end return buffer @@ -87,7 +99,8 @@ end function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(ts::Timeseries, prob, i) gpi.((ts,), (prob,), i) end -function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(buffer::AbstractArray, ts::Timeseries, prob, i) +function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})( + buffer::AbstractArray, ts::Timeseries, prob, i) for (buf_idx, subidx) in zip(eachindex(buffer), i) buffer[buf_idx] = gpi(ts, prob, subidx) end @@ -96,7 +109,8 @@ end function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(::NotTimeseries, prob) throw(ParameterTimeseriesValueIndexMismatchError{NotTimeseries}(prob, gpi)) end -function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(::AbstractArray, ::NotTimeseries, prob) +function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})( + ::AbstractArray, ::NotTimeseries, prob) throw(ParameterTimeseriesValueIndexMismatchError{NotTimeseries}(prob, gpi)) end @@ -104,20 +118,27 @@ function _getp(sys, ::NotSymbolic, ::NotSymbolic, p) return GetParameterIndex(p) end -struct GetParameterTimeseriesIndex{I <: GetParameterIndex, J <: GetParameterIndex{<:ParameterTimeseriesIndex}} <: AbstractParameterGetIndexer +struct GetParameterTimeseriesIndex{ + I <: GetParameterIndex, J <: GetParameterIndex{<:ParameterTimeseriesIndex}} <: + AbstractParameterGetIndexer param_idx::I param_timeseries_idx::J end is_indexer_timeseries(::Type{G}) where {G <: GetParameterTimeseriesIndex} = IndexerBoth() -indexer_timeseries_index(gpti::GetParameterTimeseriesIndex) = indexer_timeseries_index(gpti.param_timeseries_idx) +function indexer_timeseries_index(gpti::GetParameterTimeseriesIndex) + indexer_timeseries_index(gpti.param_timeseries_idx) +end as_not_timeseries_indexer(::IndexerBoth, gpti::GetParameterTimeseriesIndex) = gpti.param_idx -as_timeseries_indexer(::IndexerBoth, gpti::GetParameterTimeseriesIndex) = gpti.param_timeseries_idx +function as_timeseries_indexer(::IndexerBoth, gpti::GetParameterTimeseriesIndex) + gpti.param_timeseries_idx +end function (gpti::GetParameterTimeseriesIndex)(ts::Timeseries, prob, args...) gpti.param_timeseries_idx(ts, prob, args...) end -function (gpti::GetParameterTimeseriesIndex)(buffer::AbstractArray, ts::Timeseries, prob, args...) +function (gpti::GetParameterTimeseriesIndex)( + buffer::AbstractArray, ts::Timeseries, prob, args...) gpti.param_timeseries_idx(buffer, ts, prob, args...) end function (gpti::GetParameterTimeseriesIndex)(ts::NotTimeseries, prob) @@ -128,17 +149,19 @@ function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p) idx = parameter_index(sys, p) if is_timeseries_parameter(sys, p) ts_idx = timeseries_parameter_index(sys, p) - return GetParameterTimeseriesIndex(GetParameterIndex(idx), GetParameterIndex(ts_idx)) + return GetParameterTimeseriesIndex( + GetParameterIndex(idx), GetParameterIndex(ts_idx)) else return GetParameterIndex(idx) end end struct MixedTimeseriesIndexes - indexes + indexes::Any end -struct MultipleParametersGetter{T <: IsIndexerTimeseries, G, I} <: AbstractParameterGetIndexer +struct MultipleParametersGetter{T <: IsIndexerTimeseries, G, I} <: + AbstractParameterGetIndexer getters::G timeseries_idx::I @@ -165,7 +188,8 @@ struct MultipleParametersGetter{T <: IsIndexerTimeseries, G, I} <: AbstractParam IndexerBoth end - if indexer_type != IndexerNotTimeseries && !allequal(indexer_timeseries_index(g) for g in getters) + if indexer_type != IndexerNotTimeseries && + !allequal(indexer_timeseries_index(g) for g in getters) if indexer_type == IndexerTimeseries throw(ArgumentError("All parameters must belong to the same timeseries")) else @@ -175,29 +199,38 @@ struct MultipleParametersGetter{T <: IsIndexerTimeseries, G, I} <: AbstractParam end end - return new{indexer_type, typeof(getters), typeof(timeseries_idx)}(getters, timeseries_idx) + return new{indexer_type, typeof(getters), typeof(timeseries_idx)}( + getters, timeseries_idx) end end -const AtLeastTimeseriesMPG = Union{MultipleParametersGetter{IndexerTimeseries}, MultipleParametersGetter{IndexerBoth}} -const MixedTimeseriesIndexMPG = MultipleParametersGetter{IndexerNotTimeseries, G, MixedTimeseriesIndexes} where {G} +const AtLeastTimeseriesMPG = Union{ + MultipleParametersGetter{IndexerTimeseries}, MultipleParametersGetter{IndexerBoth}} +const MixedTimeseriesIndexMPG = MultipleParametersGetter{ + IndexerNotTimeseries, G, MixedTimeseriesIndexes} where {G} is_indexer_timeseries(::Type{<:MultipleParametersGetter{T}}) where {T} = T() function indexer_timeseries_index(mpg::MultipleParametersGetter) mpg.timeseries_idx end -as_not_timeseries_indexer(::IndexerBoth, mpg::MultipleParametersGetter) = MultipleParametersGetter(as_not_timeseries_indexer.(mpg.getters)) +function as_not_timeseries_indexer(::IndexerBoth, mpg::MultipleParametersGetter) + MultipleParametersGetter(as_not_timeseries_indexer.(mpg.getters)) +end -as_timeseries_indexer(::IndexerBoth, mpg::MultipleParametersGetter) = MultipleParametersGetter(as_timeseries_indexer.(mpg.getters)) +function as_timeseries_indexer(::IndexerBoth, mpg::MultipleParametersGetter) + MultipleParametersGetter(as_timeseries_indexer.(mpg.getters)) +end for (indexerTimeseriesType, timeseriesType) in [ (IndexerNotTimeseries, IsTimeseriesTrait), (IndexerBoth, NotTimeseries) ] - @eval function (mpg::MultipleParametersGetter{$indexerTimeseriesType})(::$timeseriesType, prob) + @eval function (mpg::MultipleParametersGetter{$indexerTimeseriesType})( + ::$timeseriesType, prob) CallWith(prob).(mpg.getters) end - @eval function (mpg::MultipleParametersGetter{$indexerTimeseriesType})(buffer::AbstractArray, ::$timeseriesType, prob) + @eval function (mpg::MultipleParametersGetter{$indexerTimeseriesType})( + buffer::AbstractArray, ::$timeseriesType, prob) for (buf_idx, getter) in zip(eachindex(buffer), mpg.getters) buffer[buf_idx] = getter(prob) end @@ -212,7 +245,8 @@ end function (mpg::MultipleParametersGetter{IndexerNotTimeseries})(::Timeseries, prob, args) throw(ParameterTimeseriesValueIndexMismatchError{Timeseries}(prob, mpg, args)) end -function (mpg::MultipleParametersGetter{IndexerNotTimeseries})(::AbstractArray, ::Timeseries, prob, args) +function (mpg::MultipleParametersGetter{IndexerNotTimeseries})( + ::AbstractArray, ::Timeseries, prob, args) throw(ParameterTimeseriesValueIndexMismatchError{Timeseries}(prob, mpg, args)) end function (mpg::AtLeastTimeseriesMPG)(ts::Timeseries, prob) @@ -227,7 +261,8 @@ function (mpg::AtLeastTimeseriesMPG)(ts::Timeseries, prob, ::Colon) mpg(ts, prob) end function (mpg::AtLeastTimeseriesMPG)(ts::Timeseries, prob, i::AbstractArray{Bool}) - map(only(to_indices(parameter_timeseries(prob, indexer_timeseries_index(mpg)), (i,)))) do idx + map(only(to_indices( + parameter_timeseries(prob, indexer_timeseries_index(mpg)), (i,)))) do idx mpg(ts, prob, idx) end end @@ -235,12 +270,14 @@ function (mpg::AtLeastTimeseriesMPG)(ts::Timeseries, prob, i) mpg.((ts,), (prob,), i) end function (mpg::AtLeastTimeseriesMPG)(buffer::AbstractArray, ts::Timeseries, prob) - for (buf_idx, ts_idx) in zip(eachindex(buffer), eachindex(parameter_timeseries(prob, indexer_timeseries_index(mpg)))) + for (buf_idx, ts_idx) in zip(eachindex(buffer), + eachindex(parameter_timeseries(prob, indexer_timeseries_index(mpg)))) mpg(buffer[buf_idx], ts, prob, ts_idx) end return buffer end -function (mpg::AtLeastTimeseriesMPG)(buffer::AbstractArray, ::Timeseries, prob, i::Union{Int, CartesianIndex}) +function (mpg::AtLeastTimeseriesMPG)( + buffer::AbstractArray, ::Timeseries, prob, i::Union{Int, CartesianIndex}) for (buf_idx, getter) in zip(eachindex(buffer), mpg.getters) buffer[buf_idx] = getter(prob, i) end @@ -249,8 +286,10 @@ end function (mpg::AtLeastTimeseriesMPG)(buffer::AbstractArray, ts::Timeseries, prob, ::Colon) mpg(buffer, ts, prob) end -function (mpg::AtLeastTimeseriesMPG)(buffer::AbstractArray, ts::Timeseries, prob, i::AbstractArray{Bool}) - mpg(buffer, ts, prob, only(to_indices(parameter_timeseries(prob, indexer_timeseries_index(mpg)), (i,)))) +function (mpg::AtLeastTimeseriesMPG)( + buffer::AbstractArray, ts::Timeseries, prob, i::AbstractArray{Bool}) + mpg(buffer, ts, prob, + only(to_indices(parameter_timeseries(prob, indexer_timeseries_index(mpg)), (i,)))) end function (mpg::AtLeastTimeseriesMPG)(buffer::AbstractArray, ts::Timeseries, prob, i) for (buf_idx, ts_idx) in zip(eachindex(buffer), i) @@ -261,7 +300,8 @@ end function (mpg::MultipleParametersGetter{IndexerTimeseries})(::NotTimeseries, prob) throw(ParameterTimeseriesValueIndexMismatchError{NotTimeseries}(prob, mpg)) end -function (mpg::MultipleParametersGetter{IndexerTimeseries})(::AbstractArray, ::NotTimeseries, prob) +function (mpg::MultipleParametersGetter{IndexerTimeseries})( + ::AbstractArray, ::NotTimeseries, prob) throw(ParameterTimeseriesValueIndexMismatchError{NotTimeseries}(prob, mpg)) end diff --git a/src/parameter_timeseries_collection.jl b/src/parameter_timeseries_collection.jl index 95b467d1..23e0c742 100644 --- a/src/parameter_timeseries_collection.jl +++ b/src/parameter_timeseries_collection.jl @@ -57,7 +57,8 @@ function Base.getindex(ptc::ParameterTimeseriesCollection, idx::ParameterTimeser timeseries = ptc.collection[idx.timeseries_idx] return getu(timeseries, idx.parameter_idx)(timeseries) end -function Base.getindex(ptc::ParameterTimeseriesCollection, idx::ParameterTimeseriesIndex, subidx) +function Base.getindex( + ptc::ParameterTimeseriesCollection, idx::ParameterTimeseriesIndex, subidx) timeseries = ptc.collection[idx.timeseries_idx] return getu(timeseries, idx.parameter_idx)(timeseries, subidx) end @@ -68,7 +69,8 @@ function Base.getindex(ptc::ParameterTimeseriesCollection, ts_idx, subidx, param return ptc[ParameterTimeseriesIndex(ts_idx, param_idx), subidx] end -function parameter_values(ptc::ParameterTimeseriesCollection, idx::ParameterTimeseriesIndex, subidx) +function parameter_values( + ptc::ParameterTimeseriesCollection, idx::ParameterTimeseriesIndex, subidx) return ptc[idx, subidx] end function parameter_timeseries(ptc::ParameterTimeseriesCollection, idx) diff --git a/src/state_indexing.jl b/src/state_indexing.jl index fbfebe70..98209855 100644 --- a/src/state_indexing.jl +++ b/src/state_indexing.jl @@ -2,7 +2,6 @@ function set_state!(sys, val, idx) state_values(sys)[idx] = val end - """ getu(indp, sym) @@ -69,8 +68,10 @@ end function (g::GetpAtStateTime)(ts::Timeseries, p_ts::Timeseries, prob) g(ts, p_ts, is_indexer_timeseries(g.getter), prob) end -function (g::GetpAtStateTime)(::Timeseries, ::Timeseries, ::Union{IndexerTimeseries, IndexerBoth}, prob) - g.getter.((prob,), parameter_timeseries_at_state_time(prob, indexer_timeseries_index(g.getter))) +function (g::GetpAtStateTime)( + ::Timeseries, ::Timeseries, ::Union{IndexerTimeseries, IndexerBoth}, prob) + g.getter.((prob,), + parameter_timeseries_at_state_time(prob, indexer_timeseries_index(g.getter))) end function (g::GetpAtStateTime)(::Timeseries, ::Timeseries, ::IndexerNotTimeseries, prob) g.getter(prob) @@ -78,16 +79,21 @@ end function (g::GetpAtStateTime)(ts::Timeseries, p_ts::Timeseries, prob, i) g(ts, p_ts, is_indexer_timeseries(g.getter), prob, i) end -function (g::GetpAtStateTime)(::Timeseries, ::Timeseries, ::Union{IndexerTimeseries, IndexerBoth}, prob, i) - g.getter(prob, parameter_timeseries_at_state_time(prob, indexer_timeseries_index(g.getter), i)) +function (g::GetpAtStateTime)( + ::Timeseries, ::Timeseries, ::Union{IndexerTimeseries, IndexerBoth}, prob, i) + g.getter(prob, + parameter_timeseries_at_state_time(prob, indexer_timeseries_index(g.getter), i)) end -function (g::GetpAtStateTime)(::Timeseries, ::Timeseries, ::IndexerNotTimeseries, prob, ::Union{Int, CartesianIndex}) +function (g::GetpAtStateTime)(::Timeseries, ::Timeseries, ::IndexerNotTimeseries, + prob, ::Union{Int, CartesianIndex}) g.getter(prob) end -function (g::GetpAtStateTime)(::Timeseries, ::Timeseries, ::IndexerNotTimeseries, prob, ::Colon) +function (g::GetpAtStateTime)( + ::Timeseries, ::Timeseries, ::IndexerNotTimeseries, prob, ::Colon) map(_ -> g.getter(prob), current_time(prob)) end -function (g::GetpAtStateTime)(::Timeseries, ::Timeseries, ::IndexerNotTimeseries, prob, i::AbstractArray{Bool}) +function (g::GetpAtStateTime)( + ::Timeseries, ::Timeseries, ::IndexerNotTimeseries, prob, i::AbstractArray{Bool}) num_ones = sum(i) map(_ -> g.getter(prob), 1:num_ones) end @@ -123,23 +129,28 @@ end function (o::TimeDependentObservedFunction)(ts::Timeseries, prob, i) return o(ts, is_parameter_timeseries(prob), prob, i) end -function (o::TimeDependentObservedFunction)(::Timeseries, ::Timeseries, prob, i::Union{Int, CartesianIndex}) +function (o::TimeDependentObservedFunction)( + ::Timeseries, ::Timeseries, prob, i::Union{Int, CartesianIndex}) return o.obsfn(state_values(prob, i), parameter_values_at_state_time(prob, i), current_time(prob, i)) end -function (o::TimeDependentObservedFunction)(ts::Timeseries, p_ts::IsTimeseriesTrait, prob, ::Colon) +function (o::TimeDependentObservedFunction)( + ts::Timeseries, p_ts::IsTimeseriesTrait, prob, ::Colon) return o(ts, p_ts, prob) end -function (o::TimeDependentObservedFunction)(ts::Timeseries, p_ts::IsTimeseriesTrait, prob, i::AbstractArray{Bool}) +function (o::TimeDependentObservedFunction)( + ts::Timeseries, p_ts::IsTimeseriesTrait, prob, i::AbstractArray{Bool}) map(only(to_indices(current_time(prob), (i,)))) do idx o(ts, p_ts, prob, idx) end end -function (o::TimeDependentObservedFunction)(ts::Timeseries, p_ts::IsTimeseriesTrait, prob, i) +function (o::TimeDependentObservedFunction)( + ts::Timeseries, p_ts::IsTimeseriesTrait, prob, i) o.((ts,), (p_ts,), (prob,), i) end -function (o::TimeDependentObservedFunction)(::Timeseries, ::NotTimeseries, prob, i::Union{Int, CartesianIndex}) +function (o::TimeDependentObservedFunction)( + ::Timeseries, ::NotTimeseries, prob, i::Union{Int, CartesianIndex}) o.obsfn(state_values(prob, i), parameter_values(prob), current_time(prob, i)) end function (o::TimeDependentObservedFunction)(::NotTimeseries, prob) @@ -231,7 +242,8 @@ for (t1, t2) in [ @eval function _getu(sys, ::NotSymbolic, ::$t1, sym::$t2) num_observed = count(x -> is_observed(sys, x), sym) if num_observed == 0 - if all(Base.Fix1(is_parameter, sys), sym) && all(!Base.Fix1(is_timeseries_parameter, sys), sym) + if all(Base.Fix1(is_parameter, sys), sym) && + all(!Base.Fix1(is_timeseries_parameter, sys), sym) GetpAtStateTime(getp(sys, sym)) else getters = getu.((sys,), sym) diff --git a/src/symbol_cache.jl b/src/symbol_cache.jl index 664a3e05..577263ae 100644 --- a/src/symbol_cache.jl +++ b/src/symbol_cache.jl @@ -54,11 +54,13 @@ function SymbolCache(vars = nothing, params = nothing, indepvars = nothing; throw(ArgumentError("Timeseries parameter $k must also be present in parameters.")) end if !isa(v, ParameterTimeseriesIndex) - throw(TypeError(:SymbolCache, "index of timeseries parameter $k", ParameterTimeseriesIndex, v)) + throw(TypeError(:SymbolCache, "index of timeseries parameter $k", + ParameterTimeseriesIndex, v)) end end end - return SymbolCache{typeof(vars), typeof(params), typeof(timeseries_parameters), typeof(indepvars), typeof(defaults)}( + return SymbolCache{typeof(vars), typeof(params), typeof(timeseries_parameters), + typeof(indepvars), typeof(defaults)}( vars, params, timeseries_parameters, @@ -93,7 +95,8 @@ function is_timeseries_parameter(sc::SymbolCache, sym) sc.timeseries_parameters !== nothing && haskey(sc.timeseries_parameters, sym) end function timeseries_parameter_index(sc::SymbolCache, sym) - sc.timeseries_parameters === nothing ? nothing : get(sc.timeseries_parameters, sym, nothing) + sc.timeseries_parameters === nothing ? nothing : + get(sc.timeseries_parameters, sym, nothing) end function is_independent_variable(sc::SymbolCache, sym) sc.independent_variables === nothing && return false diff --git a/src/value_provider_interface.jl b/src/value_provider_interface.jl index 6a785058..3c6deb96 100644 --- a/src/value_provider_interface.jl +++ b/src/value_provider_interface.jl @@ -264,9 +264,9 @@ end ########### struct ParameterTimeseriesValueIndexMismatchError{P <: IsTimeseriesTrait} <: Exception - valp - indexer - args + valp::Any + indexer::Any + args::Any function ParameterTimeseriesValueIndexMismatchError{Timeseries}(valp, indexer, args) if is_parameter_timeseries(valp) != Timeseries() @@ -309,7 +309,8 @@ function Base.showerror(io::IO, err::ParameterTimeseriesValueIndexMismatchError{ """) end -function Base.showerror(io::IO, err::ParameterTimeseriesValueIndexMismatchError{NotTimeseries}) +function Base.showerror( + io::IO, err::ParameterTimeseriesValueIndexMismatchError{NotTimeseries}) print(io, """ Invalid indexing operation: tried to access object of type $(typeof(err.valp)) \ (which is not a parameter timeseries object) using timeseries indexer \ @@ -318,8 +319,8 @@ function Base.showerror(io::IO, err::ParameterTimeseriesValueIndexMismatchError{ end struct MixedParameterTimeseriesIndexError <: Exception - valp - ts_idxs + valp::Any + ts_idxs::Any end function Base.showerror(io::IO, err::MixedParameterTimeseriesIndexError) diff --git a/test/parameter_indexing_test.jl b/test/parameter_indexing_test.jl index 8f9001e6..3bebd08d 100644 --- a/test/parameter_indexing_test.jl +++ b/test/parameter_indexing_test.jl @@ -28,7 +28,11 @@ end for sys in [ SymbolCache([:x, :y, :z], [:a, :b, :c, :d], [:t]), - SymbolCache([:x, :y, :z], [:a, :b, :c, :d], [:t], timeseries_parameters = Dict(:b => ParameterTimeseriesIndex(1, 1), :c => ParameterTimeseriesIndex(2, 1))), + SymbolCache([:x, :y, :z], + [:a, :b, :c, :d], + [:t], + timeseries_parameters = Dict( + :b => ParameterTimeseriesIndex(1, 1), :c => ParameterTimeseriesIndex(2, 1))) ] has_ts = sys.timeseries_parameters !== nothing for pType in [Vector, Tuple] @@ -141,7 +145,8 @@ SymbolicIndexingInterface.current_time(fs::FakeSolution) = fs.t SymbolicIndexingInterface.symbolic_container(fs::FakeSolution) = fs.sys SymbolicIndexingInterface.parameter_values(fs::FakeSolution) = fs.p SymbolicIndexingInterface.parameter_values(fs::FakeSolution, i) = fs.p[i] -function SymbolicIndexingInterface.parameter_values(fs::FakeSolution, i::ParameterTimeseriesIndex, j) +function SymbolicIndexingInterface.parameter_values( + fs::FakeSolution, i::ParameterTimeseriesIndex, j) parameter_values(fs.p_ts, i, j) end function SymbolicIndexingInterface.parameter_values_at_state_time(fs::FakeSolution, t) @@ -154,10 +159,16 @@ function SymbolicIndexingInterface.parameter_values_at_state_time(fs::FakeSoluti end return p end -SymbolicIndexingInterface.parameter_timeseries(fs::FakeSolution, idx) = parameter_timeseries(fs.p_ts, idx) +function SymbolicIndexingInterface.parameter_timeseries(fs::FakeSolution, idx) + parameter_timeseries(fs.p_ts, idx) +end SymbolicIndexingInterface.is_timeseries(::Type{FakeSolution}) = Timeseries() SymbolicIndexingInterface.is_parameter_timeseries(::Type{FakeSolution}) = Timeseries() -sys = SymbolCache([:x, :y, :z], [:a, :b, :c, :d], :t; timeseries_parameters = Dict(:b => ParameterTimeseriesIndex(1, 1), :c => ParameterTimeseriesIndex(2, 1))) +sys = SymbolCache([:x, :y, :z], + [:a, :b, :c, :d], + :t; + timeseries_parameters = Dict( + :b => ParameterTimeseriesIndex(1, 1), :c => ParameterTimeseriesIndex(2, 1))) b_timeseries = MyDiffEqArray(collect(0:0.1:0.9), [[2.5i] for i in 1:10]) c_timeseries = MyDiffEqArray(collect(0:0.25:0.9), [[3.5i] for i in 1:4]) fs = FakeSolution( @@ -176,7 +187,7 @@ bidx = timeseries_parameter_index(sys, :b) cidx = timeseries_parameter_index(sys, :c) for (sym, indexer_trait, timeseries_index, val, buffer, check_inference) in [ - (:a, IndexerNotTimeseries, 0, aval, nothing, true), + (:a, IndexerNotTimeseries, 0, aval, nothing, true), (1, IndexerNotTimeseries, 0, aval, nothing, true), ([:a, :d], IndexerNotTimeseries, 0, [aval, dval], zeros(2), true), ((:a, :d), IndexerNotTimeseries, 0, (aval, dval), zeros(2), true), @@ -186,16 +197,17 @@ for (sym, indexer_trait, timeseries_index, val, buffer, check_inference) in [ ((:a, 4), IndexerNotTimeseries, 0, (aval, dval), zeros(2), true), (:b, IndexerBoth, 1, bval, zeros(length(bval)), true), (bidx, IndexerTimeseries, 1, bval, zeros(length(bval)), true), - ([:a, :b], IndexerNotTimeseries, 0, [aval, bval[end]], zeros(2), true), - ((:a, :b), IndexerNotTimeseries, 0, (aval, bval[end]), zeros(2), true), + ([:a, :b], IndexerNotTimeseries, 0, [aval, bval[end]], zeros(2), true), + ((:a, :b), IndexerNotTimeseries, 0, (aval, bval[end]), zeros(2), true), ([1, :b], IndexerNotTimeseries, 0, [aval, bval[end]], zeros(2), true), ((1, :b), IndexerNotTimeseries, 0, (aval, bval[end]), zeros(2), true), ([:b, :b], IndexerBoth, 1, vcat.(bval, bval), map(_ -> zeros(2), bval), true), - ((:b, :b), IndexerBoth, 1, tuple.(bval, bval), map(_ -> zeros(2), bval), true), + ((:b, :b), IndexerBoth, 1, tuple.(bval, bval), map(_ -> zeros(2), bval), true), ([bidx, :b], IndexerTimeseries, 1, vcat.(bval, bval), map(_ -> zeros(2), bval), true), - ((bidx, :b), IndexerTimeseries, 1, tuple.(bval, bval),map(_ -> zeros(2), bval), true), - ([bidx, bidx], IndexerTimeseries, 1, vcat.(bval, bval), map(_ -> zeros(2), bval), true), - ((bidx, bidx), IndexerTimeseries, 1, tuple.(bval, bval), map(_ -> zeros(2), bval), true), + ((bidx, :b), IndexerTimeseries, 1, tuple.(bval, bval), map(_ -> zeros(2), bval), true), + ([bidx, bidx], IndexerTimeseries, 1, vcat.(bval, bval), map(_ -> zeros(2), bval), true), + ((bidx, bidx), IndexerTimeseries, 1, + tuple.(bval, bval), map(_ -> zeros(2), bval), true) ] getter = getp(sys, sym) @test is_indexer_timeseries(getter) isa indexer_trait @@ -206,7 +218,8 @@ for (sym, indexer_trait, timeseries_index, val, buffer, check_inference) in [ test_non_timeseries = indexer_trait !== IndexerTimeseries if test_inplace && test_non_timeseries non_timeseries_val = indexer_trait == IndexerNotTimeseries ? val : val[end] - non_timeseries_buffer = indexer_trait == IndexerNotTimeseries ? deepcopy(buffer) : deepcopy(buffer[end]) + non_timeseries_buffer = indexer_trait == IndexerNotTimeseries ? deepcopy(buffer) : + deepcopy(buffer[end]) test_non_timeseries_inplace = non_timeseries_buffer isa AbstractArray end if check_inference @@ -249,14 +262,18 @@ for (sym, indexer_trait, timeseries_index, val, buffer, check_inference) in [ else @test_throws ParameterTimeseriesValueIndexMismatchError{NotTimeseries} getter(parameter_values(fs)) if test_inplace - @test_throws ParameterTimeseriesValueIndexMismatchError{NotTimeseries} getter([], parameter_values(fs)) + @test_throws ParameterTimeseriesValueIndexMismatchError{NotTimeseries} getter( + [], parameter_values(fs)) end end - for subidx in [1, CartesianIndex(1), :, rand(Bool, length(val)), rand(eachindex(val), 3), 1:2] + for subidx in [ + 1, CartesianIndex(1), :, rand(Bool, length(val)), rand(eachindex(val), 3), 1:2] if indexer_trait <: IndexerNotTimeseries - @test_throws ParameterTimeseriesValueIndexMismatchError{Timeseries} getter(fs, subidx) + @test_throws ParameterTimeseriesValueIndexMismatchError{Timeseries} getter( + fs, subidx) if test_inplace - @test_throws ParameterTimeseriesValueIndexMismatchError{Timeseries} getter([], fs, subidx) + @test_throws ParameterTimeseriesValueIndexMismatchError{Timeseries} getter( + [], fs, subidx) end else if check_inference @@ -282,13 +299,14 @@ for (sym, indexer_trait, timeseries_index, val, buffer, check_inference) in [ end end -for sym in [[:a, bidx], (:a, bidx), [1, bidx], (1, bidx), [bidx, :c], (bidx, :c), [bidx, cidx], (bidx, cidx)] +for sym in [[:a, bidx], (:a, bidx), [1, bidx], (1, bidx), + [bidx, :c], (bidx, :c), [bidx, cidx], (bidx, cidx)] @test_throws ArgumentError getp(sys, sym) end for (sym, val) in [ ([:b, :c], [bval[end], cval[end]]), - ((:b, :c), (bval[end], cval[end])), + ((:b, :c), (bval[end], cval[end])) ] getter = getp(sys, sym) @test is_indexer_timeseries(getter) == IndexerNotTimeseries() @@ -301,26 +319,26 @@ cval_state = [c_timeseries.u[searchsortedlast(c_timeseries.t, t)][] for t in fs. xval = getindex.(fs.u, 1) for (sym, val_is_timeseries, val, check_inference) in [ - (:a, false, aval, true), - ([:a, :d], false, [aval, dval], true), - ((:a, :d), false, (aval, dval), true), - (:b, true, bval_state, true), - ([:a, :b], true, vcat.(aval, bval_state), false), - ((:a, :b), true, tuple.(aval, bval_state), true), - ([:b, :c], true, vcat.(bval_state, cval_state), true), - ((:b, :c), true, tuple.(bval_state, cval_state), true), - ([:a, :b, :c], true, vcat.(aval, bval_state, cval_state), false), - ((:a, :b, :c), true, tuple.(aval, bval_state, cval_state), true), - ([:x, :b], true, vcat.(xval, bval_state), false), - ((:x, :b), true, tuple.(xval, bval_state), true), - ([:x, :b, :c], true, vcat.(xval, bval_state, cval_state), false), - ((:x, :b, :c), true, tuple.(xval, bval_state, cval_state), true), - ([:a, :b, :x], true, vcat.(aval, bval_state, xval), false), - ((:a, :b, :x), true, tuple.(aval, bval_state, xval), true), - (:(2b), true, 2 .* bval_state, true), - ([:x, :(2b), :(3c)], true, vcat.(xval, 2 .* bval_state, 3 .* cval_state), true), - ((:x, :(2b), :(3c)), true, tuple.(xval, 2 .* bval_state, 3 .* cval_state), true), - ] + (:a, false, aval, true), + ([:a, :d], false, [aval, dval], true), + ((:a, :d), false, (aval, dval), true), + (:b, true, bval_state, true), + ([:a, :b], true, vcat.(aval, bval_state), false), + ((:a, :b), true, tuple.(aval, bval_state), true), + ([:b, :c], true, vcat.(bval_state, cval_state), true), + ((:b, :c), true, tuple.(bval_state, cval_state), true), + ([:a, :b, :c], true, vcat.(aval, bval_state, cval_state), false), + ((:a, :b, :c), true, tuple.(aval, bval_state, cval_state), true), + ([:x, :b], true, vcat.(xval, bval_state), false), + ((:x, :b), true, tuple.(xval, bval_state), true), + ([:x, :b, :c], true, vcat.(xval, bval_state, cval_state), false), + ((:x, :b, :c), true, tuple.(xval, bval_state, cval_state), true), + ([:a, :b, :x], true, vcat.(aval, bval_state, xval), false), + ((:a, :b, :x), true, tuple.(aval, bval_state, xval), true), + (:(2b), true, 2 .* bval_state, true), + ([:x, :(2b), :(3c)], true, vcat.(xval, 2 .* bval_state, 3 .* cval_state), true), + ((:x, :(2b), :(3c)), true, tuple.(xval, 2 .* bval_state, 3 .* cval_state), true) +] getter = getu(sys, sym) if val isa DataType @test_throws val getter(fs) @@ -331,7 +349,8 @@ for (sym, val_is_timeseries, val, check_inference) in [ end @test getter(fs) == val - for subidx in [1, CartesianIndex(2), :, rand(Bool, length(fs.t)), rand(eachindex(fs.t), 3), 1:2] + for subidx in [ + 1, CartesianIndex(2), :, rand(Bool, length(fs.t)), rand(eachindex(fs.t), 3), 1:2] if check_inference @inferred getter(fs, subidx) end diff --git a/test/state_indexing_test.jl b/test/state_indexing_test.jl index fb53a034..d8392636 100644 --- a/test/state_indexing_test.jl +++ b/test/state_indexing_test.jl @@ -134,7 +134,10 @@ struct FakeSolution{S, U, P, T} end SymbolicIndexingInterface.is_timeseries(::Type{<:FakeSolution}) = Timeseries() -SymbolicIndexingInterface.is_timeseries(::Type{<:FakeSolution{S, U, P, Nothing}}) where {S, U, P} = NotTimeseries() +function SymbolicIndexingInterface.is_timeseries(::Type{<:FakeSolution{ + S, U, P, Nothing}}) where {S, U, P} + NotTimeseries() +end SymbolicIndexingInterface.symbolic_container(fp::FakeSolution) = fp.sys SymbolicIndexingInterface.state_values(fp::FakeSolution) = fp.u SymbolicIndexingInterface.parameter_values(fp::FakeSolution) = fp.p @@ -191,7 +194,8 @@ for (sym, ans, check_inference) in [(:x, xvals, true) @inferred get(sol) end @test get(sol) == ans - for i in [rand(eachindex(u)), CartesianIndex(1), :, rand(Bool, length(u)), rand(eachindex(u), 3), 1:3] + for i in [rand(eachindex(u)), CartesianIndex(1), :, + rand(Bool, length(u)), rand(eachindex(u), 3), 1:3] if check_inference @inferred get(sol, i) end @@ -209,7 +213,8 @@ for (sym, val, check_inference) in [ @inferred get(sol) end @test get(sol) == val - for i in [rand(eachindex(u)), CartesianIndex(1), :, rand(Bool, length(u)), rand(eachindex(u), 3), 1:3] + for i in [rand(eachindex(u)), CartesianIndex(1), :, + rand(Bool, length(u)), rand(eachindex(u), 3), 1:3] if check_inference @inferred get(sol, i) end @@ -234,26 +239,26 @@ fs = FakeSolution(sys, u, p, nothing) @test is_timeseries(fs) == NotTimeseries() for (sym, val, check_inference) in [ - (:x, u[1], true), - (1, u[1], true), - ([:x, :y], u[1:2], true), - ((:x, :y), Tuple(u[1:2]), true), - (1:2, u[1:2], true), - ([:x, 2], u[1:2], true), - ((:x, 2), Tuple(u[1:2]), true), - ([1, 2], u[1:2], true), - ((1, 2), Tuple(u[1:2]), true), - (:a, p[1], true), - ([:a, :b], p[1:2], true), - ((:a, :b), Tuple(p[1:2]), true), - ([:x, :a], [u[1], p[1]], false), - ((:x, :a), (u[1], p[1]), true), - ([1, :a], [u[1], p[1]], false), - ((1, :a), (u[1], p[1]), true), - (:(x+y+a+b), u[1] + u[2] + p[1] + p[2], true), - ([:(x+a), :(y+b)], [u[1] + p[1], u[2] + p[2]], true), - ((:(x+a), :(y+b)), (u[1] + p[1], u[2] + p[2]), true), - ] + (:x, u[1], true), + (1, u[1], true), + ([:x, :y], u[1:2], true), + ((:x, :y), Tuple(u[1:2]), true), + (1:2, u[1:2], true), + ([:x, 2], u[1:2], true), + ((:x, 2), Tuple(u[1:2]), true), + ([1, 2], u[1:2], true), + ((1, 2), Tuple(u[1:2]), true), + (:a, p[1], true), + ([:a, :b], p[1:2], true), + ((:a, :b), Tuple(p[1:2]), true), + ([:x, :a], [u[1], p[1]], false), + ((:x, :a), (u[1], p[1]), true), + ([1, :a], [u[1], p[1]], false), + ((1, :a), (u[1], p[1]), true), + (:(x + y + a + b), u[1] + u[2] + p[1] + p[2], true), + ([:(x + a), :(y + b)], [u[1] + p[1], u[2] + p[2]], true), + ((:(x + a), :(y + b)), (u[1] + p[1], u[2] + p[2]), true) +] getter = getu(sys, sym) if check_inference @inferred getter(fs) diff --git a/test/symbol_cache_test.jl b/test/symbol_cache_test.jl index 5abae5d4..136bb0bd 100644 --- a/test/symbol_cache_test.jl +++ b/test/symbol_cache_test.jl @@ -45,12 +45,15 @@ obsfn4 = observed(sc, [:(x + a) :(y + b); :(x + y) :(a + b)]) obsfn5 = observed(sc, (:(x + a), :(y + b))) @test all(obsfn5(ones(3), 2ones(2), 3.0) .≈ (3.0, 3.0)) -@test_throws TypeError observed(sc, [:(x+a), 2]) -@test_throws TypeError observed(sc, (:(x+a), 2)) +@test_throws TypeError observed(sc, [:(x + a), 2]) +@test_throws TypeError observed(sc, (:(x + a), 2)) -@test_throws ArgumentError SymbolCache([:x, :y], [:a, :b], :t; timeseries_parameters = Dict(:c => ParameterTimeseriesIndex(1, 1))) -@test_throws TypeError SymbolCache([:x, :y], [:a, :c], :t; timeseries_parameters = Dict(:c => (1, 1))) -@test_nowarn SymbolCache([:x, :y], [:a, :c], :t; timeseries_parameters = Dict(:c => ParameterTimeseriesIndex(1, 1))) +@test_throws ArgumentError SymbolCache([:x, :y], [:a, :b], :t; + timeseries_parameters = Dict(:c => ParameterTimeseriesIndex(1, 1))) +@test_throws TypeError SymbolCache( + [:x, :y], [:a, :c], :t; timeseries_parameters = Dict(:c => (1, 1))) +@test_nowarn SymbolCache([:x, :y], [:a, :c], :t; + timeseries_parameters = Dict(:c => ParameterTimeseriesIndex(1, 1))) sc = SymbolCache([:x, :y], [:a, :b]) @test !is_time_dependent(sc) @@ -61,7 +64,8 @@ obsfn = observed(sc, :(x + b)) # make sure the constructor works @test_nowarn SymbolCache([:x, :y]) -@test_throws ArgumentError SymbolCache([:x, :y], [:a, :b]; timeseries_parameters = Dict(:b => ParameterTimeseriesIndex(1, 1))) +@test_throws ArgumentError SymbolCache( + [:x, :y], [:a, :b]; timeseries_parameters = Dict(:b => ParameterTimeseriesIndex(1, 1))) sc = SymbolCache() @test all(.!is_variable.((sc,), [:x, :y, :a, :b, :t]))