Skip to content

Commit

Permalink
feat: add set_state! and set_parameter!, update docs
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jan 2, 2024
1 parent 0166940 commit 7a8e647
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 31 deletions.
31 changes: 26 additions & 5 deletions docs/src/api.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Interface Functions

## Mandatory methods

```@docs
symbolic_container
is_variable
Expand All @@ -11,26 +13,45 @@ parameter_symbols
is_independent_variable
independent_variable_symbols
is_observed
observed
is_time_dependent
constant_structure
all_variable_symbols
all_symbols
solvedvariables
allvariables
```

## Optional Methods

### Observed equation handling

```@docs
observed
```

### Parameter indexing

```@docs
parameter_values
set_parameter!
getp
setp
```

### State indexing

```@docs
Timeseries
NotTimeseries
is_timeseries
state_values
parameter_values
set_state!
current_time
getp
setp
getu
setu
```

# Traits
# Symbolic Trait

```@docs
ScalarSymbolic
Expand Down
83 changes: 71 additions & 12 deletions docs/src/complete_sii.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ function SymbolicIndexingInterface.observed(sys::ExampleSystem, sym::Expr)
end
```

In case a type does not support such observed quantities, `is_observed` must be
defined to always return `false`, and `observed` does not need to be implemented.

### Note about constant structure

Note that the method definitions are all assuming `constant_structure(p) == true`.
Expand Down Expand Up @@ -174,35 +177,91 @@ mutable struct ExampleIntegrator
u::Vector{Float64}
p::Vector{Float64}
t::Float64
state_index::Dict{Symbol,Int}
parameter_index::Dict{Symbol,Int}
independent_variable::Symbol
sys::ExampleSystem
end
```

Assume that it implements the mandatory part of the interface as described above, and
the following methods below:

```julia
# define a fallback for the interface methods
SymbolicIndexingInterface.symbolic_container(integ::ExampleIntegrator) = integ.sys
SymbolicIndexingInterface.state_values(sys::ExampleIntegrator) = sys.u
SymbolicIndexingInterface.parameter_values(sys::ExampleIntegrator) = sys.p
SymbolicIndexingInterface.current_time(sys::ExampleIntegrator) = sys.t
```

Then the following example would work:
```julia
integrator = ExampleIntegrator([1.0, 2.0, 3.0], [4.0, 5.0], 6.0, Dict(:x => 1, :y => 2, :z => 3), Dict(:a => 1, :b => 2), :t)
getx = getu(integrator, :x)
sys = ExampleSystem(Dict(:x => 1, :y => 2, :z => 3), Dict(:a => 1, :b => 2), :t, Dict())
integrator = ExampleIntegrator([1.0, 2.0, 3.0], [4.0, 5.0], 6.0, sys)
getx = getu(sys, :x)
getx(integrator) # 1.0

get_expr = getu(integrator, :(x + y + t))
get_expr = getu(sys, :(x + y + t))
get_expr(integrator) # 13.0

setx! = setu(integrator, :y)
setx! = setu(sys, :y)
setx!(integrator, 0.0)
getx(integrator) # 0.0
```

In case a type stores timeseries data (such as solutions), then it must also implement
the [`Timeseries`](@ref) trait. The type would then return a timeseries from
[`state_values`](@ref) and [`current_time`](@ref) and the function returned from
[`getu`](@ref) would then return a timeseries as well. For example, consider the
`ExampleSolution` below:

```julia
struct ExampleSolution
u::Vector{Vector{Float64}}
t::Vector{Float64}
p::Vector{Float64}
sys::ExampleSystem
end

# define a fallback for the interface methods
SymbolicIndexingInterface.symbolic_container(sol::ExampleSolution) = sol.sys
SymbolicIndexingInterface.parameter_values(sol::ExampleSolution) = sol.p
# define the trait
SymbolicIndexingInterface.is_timeseries(::Type{ExampleSolution}) = Timeseries()
# both state_values and current_time return a timeseries, which must be
# the same length
SymbolicIndexingInterface.state_values(sol::ExampleSolution) = sol.u
SymbolicIndexingInterface.current_time(sol::ExampleSolution) = sol.t
```

Then the following example would work:
```julia
# using the same system that the ExampleIntegrator used
sol = ExampleSolution([[1.0, 2.0, 3.0], [1.5, 2.5, 3.5]], [4.0, 5.0], [6.0, 7.0], sys)
getx = getu(sys, :x)
getx(sol) # [1.0, 1.5]

get_expr = getu(sys, :(x + y + t))
get_expr(sol) # [9.0, 11.0]

get_arr = getu(sys, [:y, :(x + a)])
get_arr(sol) # [[2.0, 5.0], [2.5, 5.5]]

