From 43b5587b75eed9ee1bac5e28c95af314e8f043d7 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 20 Sep 2024 18:22:06 +0530 Subject: [PATCH] feat: support observed functions for history-dependent systems --- src/index_provider_interface.jl | 33 +++++++++++++++++++++++++++++++- src/state_indexing.jl | 34 +++++++++++++++++++++++++++++---- src/value_provider_interface.jl | 10 ++++++++++ test/state_indexing_test.jl | 31 ++++++++++++++++++++++++++++++ 4 files changed, 103 insertions(+), 5 deletions(-) diff --git a/src/index_provider_interface.jl b/src/index_provider_interface.jl index 593f2716..2688b74f 100644 --- a/src/index_provider_interface.jl +++ b/src/index_provider_interface.jl @@ -201,7 +201,13 @@ the order of states or a time index, which identifies the order of states. This does not need to be defined if [`is_observed`](@ref) always returns `false`. Thus, it is mandatory to always check `is_observed` before using this function. -See also: [`is_time_dependent`](@ref), [`constant_structure`](@ref) +If `is_historical(indp)`, the returned function must have the signature +`(u, h, p, t) -> [values...]` where `h` is the history function, which can be called +to obtain past values of the state. The exact signature and semantics of `h` depend +on how it is used inside the returned function. `h` is obtained from a value +provider using [`get_history_function`](@ref). + +See also: [`is_time_dependent`](@ref), [`is_historical`](@ref), [`constant_structure`](@ref). """ observed(indp, sym) = observed(symbolic_container(indp), sym) observed(indp, sym, states) = observed(symbolic_container(indp), sym, states) @@ -213,6 +219,31 @@ Check if `indp` has time as (one of) its independent variables. """ is_time_dependent(indp) = is_time_dependent(symbolic_container(indp)) +""" + is_historical(indp) + +Check if an index provider is historical. Historical index providers require values of states +at arbitrary offsets in the past to simulate, and these historical values are not present +as part of the states themselves. For example, systems of DDEs are historical. However, +a discrete system is not historical since the state realization includes the historical values +required to compute the next state. + +This function is only required for time-dependent index providers. + +Historical index providers return [`observed`](@ref) functions with a different signature. +All value providers associated with a historical index provider must implement +[`get_history_function`](@ref). + +Returns `false` by default. +""" +function is_historical(indp) + if hasmethod(symbolic_container, Tuple{typeof(indp)}) + is_historical(symbolic_container(indp)) + else + false + end +end + """ constant_structure(indp) diff --git a/src/state_indexing.jl b/src/state_indexing.jl index af517664..be8ff9ca 100644 --- a/src/state_indexing.jl +++ b/src/state_indexing.jl @@ -56,11 +56,17 @@ struct GetIndepvar <: AbstractStateGetIndexer end (::GetIndepvar)(::IsTimeseriesTrait, prob) = current_time(prob) (::GetIndepvar)(::Timeseries, prob, i) = current_time(prob, i) -struct TimeDependentObservedFunction{I, F} <: AbstractStateGetIndexer +struct TimeDependentObservedFunction{I, F, H} <: AbstractStateGetIndexer ts_idxs::I obsfn::F end +function TimeDependentObservedFunction{H}(ts_idxs, obsfn) where {H} + return TimeDependentObservedFunction{typeof(ts_idxs), typeof(obsfn), H}(ts_idxs, obsfn) +end + +const HistoricalObservedFunction = TimeDependentObservedFunction{I, F, true} where {I, F} + indexer_timeseries_index(t::TimeDependentObservedFunction) = t.ts_idxs function is_indexer_timeseries(::Type{G}) where {G <: TimeDependentObservedFunction{ContinuousTimeseries}} @@ -74,8 +80,14 @@ function (o::TimeDependentObservedFunction)(ts::IsTimeseriesTrait, prob, args... return o(ts, is_indexer_timeseries(o), prob, args...) end -function (o::TimeDependentObservedFunction)(ts::Timeseries, ::IndexerBoth, prob) +function (o::TimeDependentObservedFunction)(::Timeseries, ::IndexerBoth, prob) + return o.obsfn.(state_values(prob), + (parameter_values(prob),), + current_time(prob)) +end +function (o::HistoricalObservedFunction)(::Timeseries, ::IndexerBoth, prob) return o.obsfn.(state_values(prob), + (get_history_function(prob),), (parameter_values(prob),), current_time(prob)) end @@ -83,6 +95,11 @@ function (o::TimeDependentObservedFunction)( ::Timeseries, ::IndexerBoth, prob, i::Union{Int, CartesianIndex}) return o.obsfn(state_values(prob, i), parameter_values(prob), current_time(prob, i)) end +function (o::HistoricalObservedFunction)( + ::Timeseries, ::IndexerBoth, prob, i::Union{Int, CartesianIndex}) + return o.obsfn(state_values(prob, i), get_history_function(prob), + parameter_values(prob), current_time(prob, i)) +end function (o::TimeDependentObservedFunction)(ts::Timeseries, ::IndexerBoth, prob, ::Colon) return o(ts, prob) end @@ -98,6 +115,10 @@ end function (o::TimeDependentObservedFunction)(::NotTimeseries, ::IndexerBoth, prob) return o.obsfn(state_values(prob), parameter_values(prob), current_time(prob)) end +function (o::HistoricalObservedFunction)(::NotTimeseries, ::IndexerBoth, prob) + return o.obsfn(state_values(prob), get_history_function(prob), + parameter_values(prob), current_time(prob)) +end function (o::TimeDependentObservedFunction)( ::Timeseries, ::IndexerMixedTimeseries, prob, args...) @@ -107,6 +128,11 @@ function (o::TimeDependentObservedFunction)( ::NotTimeseries, ::IndexerMixedTimeseries, prob, args...) return o.obsfn(state_values(prob), parameter_values(prob), current_time(prob)) end +function (o::HistoricalObservedFunction)( + ::NotTimeseries, ::IndexerMixedTimeseries, prob, args...) + return o.obsfn(state_values(prob), get_history_function(prob), + parameter_values(prob), current_time(prob)) +end struct TimeIndependentObservedFunction{F} <: AbstractStateGetIndexer obsfn::F @@ -137,7 +163,7 @@ function _getu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym) ts_idxs = collect(ts_idxs) end fn = observed(sys, sym) - return TimeDependentObservedFunction(ts_idxs, fn) + return TimeDependentObservedFunction{is_historical(sys)}(ts_idxs, fn) else return getp(sys, sym) end @@ -256,7 +282,7 @@ for (t1, t2) in [ else obs = observed(sys, sym_arr) getter = if is_time_dependent(sys) - TimeDependentObservedFunction(ts_idxs, obs) + TimeDependentObservedFunction{is_historical(sys)}(ts_idxs, obs) else TimeIndependentObservedFunction(obs) end diff --git a/src/value_provider_interface.jl b/src/value_provider_interface.jl index 79030172..a31d62f6 100644 --- a/src/value_provider_interface.jl +++ b/src/value_provider_interface.jl @@ -142,6 +142,16 @@ current_time(arr::AbstractVector) = arr current_time(valp, i) = current_time(valp)[i] current_time(valp, ::Colon) = current_time(valp) +""" + get_history_function(valp) + +Return the history function for a value provider. This is required for all value providers +associated with an index provider `indp` for which `is_historical(indp)`. + +See also: [`is_historical`](@ref). +""" +function get_history_function end + ########### # Utilities ########### diff --git a/test/state_indexing_test.jl b/test/state_indexing_test.jl index c2b90b63..fa2f0180 100644 --- a/test/state_indexing_test.jl +++ b/test/state_indexing_test.jl @@ -291,3 +291,34 @@ for (sym, val, check_inference) in [ end @test getter(fs) == val end + +struct HistoricalWrapper{S <: SymbolCache} + sys::S +end + +SymbolicIndexingInterface.symbolic_container(hw::HistoricalWrapper) = hw.sys +SymbolicIndexingInterface.is_historical(::HistoricalWrapper) = true +function SymbolicIndexingInterface.observed(hw::HistoricalWrapper, sym) + let inner = observed(hw.sys, sym) + fn(u, h, p, t) = inner(u .+ h(t - 0.1), p, t) + end +end +function SymbolicIndexingInterface.get_history_function(fs::FakeSolution) + t -> t .* ones(length(fs.u[1])) +end +function SymbolicIndexingInterface.get_history_function(fi::FakeIntegrator) + t -> t .* ones(length(fi.u)) +end + +sys = HistoricalWrapper(SymbolCache([:x, :y, :z], [:a, :b, :c], :t)) +u0 = [1.0, 2.0, 3.0] +u = [u0 .* i for i in 1:11] +p = [10.0, 20.0, 30.0] +ts = 0.0:0.1:1.0 + +fi = FakeIntegrator(sys, u0, p, ts[1]) +fs = FakeSolution(sys, u, p, ts) +getter = getu(sys, :(x + y)) +@test getter(fi) ≈ 2.8 +@test getter(fs) ≈ [3.0i + 2(ts[i] - 0.1) for i in 1:11] +@test getter(fs, 1) ≈ 2.8