From b1a7f72a0a43cca5c16dde03648718477fe9621e Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 1 Jan 2024 21:57:26 +0530 Subject: [PATCH] feat: add `IsTimeseriesTrait`, support timeseries objects in `getu` --- docs/src/api.md | 3 + src/SymbolicIndexingInterface.jl | 2 +- src/state_indexing.jl | 118 ++++++++++++++++++++++++++----- test/state_indexing_test.jl | 24 ++++++- 4 files changed, 127 insertions(+), 20 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index dc6a579..4470842 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -18,6 +18,9 @@ all_variable_symbols all_symbols solvedvariables allvariables +Timeseries +NotTimeseries +is_timeseries state_values parameter_values current_time diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index 1635737..7f79e66 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -15,7 +15,7 @@ include("symbol_cache.jl") export parameter_values, getp, setp include("parameter_indexing.jl") -export state_values, current_time, getu, setu +export Timeseries, NotTimeseries, is_timeseries, state_values, current_time, getu, setu include("state_indexing.jl") end diff --git a/src/state_indexing.jl b/src/state_indexing.jl index 9ce95c3..284cfc9 100644 --- a/src/state_indexing.jl +++ b/src/state_indexing.jl @@ -1,29 +1,77 @@ +abstract type IsTimeseriesTrait end + +""" + struct Timeseries <: IsTimeseriesTrait end + +Trait indicating a type contains timeseries data. This affects the behaviour of +functions such as [`state_values`](@ref) and [`current_time`](@ref). + +See also: [`NotTimeseries`](@ref), [`is_timeseries`](@ref) +""" +struct Timeseries <: IsTimeseriesTrait end + +""" + struct NotTimeseries <: IsTimeseriesTrait end + +Trait indicating a type does not contain timeseries data. This affects the behaviour +of functions such as [`state_values`](@ref) and [`current_time`](@ref). Note that +if a type is `NotTimeseries` this only implies that it does not _store_ timeseries +data. It may still be time-dependent. For example, an `ODEProblem` only stores +the initial state of a system, so it is `NotTimeseries`, but still time-dependent. +This is the default trait variant for all types. + +See also: [`Timeseries`](@ref), [`is_timeseries`](@ref) +""" +struct NotTimeseries <: IsTimeseriesTrait end + +""" + is_timeseries(x) = is_timeseries(typeof(x)) + is_timeseries(::Type) + +Get the timeseries trait of a type. Defaults to [`NotTimeseries`](@ref) for all types. + +See also: [`Timeseries`](@ref), [`NotTimeseries`](@ref) +""" +function is_timeseries end + +is_timeseries(x) = is_timeseries(typeof(x)) +is_timeseries(::Type) = NotTimeseries() + """ state_values(p) Return an indexable collection containing the values of all states in the integrator or -problem `p`. +problem `p`. If `is_timeseries(p)` is [`Timeseries`](@ref), return a vector of arrays, +each of which contain the state values at the corresponding timestep. + +See: [`is_timeseries`](@ref) """ function state_values end """ current_time(p) -Return the current time in the integrator or problem `p`. +Return the current time in the integrator or problem `p`. If +`is_timeseries(p)` is [`Timeseries`](@ref), return the vector of timesteps at which +the state value is saved. + + +See: [`is_timeseries`](@ref) """ function current_time end """ getu(sys, sym) -Return a function that takes an integrator or problem of `sys`, and returns the value of -the symbolic `sym`. `sym` can be a direct index into the state vector, a symbolic state, -a symbolic expression involving symbolic quantities in the system `sys`, or an -array/tuple of the aforementioned. +Return a function that takes an integrator, problem or solution of `sys`, and returns +the value of the symbolic `sym`. `sym` can be a direct index into the state vector, a +symbolic state, a symbolic expression involving symbolic quantities in the system +`sys`, or an array/tuple of the aforementioned. -At minimum, this requires that the integrator or problem implement [`state_values`](@ref). -To support symbolic expressions, the integrator or problem must implement -[`observed`](@ref), [`parameter_values`](@ref) and [`current_time`](@ref). +At minimum, this requires that the integrator, problem or solution implement +[`state_values`](@ref). To support symbolic expressions, the integrator or problem +must implement [`observed`](@ref), [`parameter_values`](@ref) and +[`current_time`](@ref). This function typically does not need to be implemented, and has a default implementation relying on the above functions. @@ -40,37 +88,68 @@ function getu(sys, sym) end function _getu(sys, ::NotSymbolic, sym) + _getter(::Timeseries, prob) = getindex.(state_values(prob), (sym,)) + _getter(::NotTimeseries, prob) = state_values(prob)[sym] return function getter(prob) - return state_values(prob)[sym] + return _getter(is_timeseries(prob), prob) end end function _getu(sys, ::ScalarSymbolic, sym) if is_variable(sys, sym) idx = variable_index(sys, sym) - return function getter1(prob) - return state_values(prob)[idx] - end + return getu(sys, idx) elseif is_observed(sys, sym) fn = observed(sys, sym) if is_time_dependent(sys) - function getter2(prob) + function _getter2(::Timeseries, prob) + return fn.(state_values(prob), + (parameter_values(prob),), + current_time(prob)) + end + function _getter2(::NotTimeseries, prob) return fn(state_values(prob), parameter_values(prob), current_time(prob)) end + + return function getter2(prob) + return _getter2(is_timeseries(prob), prob) + end else - function getter3(prob) + function _getter3(::Timeseries, prob) + return fn.(state_values(prob), (parameter_values(prob),)) + end + function _getter3(::NotTimeseries, prob) return fn(state_values(prob), parameter_values(prob)) end + + return function getter3(prob) + return _getter3(is_timeseries(prob), prob) + end end end error("Invalid symbol $sym for `getu`") end -function _getu(sys, ::ScalarSymbolic, sym::Union{<:Tuple,<:AbstractArray}) +struct TimeseriesIndexWrapper{T, I} + timeseries::T + idx::I +end + +state_values(t::TimeseriesIndexWrapper) = state_values(t.timeseries)[t.idx] +parameter_values(t::TimeseriesIndexWrapper) = parameter_values(t.timeseries) +current_time(t::TimeseriesIndexWrapper) = current_time(t.timeseries)[t.idx] + +function _getu(sys, ::ScalarSymbolic, sym::Union{<:Tuple, <:AbstractArray}) getters = getu.((sys,), sym) _call(getter, prob) = getter(prob) + + function _getter(::Timeseries, prob) + tiws = TimeseriesIndexWrapper.((prob,), eachindex(state_values(prob))) + return [_getter(NotTimeseries(), tiw) for tiw in tiws] + end + _getter(::NotTimeseries, prob) = _call.(getters, (prob,)) return function getter(prob) - return _call.(getters, (prob,)) + return _getter(is_timeseries(prob), prob) end end @@ -86,6 +165,9 @@ the state `sym` to that value. Note that `sym` can be a direct numerical index, Requires that the integrator implement [`state_values`](@ref) and the returned collection be a mutable reference to the state vector in the integrator/problem. +This function does not work on types for which [`is_timeseries`](@ref) is +[`Timeseries`](@ref). + In case `state_values` cannot return such a mutable reference, `setu` needs to be implemented manually. """ @@ -114,7 +196,7 @@ function _setu(sys, ::ScalarSymbolic, sym) end end -function _setu(sys, ::ScalarSymbolic, sym::Union{<:Tuple,<:AbstractArray}) +function _setu(sys, ::ScalarSymbolic, sym::Union{<:Tuple, <:AbstractArray}) setters = setu.((sys,), sym) _call!(setter!, prob, val) = setter!(prob, val) return function setter!(prob, val) diff --git a/test/state_indexing_test.jl b/test/state_indexing_test.jl index 87a1a8f..7fc3a61 100644 --- a/test/state_indexing_test.jl +++ b/test/state_indexing_test.jl @@ -1,6 +1,6 @@ using SymbolicIndexingInterface -struct FakeIntegrator{S,U} +struct FakeIntegrator{S, U} sys::S u::U end @@ -20,3 +20,25 @@ for (i, sym) in [(1, :x), (2, :y), (3, :z), ([1, 2], [:x, :y]), ((3, 2), (:z, :y @test get(fi) == 0.5 .* i set!(fi, true_value) end + +struct FakeSolution{S, U} + sys::S + u::U +end + +SymbolicIndexingInterface.is_timeseries(::Type{<:FakeSolution}) = Timeseries() +SymbolicIndexingInterface.symbolic_container(fp::FakeSolution) = fp.sys +SymbolicIndexingInterface.state_values(fp::FakeSolution) = fp.u + +sys = SymbolCache([:x, :y, :z], [:a, :b], [:t]) +u = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] +sol = FakeSolution(sys, u) +for (i, sym) in [(1, :x), (2, :y), (3, :z), ([1, 2], [:x, :y]), ((3, 2), (:z, :y))] + get = getu(sys, sym) + true_value = if i isa Tuple + [getindex.((v,), i) for v in u] + else + getindex.(u, (i,)) + end + @test get(sol) == true_value +end