Skip to content

Commit

Permalink
feat: support observed functions for history-dependent systems
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Sep 20, 2024
1 parent 8cf2450 commit 43b5587
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 5 deletions.
33 changes: 32 additions & 1 deletion src/index_provider_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,13 @@ the order of states or a time index, which identifies the order of states. This
does not need to be defined if [`is_observed`](@ref) always returns `false`. Thus,
it is mandatory to always check `is_observed` before using this function.
See also: [`is_time_dependent`](@ref), [`constant_structure`](@ref)
If `is_historical(indp)`, the returned function must have the signature
`(u, h, p, t) -> [values...]` where `h` is the history function, which can be called
to obtain past values of the state. The exact signature and semantics of `h` depend
on how it is used inside the returned function. `h` is obtained from a value
provider using [`get_history_function`](@ref).
See also: [`is_time_dependent`](@ref), [`is_historical`](@ref), [`constant_structure`](@ref).
"""
observed(indp, sym) = observed(symbolic_container(indp), sym)
observed(indp, sym, states) = observed(symbolic_container(indp), sym, states)
Expand All @@ -213,6 +219,31 @@ Check if `indp` has time as (one of) its independent variables.
"""
is_time_dependent(indp) = is_time_dependent(symbolic_container(indp))

"""
is_historical(indp)
Check if an index provider is historical. Historical index providers require values of states
at arbitrary offsets in the past to simulate, and these historical values are not present
as part of the states themselves. For example, systems of DDEs are historical. However,
a discrete system is not historical since the state realization includes the historical values
required to compute the next state.
This function is only required for time-dependent index providers.
Historical index providers return [`observed`](@ref) functions with a different signature.
All value providers associated with a historical index provider must implement
[`get_history_function`](@ref).
Returns `false` by default.
"""
function is_historical(indp)
if hasmethod(symbolic_container, Tuple{typeof(indp)})
is_historical(symbolic_container(indp))
else
false
end
end

