Skip to content

Commit

Permalink
refactor: rework discrete indexing behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jul 4, 2024
1 parent 4183188 commit 37a612a
Show file tree
Hide file tree
Showing 11 changed files with 657 additions and 510 deletions.
3 changes: 2 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ allvariables
```@docs
observed
parameter_observed
ParameterObservedFunction
```

#### Parameter timeseries
Expand All @@ -46,6 +45,8 @@ may change at different times.
is_timeseries_parameter
timeseries_parameter_index
ParameterTimeseriesIndex
get_all_timeseries_indexes
ContinuousTimeseries
```

## Value provider interface
Expand Down
22 changes: 13 additions & 9 deletions docs/src/complete_sii.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,8 @@ end
In case a type does not support such observed quantities, `is_observed` must be
defined to always return `false`, and `observed` does not need to be implemented.

The same process can be followed for [`parameter_observed`](@ref), with the exception
that the returned function must not have `u` in its signature, and must be wrapped in a
[`ParameterObservedFunction`](@ref). In-place versions can also be implemented for
`parameter_observed`.
The same process can be followed for [`parameter_observed`](@ref). In-place versions
can also be implemented for `parameter_observed`.

#### Note about constant structure

Expand Down Expand Up @@ -334,7 +332,9 @@ end
# To be able to access parameter values
SymbolicIndexingInterface.parameter_values(mpo::MyParameterObject) = mpo.p
# Update the parameter object with new values
function SymbolicIndexingInterface.with_updated_parameter_timeseries_values(mpo::MyParameterObject, args::Pair...)
# Here, we don't need the index provider but it may be necessary for other implementations
function SymbolicIndexingInterface.with_updated_parameter_timeseries_values(
::SymbolCache, mpo::MyParameterObject, args::Pair...)
for (ts_idx, val) in args
mpo.p[mpo.disc_idxs[ts_idx]] = val
end
Expand Down Expand Up @@ -440,15 +440,15 @@ sol.ps[:b, idxs]
```

```@example param_timeseries
sol.ps[[:a, :b]] # returns the values at the last timestep, since :a is not timeseries
sol.ps[[:a, :b]] # :a has the same value at all time points
```

```@example param_timeseries
# throws an error since :b and :d belong to different timeseries
try
sol.ps[[:b, :d]]
catch e
@show e
showerror(stdout, e)
end
```

Expand All @@ -457,11 +457,15 @@ sol.ps[:(b + c)] # observed quantities work too
```

```@example param_timeseries
getu(sol, :b)(sol) # returns the values :b takes at the times in the state timeseries
getu(sol, :b)(sol) # works
```

```@example param_timeseries
getu(sol, [:b, :d])(sol) # works
try
getu(sol, [:b, :d])(sol) # errors since :b and :d belong to different timeseries
catch e
showerror(stdout, e)
end
```

## Custom containers
Expand Down
3 changes: 2 additions & 1 deletion src/SymbolicIndexingInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ include("trait.jl")
export is_variable, variable_index, variable_symbols, is_parameter, parameter_index,
is_timeseries_parameter, timeseries_parameter_index, ParameterTimeseriesIndex,
parameter_symbols, is_independent_variable, independent_variable_symbols,
is_observed, observed, parameter_observed, ParameterObservedFunction,
is_observed, observed, parameter_observed,
ContinuousTimeseries, get_all_timeseries_indexes,
is_time_dependent, constant_structure, symbolic_container,
all_variable_symbols, all_symbols, solvedvariables, allvariables, default_values,
symbolic_evaluate
Expand Down
67 changes: 39 additions & 28 deletions src/index_provider_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,38 +98,17 @@ function timeseries_parameter_index(indp, sym)
end
end

"""
struct ParameterObservedFunction
function ParameterObservedFunction(timeseries_idx, observed_fn::Function)
function ParameterObservedFunction(observed_fn::Function)
A struct which stores the parameter observed function and optional timeseries index for
a particular symbol. The timeseries index is optional and may be omitted. Specifying the
timeseries index allows [`getp`](@ref) to return the appropriate timeseries for a
timeseries parameter.
For time-dependent index providers (where `is_time_dependent(indp)`) `observed_fn` must
have the signature `(p, t) -> [values...]`. For non-time-dependent index providers
(where `!is_time_dependent(indp)`) `observed_fn` must have the signature
`(p) -> [values...]`. To support in-place `getp` methods, `observed_fn` must also have an
additional method which takes `buffer::AbstractArray` as its first argument. The required
values must be written to the buffer in the appropriate order.
"""
struct ParameterObservedFunction{I, F <: Function}
timeseries_idx::I
observed_fn::F
end

"""
parameter_observed(indp, sym)
Return the observed function of `sym` in `indp` as a [`ParameterObservedFunction`](@ref).
If `sym` only involves variables from a single parameter timeseries (optionally along
with non-timeseries parameters) the timeseries index of the parameter timeseries should
be provided in the [`ParameterObservedFunction`](@ref). In all other cases, just the
observed function should be returned as part of the `ParameterObservedFunction` object.
Return the observed function of `sym` in `indp`. This functions similarly to
[`observed`](@ref) except that `u` is not an argument of the returned function. For time-
dependent systems, the returned function must have the signature `(p, t) -> [values...]`.
For time-independent systems, the returned function must have the signature
`(p) -> [values...]`.
By default, this function returns `nothing`.
By default, this function returns `nothing`, indicating that the index provider does not
support generating parameter observed functions.
"""
function parameter_observed(indp, sym)
if hasmethod(symbolic_container, Tuple{typeof(indp)})
Expand All @@ -139,6 +118,38 @@ function parameter_observed(indp, sym)
end
end

"""
struct ContinuousTimeseries end
A singleton struct corresponding to the timeseries index of the continuous timeseries.
"""
struct ContinuousTimeseries end

"""
get_all_timeseries_indexes(indp, sym)
Return a `Set` of all unique timeseries indexes of variables in symbolic variable
`sym`. `sym` may be a symbolic variable or expression, an array of symbolics, an index,
or an array of indices. Continuous variables correspond to the
[`ContinuousTimeseries`](@ref) timeseries index. Non-timeseries parameters do not have a
timeseries index. Timeseries parameters have the same timeseries index as that returned by
[`timeseries_parameter_index`](@ref). Note that the independent variable corresponds to
the `ContinuousTimeseries` timeseries index.
Any ambiguities should be resolved in favor of variables. For example, if `1` could refer
to the variable at index `1` or parameter at index `1`, it should be interpreted as the
variable.
By default, this function returns `Set([ContinuousTimeseries()])`.
"""
function get_all_timeseries_indexes(indp, sym)
if hasmethod(symbolic_container, Tuple{typeof(indp)})
return get_all_timeseries_indexes(symbolic_container(indp), sym)
else
return Set([ContinuousTimeseries()])
end
end

"""
parameter_symbols(indp)
Expand Down
Loading

0 comments on commit 37a612a

Please sign in to comment.