diff --git a/docs/src/api.md b/docs/src/api.md index 8286edf..bb4558d 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -31,7 +31,6 @@ allvariables ```@docs observed parameter_observed -ParameterObservedFunction ``` #### Parameter timeseries @@ -46,6 +45,8 @@ may change at different times. is_timeseries_parameter timeseries_parameter_index ParameterTimeseriesIndex +get_all_timeseries_indexes +ContinuousTimeseries ``` ## Value provider interface diff --git a/docs/src/complete_sii.md b/docs/src/complete_sii.md index 0fd6509..bf60549 100644 --- a/docs/src/complete_sii.md +++ b/docs/src/complete_sii.md @@ -132,10 +132,8 @@ end In case a type does not support such observed quantities, `is_observed` must be defined to always return `false`, and `observed` does not need to be implemented. -The same process can be followed for [`parameter_observed`](@ref), with the exception -that the returned function must not have `u` in its signature, and must be wrapped in a -[`ParameterObservedFunction`](@ref). In-place versions can also be implemented for -`parameter_observed`. +The same process can be followed for [`parameter_observed`](@ref). In-place versions +can also be implemented for `parameter_observed`. #### Note about constant structure @@ -334,7 +332,9 @@ end # To be able to access parameter values SymbolicIndexingInterface.parameter_values(mpo::MyParameterObject) = mpo.p # Update the parameter object with new values -function SymbolicIndexingInterface.with_updated_parameter_timeseries_values(mpo::MyParameterObject, args::Pair...) +# Here, we don't need the index provider but it may be necessary for other implementations +function SymbolicIndexingInterface.with_updated_parameter_timeseries_values( + ::SymbolCache, mpo::MyParameterObject, args::Pair...) for (ts_idx, val) in args mpo.p[mpo.disc_idxs[ts_idx]] = val end diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index 93c5398..5e302e8 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -14,7 +14,8 @@ include("trait.jl") export is_variable, variable_index, variable_symbols, is_parameter, parameter_index, is_timeseries_parameter, timeseries_parameter_index, ParameterTimeseriesIndex, parameter_symbols, is_independent_variable, independent_variable_symbols, - is_observed, observed, parameter_observed, ParameterObservedFunction, + is_observed, observed, parameter_observed, + ContinuousTimeseries, get_all_timeseries_indexes, is_time_dependent, constant_structure, symbolic_container, all_variable_symbols, all_symbols, solvedvariables, allvariables, default_values, symbolic_evaluate diff --git a/src/index_provider_interface.jl b/src/index_provider_interface.jl index a114450..d78c48a 100644 --- a/src/index_provider_interface.jl +++ b/src/index_provider_interface.jl @@ -98,38 +98,17 @@ function timeseries_parameter_index(indp, sym) end end -""" - struct ParameterObservedFunction - function ParameterObservedFunction(timeseries_idx, observed_fn::Function) - function ParameterObservedFunction(observed_fn::Function) - -A struct which stores the parameter observed function and optional timeseries index for -a particular symbol. The timeseries index is optional and may be omitted. Specifying the -timeseries index allows [`getp`](@ref) to return the appropriate timeseries for a -timeseries parameter. - -For time-dependent index providers (where `is_time_dependent(indp)`) `observed_fn` must -have the signature `(p, t) -> [values...]`. For non-time-dependent index providers -(where `!is_time_dependent(indp)`) `observed_fn` must have the signature -`(p) -> [values...]`. To support in-place `getp` methods, `observed_fn` must also have an -additional method which takes `buffer::AbstractArray` as its first argument. The required -values must be written to the buffer in the appropriate order. -""" -struct ParameterObservedFunction{I, F <: Function} - timeseries_idx::I - observed_fn::F -end - """ parameter_observed(indp, sym) -Return the observed function of `sym` in `indp` as a [`ParameterObservedFunction`](@ref). -If `sym` only involves variables from a single parameter timeseries (optionally along -with non-timeseries parameters) the timeseries index of the parameter timeseries should -be provided in the [`ParameterObservedFunction`](@ref). In all other cases, just the -observed function should be returned as part of the `ParameterObservedFunction` object. +Return the observed function of `sym` in `indp`. This functions similarly to +[`observed`](@ref) except that `u` is not an argument of the returned function. For time- +dependent systems, the returned function must have the signature `(p, t) -> [values...]`. +For time-independent systems, the returned function must have the signature +`(p) -> [values...]`. -By default, this function returns `nothing`. +By default, this function returns `nothing`, indicating that the index provider does not +support generating parameter observed functions. """ function parameter_observed(indp, sym) if hasmethod(symbolic_container, Tuple{typeof(indp)}) @@ -139,6 +118,38 @@ function parameter_observed(indp, sym) end end +""" + struct ContinuousTimeseries end + +A singleton struct corresponding to the timeseries index of the continuous timeseries. +""" +struct ContinuousTimeseries end + +""" + get_all_timeseries_indexes(indp, sym) + +Return a `Set` of all unique timeseries indexes of variables in symbolic variable +`sym`. `sym` may be a symbolic variable or expression, an array of symbolics, an index, +or an array of indices. Continuous variables correspond to the +[`ContinuousTimeseries`](@ref) timeseries index. Non-timeseries parameters do not have a +timeseries index. Timeseries parameters have the same timeseries index as that returned by +[`timeseries_parameter_index`](@ref). Note that the independent variable corresponds to +the `ContinuousTimeseries` timeseries index. + +Any ambiguities should be resolved in favor of variables. For example, if `1` could refer +to the variable at index `1` or parameter at index `1`, it should be interpreted as the +variable. + +By default, this function returns `Set([ContinuousTimeseries()])`. +""" +function get_all_timeseries_indexes(indp, sym) + if hasmethod(symbolic_container, Tuple{typeof(indp)}) + return get_all_timeseries_indexes(symbolic_container(indp), sym) + else + return Set([ContinuousTimeseries()]) + end +end + """ parameter_symbols(indp) diff --git a/src/parameter_indexing.jl b/src/parameter_indexing.jl index f89cde4..0cfbb66 100644 --- a/src/parameter_indexing.jl +++ b/src/parameter_indexing.jl @@ -48,7 +48,7 @@ end is_indexer_timeseries(::Type{GetParameterIndex{I}}) where {I} = IndexerNotTimeseries() function is_indexer_timeseries(::Type{GetParameterIndex{I}}) where {I <: ParameterTimeseriesIndex} - IndexerTimeseries() + IndexerOnlyTimeseries() end function indexer_timeseries_index(gpi::GetParameterIndex{<:ParameterTimeseriesIndex}) gpi.idx.timeseries_idx @@ -56,8 +56,11 @@ end function (gpi::GetParameterIndex)(::IsTimeseriesTrait, prob) parameter_values(prob, gpi.idx) end +function (gpi::GetParameterIndex)(::IsTimeseriesTrait, prob, arg) + parameter_values(prob, gpi.idx) +end function (gpi::GetParameterIndex)(::Timeseries, prob, args) - throw(ParameterTimeseriesValueIndexMismatchError{Timeseries}(prob, gpi, args)) + parameter_values(prob, gpi.idx) end function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(::Timeseries, prob) get_parameter_timeseries_collection(prob)[gpi.idx] @@ -129,10 +132,12 @@ is_indexer_timeseries(::Type{G}) where {G <: GetParameterTimeseriesIndex} = Inde function indexer_timeseries_index(gpti::GetParameterTimeseriesIndex) indexer_timeseries_index(gpti.param_timeseries_idx) end -as_not_timeseries_indexer(::IndexerBoth, gpti::GetParameterTimeseriesIndex) = gpti.param_idx function as_timeseries_indexer(::IndexerBoth, gpti::GetParameterTimeseriesIndex) gpti.param_timeseries_idx end +function as_not_timeseries_indexer(::IndexerBoth, gpti::GetParameterTimeseriesIndex) + gpti.param_idx +end function (gpti::GetParameterTimeseriesIndex)(ts::Timeseries, prob, args...) gpti.param_timeseries_idx(ts, prob, args...) @@ -145,16 +150,18 @@ function (gpti::GetParameterTimeseriesIndex)(ts::NotTimeseries, prob) gpti.param_idx(ts, prob) end -struct GetParameterObserved{I, M, F <: Function} <: AbstractParameterGetIndexer +struct GetParameterObserved{I, M, P, F <: Function} <: AbstractParameterGetIndexer timeseries_idx::I + parameter_update_fn::P obsfn::F end -function GetParameterObserved{Multiple}(timeseries_idx::I, obsfn::F) where {Multiple, I, F} +function GetParameterObserved{Multiple}( + timeseries_idx::I, update_fn::P, obsfn::F) where {Multiple, I, P, F} if !isa(Multiple, Bool) throw(TypeError(:GetParameterObserved, "{Multiple}", Bool, Multiple)) end - return GetParameterObserved{I, Multiple, F}(timeseries_idx, obsfn) + return GetParameterObserved{I, Multiple, P, F}(timeseries_idx, update_fn, obsfn) end const MultipleGetParameterObserved = GetParameterObserved{I, true} where {I} @@ -163,35 +170,59 @@ const SingleGetParameterObserved = GetParameterObserved{I, false} where {I} function is_indexer_timeseries(::Type{G}) where {G <: GetParameterObserved{Nothing}} IndexerNotTimeseries() end +function is_indexer_timeseries(::Type{G}) where {I <: Vector, G <: GetParameterObserved{I}} + IndexerMixedTimeseries() +end is_indexer_timeseries(::Type{G}) where {G <: GetParameterObserved} = IndexerBoth() indexer_timeseries_index(gpo::GetParameterObserved) = gpo.timeseries_idx -function as_not_timeseries_indexer( - ::IndexerBoth, gpo::GetParameterObserved{I, M}) where {I, M} - return GetParameterObserved{M}(nothing, gpo.obsfn) -end as_timeseries_indexer(::IndexerBoth, gpo::GetParameterObserved) = gpo +as_not_timeseries_indexer(::IndexerBoth, gpo::GetParameterObserved) = gpo +as_not_timeseries_indexer(::IndexerMixedTimeseries, gpo::GetParameterObserved) = gpo function (gpo::GetParameterObserved{Nothing})(::Timeseries, prob) gpo.obsfn(parameter_values(prob), current_time(prob)[end]) end -for multiple in [true, false] - @eval function (gpo::GetParameterObserved{Nothing, $multiple})( - buffer::AbstractArray, ::Timeseries, prob) - gpo.obsfn(buffer, parameter_values(prob), current_time(prob)[end]) - return buffer - end +function (gpo::GetParameterObserved{Nothing, true})( + buffer::AbstractArray, ::Timeseries, prob) + gpo.obsfn(buffer, parameter_values(prob), current_time(prob)[end]) + return buffer end for argType in [Union{Int, CartesianIndex}, Colon, AbstractArray{Bool}, Any] @eval function (gpo::GetParameterObserved{Nothing})(::Timeseries, prob, args::$argType) - throw(ParameterTimeseriesValueIndexMismatchError{Timeseries}(prob, gpo, args)) + gpo.obsfn(parameter_values(prob), current_time(prob)[end]) + end + @eval function (gpo::GetParameterObserved{Nothing, true})( + buffer::AbstractArray, ::Timeseries, prob, args::$argType) + gpo.obsfn(buffer, parameter_values(prob), current_time(prob)[end]) + return buffer + end + @eval function (gpo::GetParameterObserved{<:Vector})(::Timeseries, prob, args::$argType) + throw(MixedParameterTimeseriesIndexError(prob, indexer_timeseries_index(gpo))) end for multiple in [true, false] - @eval function (gpo::GetParameterObserved{Nothing, $multiple})( + @eval function (gpo::GetParameterObserved{<:Vector, $multiple})( ::AbstractArray, ::Timeseries, prob, args::$argType) - throw(ParameterTimeseriesValueIndexMismatchError{Timeseries}(prob, gpo, args)) + throw(MixedParameterTimeseriesIndexError(prob, indexer_timeseries_index(gpo))) end end end + +function (gpo::GetParameterObserved{<:Vector})(::NotTimeseries, prob) + gpo.obsfn(parameter_values(prob), current_time(prob)) +end +function (gpo::GetParameterObserved{<:Vector, true})( + buffer::AbstractArray, ::NotTimeseries, prob) + gpo.obsfn(buffer, parameter_values(prob), current_time(prob)) +end +function (gpo::GetParameterObserved{<:Vector})(::Timeseries, prob) + throw(MixedParameterTimeseriesIndexError(prob, indexer_timeseries_index(gpo))) +end +function (gpo::GetParameterObserved{<:Vector, true})(::AbstractArray, ::Timeseries, prob) + throw(MixedParameterTimeseriesIndexError(prob, indexer_timeseries_index(gpo))) +end +function (gpo::GetParameterObserved{<:Vector, false})(::AbstractArray, ::Timeseries, prob) + throw(MixedParameterTimeseriesIndexError(prob, indexer_timeseries_index(gpo))) +end function (gpo::GetParameterObserved)(::NotTimeseries, prob) gpo.obsfn(parameter_values(prob), current_time(prob)) end @@ -201,31 +232,31 @@ function (gpo::GetParameterObserved)(buffer::AbstractArray, ::NotTimeseries, pro end function (gpo::GetParameterObserved)(::Timeseries, prob) map(parameter_timeseries(prob, gpo.timeseries_idx)) do t - gpo.obsfn(parameter_values_at_time(prob, t), t) + gpo.obsfn(gpo.parameter_update_fn(prob, t), t) end end function (gpo::MultipleGetParameterObserved)(buffer::AbstractArray, ::Timeseries, prob) times = parameter_timeseries(prob, gpo.timeseries_idx) for (buf_idx, time) in zip(eachindex(buffer), times) - gpo.obsfn(buffer[buf_idx], parameter_values_at_time(prob, time), time) + gpo.obsfn(buffer[buf_idx], gpo.parameter_update_fn(prob, time), time) end return buffer end function (gpo::SingleGetParameterObserved)(buffer::AbstractArray, ::Timeseries, prob) times = parameter_timeseries(prob, gpo.timeseries_idx) for (buf_idx, time) in zip(eachindex(buffer), times) - buffer[buf_idx] = gpo.obsfn(parameter_values_at_time(prob, time), time) + buffer[buf_idx] = gpo.obsfn(gpo.parameter_update_fn(prob, time), time) end return buffer end function (gpo::GetParameterObserved)(::Timeseries, prob, i::Union{Int, CartesianIndex}) time = parameter_timeseries(prob, gpo.timeseries_idx)[i] - gpo.obsfn(parameter_values_at_time(prob, time), time) + gpo.obsfn(gpo.parameter_update_fn(prob, time), time) end function (gpo::MultipleGetParameterObserved)( buffer::AbstractArray, ::Timeseries, prob, i::Union{Int, CartesianIndex}) time = parameter_timeseries(prob, gpo.timeseries_idx)[i] - gpo.obsfn(buffer, parameter_values_at_time(prob, time), time) + gpo.obsfn(buffer, gpo.parameter_update_fn(prob, time), time) end function (gpo::GetParameterObserved)(ts::Timeseries, prob, ::Colon) gpo(ts, prob) @@ -305,18 +336,16 @@ function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p) throw(ArgumentError("Index provider does not support `parameter_observed`; cannot use generate function for $p")) end if !is_time_dependent(sys) - return GetParameterObservedNoTime(pofn.observed_fn) + return GetParameterObservedNoTime(pofn) end - return GetParameterObserved{false}(pofn.timeseries_idx, pofn.observed_fn) + ts_idxs = _postprocess_tsidxs(get_all_timeseries_indexes(sys, p)) + update_fn = Fix1Multiple(parameter_values_at_time, sys) + return GetParameterObserved{false}(ts_idxs, update_fn, pofn) end error("Invalid symbol $p for `getp`") end -struct MixedTimeseriesIndexes - indexes::Any -end - -struct MultipleParametersGetter{T <: IsIndexerTimeseries, G, I} <: +struct MultipleParametersGetter{T <: IndexerTimeseriesType, G, I} <: AbstractParameterGetIndexer getters::G timeseries_idx::I @@ -327,36 +356,40 @@ function MultipleParametersGetter(getters) return MultipleParametersGetter{IndexerNotTimeseries, typeof(getters), Nothing}( getters, nothing) end - has_timeseries_indexers = any(getters) do g - is_indexer_timeseries(g) == IndexerTimeseries() + has_only_timeseries_indexers = any(getters) do g + is_indexer_timeseries(g) == IndexerOnlyTimeseries() end has_non_timeseries_indexers = any(getters) do g is_indexer_timeseries(g) == IndexerNotTimeseries() end - if has_timeseries_indexers && has_non_timeseries_indexers - throw(ArgumentError("Cannot mix timeseries and non-timeseries indexers in `$MultipleParametersGetter`")) + all_non_timeseries_indexers = all(getters) do g + is_indexer_timeseries(g) == IndexerNotTimeseries() + end + if has_only_timeseries_indexers && has_non_timeseries_indexers + throw(ArgumentError("Cannot mix timeseries indexes with non-timeseries variables in `$MultipleParametersGetter`")) end - indexer_type = if has_timeseries_indexers + # If any getters are only timeseries, all of them should be + indexer_type = if has_only_timeseries_indexers getters = as_timeseries_indexer.(getters) - timeseries_idx = indexer_timeseries_index(first(getters)) - IndexerTimeseries - elseif has_non_timeseries_indexers - getters = as_not_timeseries_indexer.(getters) - timeseries_idx = nothing + IndexerOnlyTimeseries + # If all indexers are non-timeseries, so should their combination + elseif all_non_timeseries_indexers IndexerNotTimeseries else - timeseries_idx = indexer_timeseries_index(first(getters)) IndexerBoth end - if indexer_type != IndexerNotTimeseries && - !allequal(indexer_timeseries_index(g) for g in getters) - if indexer_type == IndexerTimeseries - throw(ArgumentError("All parameters must belong to the same timeseries")) - else - indexer_type = IndexerNotTimeseries - timeseries_idx = MixedTimeseriesIndexes(indexer_timeseries_index.(getters)) + timeseries_idx = if indexer_type == IndexerNotTimeseries + nothing + else + timeseries_idxs = Set(Iterators.flatten(indexer_timeseries_index(g) + for g in getters if is_indexer_timeseries(g) != IndexerNotTimeseries())) + if length(timeseries_idxs) > 1 + indexer_type = IndexerMixedTimeseries getters = as_not_timeseries_indexer.(getters) + collect(timeseries_idxs) + else + only(timeseries_idxs) end end @@ -364,26 +397,37 @@ function MultipleParametersGetter(getters) getters, timeseries_idx) end -const AtLeastTimeseriesMPG = Union{ - MultipleParametersGetter{IndexerTimeseries}, MultipleParametersGetter{IndexerBoth}} -const MixedTimeseriesIndexMPG = MultipleParametersGetter{ - IndexerNotTimeseries, G, MixedTimeseriesIndexes} where {G} +const OnlyTimeseriesMPG = MultipleParametersGetter{IndexerOnlyTimeseries} +const BothMPG = MultipleParametersGetter{IndexerBoth} +const NonTimeseriesMPG = MultipleParametersGetter{IndexerNotTimeseries} +const MixedTimeseriesMPG = MultipleParametersGetter{IndexerMixedTimeseries} +const AtLeastTimeseriesMPG = Union{OnlyTimeseriesMPG, BothMPG} is_indexer_timeseries(::Type{<:MultipleParametersGetter{T}}) where {T} = T() function indexer_timeseries_index(mpg::MultipleParametersGetter) mpg.timeseries_idx end +function as_timeseries_indexer(::IndexerBoth, mpg::MultipleParametersGetter) + MultipleParametersGetter(as_timeseries_indexer.(mpg.getters)) +end function as_not_timeseries_indexer(::IndexerBoth, mpg::MultipleParametersGetter) MultipleParametersGetter(as_not_timeseries_indexer.(mpg.getters)) end +function as_not_timeseries_indexer(::IndexerMixedTimeseries, mpg::MultipleParametersGetter) + return mpg +end -function as_timeseries_indexer(::IndexerBoth, mpg::MultipleParametersGetter) - MultipleParametersGetter(as_timeseries_indexer.(mpg.getters)) +function (mpg::MixedTimeseriesMPG)(::Timeseries, prob, args...) + throw(MixedParameterTimeseriesIndexError(prob, indexer_timeseries_index(mpg))) +end +function (mpg::MixedTimeseriesMPG)(::AbstractArray, ::Timeseries, prob, args...) + throw(MixedParameterTimeseriesIndexError(prob, indexer_timeseries_index(mpg))) end for (indexerTimeseriesType, timeseriesType) in [ (IndexerNotTimeseries, IsTimeseriesTrait), - (IndexerBoth, NotTimeseries) + (IndexerBoth, NotTimeseries), + (IndexerMixedTimeseries, NotTimeseries) ] @eval function (mpg::MultipleParametersGetter{$indexerTimeseriesType})( ::$timeseriesType, prob) @@ -398,17 +442,16 @@ for (indexerTimeseriesType, timeseriesType) in [ end end -function (mpg::MixedTimeseriesIndexMPG)(::Timeseries, prob, args...) - throw(MixedParameterTimeseriesIndexError(prob, mpg.timeseries_idx.indexes)) -end - -function (mpg::MultipleParametersGetter{IndexerNotTimeseries})(::Timeseries, prob, args) - throw(ParameterTimeseriesValueIndexMismatchError{Timeseries}(prob, mpg, args)) +function (mpg::NonTimeseriesMPG)(::IsTimeseriesTrait, prob, arg) + return _call.(mpg.getters, (prob,), (arg,)) end -function (mpg::MultipleParametersGetter{IndexerNotTimeseries})( - ::AbstractArray, ::Timeseries, prob, args) - throw(ParameterTimeseriesValueIndexMismatchError{Timeseries}(prob, mpg, args)) +function (mpg::NonTimeseriesMPG)(buffer::AbstractArray, ::IsTimeseriesTrait, prob, arg) + for (buf_idx, getter) in zip(eachindex(buffer), mpg.getters) + buffer[buf_idx] = getter(prob) + end + return buffer end + function (mpg::AtLeastTimeseriesMPG)(ts::Timeseries, prob) map(eachindex(parameter_timeseries(prob, indexer_timeseries_index(mpg)))) do i mpg(ts, prob, i) @@ -457,10 +500,10 @@ function (mpg::AtLeastTimeseriesMPG)(buffer::AbstractArray, ts::Timeseries, prob end return buffer end -function (mpg::MultipleParametersGetter{IndexerTimeseries})(::NotTimeseries, prob) +function (mpg::OnlyTimeseriesMPG)(::NotTimeseries, prob) throw(ParameterTimeseriesValueIndexMismatchError{NotTimeseries}(prob, mpg)) end -function (mpg::MultipleParametersGetter{IndexerTimeseries})( +function (mpg::OnlyTimeseriesMPG)( ::AbstractArray, ::NotTimeseries, prob) throw(ParameterTimeseriesValueIndexMismatchError{NotTimeseries}(prob, mpg)) end @@ -490,6 +533,10 @@ wrap_tuple(::AsParameterTupleWrapper{N}, val) where {N} = ntuple(i -> val[i], Va function (atw::AsParameterTupleWrapper)(ts::IsTimeseriesTrait, prob, args...) atw(ts, is_indexer_timeseries(atw), prob, args...) end +function (atw::AsParameterTupleWrapper)( + ts::IsTimeseriesTrait, ::IndexerMixedTimeseries, prob, args...) + wrap_tuple(atw, atw.getter(ts, prob, args...)) +end function (atw::AsParameterTupleWrapper)(ts::Timeseries, ::AtLeastTimeseriesIndexer, prob) wrap_tuple.((atw,), atw.getter(ts, prob)) end @@ -500,7 +547,6 @@ end function (atw::AsParameterTupleWrapper)(ts::Timeseries, ::AtLeastTimeseriesIndexer, prob, i) wrap_tuple.((atw,), atw.getter(ts, prob, i)) end -# args is just so it throws function (atw::AsParameterTupleWrapper)( ts::Timeseries, ::IndexerNotTimeseries, prob, args...) wrap_tuple(atw, atw.getter(ts, prob, args...)) @@ -530,15 +576,21 @@ for (t1, t2) in [ # `getp` errors on older MTK that doesn't support `parameter_observed`. getters = getp.((sys,), p) num_observed = count(is_observed_getter, getters) + p_arr = p isa Tuple ? collect(p) : p if num_observed == 0 return MultipleParametersGetter(getters) else - pofn = parameter_observed(sys, p isa Tuple ? collect(p) : p) + pofn = parameter_observed(sys, p_arr) + if pofn === nothing + return MultipleParametersGetter.(getters) + end if is_time_dependent(sys) - getter = GetParameterObserved{true}(pofn.timeseries_idx, pofn.observed_fn) + ts_idxs = _postprocess_tsidxs(get_all_timeseries_indexes(sys, p_arr)) + update_fn = Fix1Multiple(parameter_values_at_time, sys) + getter = GetParameterObserved{true}(ts_idxs, update_fn, pofn) else - getter = GetParameterObservedNoTime(pofn.observed_fn) + getter = GetParameterObservedNoTime(pofn) end return p isa Tuple ? AsParameterTupleWrapper{length(p)}(getter) : getter end @@ -632,7 +684,9 @@ function _setp(sys, ::ArraySymbolic, ::SymbolicTypeTrait, p) idx = parameter_index(sys, p) return setp(sys, idx; run_hook = false) elseif is_observed(sys, p) && (pobsfn = parameter_observed(sys, p)) !== nothing - return GetParameterObserved{false}(pobsfn.timeseries_idx, pobsfn.observed_fn) + ts_idxs = _postprocess_tsidxs(get_all_timeseries_indexes(sys, p)) + update_fn = Fix1Multiple(parameter_values_at_time, sys) + return GetParameterObserved{false}(ts_idxs, update_fn, pobsfn.observed_fn) end return setp(sys, collect(p); run_hook = false) end diff --git a/src/parameter_timeseries_collection.jl b/src/parameter_timeseries_collection.jl index 0da27ab..14ca1f0 100644 --- a/src/parameter_timeseries_collection.jl +++ b/src/parameter_timeseries_collection.jl @@ -110,7 +110,7 @@ function _timeseries_value(ptc::ParameterTimeseriesCollection, ts_idx, t) end """ - parameter_values_at_time(valp, t) + parameter_values_at_time(indp, valp, t) Return an indexable collection containing the value of all parameters in `valp` at time `t`. Note that `t` here is a floating-point time, and not an index into a timeseries. @@ -118,36 +118,12 @@ Return an indexable collection containing the value of all parameters in `valp` This has a default implementation relying on [`get_parameter_timeseries_collection`](@ref) and [`with_updated_parameter_timeseries_values`](@ref). """ -function parameter_values_at_time(valp, t) +function parameter_values_at_time(indp, valp, t) ptc = get_parameter_timeseries_collection(valp) - with_updated_parameter_timeseries_values(ptc.paramcache, + with_updated_parameter_timeseries_values(indp, ptc.paramcache, (ts_idx => _timeseries_value(ptc, ts_idx, t) for ts_idx in eachindex(ptc))...) end -""" - parameter_values_at_state_time(valp, i) - parameter_values_at_state_time(valp) - -Return an indexable collection containing the value of all parameters in `valp` at time -index `i` in the state timeseries. - -By default, this function relies on [`parameter_values_at_time`](@ref) and -[`current_time`](@ref) for a default implementation. - -The single-argument version of this function is a shorthand to return parameter values -at each point in the state timeseries. This also has a default implementation relying on -[`parameter_values_at_time`](@ref) and [`current_time`](@ref). -""" -function parameter_values_at_state_time end - -function parameter_values_at_state_time(p, i) - state_time = current_time(p, i) - return parameter_values_at_time(p, state_time) -end -function parameter_values_at_state_time(p) - return (parameter_values_at_time(p, t) for t in current_time(p)) -end - """ parameter_timeseries(valp, i) @@ -163,43 +139,3 @@ function parameter_timeseries end function parameter_timeseries(valp, i) return parameter_timeseries(get_parameter_timeseries_collection(valp), i) end - -""" - parameter_timeseries_at_state_time(valp, i, j) - parameter_timeseries_at_state_time(valp, i) - -Return the index of the timestep in the parameter timeseries at timeseries index `i` which -occurs just before or at the same time as the state timestep with index `j`. The two- -argument version of this function returns an iterable of indexes, one for each timestep in -the state timeseries. If `j` is an object that refers to multiple values in the state -timeseries (e.g. `Colon`), return an iterable of the indexes in the parameter timeseries -at the appropriate points. - -Both versions of this function have default implementations relying on -[`current_time`](@ref) and [`parameter_timeseries`](@ref), for the cases where `j` is one -of: `Int`, `CartesianIndex`, `AbstractArray{Bool}`, `Colon` or an iterable of the -aforementioned. -""" -function parameter_timeseries_at_state_time end - -function parameter_timeseries_at_state_time(valp, i, j::Union{Int, CartesianIndex}) - state_time = current_time(valp, j) - timeseries = parameter_timeseries(valp, i) - searchsortedlast(timeseries, state_time) -end - -function parameter_timeseries_at_state_time(valp, i, ::Colon) - parameter_timeseries_at_state_time(valp, i) -end - -function parameter_timeseries_at_state_time(valp, i, j::AbstractArray{Bool}) - parameter_timeseries_at_state_time(valp, i, only(to_indices(current_time(valp), (j,)))) -end - -function parameter_timeseries_at_state_time(valp, i, j) - (parameter_timeseries_at_state_time(valp, i, jj) for jj in j) -end - -function parameter_timeseries_at_state_time(valp, i) - parameter_timeseries_at_state_time(valp, i, eachindex(current_time(valp))) -end diff --git a/src/state_indexing.jl b/src/state_indexing.jl index b17e417..af51766 100644 --- a/src/state_indexing.jl +++ b/src/state_indexing.jl @@ -51,107 +51,60 @@ function _getu(sys, ::NotSymbolic, ::NotSymbolic, sym) return GetStateIndex(sym) end -struct GetpAtStateTime{G} <: AbstractStateGetIndexer - getter::G -end - -function (g::GetpAtStateTime)(ts::Timeseries, prob) - g(ts, is_parameter_timeseries(prob), prob) -end -function (g::GetpAtStateTime)(ts::Timeseries, prob, i) - g(ts, is_parameter_timeseries(prob), prob, i) -end -function (g::GetpAtStateTime)(::Timeseries, ::NotTimeseries, prob, _...) - g.getter(prob) -end -function (g::GetpAtStateTime)(ts::Timeseries, p_ts::Timeseries, prob) - g(ts, p_ts, is_indexer_timeseries(g.getter), prob) -end -function (g::GetpAtStateTime)( - ::Timeseries, ::Timeseries, ::Union{IndexerTimeseries, IndexerBoth}, prob) - g.getter.((prob,), - parameter_timeseries_at_state_time(prob, indexer_timeseries_index(g.getter))) -end -function (g::GetpAtStateTime)(::Timeseries, ::Timeseries, ::IndexerNotTimeseries, prob) - g.getter(prob) -end -function (g::GetpAtStateTime)(ts::Timeseries, p_ts::Timeseries, prob, i) - g(ts, p_ts, is_indexer_timeseries(g.getter), prob, i) -end -function (g::GetpAtStateTime)( - ::Timeseries, ::Timeseries, ::Union{IndexerTimeseries, IndexerBoth}, prob, i) - g.getter(prob, - parameter_timeseries_at_state_time(prob, indexer_timeseries_index(g.getter), i)) -end -function (g::GetpAtStateTime)(::Timeseries, ::Timeseries, ::IndexerNotTimeseries, - prob, ::Union{Int, CartesianIndex}) - g.getter(prob) -end -function (g::GetpAtStateTime)( - ::Timeseries, ::Timeseries, ::IndexerNotTimeseries, prob, ::Colon) - map(_ -> g.getter(prob), current_time(prob)) -end -function (g::GetpAtStateTime)( - ::Timeseries, ::Timeseries, ::IndexerNotTimeseries, prob, i::AbstractArray{Bool}) - num_ones = sum(i) - map(_ -> g.getter(prob), 1:num_ones) -end -function (g::GetpAtStateTime)(::Timeseries, ::Timeseries, ::IndexerNotTimeseries, prob, i) - map(_ -> g.getter(prob), 1:length(i)) -end -function (g::GetpAtStateTime)(::NotTimeseries, prob) - g.getter(prob) -end - struct GetIndepvar <: AbstractStateGetIndexer end (::GetIndepvar)(::IsTimeseriesTrait, prob) = current_time(prob) (::GetIndepvar)(::Timeseries, prob, i) = current_time(prob, i) -struct TimeDependentObservedFunction{F} <: AbstractStateGetIndexer +struct TimeDependentObservedFunction{I, F} <: AbstractStateGetIndexer + ts_idxs::I obsfn::F end -function (o::TimeDependentObservedFunction)(ts::Timeseries, prob) - return o(ts, is_parameter_timeseries(prob), prob) +indexer_timeseries_index(t::TimeDependentObservedFunction) = t.ts_idxs +function is_indexer_timeseries(::Type{G}) where {G <: + TimeDependentObservedFunction{ContinuousTimeseries}} + return IndexerBoth() +end +function is_indexer_timeseries(::Type{G}) where {G <: + TimeDependentObservedFunction{<:Vector}} + return IndexerMixedTimeseries() end -function (o::TimeDependentObservedFunction)(::Timeseries, ::Timeseries, prob) - map(o.obsfn, state_values(prob), - parameter_values_at_state_time(prob), current_time(prob)) +function (o::TimeDependentObservedFunction)(ts::IsTimeseriesTrait, prob, args...) + return o(ts, is_indexer_timeseries(o), prob, args...) end -function (o::TimeDependentObservedFunction)(::Timeseries, ::NotTimeseries, prob) - o.obsfn.(state_values(prob), + +function (o::TimeDependentObservedFunction)(ts::Timeseries, ::IndexerBoth, prob) + return o.obsfn.(state_values(prob), (parameter_values(prob),), current_time(prob)) end -function (o::TimeDependentObservedFunction)(ts::Timeseries, prob, i) - return o(ts, is_parameter_timeseries(prob), prob, i) -end function (o::TimeDependentObservedFunction)( - ::Timeseries, ::Timeseries, prob, i::Union{Int, CartesianIndex}) - return o.obsfn(state_values(prob, i), - parameter_values_at_state_time(prob, i), - current_time(prob, i)) + ::Timeseries, ::IndexerBoth, prob, i::Union{Int, CartesianIndex}) + return o.obsfn(state_values(prob, i), parameter_values(prob), current_time(prob, i)) end -function (o::TimeDependentObservedFunction)( - ts::Timeseries, p_ts::IsTimeseriesTrait, prob, ::Colon) - return o(ts, p_ts, prob) +function (o::TimeDependentObservedFunction)(ts::Timeseries, ::IndexerBoth, prob, ::Colon) + return o(ts, prob) end function (o::TimeDependentObservedFunction)( - ts::Timeseries, p_ts::IsTimeseriesTrait, prob, i::AbstractArray{Bool}) + ts::Timeseries, ::IndexerBoth, prob, i::AbstractArray{Bool}) map(only(to_indices(current_time(prob), (i,)))) do idx - o(ts, p_ts, prob, idx) + o(ts, prob, idx) end end -function (o::TimeDependentObservedFunction)( - ts::Timeseries, p_ts::IsTimeseriesTrait, prob, i) - o.((ts,), (p_ts,), (prob,), i) +function (o::TimeDependentObservedFunction)(ts::Timeseries, ::IndexerBoth, prob, i) + o.((ts,), (prob,), i) end +function (o::TimeDependentObservedFunction)(::NotTimeseries, ::IndexerBoth, prob) + return o.obsfn(state_values(prob), parameter_values(prob), current_time(prob)) +end + function (o::TimeDependentObservedFunction)( - ::Timeseries, ::NotTimeseries, prob, i::Union{Int, CartesianIndex}) - o.obsfn(state_values(prob, i), parameter_values(prob), current_time(prob, i)) + ::Timeseries, ::IndexerMixedTimeseries, prob, args...) + throw(MixedParameterTimeseriesIndexError(prob, indexer_timeseries_index(o))) end -function (o::TimeDependentObservedFunction)(::NotTimeseries, prob) +function (o::TimeDependentObservedFunction)( + ::NotTimeseries, ::IndexerMixedTimeseries, prob, args...) return o.obsfn(state_values(prob), parameter_values(prob), current_time(prob)) end @@ -168,42 +121,77 @@ function _getu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym) idx = variable_index(sys, sym) return getu(sys, idx) elseif is_parameter(sys, sym) - return GetpAtStateTime(getp(sys, sym)) + return getp(sys, sym) elseif is_independent_variable(sys, sym) return GetIndepvar() elseif is_observed(sys, sym) - fn = observed(sys, sym) - if is_time_dependent(sys) - return TimeDependentObservedFunction(fn) + if !is_time_dependent(sys) + return TimeIndependentObservedFunction(observed(sys, sym)) + end + + ts_idxs = get_all_timeseries_indexes(sys, sym) + if ContinuousTimeseries() in ts_idxs + if length(ts_idxs) == 1 + ts_idxs = only(ts_idxs) + else + ts_idxs = collect(ts_idxs) + end + fn = observed(sys, sym) + return TimeDependentObservedFunction(ts_idxs, fn) else - return TimeIndependentObservedFunction(fn) + return getp(sys, sym) end end error("Invalid symbol $sym for `getu`") end -struct MultipleGetters{G} <: AbstractStateGetIndexer +struct MultipleGetters{I, G} <: AbstractStateGetIndexer + ts_idxs::I getters::G end -function (mg::MultipleGetters)(ts::Timeseries, prob) +indexer_timeseries_index(mg::MultipleGetters) = mg.ts_idxs +function is_indexer_timeseries(::Type{G}) where {G <: MultipleGetters{ContinuousTimeseries}} + return IndexerBoth() +end +function is_indexer_timeseries(::Type{G}) where {G <: MultipleGetters{<:Vector}} + return IndexerMixedTimeseries() +end +function is_indexer_timeseries(::Type{G}) where {G <: MultipleGetters{Nothing}} + return IndexerNotTimeseries() +end + +function (mg::MultipleGetters)(ts::IsTimeseriesTrait, prob, args...) + return mg(ts, is_indexer_timeseries(mg), prob, args...) +end + +function (mg::MultipleGetters)(ts::Timeseries, ::IndexerBoth, prob) return mg.((ts,), (prob,), eachindex(current_time(prob))) end -function (mg::MultipleGetters)(::Timeseries, prob, i::Union{Int, CartesianIndex}) +function (mg::MultipleGetters)( + ::Timeseries, ::IndexerBoth, prob, i::Union{Int, CartesianIndex}) return map(CallWith(prob, i), mg.getters) end -function (mg::MultipleGetters)(ts::Timeseries, prob, ::Colon) +function (mg::MultipleGetters)(ts::Timeseries, ::IndexerBoth, prob, ::Colon) return mg(ts, prob) end -function (mg::MultipleGetters)(ts::Timeseries, prob, i::AbstractArray{Bool}) +function (mg::MultipleGetters)(ts::Timeseries, ::IndexerBoth, prob, i::AbstractArray{Bool}) return map(only(to_indices(current_time(prob), (i,)))) do idx mg(ts, prob, idx) end end -function (mg::MultipleGetters)(ts::Timeseries, prob, i) +function (mg::MultipleGetters)(ts::Timeseries, ::IndexerBoth, prob, i) mg.((ts,), (prob,), i) end -function (mg::MultipleGetters)(::NotTimeseries, prob) +function (mg::MultipleGetters)( + ::NotTimeseries, ::Union{IndexerBoth, IndexerNotTimeseries}, prob) + return map(g -> g(prob), mg.getters) +end + +function (mg::MultipleGetters)(::Timeseries, ::IndexerMixedTimeseries, prob, args...) + throw(MixedParameterTimeseriesIndexError(prob, indexer_timeseries_index(mg))) +end +function (mg::MultipleGetters)(::NotTimeseries, ::IndexerMixedTimeseries, prob, args...) return map(g -> g(prob), mg.getters) end @@ -233,20 +221,42 @@ for (t1, t2) in [ (ArraySymbolic, Any), (NotSymbolic, Union{<:Tuple, <:AbstractArray}) ] - @eval function _getu(sys, ::NotSymbolic, ::$t1, sym::$t2) + @eval function _getu(sys, ::NotSymbolic, elt::$t1, sym::$t2) + if isempty(sym) + return MultipleGetters(ContinuousTimeseries(), sym) + end + sym_arr = sym isa Tuple ? collect(sym) : sym num_observed = count(x -> is_observed(sys, x), sym) - if num_observed == 0 || num_observed == 1 && sym isa Tuple - if !isempty(sym) && all(Base.Fix1(is_parameter, sys), sym) && - all(!Base.Fix1(is_timeseries_parameter, sys), sym) - GetpAtStateTime(getp(sys, sym)) + if !is_time_dependent(sys) + if num_observed == 0 || num_observed == 1 && sym isa Tuple + return MultipleGetters(nothing, getu.((sys,), sym)) else - getters = getu.((sys,), sym) - return MultipleGetters(getters) + obs = observed(sys, sym_arr) + getter = TimeIndependentObservedFunction(obs) + if sym isa Tuple + getter = AsTupleWrapper{length(sym)}(getter) + end + return getter end + end + ts_idxs = get_all_timeseries_indexes(sys, sym_arr) + if !(ContinuousTimeseries() in ts_idxs) + return getp(sys, sym) + end + if length(ts_idxs) == 1 + ts_idxs = only(ts_idxs) + else + ts_idxs = collect(ts_idxs) + end + + num_observed = count(x -> is_observed(sys, x), sym) + if num_observed == 0 || num_observed == 1 && sym isa Tuple + getters = getu.((sys,), sym) + return MultipleGetters(ts_idxs, getters) else - obs = observed(sys, sym isa Tuple ? collect(sym) : sym) + obs = observed(sys, sym_arr) getter = if is_time_dependent(sys) - TimeDependentObservedFunction(obs) + TimeDependentObservedFunction(ts_idxs, obs) else TimeIndependentObservedFunction(obs) end @@ -263,7 +273,7 @@ function _getu(sys, ::ArraySymbolic, ::SymbolicTypeTrait, sym) idx = variable_index(sys, sym) return getu(sys, idx) elseif is_parameter(sys, sym) - return GetpAtStateTime(getp(sys, sym)) + return getp(sys, sym) end return getu(sys, collect(sym)) end diff --git a/src/symbol_cache.jl b/src/symbol_cache.jl index 76262b5..823783f 100644 --- a/src/symbol_cache.jl +++ b/src/symbol_cache.jl @@ -72,10 +72,20 @@ function SymbolCache(vars = nothing, params = nothing, indepvars = nothing; end function is_variable(sc::SymbolCache, sym) - sc.variables !== nothing && haskey(sc.variables, sym) + sc.variables === nothing && return false + if symbolic_type(sym) == NotSymbolic() + return sym in values(sc.variables) + else + return haskey(sc.variables, sym) + end end function variable_index(sc::SymbolCache, sym) - sc.variables === nothing ? nothing : get(sc.variables, sym, nothing) + sc.variables === nothing && return nothing + if symbolic_type(sym) == NotSymbolic() + return sym + else + return get(sc.variables, sym, nothing) + end end function variable_symbols(sc::SymbolCache, i = nothing) sc.variables === nothing && return [] @@ -86,27 +96,64 @@ function variable_symbols(sc::SymbolCache, i = nothing) return buffer end function is_parameter(sc::SymbolCache, sym) - sc.parameters !== nothing && haskey(sc.parameters, sym) + sc.parameters === nothing && return false + if symbolic_type(sym) == NotSymbolic() + return sym in values(sc.parameters) + else + return haskey(sc.parameters, sym) + end end function parameter_index(sc::SymbolCache, sym) - sc.parameters === nothing ? nothing : get(sc.parameters, sym, nothing) + sc.parameters === nothing && return nothing + if symbolic_type(sym) == NotSymbolic() + return sym + else + return get(sc.parameters, sym, nothing) + end end function parameter_symbols(sc::SymbolCache) sc.parameters === nothing ? [] : collect(keys(sc.parameters)) end function is_timeseries_parameter(sc::SymbolCache, sym) - sc.timeseries_parameters !== nothing && haskey(sc.timeseries_parameters, sym) + sc.timeseries_parameters === nothing && return false + if symbolic_type(sym) == NotSymbolic() + return sym in values(sc.timeseries_parameters) + else + return haskey(sc.timeseries_parameters, sym) + end end function timeseries_parameter_index(sc::SymbolCache, sym) - sc.timeseries_parameters === nothing ? nothing : - get(sc.timeseries_parameters, sym, nothing) + sc.timeseries_parameters === nothing && return nothing + if symbolic_type(sym) == NotSymbolic() + return sym + else + return get(sc.timeseries_parameters, sym, nothing) + end +end +function get_all_timeseries_indexes(sc::SymbolCache, sym) + if is_variable(sc, sym) || is_independent_variable(sc, sym) + return Set([ContinuousTimeseries()]) + elseif is_timeseries_parameter(sc, sym) + return Set([timeseries_parameter_index(sc, sym).timeseries_idx]) + else + return Set() + end +end +function get_all_timeseries_indexes(sc::SymbolCache, sym::Expr) + exs = ExpressionSearcher() + exs(sc, sym) + return mapreduce( + Base.Fix1(get_all_timeseries_indexes, sc), union, exs.declared; init = Set()) +end +function get_all_timeseries_indexes(sc::SymbolCache, sym::AbstractArray) + return mapreduce(Base.Fix1(get_all_timeseries_indexes, sc), union, sym; init = Set()) end function is_independent_variable(sc::SymbolCache, sym) sc.independent_variables === nothing && return false if symbolic_type(sc.independent_variables) == NotSymbolic() return any(isequal(sym), sc.independent_variables) elseif symbolic_type(sc.independent_variables) == ScalarSymbolic() - return sym == sc.independent_variables + return isequal(sym, sc.independent_variables) else return any(isequal(sym), collect(sc.independent_variables)) end @@ -227,24 +274,13 @@ function parameter_observed(sc::SymbolCache, expr::Expr) if is_time_dependent(sc) exs = ExpressionSearcher() exs(sc, expr) - ts_idxs = Set() - for p in exs.parameters - is_timeseries_parameter(sc, p) || continue - push!(ts_idxs, timeseries_parameter_index(sc, p).timeseries_idx) - end - f = let fn = observed(sc, expr) + return let fn = observed(sc, expr) f1(p, t) = fn(nothing, p, t) end - if length(ts_idxs) == 1 - return ParameterObservedFunction(only(ts_idxs), f) - else - return ParameterObservedFunction(nothing, f) - end else - f = let fn = observed(sc, expr) + return let fn = observed(sc, expr) f2(p) = fn(nothing, p) end - return ParameterObservedFunction(nothing, f) end end @@ -257,27 +293,16 @@ function parameter_observed(sc::SymbolCache, exprs::Union{AbstractArray, Tuple}) if is_time_dependent(sc) exs = ExpressionSearcher() exs(sc, to_expr(exprs)) - ts_idxs = Set() - for p in exs.parameters - is_timeseries_parameter(sc, p) || continue - push!(ts_idxs, timeseries_parameter_index(sc, p).timeseries_idx) - end - f = let oop = observed(sc, to_expr(exprs)), iip = inplace_observed(sc, exprs) + return let oop = observed(sc, to_expr(exprs)), iip = inplace_observed(sc, exprs) f1(p, t) = oop(nothing, p, t) f1(buffer, p, t) = iip(buffer, nothing, p, t) end - if length(ts_idxs) == 1 - return ParameterObservedFunction(only(ts_idxs), f) - else - return ParameterObservedFunction(nothing, f) - end else - f = let oop = observed(sc, to_expr(exprs)), iip = inplace_observed(sc, exprs) + return let oop = observed(sc, to_expr(exprs)), iip = inplace_observed(sc, exprs) f2(p) = oop(nothing, p) f2(buffer, p) = iip(buffer, nothing, p) end - return ParameterObservedFunction(nothing, f) end end diff --git a/src/value_provider_interface.jl b/src/value_provider_interface.jl index 9a9cb0e..7903017 100644 --- a/src/value_provider_interface.jl +++ b/src/value_provider_interface.jl @@ -23,25 +23,30 @@ parameter_values(arr::Tuple, i) = arr[i] parameter_values(prob, i) = parameter_values(parameter_values(prob), i) """ - parameter_values_at_time(valp, t) + get_parameter_timeseries_collection(valp) -Return an indexable collection containing the value of all parameters in `valp` at time -`t`. Note that `t` here is a floating-point time, and not an index into a timeseries. - -This is useful for parameter timeseries objects, since some parameters change over time. +Return the [`ParameterTimeseriesCollection`](@ref) contained in timeseries value provider +`valp`. Valid only for value providers where [`is_parameter_timeseries`](@ref) returns +[`Timeseries`](@ref). """ function get_parameter_timeseries_collection end """ - with_updated_parameter_timeseries_values(valp, args::Pair...) + with_updated_parameter_timeseries_values(indp, params, args::Pair...) -Return an indexable collection containing the value of all parameters in `valp`, with +Return an indexable collection containing the value of all parameters in `params`, with parameters belonging to specific timeseries updated to different values. Each element in `args...` contains the timeseries index as the first value, and the saved parameter values in that partition. Not all parameter timeseries have to be updated using this method. If -an in-place update can be performed, it should be done and the modified `valp` returned. +an in-place update can be performed, it should be done and the modified `params` returned. +This method falls back on the basis of `symbolic_container(indp)`. + +Note that here `params` is the parameter object. """ -function with_updated_parameter_timeseries_values end +function with_updated_parameter_timeseries_values(indp, params, args...) + return with_updated_parameter_timeseries_values( + symbolic_container(indp), params, args...) +end """ set_parameter!(valp, val, idx) @@ -159,34 +164,38 @@ function (ai::AbstractParameterGetIndexer)(buffer::AbstractArray, prob, i) ai(buffer, is_parameter_timeseries(prob), prob, i) end -abstract type IsIndexerTimeseries end +abstract type IndexerTimeseriesType end -struct IndexerTimeseries <: IsIndexerTimeseries end -struct IndexerNotTimeseries <: IsIndexerTimeseries end -struct IndexerBoth <: IsIndexerTimeseries end - -const AtLeastTimeseriesIndexer = Union{IndexerTimeseries, IndexerBoth} -const AtLeastNotTimeseriesIndexer = Union{IndexerNotTimeseries, IndexerBoth} +# Can only index parameter timeseries +struct IndexerOnlyTimeseries <: IndexerTimeseriesType end +# Can only index non-timeseries objects +struct IndexerMixedTimeseries <: IndexerTimeseriesType end +# Can index timeseres and non-timeseries objects, has the same value at all times +struct IndexerNotTimeseries <: IndexerTimeseriesType end +# Value changes over time, can index timeseries and non-timeseries objects +struct IndexerBoth <: IndexerTimeseriesType end is_indexer_timeseries(x) = is_indexer_timeseries(typeof(x)) function indexer_timeseries_index end +const AtLeastTimeseriesIndexer = Union{IndexerOnlyTimeseries, IndexerBoth} +const AtLeastNotTimeseriesIndexer = Union{IndexerNotTimeseries, IndexerBoth} + +as_timeseries_indexer(x) = as_timeseries_indexer(is_indexer_timeseries(x), x) +as_timeseries_indexer(::IndexerOnlyTimeseries, x) = x +as_timeseries_indexer(::IndexerNotTimeseries, x) = x as_not_timeseries_indexer(x) = as_not_timeseries_indexer(is_indexer_timeseries(x), x) as_not_timeseries_indexer(::IndexerNotTimeseries, x) = x -function as_not_timeseries_indexer(::IndexerTimeseries, x) - error(""" - Tried to convert an `$IndexerTimeseries` to an `$IndexerNotTimeseries`. This \ - should never happen. Please file an issue with an MWE. - """) -end -as_timeseries_indexer(x) = as_timeseries_indexer(is_indexer_timeseries(x), x) -as_timeseries_indexer(::IndexerTimeseries, x) = x -function as_timeseries_indexer(::IndexerNotTimeseries, x) - error(""" - Tried to convert an `$IndexerNotTimeseries` to an `$IndexerTimeseries`. This \ - should never happen. Please file an issue with an MWE. - """) +function _postprocess_tsidxs(ts_idxs) + delete!(ts_idxs, ContinuousTimeseries()) + if isempty(ts_idxs) + return nothing + elseif length(ts_idxs) == 1 + return only(ts_idxs) + else + return collect(ts_idxs) + end end struct CallWith{A} @@ -203,6 +212,15 @@ function _call(f, args...) return f(args...) end +struct Fix1Multiple{F, A} + f::F + arg::A +end + +function (fn::Fix1Multiple)(args...) + fn.f(fn.arg, args...) +end + ########### # Errors ########### @@ -219,12 +237,6 @@ struct ParameterTimeseriesValueIndexMismatchError{P <: IsTimeseriesTrait} <: Exc got $(valp). Open an issue in SymbolicIndexingInterface.jl with an MWE. """)) end - if is_indexer_timeseries(indexer) != IndexerNotTimeseries() - throw(ArgumentError(""" - This should never happen. Expected non-timeseries indexer, got \ - $(indexer). Open an issue in SymbolicIndexingInterface.jl with an MWE. - """)) - end return new{Timeseries}(valp, indexer, args) end function ParameterTimeseriesValueIndexMismatchError{NotTimeseries}(valp, indexer) @@ -235,7 +247,7 @@ struct ParameterTimeseriesValueIndexMismatchError{P <: IsTimeseriesTrait} <: Exc with an MWE. """)) end - if is_indexer_timeseries(indexer) != IndexerTimeseries() + if is_indexer_timeseries(indexer) != IndexerOnlyTimeseries() throw(ArgumentError(""" This should never happen. Expected timeseries indexer, got $(indexer). \ Open an issue in SymbolicIndexingInterface.jl with an MWE. @@ -263,13 +275,13 @@ function Base.showerror( end struct MixedParameterTimeseriesIndexError <: Exception - valp::Any + obj::Any ts_idxs::Any end function Base.showerror(io::IO, err::MixedParameterTimeseriesIndexError) print(io, """ - Invalid indexing operation: tried to access object of type $(typeof(err.valp)) \ + Invalid indexing operation: tried to access object of type $(typeof(err.obj)) \ (which is a parameter timeseries object) with variables having mixed timeseries \ indexes $(err.ts_idxs). """) diff --git a/test/parameter_indexing_test.jl b/test/parameter_indexing_test.jl index 47a75f5..b2347b2 100644 --- a/test/parameter_indexing_test.jl +++ b/test/parameter_indexing_test.jl @@ -1,5 +1,6 @@ using SymbolicIndexingInterface -using SymbolicIndexingInterface: IndexerTimeseries, IndexerNotTimeseries, IndexerBoth, +using SymbolicIndexingInterface: IndexerOnlyTimeseries, IndexerNotTimeseries, IndexerBoth, + IndexerMixedTimeseries, is_indexer_timeseries, indexer_timeseries_index, ParameterTimeseriesValueIndexMismatchError, MixedParameterTimeseriesIndexError @@ -54,8 +55,7 @@ for sys in [ ([1, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false), ([1, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false), ((1, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true), - ((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true) - ] + ((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true)] get = getp(sys, sym) set! = setp(sys, sym) if check_inference @@ -105,21 +105,27 @@ for sys in [ fi.counter[] = 0 end - for (sym, val) in [ - ([:a, :b, :c, :d], p), - ([:c, :a], p[[3, 1]]), - ((:b, :a), Tuple(p[[2, 1]])), - ((1, :c), Tuple(p[[1, 3]])), - (:(a + b + t), p[1] + p[2] + fi.t), - ([:(a + b + t), :c], [p[1] + p[2] + fi.t, p[3]]), - ((:(a + b + t), :c), (p[1] + p[2] + fi.t, p[3])) + for (sym, val, check_inference) in [ + ([:a, :b, :c, :d], p, true), + ([:c, :a], p[[3, 1]], !has_ts), + ((:b, :a), Tuple(p[[2, 1]]), true), + ((1, :c), Tuple(p[[1, 3]]), true), + (:(a + b + t), p[1] + p[2] + fi.t, true), + ([:(a + b + t), :c], [p[1] + p[2] + fi.t, p[3]], true), + ((:(a + b + t), :c), (p[1] + p[2] + fi.t, p[3]), true) ] get = getp(sys, sym) - @inferred get(fi) + if check_inference + @inferred get(fi) + end @test get(fi) == val if sym isa Union{Array, Tuple} buffer = zeros(length(sym)) - @inferred get(buffer, fi) + if check_inference + @inferred get(buffer, fi) + else + get(buffer, fi) + end @test buffer == collect(val) end end @@ -155,7 +161,7 @@ end SymbolicIndexingInterface.parameter_values(mpo::MyParameterObject) = mpo.p function SymbolicIndexingInterface.with_updated_parameter_timeseries_values( - mpo::MyParameterObject, args::Pair...) + ::SymbolCache, mpo::MyParameterObject, args::Pair...) for (ts_idx, val) in args mpo.p[mpo.disc_idxs[ts_idx]] = val end @@ -206,173 +212,235 @@ dval = fs.p[4] bidx = timeseries_parameter_index(sys, :b) cidx = timeseries_parameter_index(sys, :c) -for (sym, indexer_trait, timeseries_index, val, buffer, check_inference) in [ - (:a, IndexerNotTimeseries, 0, aval, nothing, true), - (1, IndexerNotTimeseries, 0, aval, nothing, true), - ([:a, :d], IndexerNotTimeseries, 0, [aval, dval], zeros(2), true), - ((:a, :d), IndexerNotTimeseries, 0, (aval, dval), zeros(2), true), - ([1, 4], IndexerNotTimeseries, 0, [aval, dval], zeros(2), true), - ((1, 4), IndexerNotTimeseries, 0, (aval, dval), zeros(2), true), - ([:a, 4], IndexerNotTimeseries, 0, [aval, dval], zeros(2), true), - ((:a, 4), IndexerNotTimeseries, 0, (aval, dval), zeros(2), true), - (:b, IndexerBoth, 1, bval, zeros(length(bval)), true), - (bidx, IndexerTimeseries, 1, bval, zeros(length(bval)), true), - ([:a, :b], IndexerNotTimeseries, 0, [aval, bval[end]], zeros(2), true), - ((:a, :b), IndexerNotTimeseries, 0, (aval, bval[end]), zeros(2), true), - ([1, :b], IndexerNotTimeseries, 0, [aval, bval[end]], zeros(2), true), - ((1, :b), IndexerNotTimeseries, 0, (aval, bval[end]), zeros(2), true), - ([:b, :b], IndexerBoth, 1, vcat.(bval, bval), map(_ -> zeros(2), bval), true), - ((:b, :b), IndexerBoth, 1, tuple.(bval, bval), map(_ -> zeros(2), bval), true), - ([bidx, :b], IndexerTimeseries, 1, vcat.(bval, bval), map(_ -> zeros(2), bval), true), - ((bidx, :b), IndexerTimeseries, 1, tuple.(bval, bval), map(_ -> zeros(2), bval), true), - ([bidx, bidx], IndexerTimeseries, 1, vcat.(bval, bval), map(_ -> zeros(2), bval), true), - ((bidx, bidx), IndexerTimeseries, 1, - tuple.(bval, bval), map(_ -> zeros(2), bval), true), - (:(a + b), IndexerBoth, 1, bval .+ aval, zeros(length(bval)), true), - ([:(a + b), :a], IndexerBoth, 1, vcat.(bval .+ aval, aval), - map(_ -> zeros(2), bval), true), - ((:(a + b), :a), IndexerBoth, 1, tuple.(bval .+ aval, aval), - map(_ -> zeros(2), bval), true), - ([:(a + b), :b], IndexerBoth, 1, vcat.(bval .+ aval, bval), - map(_ -> zeros(2), bval), true), - ((:(a + b), :b), IndexerBoth, 1, tuple.(bval .+ aval, bval), - map(_ -> zeros(2), bval), true), - ([:(a + b), :c], IndexerNotTimeseries, 0, - [aval + bval[end], cval[end]], zeros(2), true), - ((:(a + b), :c), IndexerNotTimeseries, 0, - (aval + bval[end], cval[end]), zeros(2), true) +# IndexerNotTimeseries +for (sym, val, buffer, check_inference) in [ + (:a, aval, nothing, true), + (1, aval, nothing, true), + ([:a, :d], [aval, dval], zeros(2), true), + ((:a, :d), (aval, dval), zeros(2), true), + ([1, 4], [aval, dval], zeros(2), true), + ((1, 4), (aval, dval), zeros(2), true), + ([:a, 4], [aval, dval], zeros(2), true), + ((:a, 4), (aval, dval), zeros(2), true), + (:(a + d), aval + dval, nothing, true), + ([:(a + d), :(a * d)], [aval + dval, aval * dval], zeros(2), true), + ((:(a + d), :(a * d)), (aval + dval, aval * dval), zeros(2), true) ] - getter = getp(sys, sym) - @test is_indexer_timeseries(getter) isa indexer_trait - if indexer_trait <: Union{IndexerTimeseries, IndexerBoth} - @test indexer_timeseries_index(getter) == timeseries_index - end + getter = getp(fs, sym) + @test is_indexer_timeseries(getter) isa IndexerNotTimeseries test_inplace = buffer !== nothing - test_non_timeseries = indexer_trait !== IndexerTimeseries - if test_inplace && test_non_timeseries - non_timeseries_val = indexer_trait == IndexerNotTimeseries ? val : val[end] - non_timeseries_buffer = indexer_trait == IndexerNotTimeseries ? deepcopy(buffer) : - deepcopy(buffer[end]) - test_non_timeseries_inplace = non_timeseries_buffer isa AbstractArray - end - isobs = sym isa Union{AbstractArray, Tuple} ? any(Base.Fix1(is_observed, sys), sym) : - is_observed(sys, sym) + is_observed = sym isa Expr || + sym isa Union{AbstractArray, Tuple} && any(x -> x isa Expr, sym) if check_inference @inferred getter(fs) + if !is_observed + @inferred getter(parameter_values(fs)) + end if test_inplace @inferred getter(deepcopy(buffer), fs) - end - if test_non_timeseries && !isobs - @inferred getter(parameter_values(fs)) - if test_inplace && test_non_timeseries_inplace && test_non_timeseries_inplace - @inferred getter(deepcopy(non_timeseries_buffer), parameter_values(fs)) + if !is_observed + @inferred getter(deepcopy(buffer), parameter_values(fs)) end end end @test getter(fs) == val + if !is_observed + @test getter(parameter_values(fs)) == val + end if test_inplace - tmp = deepcopy(buffer) - getter(tmp, fs) - if val isa Tuple - target = collect(val) - elseif eltype(val) <: Tuple - target = collect.(val) - else - target = val + target = collect(val) + valps = is_observed ? (fs,) : (fs, parameter_values(fs)) + for valp in valps + tmp = deepcopy(buffer) + getter(tmp, valp) + @test tmp == target end - @test tmp == target end - if test_non_timeseries && !isobs - non_timeseries_val = indexer_trait == IndexerNotTimeseries ? val : val[end] - @test getter(parameter_values(fs)) == non_timeseries_val - if test_inplace && test_non_timeseries && test_non_timeseries_inplace - getter(non_timeseries_buffer, parameter_values(fs)) - if non_timeseries_val isa Tuple - target = collect(non_timeseries_val) - else - target = non_timeseries_val + for subidx in [1, CartesianIndex(1), :, rand(Bool, 4), rand(1:4, 3), 1:2] + @test getter(fs, subidx) == val + if test_inplace + tmp = deepcopy(buffer) + getter(tmp, fs, subidx) + @test tmp == collect(val) + end + end +end + +# IndexerBoth +for (sym, timeseries_index, val, buffer, check_inference) in [ + (:b, 1, bval, zeros(length(bval)), true), + ([:a, :b], 1, vcat.(aval, bval), map(_ -> zeros(2), bval), false), + ((:a, :b), 1, tuple.(aval, bval), map(_ -> zeros(2), bval), true), + ([1, :b], 1, vcat.(aval, bval), map(_ -> zeros(2), bval), false), + ((1, :b), 1, tuple.(aval, bval), map(_ -> zeros(2), bval), true), + ([:b, :b], 1, vcat.(bval, bval), map(_ -> zeros(2), bval), true), + ((:b, :b), 1, tuple.(bval, bval), map(_ -> zeros(2), bval), true), + (:(a + b), 1, bval .+ aval, zeros(length(bval)), true), + ([:(a + b), :a], 1, vcat.(bval .+ aval, aval), map(_ -> zeros(2), bval), true), + ((:(a + b), :a), 1, tuple.(bval .+ aval, aval), map(_ -> zeros(2), bval), true), + ([:(a + b), :b], 1, vcat.(bval .+ aval, bval), map(_ -> zeros(2), bval), true), + ((:(a + b), :b), 1, tuple.(bval .+ aval, bval), map(_ -> zeros(2), bval), true) +] + getter = getp(sys, sym) + @test is_indexer_timeseries(getter) isa IndexerBoth + @test indexer_timeseries_index(getter) == timeseries_index + isobs = sym isa Union{AbstractArray, Tuple} ? any(Base.Fix1(is_observed, sys), sym) : + is_observed(sys, sym) + + if check_inference + @inferred getter(fs) + @inferred getter(deepcopy(buffer), fs) + if !isobs + @inferred getter(parameter_values(fs)) + if !(eltype(val) <: Number) + @inferred getter(deepcopy(buffer[1]), parameter_values(fs)) end - @test non_timeseries_buffer == target end - elseif !isobs - @test_throws ParameterTimeseriesValueIndexMismatchError{NotTimeseries} getter(parameter_values(fs)) - if test_inplace - @test_throws ParameterTimeseriesValueIndexMismatchError{NotTimeseries} getter( - [], parameter_values(fs)) + end + + @test getter(fs) == val + if eltype(val) <: Number + target = val + else + target = collect.(val) + end + tmp = deepcopy(buffer) + getter(tmp, fs) + @test tmp == target + + if !isobs + @test getter(parameter_values(fs)) == val[end] + if !(eltype(val) <: Number) + target = collect(val[end]) + tmp = deepcopy(buffer)[end] + getter(tmp, parameter_values(fs)) + @test tmp == target end end for subidx in [ 1, CartesianIndex(1), :, rand(Bool, length(val)), rand(eachindex(val), 3), 1:2] - if indexer_trait <: IndexerNotTimeseries - @test_throws ParameterTimeseriesValueIndexMismatchError{Timeseries} getter( - fs, subidx) - if test_inplace - @test_throws ParameterTimeseriesValueIndexMismatchError{Timeseries} getter( - [], fs, subidx) + if check_inference + @inferred getter(fs, subidx) + if !isa(val[subidx], Number) + @inferred getter(deepcopy(buffer[subidx]), fs, subidx) end + end + @test getter(fs, subidx) == val[subidx] + tmp = deepcopy(buffer[subidx]) + if val[subidx] isa Number + continue + end + target = val[subidx] + if eltype(target) <: Number + target = collect(target) else - if check_inference - @inferred getter(fs, subidx) - if test_inplace && buffer[subidx] isa AbstractArray - @inferred getter(deepcopy(buffer[subidx]), fs, subidx) - end - end - @test getter(fs, subidx) == val[subidx] - if test_inplace && buffer[subidx] isa AbstractArray - tmp = deepcopy(buffer[subidx]) - getter(tmp, fs, subidx) - if val[subidx] isa Tuple - target = collect(val[subidx]) - elseif eltype(val) <: Tuple - target = collect.(val[subidx]) - else - target = val[subidx] - end - @test tmp == target + target = collect.(target) + end + getter(tmp, fs, subidx) + @test tmp == target + end +end + +# IndexerOnlyTimeseries +for (sym, timeseries_index, val, buffer, check_inference) in [ + (bidx, 1, bval, zeros(length(bval)), true), + ([bidx, :b], 1, vcat.(bval, bval), map(_ -> zeros(2), bval), true), + ((bidx, :b), 1, tuple.(bval, bval), map(_ -> zeros(2), bval), true), + ([bidx, bidx], 1, vcat.(bval, bval), map(_ -> zeros(2), bval), true), + ((bidx, bidx), 1, tuple.(bval, bval), map(_ -> zeros(2), bval), true) +] + getter = getp(sys, sym) + @test is_indexer_timeseries(getter) isa IndexerOnlyTimeseries + @test indexer_timeseries_index(getter) == timeseries_index + + isscalar = eltype(val) <: Number + + if check_inference + @inferred getter(fs) + @inferred getter(deepcopy(buffer), fs) + end + + @test getter(fs) == val + target = if isscalar + val + else + collect.(val) + end + tmp = deepcopy(buffer) + getter(tmp, fs) + @test tmp == target + + @test_throws ParameterTimeseriesValueIndexMismatchError{NotTimeseries} getter(parameter_values(fs)) + @test_throws ParameterTimeseriesValueIndexMismatchError{NotTimeseries} getter( + [], parameter_values(fs)) + + for subidx in [ + 1, CartesianIndex(1), :, rand(Bool, length(val)), rand(eachindex(val), 3), 1:2] + if check_inference + @inferred getter(fs, subidx) + if !isa(val[subidx], Number) + @inferred getter(deepcopy(buffer[subidx]), fs, subidx) end end + @test getter(fs, subidx) == val[subidx] + if val[subidx] isa Number + continue + end + tmp = deepcopy(buffer[subidx]) + target = val[subidx] + if eltype(target) <: Number + target = collect(target) + else + target = collect.(target) + end + getter(tmp, fs, subidx) + @test tmp == target + end +end + +# IndexerMixedTimeseries +for sym in [ + [:a, :b, :c], + (:a, :b, :c), + :(b + c), + [:(a + b), :c], + (:(a + b), :c) +] + getter = getp(sys, sym) + @test_throws MixedParameterTimeseriesIndexError getter(fs) + @test_throws MixedParameterTimeseriesIndexError getter([], fs) + for subidx in [1, CartesianIndex(1), :, rand(Bool, 4), rand(1:4, 3), 1:2] + @test_throws MixedParameterTimeseriesIndexError getter(fs, subidx) + @test_throws MixedParameterTimeseriesIndexError getter([], fs, subidx) end end -for sym in [[:a, bidx], (:a, bidx), [1, bidx], (1, bidx), - [bidx, :c], (bidx, :c), [bidx, cidx], (bidx, cidx)] +for sym in [[:a, bidx], (:a, bidx), [1, bidx], (1, bidx)] @test_throws ArgumentError getp(sys, sym) end -for (sym, val) in [ - ([:b, :c], [bval[end], cval[end]]), - ((:b, :c), (bval[end], cval[end])) -] +for (sym, val) in [([:b, :c], [bval[end], cval[end]]) + ((:b, :c), (bval[end], cval[end]))] getter = getp(sys, sym) - @test is_indexer_timeseries(getter) == IndexerNotTimeseries() + @test is_indexer_timeseries(getter) == IndexerMixedTimeseries() @test_throws MixedParameterTimeseriesIndexError getter(fs) @test getter(parameter_values(fs)) == val end -bval_state = [b_timeseries.u[searchsortedlast(b_timeseries.t, t)][] for t in fs.t] -cval_state = [c_timeseries.u[searchsortedlast(c_timeseries.t, t)][] for t in fs.t] xval = getindex.(fs.u, 1) for (sym, val_is_timeseries, val, check_inference) in [ (:a, false, aval, true), ([:a, :d], false, [aval, dval], true), ((:a, :d), false, (aval, dval), true), - (:b, true, bval_state, true), - ([:a, :b], true, vcat.(aval, bval_state), false), - ((:a, :b), true, tuple.(aval, bval_state), true), - ([:b, :c], true, vcat.(bval_state, cval_state), true), - ((:b, :c), true, tuple.(bval_state, cval_state), true), - ([:a, :b, :c], true, vcat.(aval, bval_state, cval_state), false), - ((:a, :b, :c), true, tuple.(aval, bval_state, cval_state), true), - ([:x, :b], true, vcat.(xval, bval_state), false), - ((:x, :b), true, tuple.(xval, bval_state), true), - ([:x, :b, :c], true, vcat.(xval, bval_state, cval_state), false), - ((:x, :b, :c), true, tuple.(xval, bval_state, cval_state), true), - ([:a, :b, :x], true, vcat.(aval, bval_state, xval), false), - ((:a, :b, :x), true, tuple.(aval, bval_state, xval), true), - (:(2b), true, 2 .* bval_state, true), - ([:x, :(2b), :(3c)], true, vcat.(xval, 2 .* bval_state, 3 .* cval_state), true), - ((:x, :(2b), :(3c)), true, tuple.(xval, 2 .* bval_state, 3 .* cval_state), true) + (:b, true, bval, true), + ([:a, :b], true, vcat.(aval, bval), false), + ((:a, :b), true, tuple.(aval, bval), true), + ([:a, :x], true, vcat.(aval, xval), false), + ((:a, :x), true, tuple.(aval, xval), true), + (:(2b), true, 2 .* bval, true), + ([:a, :(2b)], true, vcat.(aval, 2 .* bval), true), + ((:a, :(2b)), true, tuple.(aval, 2 .* bval), true) ] getter = getu(sys, sym) if check_inference @@ -380,26 +448,54 @@ for (sym, val_is_timeseries, val, check_inference) in [ end @test getter(fs) == val + reference = val_is_timeseries ? val : xval for subidx in [ - 1, CartesianIndex(2), :, rand(Bool, length(fs.t)), rand(eachindex(fs.t), 3), 1:2] + 1, CartesianIndex(2), :, rand(Bool, length(reference)), + rand(eachindex(reference), 3), 1:2] if check_inference @inferred getter(fs, subidx) end target = if val_is_timeseries val[subidx] else - if fs.t[subidx] isa AbstractArray - len = length(fs.t[subidx]) - fill(val, len) - else - val - end + val end @test getter(fs, subidx) == target end end -@test_throws ErrorException getp(sys, :not_a_param) +temp_state = ProblemState(; u = fs.u[1], + p = with_updated_parameter_timeseries_values( + sys, parameter_values(fs), 1 => fs.p_ts[1, 1], 2 => fs.p_ts[2, 1]), + t = fs.t[1]) +_xval = temp_state.u[1] +_bval = bval[1] +_cval = cval[1] +for (sym, val, check_inference) in [ + ([:x, :b], [_xval, _bval], false), + ((:x, :c), (_xval, _cval), true), + (:(x + b), _xval + _bval, true), + ([:(2b), :(3x)], [2_bval, 3_xval], true), + ((:(2b), :(3x)), (2_bval, 3_xval), true) +] + getter = getu(sys, sym) + @test_throws MixedParameterTimeseriesIndexError getter(fs) + for subidx in [1, CartesianIndex(2), :, rand(Bool, 4), rand(1:4, 3), 1:2] + @test_throws MixedParameterTimeseriesIndexError getter(fs, subidx) + end + if check_inference + @inferred getter(temp_state) + end + @test getter(temp_state) == val +end + +for sym in [ + :err, + [:err, :b], + (:err, :b) +] + @test_throws ErrorException getp(sys, sym) +end let fs = fs, sys = sys getter = getp(sys, []) diff --git a/test/symbol_cache_test.jl b/test/symbol_cache_test.jl index d4a55ab..23320c4 100644 --- a/test/symbol_cache_test.jl +++ b/test/symbol_cache_test.jl @@ -48,14 +48,14 @@ obsfn5 = observed(sc, (:(x + a), :(y + b))) @test_throws TypeError observed(sc, [:(x + a), 2]) @test_throws TypeError observed(sc, (:(x + a), 2)) -pobsfn1 = parameter_observed(sc, :(a + b + t)).observed_fn +pobsfn1 = parameter_observed(sc, :(a + b + t)) @test pobsfn1(2ones(2), 3.0) == 7.0 -pobsfn2 = parameter_observed(sc, [:(a + b + t), :(a + t)]).observed_fn +pobsfn2 = parameter_observed(sc, [:(a + b + t), :(a + t)]) @test pobsfn2(2ones(2), 3.0) == [7.0, 5.0] buffer = zeros(2) pobsfn2(buffer, 2ones(2), 3.0) @test buffer == [7.0, 5.0] -pobsfn3 = parameter_observed(sc, (:(a + b + t), :(a + t))).observed_fn +pobsfn3 = parameter_observed(sc, (:(a + b + t), :(a + t))) @test pobsfn3(2ones(2), 3.0) == (7.0, 5.0) buffer = zeros(2) pobsfn3(buffer, 2ones(2), 3.0) @@ -67,14 +67,11 @@ pobsfn3(buffer, 2ones(2), 3.0) sc = SymbolCache([:x, :y], [:a, :b, :c], :t; timeseries_parameters = Dict( :b => ParameterTimeseriesIndex(1, 1), :c => ParameterTimeseriesIndex(2, 1))) -@test parameter_observed(sc, :(a + c)).timeseries_idx == 2 -@test parameter_observed(sc, [:a, :c]).timeseries_idx == 2 -@test parameter_observed(sc, (:a, :c)).timeseries_idx == 2 -@test parameter_observed(sc, :(2a)).timeseries_idx === nothing -@test parameter_observed(sc, [:(2a), :(3a)]).timeseries_idx === nothing -@test parameter_observed(sc, (:(2a), :(3a))).timeseries_idx === nothing -@test parameter_observed(sc, [:b, :c]).timeseries_idx === nothing -@test parameter_observed(sc, (:b, :c)).timeseries_idx === nothing +@test only(get_all_timeseries_indexes(sc, :(a + c))) == 2 +@test only(get_all_timeseries_indexes(sc, [:a, :c])) == 2 +@test isempty(get_all_timeseries_indexes(sc, :(2a))) +@test isempty(get_all_timeseries_indexes(sc, [:(2a), :(3a)])) +@test sort(collect(get_all_timeseries_indexes(sc, [:b, :c]))) == [1, 2] @test_throws ArgumentError SymbolCache([:x, :y], [:a, :b], :t; timeseries_parameters = Dict(:c => ParameterTimeseriesIndex(1, 1)))