"""
constant_structure(indp)
Expand Down
34 changes: 30 additions & 4 deletions src/state_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,17 @@ struct GetIndepvar <: AbstractStateGetIndexer end
(::GetIndepvar)(::IsTimeseriesTrait, prob) = current_time(prob)
(::GetIndepvar)(::Timeseries, prob, i) = current_time(prob, i)

struct TimeDependentObservedFunction{I, F} <: AbstractStateGetIndexer
struct TimeDependentObservedFunction{I, F, H} <: AbstractStateGetIndexer
ts_idxs::I
obsfn::F
end

function TimeDependentObservedFunction{H}(ts_idxs, obsfn) where {H}
return TimeDependentObservedFunction{typeof(ts_idxs), typeof(obsfn), H}(ts_idxs, obsfn)
end

const HistoricalObservedFunction = TimeDependentObservedFunction{I, F, true} where {I, F}

indexer_timeseries_index(t::TimeDependentObservedFunction) = t.ts_idxs
function is_indexer_timeseries(::Type{G}) where {G <:
TimeDependentObservedFunction{ContinuousTimeseries}}
Expand All @@ -74,15 +80,26 @@ function (o::TimeDependentObservedFunction)(ts::IsTimeseriesTrait, prob, args...
return o(ts, is_indexer_timeseries(o), prob, args...)
end

function (o::TimeDependentObservedFunction)(ts::Timeseries, ::IndexerBoth, prob)
function (o::TimeDependentObservedFunction)(::Timeseries, ::IndexerBoth, prob)
return o.obsfn.(state_values(prob),
(parameter_values(prob),),
current_time(prob))
end
function (o::HistoricalObservedFunction)(::Timeseries, ::IndexerBoth, prob)
return o.obsfn.(state_values(prob),
(get_history_function(prob),),
(parameter_values(prob),),
current_time(prob))
end
function (o::TimeDependentObservedFunction)(
::Timeseries, ::IndexerBoth, prob, i::Union{Int, CartesianIndex})
return o.obsfn(state_values(prob, i), parameter_values(prob), current_time(prob, i))
end
function (o::HistoricalObservedFunction)(
::Timeseries, ::IndexerBoth, prob, i::Union{Int, CartesianIndex})
return o.obsfn(state_values(prob, i), get_history_function(prob),
parameter_values(prob), current_time(prob, i))
end
function (o::TimeDependentObservedFunction)(ts::Timeseries, ::IndexerBoth, prob, ::Colon)
return o(ts, prob)
end
Expand All @@ -98,6 +115,10 @@ end
function (o::TimeDependentObservedFunction)(::NotTimeseries, ::IndexerBoth, prob)
return o.obsfn(state_values(prob), parameter_values(prob), current_time(prob))
end
function (o::HistoricalObservedFunction)(::NotTimeseries, ::IndexerBoth, prob)
return o.obsfn(state_values(prob), get_history_function(prob),
parameter_values(prob), current_time(prob))
end

function (o::TimeDependentObservedFunction)(
::Timeseries, ::IndexerMixedTimeseries, prob, args...)
Expand All @@ -107,6 +128,11 @@ function (o::TimeDependentObservedFunction)(
::NotTimeseries, ::IndexerMixedTimeseries, prob, args...)
return o.obsfn(state_values(prob), parameter_values(prob), current_time(prob))
end
function (o::HistoricalObservedFunction)(
::NotTimeseries, ::IndexerMixedTimeseries, prob, args...)
return o.obsfn(state_values(prob), get_history_function(prob),
parameter_values(prob), current_time(prob))
end

struct TimeIndependentObservedFunction{F} <: AbstractStateGetIndexer
obsfn::F
Expand Down Expand Up @@ -137,7 +163,7 @@ function _getu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym)
ts_idxs = collect(ts_idxs)
end
fn = observed(sys, sym)
return TimeDependentObservedFunction(ts_idxs, fn)
return TimeDependentObservedFunction{is_historical(sys)}(ts_idxs, fn)
else
return getp(sys, sym)
end
Expand Down Expand Up @@ -256,7 +282,7 @@ for (t1, t2) in [
else
obs = observed(sys, sym_arr)
getter = if is_time_dependent(sys)
TimeDependentObservedFunction(ts_idxs, obs)
TimeDependentObservedFunction{is_historical(sys)}(ts_idxs, obs)
else
TimeIndependentObservedFunction(obs)
end
Expand Down
10 changes: 10 additions & 0 deletions src/value_provider_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,16 @@ current_time(arr::AbstractVector) = arr
current_time(valp, i) = current_time(valp)[i]
current_time(valp, ::Colon) = current_time(valp)

"""
get_history_function(valp)
Return the history function for a value provider. This is required for all value providers
associated with an index provider `indp` for which `is_historical(indp)`.
See also: [`is_historical`](@ref).
"""
function get_history_function end

###########
# Utilities
###########
Expand Down
31 changes: 31 additions & 0 deletions test/state_indexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -291,3 +291,34 @@ for (sym, val, check_inference) in [
end
@test getter(fs) == val
end

struct HistoricalWrapper{S <: SymbolCache}
sys::S
end

SymbolicIndexingInterface.symbolic_container(hw::HistoricalWrapper) = hw.sys
SymbolicIndexingInterface.is_historical(::HistoricalWrapper) = true
function SymbolicIndexingInterface.observed(hw::HistoricalWrapper, sym)
let inner = observed(hw.sys, sym)
fn(u, h, p, t) = inner(u .+ h(t - 0.1), p, t)
end
end
function SymbolicIndexingInterface.get_history_function(fs::FakeSolution)
t -> t .* ones(length(fs.u[1]))
end
function SymbolicIndexingInterface.get_history_function(fi::FakeIntegrator)
t -> t .* ones(length(fi.u))
end

sys = HistoricalWrapper(SymbolCache([:x, :y, :z], [:a, :b, :c], :t))
u0 = [1.0, 2.0, 3.0]
u = [u0 .* i for i in 1:11]
p = [10.0, 20.0, 30.0]
ts = 0.0:0.1:1.0

fi = FakeIntegrator(sys, u0, p, ts[1])
fs = FakeSolution(sys, u, p, ts)
getter = getu(sys, :(x + y))
@test getter(fi) 2.8
@test getter(fs) [3.0i + 2(ts[i] - 0.1) for i in 1:11]
@test getter(fs, 1) 2.8

0 comments on commit 43b5587

Please sign in to comment.