From 5a8d9a5ff5f216dc54e3add27c3dc562e1e0ac02 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 21 Jun 2024 16:05:09 +0530 Subject: [PATCH] fixup! refactor: rework discrete indexing behavior --- docs/src/api.md | 3 ++- docs/src/complete_sii.md | 6 ++--- src/SymbolicIndexingInterface.jl | 2 +- src/index_provider_interface.jl | 38 +++++--------------------------- src/parameter_indexing.jl | 16 +++++++++----- src/state_indexing.jl | 5 +++-- src/symbol_cache.jl | 34 ++++------------------------ src/value_provider_interface.jl | 11 +++++++-- test/parameter_indexing_test.jl | 2 +- test/symbol_cache_test.jl | 19 +++++++--------- 10 files changed, 45 insertions(+), 91 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 8286edf..bb4558d 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -31,7 +31,6 @@ allvariables ```@docs observed parameter_observed -ParameterObservedFunction ``` #### Parameter timeseries @@ -46,6 +45,8 @@ may change at different times. is_timeseries_parameter timeseries_parameter_index ParameterTimeseriesIndex +get_all_timeseries_indexes +ContinuousTimeseries ``` ## Value provider interface diff --git a/docs/src/complete_sii.md b/docs/src/complete_sii.md index 0fd6509..0b67892 100644 --- a/docs/src/complete_sii.md +++ b/docs/src/complete_sii.md @@ -132,10 +132,8 @@ end In case a type does not support such observed quantities, `is_observed` must be defined to always return `false`, and `observed` does not need to be implemented. -The same process can be followed for [`parameter_observed`](@ref), with the exception -that the returned function must not have `u` in its signature, and must be wrapped in a -[`ParameterObservedFunction`](@ref). In-place versions can also be implemented for -`parameter_observed`. +The same process can be followed for [`parameter_observed`](@ref). In-place versions +can also be implemented for `parameter_observed`. #### Note about constant structure diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index 8a1a569..5e302e8 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -14,7 +14,7 @@ include("trait.jl") export is_variable, variable_index, variable_symbols, is_parameter, parameter_index, is_timeseries_parameter, timeseries_parameter_index, ParameterTimeseriesIndex, parameter_symbols, is_independent_variable, independent_variable_symbols, - is_observed, observed, parameter_observed, ParameterObservedFunction, + is_observed, observed, parameter_observed, ContinuousTimeseries, get_all_timeseries_indexes, is_time_dependent, constant_structure, symbolic_container, all_variable_symbols, all_symbols, solvedvariables, allvariables, default_values, diff --git a/src/index_provider_interface.jl b/src/index_provider_interface.jl index a7d0185..c51c7a8 100644 --- a/src/index_provider_interface.jl +++ b/src/index_provider_interface.jl @@ -98,42 +98,14 @@ function timeseries_parameter_index(indp, sym) end end -""" - struct ParameterObservedFunction - function ParameterObservedFunction(timeseries_idx, observed_fn::Function) - function ParameterObservedFunction(observed_fn::Function) - -A struct which stores the parameter observed function and optional timeseries index for -a particular symbol. The timeseries index is optional and may be omitted. Specifying the -timeseries index allows [`getp`](@ref) to return the appropriate timeseries for a -timeseries parameter. - -For time-dependent index providers (where `is_time_dependent(indp)`) `observed_fn` must -have the signature `(p, t) -> [values...]`. For non-time-dependent index providers -(where `!is_time_dependent(indp)`) `observed_fn` must have the signature -`(p) -> [values...]`. To support in-place `getp` methods, `observed_fn` must also have an -additional method which takes `buffer::AbstractArray` as its first argument. The required -values must be written to the buffer in the appropriate order. -""" -struct ParameterObservedFunction{I, F <: Function} - timeseries_idx::I - observed_fn::F -end - -function ParameterObservedFunction(ts_idx, f) - ParameterObservedFunction{typeof(ts_idx), typeof(f)}(ts_idx, f) -end -ParameterObservedFunction(f) = ParameterObservedFunction(nothing, f) - """ parameter_observed(indp, sym) -Return the observed function of `sym` in `indp` as a [`ParameterObservedFunction`](@ref). -If `sym` only involves non-timeseries parameters, the timeseries index in the -`ParameterObservedFunction` should be `nothing`. If `sym` involves timeseries parameters, -the timeseries index should be specified in the `ParameterObservedFunction`. If `sym` -involves timeseries parameters from multiple different timeseries, all the relevant -timeseries indexes should be specified in the `ParameterObservedFunction` as a `Vector`. +Return the observed function of `sym` in `indp`. This functions similarly to +[`observed`](@ref) except that `u` is not an argument of the returned function. For time- +dependent systems, the returned function must have the signature `(p, t) -> [values...]`. +For time-independent systems, the returned function must have the signature +`(p) -> [values...]`. By default, this function returns `nothing`, indicating that the index provider does not support generating parameter observed functions. diff --git a/src/parameter_indexing.jl b/src/parameter_indexing.jl index 6658a44..45b223a 100644 --- a/src/parameter_indexing.jl +++ b/src/parameter_indexing.jl @@ -334,9 +334,10 @@ function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p) throw(ArgumentError("Index provider does not support `parameter_observed`; cannot use generate function for $p")) end if !is_time_dependent(sys) - return GetParameterObservedNoTime(pofn.observed_fn) + return GetParameterObservedNoTime(pofn) end - return GetParameterObserved{false}(pofn.timeseries_idx, pofn.observed_fn) + ts_idxs = _postprocess_tsidxs(get_all_timeseries_indexes(sys, p)) + return GetParameterObserved{false}(ts_idxs, pofn) end error("Invalid symbol $p for `getp`") end @@ -572,18 +573,20 @@ for (t1, t2) in [ # `getp` errors on older MTK that doesn't support `parameter_observed`. getters = getp.((sys,), p) num_observed = count(is_observed_getter, getters) + p_arr = p isa Tuple ? collect(p) : p if num_observed == 0 return MultipleParametersGetter(getters) else - pofn = parameter_observed(sys, p isa Tuple ? collect(p) : p) + pofn = parameter_observed(sys, p_arr) if pofn === nothing return MultipleParametersGetter.(getters) end if is_time_dependent(sys) - getter = GetParameterObserved{true}(pofn.timeseries_idx, pofn.observed_fn) + ts_idxs = _postprocess_tsidxs(get_all_timeseries_indexes(sys, p_arr)) + getter = GetParameterObserved{true}(ts_idxs, pofn) else - getter = GetParameterObservedNoTime(pofn.observed_fn) + getter = GetParameterObservedNoTime(pofn) end return p isa Tuple ? AsParameterTupleWrapper{length(p)}(getter) : getter end @@ -677,7 +680,8 @@ function _setp(sys, ::ArraySymbolic, ::SymbolicTypeTrait, p) idx = parameter_index(sys, p) return setp(sys, idx; run_hook = false) elseif is_observed(sys, p) && (pobsfn = parameter_observed(sys, p)) !== nothing - return GetParameterObserved{false}(pobsfn.timeseries_idx, pobsfn.observed_fn) + ts_idxs = _postprocess_tsidxs(get_all_timeseries_indexes(sys, p)) + return GetParameterObserved{false}(ts_idxs, pobsfn.observed_fn) end return setp(sys, collect(p); run_hook = false) end diff --git a/src/state_indexing.jl b/src/state_indexing.jl index 99b95e2..55c3699 100644 --- a/src/state_indexing.jl +++ b/src/state_indexing.jl @@ -174,7 +174,8 @@ for (t1, t2) in [ if isempty(sym) return MultipleGetters(sym) end - ts_idxs = get_all_timeseries_indexes(sys, sym isa Tuple ? collect(sym) : sym) + sym_arr = sym isa Tuple ? collect(sym) : sym + ts_idxs = get_all_timeseries_indexes(sys, sym_arr) if ContinuousTimeseries() in ts_idxs && length(ts_idxs) > 1 throw(MixedContinuousParameterTimeseriesError(sym)) end @@ -187,7 +188,7 @@ for (t1, t2) in [ getters = getu.((sys,), sym) return MultipleGetters(getters) else - obs = observed(sys, sym isa Tuple ? collect(sym) : sym) + obs = observed(sys, sym_arr) getter = if is_time_dependent(sys) TimeDependentObservedFunction(obs) else diff --git a/src/symbol_cache.jl b/src/symbol_cache.jl index 8b17441..2dc82b7 100644 --- a/src/symbol_cache.jl +++ b/src/symbol_cache.jl @@ -245,26 +245,13 @@ function parameter_observed(sc::SymbolCache, expr::Expr) if is_time_dependent(sc) exs = ExpressionSearcher() exs(sc, expr) - ts_idxs = Set() - for p in exs.parameters - is_timeseries_parameter(sc, p) || continue - push!(ts_idxs, timeseries_parameter_index(sc, p).timeseries_idx) - end - f = let fn = observed(sc, expr) + return let fn = observed(sc, expr) f1(p, t) = fn(nothing, p, t) end - if isempty(ts_idxs) - return ParameterObservedFunction(f) - elseif length(ts_idxs) == 1 - return ParameterObservedFunction(only(ts_idxs), f) - else - return ParameterObservedFunction(collect(ts_idxs), f) - end else - f = let fn = observed(sc, expr) + return let fn = observed(sc, expr) f2(p) = fn(nothing, p) end - return ParameterObservedFunction(nothing, f) end end @@ -277,29 +264,16 @@ function parameter_observed(sc::SymbolCache, exprs::Union{AbstractArray, Tuple}) if is_time_dependent(sc) exs = ExpressionSearcher() exs(sc, to_expr(exprs)) - ts_idxs = Set() - for p in exs.parameters - is_timeseries_parameter(sc, p) || continue - push!(ts_idxs, timeseries_parameter_index(sc, p).timeseries_idx) - end - f = let oop = observed(sc, to_expr(exprs)), iip = inplace_observed(sc, exprs) + return let oop = observed(sc, to_expr(exprs)), iip = inplace_observed(sc, exprs) f1(p, t) = oop(nothing, p, t) f1(buffer, p, t) = iip(buffer, nothing, p, t) end - if isempty(ts_idxs) - return ParameterObservedFunction(f) - elseif length(ts_idxs) == 1 - return ParameterObservedFunction(only(ts_idxs), f) - else - return ParameterObservedFunction(collect(ts_idxs), f) - end else - f = let oop = observed(sc, to_expr(exprs)), iip = inplace_observed(sc, exprs) + return let oop = observed(sc, to_expr(exprs)), iip = inplace_observed(sc, exprs) f2(p) = oop(nothing, p) f2(buffer, p) = iip(buffer, nothing, p) end - return ParameterObservedFunction(nothing, f) end end diff --git a/src/value_provider_interface.jl b/src/value_provider_interface.jl index 69394ca..7fd7f06 100644 --- a/src/value_provider_interface.jl +++ b/src/value_provider_interface.jl @@ -182,8 +182,15 @@ as_timeseries_indexer(::IndexerNotTimeseries, x) = x as_not_timeseries_indexer(x) = as_not_timeseries_indexer(is_indexer_timeseries(x), x) as_not_timeseries_indexer(::IndexerNotTimeseries, x) = x -struct MultipleTimeseriesIndexes{I} - idxs::Vector{I} +function _postprocess_tsidxs(ts_idxs) + delete!(ts_idxs, ContinuousTimeseries()) + if isempty(ts_idxs) + return nothing + elseif length(ts_idxs) == 1 + return only(ts_idxs) + else + return collect(ts_idxs) + end end struct CallWith{A} diff --git a/test/parameter_indexing_test.jl b/test/parameter_indexing_test.jl index 642219e..b276ce4 100644 --- a/test/parameter_indexing_test.jl +++ b/test/parameter_indexing_test.jl @@ -227,7 +227,7 @@ for (sym, val, buffer, check_inference) in [ ([:(a + d), :(a * d)], [aval + dval, aval * dval], zeros(2), true), ((:(a + d), :(a * d)), (aval + dval, aval * dval), zeros(2), true) ] - getter = getp(sys, sym) + getter = getp(fs, sym) @test is_indexer_timeseries(getter) isa IndexerNotTimeseries test_inplace = buffer !== nothing is_observed = sym isa Expr || diff --git a/test/symbol_cache_test.jl b/test/symbol_cache_test.jl index 43fdf73..23320c4 100644 --- a/test/symbol_cache_test.jl +++ b/test/symbol_cache_test.jl @@ -48,14 +48,14 @@ obsfn5 = observed(sc, (:(x + a), :(y + b))) @test_throws TypeError observed(sc, [:(x + a), 2]) @test_throws TypeError observed(sc, (:(x + a), 2)) -pobsfn1 = parameter_observed(sc, :(a + b + t)).observed_fn +pobsfn1 = parameter_observed(sc, :(a + b + t)) @test pobsfn1(2ones(2), 3.0) == 7.0 -pobsfn2 = parameter_observed(sc, [:(a + b + t), :(a + t)]).observed_fn +pobsfn2 = parameter_observed(sc, [:(a + b + t), :(a + t)]) @test pobsfn2(2ones(2), 3.0) == [7.0, 5.0] buffer = zeros(2) pobsfn2(buffer, 2ones(2), 3.0) @test buffer == [7.0, 5.0] -pobsfn3 = parameter_observed(sc, (:(a + b + t), :(a + t))).observed_fn +pobsfn3 = parameter_observed(sc, (:(a + b + t), :(a + t))) @test pobsfn3(2ones(2), 3.0) == (7.0, 5.0) buffer = zeros(2) pobsfn3(buffer, 2ones(2), 3.0) @@ -67,14 +67,11 @@ pobsfn3(buffer, 2ones(2), 3.0) sc = SymbolCache([:x, :y], [:a, :b, :c], :t; timeseries_parameters = Dict( :b => ParameterTimeseriesIndex(1, 1), :c => ParameterTimeseriesIndex(2, 1))) -@test parameter_observed(sc, :(a + c)).timeseries_idx == 2 -@test parameter_observed(sc, [:a, :c]).timeseries_idx == 2 -@test parameter_observed(sc, (:a, :c)).timeseries_idx == 2 -@test parameter_observed(sc, :(2a)).timeseries_idx === nothing -@test parameter_observed(sc, [:(2a), :(3a)]).timeseries_idx === nothing -@test parameter_observed(sc, (:(2a), :(3a))).timeseries_idx === nothing -@test sort(parameter_observed(sc, [:b, :c]).timeseries_idx) == [1, 2] -@test sort(parameter_observed(sc, (:b, :c)).timeseries_idx) == [1, 2] +@test only(get_all_timeseries_indexes(sc, :(a + c))) == 2 +@test only(get_all_timeseries_indexes(sc, [:a, :c])) == 2 +@test isempty(get_all_timeseries_indexes(sc, :(2a))) +@test isempty(get_all_timeseries_indexes(sc, [:(2a), :(3a)])) +@test sort(collect(get_all_timeseries_indexes(sc, [:b, :c]))) == [1, 2] @test_throws ArgumentError SymbolCache([:x, :y], [:a, :b], :t; timeseries_parameters = Dict(:c => ParameterTimeseriesIndex(1, 1)))