Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support indexing in mixed discrete-continuous systems #53

Merged
merged 5 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,18 @@ getu
setu
```

### Parameter timeseries

If a solution object saves a timeseries of parameter values that are updated during the
simulation (such as by callbacks), it must implement the following methods to ensure
correct functioning of [`getu`](@ref) and [`getp`](@ref).

```@docs
parameter_timeseries
parameter_values_at_time
parameter_values_at_state_time
```

# Symbolic Trait

```@docs
Expand Down
78 changes: 78 additions & 0 deletions docs/src/complete_sii.md
Original file line number Diff line number Diff line change
Expand Up @@ -338,3 +338,81 @@ end

[`hasname`](@ref) is not required to always be `true` for symbolic types. For example,
`Symbolics.Num` returns `false` whenever the wrapped value is a number, or an expression.

## Parameter Timeseries

If a solution object saves modified parameter values (such as through callbacks) during the
simulation, it must implement [`parameter_timeseries`](@ref),
[`parameter_values_at_time`](@ref) and [`parameter_values_at_state_time`](@ref) for correct
functioning of [`getu`](@ref) and [`getp`](@ref). The following mockup gives an example
of correct implementation of these functions and the indexing syntax they enable.

```@example param_timeseries
using SymbolicIndexingInterface

struct ExampleSolution2
sys::SymbolCache
u::Vector{Vector{Float64}}
t::Vector{Float64}
p::Vector{Vector{Float64}}
pt::Vector{Float64}
end

# Add the `:ps` property to automatically wrap in `ParameterIndexingProxy`
function Base.getproperty(fs::ExampleSolution2, s::Symbol)
s === :ps ? ParameterIndexingProxy(fs) : getfield(fs, s)
end
# Use the contained `SymbolCache` for indexing
SymbolicIndexingInterface.symbolic_container(fs::ExampleSolution2) = fs.sys
# By default, `parameter_values` refers to the last value
SymbolicIndexingInterface.parameter_values(fs::ExampleSolution2) = fs.p[end]
SymbolicIndexingInterface.parameter_values(fs::ExampleSolution2, i) = fs.p[end][i]
# Index into the parameter timeseries vector
function SymbolicIndexingInterface.parameter_values_at_time(fs::ExampleSolution2, t)
fs.p[t]
end
# Find the first index in the parameter timeseries vector with a time smaller
# than the time from the state timeseries, and use that to index the parameter
# timeseries
function SymbolicIndexingInterface.parameter_values_at_state_time(fs::ExampleSolution2, t)
ptind = searchsortedfirst(fs.pt, fs.t[t]; lt = <=)
fs.p[ptind - 1]
end
SymbolicIndexingInterface.parameter_timeseries(fs::ExampleSolution2) = fs.pt
# Mark the object as a `Timeseries` object
SymbolicIndexingInterface.is_timeseries(::Type{ExampleSolution2}) = Timeseries()

```

Now we can create an example object and observe the new functionality. Note that
`sol.ps[sym, args...]` is identical to `getp(sol, sym)(sol, args...)`.

```@example param_timeseries
sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t)
sol = ExampleSolution2(
sys,
[i * ones(3) for i in 1:5],
[0.2i for i in 1:5],
[2i * ones(3) for i in 1:10],
[0.1i for i in 1:10]
)
sol.ps[:a] # returns the value at the last timestep
```

```@example param_timeseries
sol.ps[:a, :] # use Colon to fetch the entire parameter timeseries
```

```@example param_timeseries
sol.ps[:a, 3] # index at a specific index in the parameter timeseries
```

```@example param_timeseries
sol.ps[:a, [3, 6, 8]] # index using arrays
```

```@example param_timeseries
idxs = @show rand(Bool, 10) # boolean mask for indexing
sol.ps[:a, idxs]
```

