Skip to content

Commit

Permalink
feat: add IsTimeseriesTrait, support timeseries objects in getu
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jan 1, 2024
1 parent 9c5e3b5 commit b1a7f72
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 20 deletions.
3 changes: 3 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ all_variable_symbols
all_symbols
solvedvariables
allvariables
Timeseries
NotTimeseries
is_timeseries
state_values
parameter_values
current_time
Expand Down
2 changes: 1 addition & 1 deletion src/SymbolicIndexingInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ include("symbol_cache.jl")
export parameter_values, getp, setp
include("parameter_indexing.jl")

export state_values, current_time, getu, setu
export Timeseries, NotTimeseries, is_timeseries, state_values, current_time, getu, setu
include("state_indexing.jl")

end
118 changes: 100 additions & 18 deletions src/state_indexing.jl
Original file line number Diff line number Diff line change
@@ -1,29 +1,77 @@
abstract type IsTimeseriesTrait end

"""
struct Timeseries <: IsTimeseriesTrait end
Trait indicating a type contains timeseries data. This affects the behaviour of
functions such as [`state_values`](@ref) and [`current_time`](@ref).
See also: [`NotTimeseries`](@ref), [`is_timeseries`](@ref)
"""
struct Timeseries <: IsTimeseriesTrait end

"""
struct NotTimeseries <: IsTimeseriesTrait end
Trait indicating a type does not contain timeseries data. This affects the behaviour
of functions such as [`state_values`](@ref) and [`current_time`](@ref). Note that
if a type is `NotTimeseries` this only implies that it does not _store_ timeseries
data. It may still be time-dependent. For example, an `ODEProblem` only stores
the initial state of a system, so it is `NotTimeseries`, but still time-dependent.
This is the default trait variant for all types.
See also: [`Timeseries`](@ref), [`is_timeseries`](@ref)
"""
struct NotTimeseries <: IsTimeseriesTrait end

"""
is_timeseries(x) = is_timeseries(typeof(x))
is_timeseries(::Type)
Get the timeseries trait of a type. Defaults to [`NotTimeseries`](@ref) for all types.
See also: [`Timeseries`](@ref), [`NotTimeseries`](@ref)
"""
function is_timeseries end

is_timeseries(x) = is_timeseries(typeof(x))
is_timeseries(::Type) = NotTimeseries()

"""
state_values(p)
Return an indexable collection containing the values of all states in the integrator or
problem `p`.
problem `p`. If `is_timeseries(p)` is [`Timeseries`](@ref), return a vector of arrays,
each of which contain the state values at the corresponding timestep.
See: [`is_timeseries`](@ref)
"""
function state_values end

"""
current_time(p)
Return the current time in the integrator or problem `p`.
Return the current time in the integrator or problem `p`. If
`is_timeseries(p)` is [`Timeseries`](@ref), return the vector of timesteps at which
the state value is saved.
See: [`is_timeseries`](@ref)
"""
function current_time end