get_tuple = getu(sys, (:z, :(z * t)))
get_tuple(sol) # [(3.0, 18.0), (3.5, 24.5)]
```

Note that `setu` is not designed to work for `Timeseries` objects.

If a type needs to perform some additional actions when updating the state/parameters
or if it is not possible to return a mutable reference to the state/parameter vector
which can directly be modified, the functions [`set_state!`](@ref) and/or
[`set_parameter!`](@ref) can be used. For example, suppose our `ExampleIntegrator`
had an additional field `u_modified::Bool` to allow it to keep track of when a
discontinuity occurs and handle it appropriately. This flag needs to be set to `true`
whenever the state is modified. The `set_state!` function can then be implemented as
follows:

```julia
function SymbolicIndexingInterface.set_state!(integrator::ExampleIntegrator, val, idx)
integrator.u[idx] = val
integrator.u_modified = true
end
```

# Implementing the `SymbolicTypeTrait` for a type

The `SymbolicTypeTrait` is used to identify values that can act as symbolic variables. It
Expand Down
5 changes: 3 additions & 2 deletions src/SymbolicIndexingInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ include("interface.jl")
export SymbolCache
include("symbol_cache.jl")

export parameter_values, getp, setp
export parameter_values, set_parameter!, getp, setp
include("parameter_indexing.jl")

export Timeseries, NotTimeseries, is_timeseries, state_values, current_time, getu, setu
export Timeseries,
NotTimeseries, is_timeseries, state_values, set_state!, current_time, getu, setu
include("state_indexing.jl")

end
4 changes: 3 additions & 1 deletion src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ have the signature `(u, p) -> [values...]` where `u` and `p` is the current stat
parameter vector, respectively. If `istimedependent(sys) == true`, the function should accept
the current time `t` as its third parameter. If `constant_structure(sys) == false`,
accept a third parameter, which can either be a vector of symbols indicating the order
of states or a time index, which identifies the order of states.
of states or a time index, which identifies the order of states. This function
does not need to be defined if [`is_observed`](@ref) always returns `false`. Thus,
it is mandatory to always check `is_observed` before using this function.
See also: [`is_time_dependent`](@ref), [`constant_structure`](@ref)
"""
Expand Down
25 changes: 20 additions & 5 deletions src/parameter_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,20 @@ Return an indexable collection containing the value of each parameter in `p`.
"""
function parameter_values end

"""
set_parameter!(sys, val, idx)
Set the parameter at index `idx` to `val` for system `sys`. This defaults to modifying
`parameter_values(sys)`. If any additional bookkeeping needs to be performed or the
default implementation does not work for a particular type, this method needs to be
defined to enable the proper functioning of [`setp`](@ref).
See: [`parameter_values`](@ref)
"""
function set_parameter!(sys, val, idx)
parameter_values(sys)[idx] = val
end

"""
getp(sys, p)
Expand Down Expand Up @@ -55,8 +69,9 @@ Return a function that takes an integrator of `sys` and a value, and sets
the parameter `p` to that value. Note that `p` can be a direct numerical index or a
symbolic value. Requires that the integrator implement [`parameter_values`](@ref) and the
returned collection be a mutable reference to the parameter vector in the integrator. In
case `parameter_values` cannot return such a mutable reference, `setp` needs to be
implemented manually.
case `parameter_values` cannot return such a mutable reference, or additional actions
need to be performed when updating parameters, [`set_parameter!`](@ref) must be
implemented.
"""
function setp(sys, p)
symtype = symbolic_type(p)
Expand All @@ -70,21 +85,21 @@ end

function _setp(sys, ::NotSymbolic, p)
return function setter!(sol, val)
parameter_values(sol)[p] = val
set_parameter!(sol, val, p)

Check warning on line 88 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L88

Added line #L88 was not covered by tests
end
end

function _setp(sys, ::ScalarSymbolic, p)
idx = parameter_index(sys, p)
return function setter!(sol, val)
parameter_values(sol)[idx] = val
set_parameter!(sol, val, idx)
end
end

function _setp(sys, ::ScalarSymbolic, p::Union{Tuple, AbstractArray})
idxs = parameter_index.((sys,), p)
return function setter!(sol, val)
setindex!.((parameter_values(sol),), val, idxs)
set_parameter!.((sol,), val, idxs)

Check warning on line 102 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L102

Added line #L102 was not covered by tests
end
end

Expand Down
24 changes: 18 additions & 6 deletions src/state_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,20 @@ See: [`is_timeseries`](@ref)
"""
function state_values end

"""
set_state!(sys, val, idx)
Set the state at index `idx` to `val` for system `sys`. This defaults to modifying
`state_values(sys)`. If any additional bookkeeping needs to be performed or the
default implementation does not work for a particular type, this method needs to be
defined to enable the proper functioning of [`setu`](@ref).
See: [`state_values`](@ref)
"""
function set_state!(sys, val, idx)
state_values(sys)[idx] = val

Check warning on line 62 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L61-L62

Added lines #L61 - L62 were not covered by tests
end

"""
current_time(p)
Expand Down Expand Up @@ -164,12 +178,10 @@ Return a function that takes an integrator or problem of `sys` and a value, and
the state `sym` to that value. Note that `sym` can be a direct numerical index, a symbolic state, or an array/tuple of the aforementioned.
Requires that the integrator implement [`state_values`](@ref) and the
returned collection be a mutable reference to the state vector in the integrator/problem.
returned collection be a mutable reference to the state vector in the integrator/problem. Alternatively, if this is not possible or additional actions need to
be performed when updating state, [`set_state!`](@ref) can be defined.
This function does not work on types for which [`is_timeseries`](@ref) is
[`Timeseries`](@ref).
In case `state_values` cannot return such a mutable reference, `setu` needs to be
implemented manually.
"""
function setu(sys, sym)
symtype = symbolic_type(sym)
Expand All @@ -184,15 +196,15 @@ end

function _setu(sys, ::NotSymbolic, sym)
return function setter!(prob, val)
state_values(prob)[sym] = val
set_state!(prob, val, sym)

Check warning on line 199 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L199

Added line #L199 was not covered by tests
end
end

function _setu(sys, ::ScalarSymbolic, sym)
is_variable(sys, sym) || error("Invalid symbol $sym for `setu`")
idx = variable_index(sys, sym)
return function setter!(prob, val)
state_values(prob)[idx] = val
set_state!(prob, val, idx)

Check warning on line 207 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L207

Added line #L207 was not covered by tests
end
end

Expand Down

0 comments on commit 7a8e647

Please sign in to comment.