6 changes: 3 additions & 3 deletions docs/src/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -168,11 +168,11 @@ sol2 = solve(prob, Tsit5())
σ_ρ_getter(sol)
```

To set the entire parameter vector at once, [`parameter_values`](@ref) can be used
(note the usage of broadcasted assignment).
To set the entire parameter vector at once, [`setp`](@ref) can be used
(note that the order of symbols passed to `setp` must match the order of values in the array).

```@example Usage
parameter_values(prob) .= [29.0, 11.0, 2.5]
setp(prob, parameter_symbols(prob))(prob, [29.0, 11.0, 2.5])
parameter_values(prob)
```

Expand Down
9 changes: 5 additions & 4 deletions src/SymbolicIndexingInterface.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module SymbolicIndexingInterface

export ScalarSymbolic, ArraySymbolic, NotSymbolic, symbolic_type, hasname, getname
export ScalarSymbolic, ArraySymbolic, NotSymbolic, symbolic_type, hasname, getname,
Timeseries, NotTimeseries, is_timeseries
include("trait.jl")

export is_variable, variable_index, variable_symbols, is_parameter, parameter_index,
Expand All @@ -14,11 +15,11 @@ include("interface.jl")
export SymbolCache
include("symbol_cache.jl")

export parameter_values, set_parameter!, getp, setp
export parameter_values, set_parameter!, parameter_values_at_time,
parameter_values_at_state_time, parameter_timeseries, getp, setp
include("parameter_indexing.jl")

export Timeseries,
NotTimeseries, is_timeseries, state_values, set_state!, current_time, getu, setu
export state_values, set_state!, current_time, getu, setu
include("state_indexing.jl")

export ParameterIndexingProxy
Expand Down
147 changes: 135 additions & 12 deletions src/parameter_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,56 @@
parameter_values(arr::AbstractArray, i) = arr[i]
parameter_values(prob, i) = parameter_values(parameter_values(prob), i)

"""
parameter_values_at_time(p, i)

Return an indexable collection containing the value of all parameters in `p` at time index
`i`. This is useful when parameter values change during the simulation
(such as through callbacks) and their values are saved. `i` is the time index in the
timeseries formed by these changing parameter values, obtained using
[`parameter_timeseries`](@ref).

By default, this function returns `parameter_values(p)` regardless of `i`, and only needs
to be specialized for timeseries objects where parameter values are not constant at all
times. The resultant object should be indexable using [`parameter_values`](@ref).

If this function is implemented, [`parameter_values_at_state_time`](@ref) must be
implemented for [`getu`](@ref) to work correctly.
"""
function parameter_values_at_time end
parameter_values_at_time(p, i) = parameter_values(p)

Check warning on line 35 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L35

Added line #L35 was not covered by tests

"""
parameter_values_at_state_time(p, i)

Return an indexable collection containing the value of all parameters in `p` at time
index `i`. This is useful when parameter values change during the simulation (such as
through callbacks) and their values are saved. `i` is the time index in the timeseries
formed by dependent variables (as opposed to the timeseries of the parameters, as in
[`parameter_values_at_time`](@ref)).

By default, this function returns `parameter_values(p)` regardless of `i`, and only needs
to be specialized for timeseries objects where parameter values are not constant at
all times. The resultant object should be indexable using [`parameter_values`](@ref).

If this function is implemented, [`parameter_values_at_time`](@ref) must be implemented for
[`getp`](@ref) to work correctly.
"""
function parameter_values_at_state_time end
parameter_values_at_state_time(p, i) = parameter_values(p)

Check warning on line 54 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L54

Added line #L54 was not covered by tests

"""
parameter_timeseries(p)

Return an iterable of time steps at which the parameter values are saved. This is only
required for objects where `is_timeseries(p) === Timeseries()` and the parameter values
change during the simulation (such as through callbacks). By default, this returns `[0]`.

See also: [`parameter_values_at_time`](@ref).
"""
function parameter_timeseries end
parameter_timeseries(_) = [0]

Check warning on line 66 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L66

Added line #L66 was not covered by tests

"""
set_parameter!(sys, val, idx)

Expand Down Expand Up @@ -47,6 +97,13 @@
Requires that the integrator or solution implement [`parameter_values`](@ref). This function
typically does not need to be implemented, and has a default implementation relying on
[`parameter_values`](@ref).

If the returned function is used on a timeseries object which saves parameter timeseries, it
can be used to index said timeseries. The timeseries object must implement
[`parameter_timeseries`](@ref), [`parameter_values_at_time`](@ref) and
[`parameter_values_at_state_time`](@ref). The function returned from `getp` will can be passed
`Colon()` (`:`) as the last argument to return the entire parameter timeseries for `p`, or
any index into the parameter timeseries for a subset of values.
"""
function getp(sys, p)
symtype = symbolic_type(p)
Expand All @@ -55,18 +112,42 @@
end

function _getp(sys, ::NotSymbolic, ::NotSymbolic, p)
return function getter(sol)
return parameter_values(sol, p)
return let p = p
function _getter(::NotTimeseries, prob)
parameter_values(prob, p)
end
function _getter(::Timeseries, prob)
parameter_values(prob, p)
end
function _getter(::Timeseries, prob, i::Union{Int, CartesianIndex})
parameter_values(
parameter_values_at_time(
prob, only(to_indices(parameter_timeseries(prob), (i,)))),
p)
end
function _getter(::Timeseries, prob, i::Union{AbstractArray{Bool}, Colon})
parameter_values.(
parameter_values_at_time.((prob,),
(j for j in only(to_indices(parameter_timeseries(prob), (i,))))),
p)
end
function _getter(::Timeseries, prob, i)
parameter_values.(parameter_values_at_time.((prob,), i), (p,))
end
getter = let _getter = _getter
function getter(prob, args...)
return _getter(is_timeseries(prob), prob, args...)
end
end
getter
end
end

function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p)
idx = parameter_index(sys, p)
return let idx = idx
function getter(sol)
return parameter_values(sol, idx)
end
end
return invoke(_getp, Tuple{Any, NotSymbolic, NotSymbolic, Any},
sys, NotSymbolic(), NotSymbolic(), idx)
return _getp(sys, NotSymbolic(), NotSymbolic(), idx)

Check warning on line 150 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L150

Added line #L150 was not covered by tests
end

for (t1, t2) in [
Expand All @@ -78,15 +159,57 @@
getters = getp.((sys,), p)

return let getters = getters
function getter(sol)
map(g -> g(sol), getters)
function _getter(::NotTimeseries, prob)
map(g -> g(prob), getters)

Check warning on line 163 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L163

Added line #L163 was not covered by tests
end
function getter(buffer, sol)
for (i, g) in zip(eachindex(buffer), getters)
buffer[i] = g(sol)
function _getter(::Timeseries, prob)
map(g -> g(prob), getters)
end
function _getter(::Timeseries, prob, i::Union{Int, CartesianIndex})
map(g -> g(prob, i), getters)

Check warning on line 169 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L168-L169

Added lines #L168 - L169 were not covered by tests
end
function _getter(::Timeseries, prob, i)
[map(g -> g(prob, j), getters)

Check warning on line 172 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L171-L172

Added lines #L171 - L172 were not covered by tests
for j in only(to_indices(parameter_timeseries(prob), (i,)))]
end
function _getter!(buffer, ::NotTimeseries, prob)
for (g, bufi) in zip(getters, eachindex(buffer))
buffer[bufi] = g(prob)

Check warning on line 177 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L176-L177

Added lines #L176 - L177 were not covered by tests
end
buffer
end
function _getter!(buffer, ::Timeseries, prob)
for (g, bufi) in zip(getters, eachindex(buffer))
buffer[bufi] = g(prob)
end
buffer

Check warning on line 185 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L181-L185

Added lines #L181 - L185 were not covered by tests
end
function _getter!(buffer, ::Timeseries, prob, i::Union{Int, CartesianIndex})
for (g, bufi) in zip(getters, eachindex(buffer))
buffer[bufi] = g(prob, i)
end
buffer

Check warning on line 191 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L187-L191

Added lines #L187 - L191 were not covered by tests
end
function _getter!(buffer, ::Timeseries, prob, i)
for (bufi, tsi) in zip(

Check warning on line 194 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L193-L194

Added lines #L193 - L194 were not covered by tests
eachindex(buffer), only(to_indices(parameter_timeseries(prob), (i,))))
for (g, bufj) in zip(getters, eachindex(buffer[bufi]))
buffer[bufi][bufj] = g(prob, tsi)
end
end
buffer

Check warning on line 200 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L196-L200

Added lines #L196 - L200 were not covered by tests
end
_getter, _getter!
getter = let _getter = _getter, _getter! = _getter!
function getter(prob, i...)
return _getter(is_timeseries(prob), prob, i...)
end
function getter(buffer::AbstractArray, prob, i...)
return _getter!(buffer, is_timeseries(prob), prob, i...)

Check warning on line 208 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L207-L208

Added lines #L207 - L208 were not covered by tests
end
getter
end
getter
end
end
end
Expand Down
4 changes: 2 additions & 2 deletions src/parameter_indexing_proxy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ struct ParameterIndexingProxy{T}
wrapped::T
end

function Base.getindex(p::ParameterIndexingProxy, idx)
return getp(p.wrapped, idx)(p.wrapped)
function Base.getindex(p::ParameterIndexingProxy, idx, args...)
getp(p.wrapped, idx)(p.wrapped, args...)
end

function Base.setindex!(p::ParameterIndexingProxy, val, idx)
Expand Down
Loading
Loading