"""
getu(sys, sym)
Return a function that takes an integrator or problem of `sys`, and returns the value of
the symbolic `sym`. `sym` can be a direct index into the state vector, a symbolic state,
a symbolic expression involving symbolic quantities in the system `sys`, or an
array/tuple of the aforementioned.
Return a function that takes an integrator, problem or solution of `sys`, and returns
the value of the symbolic `sym`. `sym` can be a direct index into the state vector, a
symbolic state, a symbolic expression involving symbolic quantities in the system
`sys`, or an array/tuple of the aforementioned.
At minimum, this requires that the integrator or problem implement [`state_values`](@ref).
To support symbolic expressions, the integrator or problem must implement
[`observed`](@ref), [`parameter_values`](@ref) and [`current_time`](@ref).
At minimum, this requires that the integrator, problem or solution implement
[`state_values`](@ref). To support symbolic expressions, the integrator or problem
must implement [`observed`](@ref), [`parameter_values`](@ref) and
[`current_time`](@ref).
This function typically does not need to be implemented, and has a default implementation
relying on the above functions.
Expand All @@ -40,37 +88,68 @@ function getu(sys, sym)
end

function _getu(sys, ::NotSymbolic, sym)
_getter(::Timeseries, prob) = getindex.(state_values(prob), (sym,))
_getter(::NotTimeseries, prob) = state_values(prob)[sym]
return function getter(prob)
return state_values(prob)[sym]
return _getter(is_timeseries(prob), prob)
end
end

function _getu(sys, ::ScalarSymbolic, sym)
if is_variable(sys, sym)
idx = variable_index(sys, sym)
return function getter1(prob)
return state_values(prob)[idx]
end
return getu(sys, idx)
elseif is_observed(sys, sym)
fn = observed(sys, sym)
if is_time_dependent(sys)
function getter2(prob)
function _getter2(::Timeseries, prob)
return fn.(state_values(prob),
(parameter_values(prob),),
current_time(prob))
end
function _getter2(::NotTimeseries, prob)
return fn(state_values(prob), parameter_values(prob), current_time(prob))
end

return function getter2(prob)
return _getter2(is_timeseries(prob), prob)
end
else
function getter3(prob)
function _getter3(::Timeseries, prob)
return fn.(state_values(prob), (parameter_values(prob),))
end
function _getter3(::NotTimeseries, prob)
return fn(state_values(prob), parameter_values(prob))
end

return function getter3(prob)
return _getter3(is_timeseries(prob), prob)
end
end
end
error("Invalid symbol $sym for `getu`")
end

function _getu(sys, ::ScalarSymbolic, sym::Union{<:Tuple,<:AbstractArray})
struct TimeseriesIndexWrapper{T, I}
timeseries::T
idx::I
end

state_values(t::TimeseriesIndexWrapper) = state_values(t.timeseries)[t.idx]
parameter_values(t::TimeseriesIndexWrapper) = parameter_values(t.timeseries)
current_time(t::TimeseriesIndexWrapper) = current_time(t.timeseries)[t.idx]

function _getu(sys, ::ScalarSymbolic, sym::Union{<:Tuple, <:AbstractArray})
getters = getu.((sys,), sym)
_call(getter, prob) = getter(prob)

function _getter(::Timeseries, prob)
tiws = TimeseriesIndexWrapper.((prob,), eachindex(state_values(prob)))
return [_getter(NotTimeseries(), tiw) for tiw in tiws]
end
_getter(::NotTimeseries, prob) = _call.(getters, (prob,))
return function getter(prob)
return _call.(getters, (prob,))
return _getter(is_timeseries(prob), prob)
end
end

Expand All @@ -86,6 +165,9 @@ the state `sym` to that value. Note that `sym` can be a direct numerical index,
Requires that the integrator implement [`state_values`](@ref) and the
returned collection be a mutable reference to the state vector in the integrator/problem.
This function does not work on types for which [`is_timeseries`](@ref) is
[`Timeseries`](@ref).
In case `state_values` cannot return such a mutable reference, `setu` needs to be
implemented manually.
"""
Expand Down Expand Up @@ -114,7 +196,7 @@ function _setu(sys, ::ScalarSymbolic, sym)
end
end

function _setu(sys, ::ScalarSymbolic, sym::Union{<:Tuple,<:AbstractArray})
function _setu(sys, ::ScalarSymbolic, sym::Union{<:Tuple, <:AbstractArray})
setters = setu.((sys,), sym)
_call!(setter!, prob, val) = setter!(prob, val)
return function setter!(prob, val)
Expand Down
24 changes: 23 additions & 1 deletion test/state_indexing_test.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using SymbolicIndexingInterface

struct FakeIntegrator{S,U}
struct FakeIntegrator{S, U}
sys::S
u::U
end
Expand All @@ -20,3 +20,25 @@ for (i, sym) in [(1, :x), (2, :y), (3, :z), ([1, 2], [:x, :y]), ((3, 2), (:z, :y
@test get(fi) == 0.5 .* i
set!(fi, true_value)
end

struct FakeSolution{S, U}
sys::S
u::U
end

SymbolicIndexingInterface.is_timeseries(::Type{<:FakeSolution}) = Timeseries()
SymbolicIndexingInterface.symbolic_container(fp::FakeSolution) = fp.sys
SymbolicIndexingInterface.state_values(fp::FakeSolution) = fp.u

sys = SymbolCache([:x, :y, :z], [:a, :b], [:t])
u = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
sol = FakeSolution(sys, u)
for (i, sym) in [(1, :x), (2, :y), (3, :z), ([1, 2], [:x, :y]), ((3, 2), (:z, :y))]
get = getu(sys, sym)
true_value = if i isa Tuple
[getindex.((v,), i) for v in u]
else
getindex.(u, (i,))
end
@test get(sol) == true_value
end

0 comments on commit b1a7f72

Please sign in to comment.