Skip to content

Commit

Permalink
Merge pull request #25 from SciML/as/getu-sol
Browse files Browse the repository at this point in the history
feat: add `IsTimeseriesTrait`, support timeseries objects in `getu`
  • Loading branch information
ChrisRackauckas authored Jan 2, 2024
2 parents 9c5e3b5 + 29d692b commit 4da53e6
Show file tree
Hide file tree
Showing 10 changed files with 272 additions and 51 deletions.
1 change: 1 addition & 0 deletions .github/workflows/Downstream.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ jobs:
package:
- {user: SciML, repo: RecursiveArrayTools.jl, group: All}
- {user: JuliaSymbolics, repo: Symbolics.jl, group: SymbolicIndexingInterface}
- {user: SciML, repo: SciMLBase.jl, group: SymbolicIndexingInterface}
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
Expand Down
32 changes: 28 additions & 4 deletions docs/src/api.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Interface Functions

## Mandatory methods

```@docs
symbolic_container
is_variable
Expand All @@ -11,23 +13,45 @@ parameter_symbols
is_independent_variable
independent_variable_symbols
is_observed
observed
is_time_dependent
constant_structure
all_variable_symbols
all_symbols
solvedvariables
allvariables
state_values
```

## Optional Methods

### Observed equation handling

```@docs
observed
```

### Parameter indexing

```@docs
parameter_values
current_time
set_parameter!
getp
setp
```

### State indexing

```@docs
Timeseries
NotTimeseries
is_timeseries
state_values
set_state!
current_time
getu
setu
```

# Traits
# Symbolic Trait

```@docs
ScalarSymbolic
Expand Down
83 changes: 71 additions & 12 deletions docs/src/complete_sii.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ function SymbolicIndexingInterface.observed(sys::ExampleSystem, sym::Expr)
end
```

In case a type does not support such observed quantities, `is_observed` must be
defined to always return `false`, and `observed` does not need to be implemented.

### Note about constant structure

Note that the method definitions are all assuming `constant_structure(p) == true`.
Expand Down Expand Up @@ -174,35 +177,91 @@ mutable struct ExampleIntegrator
u::Vector{Float64}
p::Vector{Float64}
t::Float64
state_index::Dict{Symbol,Int}
parameter_index::Dict{Symbol,Int}
independent_variable::Symbol
sys::ExampleSystem
end
```

Assume that it implements the mandatory part of the interface as described above, and
the following methods below:

```julia
# define a fallback for the interface methods
SymbolicIndexingInterface.symbolic_container(integ::ExampleIntegrator) = integ.sys
SymbolicIndexingInterface.state_values(sys::ExampleIntegrator) = sys.u
SymbolicIndexingInterface.parameter_values(sys::ExampleIntegrator) = sys.p
SymbolicIndexingInterface.current_time(sys::ExampleIntegrator) = sys.t
```

Then the following example would work:
```julia
integrator = ExampleIntegrator([1.0, 2.0, 3.0], [4.0, 5.0], 6.0, Dict(:x => 1, :y => 2, :z => 3), Dict(:a => 1, :b => 2), :t)
getx = getu(integrator, :x)
sys = ExampleSystem(Dict(:x => 1, :y => 2, :z => 3), Dict(:a => 1, :b => 2), :t, Dict())
integrator = ExampleIntegrator([1.0, 2.0, 3.0], [4.0, 5.0], 6.0, sys)
getx = getu(sys, :x)
getx(integrator) # 1.0

get_expr = getu(integrator, :(x + y + t))
get_expr = getu(sys, :(x + y + t))
get_expr(integrator) # 13.0

setx! = setu(integrator, :y)
setx! = setu(sys, :y)
setx!(integrator, 0.0)
getx(integrator) # 0.0
```

In case a type stores timeseries data (such as solutions), then it must also implement
the [`Timeseries`](@ref) trait. The type would then return a timeseries from
[`state_values`](@ref) and [`current_time`](@ref) and the function returned from
[`getu`](@ref) would then return a timeseries as well. For example, consider the
`ExampleSolution` below:

```julia
struct ExampleSolution
u::Vector{Vector{Float64}}
t::Vector{Float64}
p::Vector{Float64}
sys::ExampleSystem
end

# define a fallback for the interface methods
SymbolicIndexingInterface.symbolic_container(sol::ExampleSolution) = sol.sys
SymbolicIndexingInterface.parameter_values(sol::ExampleSolution) = sol.p
# define the trait
SymbolicIndexingInterface.is_timeseries(::Type{ExampleSolution}) = Timeseries()
# both state_values and current_time return a timeseries, which must be
# the same length
SymbolicIndexingInterface.state_values(sol::ExampleSolution) = sol.u
SymbolicIndexingInterface.current_time(sol::ExampleSolution) = sol.t
```

Then the following example would work:
```julia
# using the same system that the ExampleIntegrator used
sol = ExampleSolution([[1.0, 2.0, 3.0], [1.5, 2.5, 3.5]], [4.0, 5.0], [6.0, 7.0], sys)
getx = getu(sys, :x)
getx(sol) # [1.0, 1.5]

