From aff02b3b3918546bc4c4faf338902bd0fbf77be9 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 21 Jun 2024 12:30:29 +0530 Subject: [PATCH] fixup! refactor: rework discrete indexing behavior --- src/SymbolicIndexingInterface.jl | 1 + src/parameter_indexing.jl | 24 +++++----- src/parameter_timeseries_collection.jl | 64 -------------------------- src/state_indexing.jl | 37 ++++----------- test/parameter_indexing_test.jl | 28 +++++++---- test/symbol_cache_test.jl | 4 +- 6 files changed, 43 insertions(+), 115 deletions(-) diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index 93c5398..8a1a569 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -15,6 +15,7 @@ export is_variable, variable_index, variable_symbols, is_parameter, parameter_in is_timeseries_parameter, timeseries_parameter_index, ParameterTimeseriesIndex, parameter_symbols, is_independent_variable, independent_variable_symbols, is_observed, observed, parameter_observed, ParameterObservedFunction, + ContinuousTimeseries, get_all_timeseries_indexes, is_time_dependent, constant_structure, symbolic_container, all_variable_symbols, all_symbols, solvedvariables, allvariables, default_values, symbolic_evaluate diff --git a/src/parameter_indexing.jl b/src/parameter_indexing.jl index 1c82d15..6658a44 100644 --- a/src/parameter_indexing.jl +++ b/src/parameter_indexing.jl @@ -177,6 +177,14 @@ as_timeseries_indexer(::IndexerBoth, gpo::GetParameterObserved) = gpo as_not_timeseries_indexer(::IndexerBoth, gpo::GetParameterObserved) = gpo as_not_timeseries_indexer(::IndexerMixedTimeseries, gpo::GetParameterObserved) = gpo +function (gpo::GetParameterObserved{Nothing})(::Timeseries, prob) + gpo.obsfn(parameter_values(prob), current_time(prob)[end]) +end +function (gpo::GetParameterObserved{Nothing, true})( + buffer::AbstractArray, ::Timeseries, prob) + gpo.obsfn(buffer, parameter_values(prob), current_time(prob)[end]) + return buffer +end for argType in [Union{Int, CartesianIndex}, Colon, AbstractArray{Bool}, Any] @eval function (gpo::GetParameterObserved{Nothing})(::Timeseries, prob, args::$argType) gpo.obsfn(parameter_values(prob), current_time(prob)[end]) @@ -535,7 +543,6 @@ end function (atw::AsParameterTupleWrapper)(ts::Timeseries, ::AtLeastTimeseriesIndexer, prob, i) wrap_tuple.((atw,), atw.getter(ts, prob, i)) end -# args is just so it throws function (atw::AsParameterTupleWrapper)( ts::Timeseries, ::IndexerNotTimeseries, prob, args...) wrap_tuple(atw, atw.getter(ts, prob, args...)) @@ -554,16 +561,6 @@ is_observed_getter(::GetParameterObserved) = true is_observed_getter(::GetParameterObservedNoTime) = true is_observed_getter(mpg::MultipleParametersGetter) = any(is_observed_getter, mpg.getters) -function _supports_parameter_observed(indp, sym) - if hasmethod(parameter_observed, Tuple{typeof(indp), typeof(sym)}) - return true - elseif hasmethod(symbolic_container, Tuple{typeof(indp)}) - return _supports_parameter_observed(symbolic_container(indp), sym) - else - return false - end -end - for (t1, t2) in [ (ArraySymbolic, Any), (ScalarSymbolic, Any), @@ -576,10 +573,13 @@ for (t1, t2) in [ getters = getp.((sys,), p) num_observed = count(is_observed_getter, getters) - if num_observed == 0 || !_supports_parameter_observed(sys, p) + if num_observed == 0 return MultipleParametersGetter(getters) else pofn = parameter_observed(sys, p isa Tuple ? collect(p) : p) + if pofn === nothing + return MultipleParametersGetter.(getters) + end if is_time_dependent(sys) getter = GetParameterObserved{true}(pofn.timeseries_idx, pofn.observed_fn) else diff --git a/src/parameter_timeseries_collection.jl b/src/parameter_timeseries_collection.jl index 0da27ab..535b8b7 100644 --- a/src/parameter_timeseries_collection.jl +++ b/src/parameter_timeseries_collection.jl @@ -124,30 +124,6 @@ function parameter_values_at_time(valp, t) (ts_idx => _timeseries_value(ptc, ts_idx, t) for ts_idx in eachindex(ptc))...) end -""" - parameter_values_at_state_time(valp, i) - parameter_values_at_state_time(valp) - -Return an indexable collection containing the value of all parameters in `valp` at time -index `i` in the state timeseries. - -By default, this function relies on [`parameter_values_at_time`](@ref) and -[`current_time`](@ref) for a default implementation. - -The single-argument version of this function is a shorthand to return parameter values -at each point in the state timeseries. This also has a default implementation relying on -[`parameter_values_at_time`](@ref) and [`current_time`](@ref). -""" -function parameter_values_at_state_time end - -function parameter_values_at_state_time(p, i) - state_time = current_time(p, i) - return parameter_values_at_time(p, state_time) -end -function parameter_values_at_state_time(p) - return (parameter_values_at_time(p, t) for t in current_time(p)) -end - """ parameter_timeseries(valp, i) @@ -163,43 +139,3 @@ function parameter_timeseries end function parameter_timeseries(valp, i) return parameter_timeseries(get_parameter_timeseries_collection(valp), i) end - -""" - parameter_timeseries_at_state_time(valp, i, j) - parameter_timeseries_at_state_time(valp, i) - -Return the index of the timestep in the parameter timeseries at timeseries index `i` which -occurs just before or at the same time as the state timestep with index `j`. The two- -argument version of this function returns an iterable of indexes, one for each timestep in -the state timeseries. If `j` is an object that refers to multiple values in the state -timeseries (e.g. `Colon`), return an iterable of the indexes in the parameter timeseries -at the appropriate points. - -Both versions of this function have default implementations relying on -[`current_time`](@ref) and [`parameter_timeseries`](@ref), for the cases where `j` is one -of: `Int`, `CartesianIndex`, `AbstractArray{Bool}`, `Colon` or an iterable of the -aforementioned. -""" -function parameter_timeseries_at_state_time end - -function parameter_timeseries_at_state_time(valp, i, j::Union{Int, CartesianIndex}) - state_time = current_time(valp, j) - timeseries = parameter_timeseries(valp, i) - searchsortedlast(timeseries, state_time) -end - -function parameter_timeseries_at_state_time(valp, i, ::Colon) - parameter_timeseries_at_state_time(valp, i) -end - -function parameter_timeseries_at_state_time(valp, i, j::AbstractArray{Bool}) - parameter_timeseries_at_state_time(valp, i, only(to_indices(current_time(valp), (j,)))) -end - -function parameter_timeseries_at_state_time(valp, i, j) - (parameter_timeseries_at_state_time(valp, i, jj) for jj in j) -end - -function parameter_timeseries_at_state_time(valp, i) - parameter_timeseries_at_state_time(valp, i, eachindex(current_time(valp))) -end diff --git a/src/state_indexing.jl b/src/state_indexing.jl index 7583ae3..99b95e2 100644 --- a/src/state_indexing.jl +++ b/src/state_indexing.jl @@ -61,43 +61,24 @@ struct TimeDependentObservedFunction{F} <: AbstractStateGetIndexer end function (o::TimeDependentObservedFunction)(ts::Timeseries, prob) - return o(ts, is_parameter_timeseries(prob), prob) -end -function (o::TimeDependentObservedFunction)(::Timeseries, ::Timeseries, prob) - map(o.obsfn, state_values(prob), - parameter_values_at_state_time(prob), current_time(prob)) -end -function (o::TimeDependentObservedFunction)(::Timeseries, ::NotTimeseries, prob) - o.obsfn.(state_values(prob), + return o.obsfn.(state_values(prob), (parameter_values(prob),), current_time(prob)) end -function (o::TimeDependentObservedFunction)(ts::Timeseries, prob, i) - return o(ts, is_parameter_timeseries(prob), prob, i) -end function (o::TimeDependentObservedFunction)( - ::Timeseries, ::Timeseries, prob, i::Union{Int, CartesianIndex}) - return o.obsfn(state_values(prob, i), - parameter_values_at_state_time(prob, i), - current_time(prob, i)) + ::Timeseries, prob, i::Union{Int, CartesianIndex}) + return o.obsfn(state_values(prob, i), parameter_values(prob), current_time(prob, i)) end -function (o::TimeDependentObservedFunction)( - ts::Timeseries, p_ts::IsTimeseriesTrait, prob, ::Colon) - return o(ts, p_ts, prob) +function (o::TimeDependentObservedFunction)(ts::Timeseries, prob, ::Colon) + return o(ts, prob) end -function (o::TimeDependentObservedFunction)( - ts::Timeseries, p_ts::IsTimeseriesTrait, prob, i::AbstractArray{Bool}) +function (o::TimeDependentObservedFunction)(ts::Timeseries, prob, i::AbstractArray{Bool}) map(only(to_indices(current_time(prob), (i,)))) do idx - o(ts, p_ts, prob, idx) + o(ts, prob, idx) end end -function (o::TimeDependentObservedFunction)( - ts::Timeseries, p_ts::IsTimeseriesTrait, prob, i) - o.((ts,), (p_ts,), (prob,), i) -end -function (o::TimeDependentObservedFunction)( - ::Timeseries, ::NotTimeseries, prob, i::Union{Int, CartesianIndex}) - o.obsfn(state_values(prob, i), parameter_values(prob), current_time(prob, i)) +function (o::TimeDependentObservedFunction)(ts::Timeseries, prob, i) + o.((ts,), (prob,), i) end function (o::TimeDependentObservedFunction)(::NotTimeseries, prob) return o.obsfn(state_values(prob), parameter_values(prob), current_time(prob)) diff --git a/test/parameter_indexing_test.jl b/test/parameter_indexing_test.jl index fcc5008..642219e 100644 --- a/test/parameter_indexing_test.jl +++ b/test/parameter_indexing_test.jl @@ -56,8 +56,7 @@ for sys in [ ([1, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false), ([1, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false), ((1, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true), - ((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true) - ] + ((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true)] get = getp(sys, sym) set! = setp(sys, sym) if check_inference @@ -223,27 +222,38 @@ for (sym, val, buffer, check_inference) in [ ([1, 4], [aval, dval], zeros(2), true), ((1, 4), (aval, dval), zeros(2), true), ([:a, 4], [aval, dval], zeros(2), true), - ((:a, 4), (aval, dval), zeros(2), true) + ((:a, 4), (aval, dval), zeros(2), true), + (:(a + d), aval + dval, nothing, true), + ([:(a + d), :(a * d)], [aval + dval, aval * dval], zeros(2), true), + ((:(a + d), :(a * d)), (aval + dval, aval * dval), zeros(2), true) ] getter = getp(sys, sym) @test is_indexer_timeseries(getter) isa IndexerNotTimeseries test_inplace = buffer !== nothing - + is_observed = sym isa Expr || + sym isa Union{AbstractArray, Tuple} && any(x -> x isa Expr, sym) if check_inference @inferred getter(fs) - @inferred getter(parameter_values(fs)) + if !is_observed + @inferred getter(parameter_values(fs)) + end if test_inplace @inferred getter(deepcopy(buffer), fs) - @inferred getter(deepcopy(buffer), parameter_values(fs)) + if !is_observed + @inferred getter(deepcopy(buffer), parameter_values(fs)) + end end end @test getter(fs) == val - @test getter(parameter_values(fs)) == val + if !is_observed + @test getter(parameter_values(fs)) == val + end if test_inplace target = collect(val) - for obj in (fs, parameter_values(fs)) + valps = is_observed ? (fs,) : (fs, parameter_values(fs)) + for valp in valps tmp = deepcopy(buffer) - getter(tmp, obj) + getter(tmp, valp) @test tmp == target end end diff --git a/test/symbol_cache_test.jl b/test/symbol_cache_test.jl index d4a55ab..43fdf73 100644 --- a/test/symbol_cache_test.jl +++ b/test/symbol_cache_test.jl @@ -73,8 +73,8 @@ sc = SymbolCache([:x, :y], [:a, :b, :c], :t; @test parameter_observed(sc, :(2a)).timeseries_idx === nothing @test parameter_observed(sc, [:(2a), :(3a)]).timeseries_idx === nothing @test parameter_observed(sc, (:(2a), :(3a))).timeseries_idx === nothing -@test parameter_observed(sc, [:b, :c]).timeseries_idx === nothing -@test parameter_observed(sc, (:b, :c)).timeseries_idx === nothing +@test sort(parameter_observed(sc, [:b, :c]).timeseries_idx) == [1, 2] +@test sort(parameter_observed(sc, (:b, :c)).timeseries_idx) == [1, 2] @test_throws ArgumentError SymbolCache([:x, :y], [:a, :b], :t; timeseries_parameters = Dict(:c => ParameterTimeseriesIndex(1, 1)))