get_expr = getu(sys, :(x + y + t))
get_expr(sol) # [9.0, 11.0]

get_arr = getu(sys, [:y, :(x + a)])
get_arr(sol) # [[2.0, 5.0], [2.5, 5.5]]

get_tuple = getu(sys, (:z, :(z * t)))
get_tuple(sol) # [(3.0, 18.0), (3.5, 24.5)]
```

Note that `setu` is not designed to work for `Timeseries` objects.

If a type needs to perform some additional actions when updating the state/parameters
or if it is not possible to return a mutable reference to the state/parameter vector
which can directly be modified, the functions [`set_state!`](@ref) and/or
[`set_parameter!`](@ref) can be used. For example, suppose our `ExampleIntegrator`
had an additional field `u_modified::Bool` to allow it to keep track of when a
discontinuity occurs and handle it appropriately. This flag needs to be set to `true`
whenever the state is modified. The `set_state!` function can then be implemented as
follows:

```julia
function SymbolicIndexingInterface.set_state!(integrator::ExampleIntegrator, val, idx)
integrator.u[idx] = val
integrator.u_modified = true
end
```

# Implementing the `SymbolicTypeTrait` for a type

The `SymbolicTypeTrait` is used to identify values that can act as symbolic variables. It
Expand Down
8 changes: 5 additions & 3 deletions src/SymbolicIndexingInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,19 @@ include("trait.jl")

export is_variable, variable_index, variable_symbols, is_parameter, parameter_index,
parameter_symbols, is_independent_variable, independent_variable_symbols, is_observed,
observed, is_time_dependent, constant_structure, symbolic_container, all_variable_symbols,
observed, is_time_dependent, constant_structure, symbolic_container,
all_variable_symbols,
all_symbols, solvedvariables, allvariables
include("interface.jl")

export SymbolCache
include("symbol_cache.jl")

export parameter_values, getp, setp
export parameter_values, set_parameter!, getp, setp
include("parameter_indexing.jl")

export state_values, current_time, getu, setu
export Timeseries,
NotTimeseries, is_timeseries, state_values, set_state!, current_time, getu, setu
include("state_indexing.jl")

end
4 changes: 3 additions & 1 deletion src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ have the signature `(u, p) -> [values...]` where `u` and `p` is the current stat
parameter vector, respectively. If `istimedependent(sys) == true`, the function should accept
the current time `t` as its third parameter. If `constant_structure(sys) == false`,
accept a third parameter, which can either be a vector of symbols indicating the order
of states or a time index, which identifies the order of states.
of states or a time index, which identifies the order of states. This function
does not need to be defined if [`is_observed`](@ref) always returns `false`. Thus,
it is mandatory to always check `is_observed` before using this function.
See also: [`is_time_dependent`](@ref), [`constant_structure`](@ref)
"""
Expand Down
25 changes: 20 additions & 5 deletions src/parameter_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,20 @@ Return an indexable collection containing the value of each parameter in `p`.
"""
function parameter_values end

"""
set_parameter!(sys, val, idx)
Set the parameter at index `idx` to `val` for system `sys`. This defaults to modifying
`parameter_values(sys)`. If any additional bookkeeping needs to be performed or the
default implementation does not work for a particular type, this method needs to be
defined to enable the proper functioning of [`setp`](@ref).
See: [`parameter_values`](@ref)
"""
function set_parameter!(sys, val, idx)
parameter_values(sys)[idx] = val
end

"""
getp(sys, p)
Expand Down Expand Up @@ -55,8 +69,9 @@ Return a function that takes an integrator of `sys` and a value, and sets
the parameter `p` to that value. Note that `p` can be a direct numerical index or a
symbolic value. Requires that the integrator implement [`parameter_values`](@ref) and the
returned collection be a mutable reference to the parameter vector in the integrator. In
case `parameter_values` cannot return such a mutable reference, `setp` needs to be
implemented manually.
case `parameter_values` cannot return such a mutable reference, or additional actions
need to be performed when updating parameters, [`set_parameter!`](@ref) must be
implemented.
"""
function setp(sys, p)
symtype = symbolic_type(p)
Expand All @@ -70,21 +85,21 @@ end

function _setp(sys, ::NotSymbolic, p)
return function setter!(sol, val)
parameter_values(sol)[p] = val
set_parameter!(sol, val, p)
end
end

function _setp(sys, ::ScalarSymbolic, p)
idx = parameter_index(sys, p)
return function setter!(sol, val)
parameter_values(sol)[idx] = val
set_parameter!(sol, val, idx)
end
end

function _setp(sys, ::ScalarSymbolic, p::Union{Tuple, AbstractArray})
idxs = parameter_index.((sys,), p)
return function setter!(sol, val)
setindex!.((parameter_values(sol),), val, idxs)
set_parameter!.((sol,), val, idxs)
end
end

Expand Down
Loading

0 comments on commit 4da53e6

Please sign in to comment.