From b1a7f72a0a43cca5c16dde03648718477fe9621e Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 1 Jan 2024 21:57:26 +0530 Subject: [PATCH 1/4] feat: add `IsTimeseriesTrait`, support timeseries objects in `getu` --- docs/src/api.md | 3 + src/SymbolicIndexingInterface.jl | 2 +- src/state_indexing.jl | 118 ++++++++++++++++++++++++++----- test/state_indexing_test.jl | 24 ++++++- 4 files changed, 127 insertions(+), 20 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index dc6a579..4470842 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -18,6 +18,9 @@ all_variable_symbols all_symbols solvedvariables allvariables +Timeseries +NotTimeseries +is_timeseries state_values parameter_values current_time diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index 1635737..7f79e66 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -15,7 +15,7 @@ include("symbol_cache.jl") export parameter_values, getp, setp include("parameter_indexing.jl") -export state_values, current_time, getu, setu +export Timeseries, NotTimeseries, is_timeseries, state_values, current_time, getu, setu include("state_indexing.jl") end diff --git a/src/state_indexing.jl b/src/state_indexing.jl index 9ce95c3..284cfc9 100644 --- a/src/state_indexing.jl +++ b/src/state_indexing.jl @@ -1,29 +1,77 @@ +abstract type IsTimeseriesTrait end + +""" + struct Timeseries <: IsTimeseriesTrait end + +Trait indicating a type contains timeseries data. This affects the behaviour of +functions such as [`state_values`](@ref) and [`current_time`](@ref). + +See also: [`NotTimeseries`](@ref), [`is_timeseries`](@ref) +""" +struct Timeseries <: IsTimeseriesTrait end + +""" + struct NotTimeseries <: IsTimeseriesTrait end + +Trait indicating a type does not contain timeseries data. This affects the behaviour +of functions such as [`state_values`](@ref) and [`current_time`](@ref). Note that +if a type is `NotTimeseries` this only implies that it does not _store_ timeseries +data. It may still be time-dependent. For example, an `ODEProblem` only stores +the initial state of a system, so it is `NotTimeseries`, but still time-dependent. +This is the default trait variant for all types. + +See also: [`Timeseries`](@ref), [`is_timeseries`](@ref) +""" +struct NotTimeseries <: IsTimeseriesTrait end + +""" + is_timeseries(x) = is_timeseries(typeof(x)) + is_timeseries(::Type) + +Get the timeseries trait of a type. Defaults to [`NotTimeseries`](@ref) for all types. + +See also: [`Timeseries`](@ref), [`NotTimeseries`](@ref) +""" +function is_timeseries end + +is_timeseries(x) = is_timeseries(typeof(x)) +is_timeseries(::Type) = NotTimeseries() + """ state_values(p) Return an indexable collection containing the values of all states in the integrator or -problem `p`. +problem `p`. If `is_timeseries(p)` is [`Timeseries`](@ref), return a vector of arrays, +each of which contain the state values at the corresponding timestep. + +See: [`is_timeseries`](@ref) """ function state_values end """ current_time(p) -Return the current time in the integrator or problem `p`. +Return the current time in the integrator or problem `p`. If +`is_timeseries(p)` is [`Timeseries`](@ref), return the vector of timesteps at which +the state value is saved. + + +See: [`is_timeseries`](@ref) """ function current_time end """ getu(sys, sym) -Return a function that takes an integrator or problem of `sys`, and returns the value of -the symbolic `sym`. `sym` can be a direct index into the state vector, a symbolic state, -a symbolic expression involving symbolic quantities in the system `sys`, or an -array/tuple of the aforementioned. +Return a function that takes an integrator, problem or solution of `sys`, and returns +the value of the symbolic `sym`. `sym` can be a direct index into the state vector, a +symbolic state, a symbolic expression involving symbolic quantities in the system +`sys`, or an array/tuple of the aforementioned. -At minimum, this requires that the integrator or problem implement [`state_values`](@ref). -To support symbolic expressions, the integrator or problem must implement -[`observed`](@ref), [`parameter_values`](@ref) and [`current_time`](@ref). +At minimum, this requires that the integrator, problem or solution implement +[`state_values`](@ref). To support symbolic expressions, the integrator or problem +must implement [`observed`](@ref), [`parameter_values`](@ref) and +[`current_time`](@ref). This function typically does not need to be implemented, and has a default implementation relying on the above functions. @@ -40,37 +88,68 @@ function getu(sys, sym) end function _getu(sys, ::NotSymbolic, sym) + _getter(::Timeseries, prob) = getindex.(state_values(prob), (sym,)) + _getter(::NotTimeseries, prob) = state_values(prob)[sym] return function getter(prob) - return state_values(prob)[sym] + return _getter(is_timeseries(prob), prob) end end function _getu(sys, ::ScalarSymbolic, sym) if is_variable(sys, sym) idx = variable_index(sys, sym) - return function getter1(prob) - return state_values(prob)[idx] - end + return getu(sys, idx) elseif is_observed(sys, sym) fn = observed(sys, sym) if is_time_dependent(sys) - function getter2(prob) + function _getter2(::Timeseries, prob) + return fn.(state_values(prob), + (parameter_values(prob),), + current_time(prob)) + end + function _getter2(::NotTimeseries, prob) return fn(state_values(prob), parameter_values(prob), current_time(prob)) end + + return function getter2(prob) + return _getter2(is_timeseries(prob), prob) + end else - function getter3(prob) + function _getter3(::Timeseries, prob) + return fn.(state_values(prob), (parameter_values(prob),)) + end + function _getter3(::NotTimeseries, prob) return fn(state_values(prob), parameter_values(prob)) end + + return function getter3(prob) + return _getter3(is_timeseries(prob), prob) + end end end error("Invalid symbol $sym for `getu`") end -function _getu(sys, ::ScalarSymbolic, sym::Union{<:Tuple,<:AbstractArray}) +struct TimeseriesIndexWrapper{T, I} + timeseries::T + idx::I +end + +state_values(t::TimeseriesIndexWrapper) = state_values(t.timeseries)[t.idx] +parameter_values(t::TimeseriesIndexWrapper) = parameter_values(t.timeseries) +current_time(t::TimeseriesIndexWrapper) = current_time(t.timeseries)[t.idx] + +function _getu(sys, ::ScalarSymbolic, sym::Union{<:Tuple, <:AbstractArray}) getters = getu.((sys,), sym) _call(getter, prob) = getter(prob) + + function _getter(::Timeseries, prob) + tiws = TimeseriesIndexWrapper.((prob,), eachindex(state_values(prob))) + return [_getter(NotTimeseries(), tiw) for tiw in tiws] + end + _getter(::NotTimeseries, prob) = _call.(getters, (prob,)) return function getter(prob) - return _call.(getters, (prob,)) + return _getter(is_timeseries(prob), prob) end end @@ -86,6 +165,9 @@ the state `sym` to that value. Note that `sym` can be a direct numerical index, Requires that the integrator implement [`state_values`](@ref) and the returned collection be a mutable reference to the state vector in the integrator/problem. +This function does not work on types for which [`is_timeseries`](@ref) is +[`Timeseries`](@ref). + In case `state_values` cannot return such a mutable reference, `setu` needs to be implemented manually. """ @@ -114,7 +196,7 @@ function _setu(sys, ::ScalarSymbolic, sym) end end -function _setu(sys, ::ScalarSymbolic, sym::Union{<:Tuple,<:AbstractArray}) +function _setu(sys, ::ScalarSymbolic, sym::Union{<:Tuple, <:AbstractArray}) setters = setu.((sys,), sym) _call!(setter!, prob, val) = setter!(prob, val) return function setter!(prob, val) diff --git a/test/state_indexing_test.jl b/test/state_indexing_test.jl index 87a1a8f..7fc3a61 100644 --- a/test/state_indexing_test.jl +++ b/test/state_indexing_test.jl @@ -1,6 +1,6 @@ using SymbolicIndexingInterface -struct FakeIntegrator{S,U} +struct FakeIntegrator{S, U} sys::S u::U end @@ -20,3 +20,25 @@ for (i, sym) in [(1, :x), (2, :y), (3, :z), ([1, 2], [:x, :y]), ((3, 2), (:z, :y @test get(fi) == 0.5 .* i set!(fi, true_value) end + +struct FakeSolution{S, U} + sys::S + u::U +end + +SymbolicIndexingInterface.is_timeseries(::Type{<:FakeSolution}) = Timeseries() +SymbolicIndexingInterface.symbolic_container(fp::FakeSolution) = fp.sys +SymbolicIndexingInterface.state_values(fp::FakeSolution) = fp.u + +sys = SymbolCache([:x, :y, :z], [:a, :b], [:t]) +u = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] +sol = FakeSolution(sys, u) +for (i, sym) in [(1, :x), (2, :y), (3, :z), ([1, 2], [:x, :y]), ((3, 2), (:z, :y))] + get = getu(sys, sym) + true_value = if i isa Tuple + [getindex.((v,), i) for v in u] + else + getindex.(u, (i,)) + end + @test get(sol) == true_value +end From 0166940e38233960d2f14ba1629a0c484968a39b Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 1 Jan 2024 21:57:33 +0530 Subject: [PATCH 2/4] refactor: format --- src/SymbolicIndexingInterface.jl | 3 ++- src/symbol_cache.jl | 4 +++- test/parameter_indexing_test.jl | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index 7f79e66..4663bdc 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -5,7 +5,8 @@ include("trait.jl") export is_variable, variable_index, variable_symbols, is_parameter, parameter_index, parameter_symbols, is_independent_variable, independent_variable_symbols, is_observed, - observed, is_time_dependent, constant_structure, symbolic_container, all_variable_symbols, + observed, is_time_dependent, constant_structure, symbolic_container, + all_variable_symbols, all_symbols, solvedvariables, allvariables include("interface.jl") diff --git a/src/symbol_cache.jl b/src/symbol_cache.jl index 2509f96..a1f566e 100644 --- a/src/symbol_cache.jl +++ b/src/symbol_cache.jl @@ -72,7 +72,9 @@ function is_time_dependent(sc::SymbolCache) end constant_structure(::SymbolCache) = true all_variable_symbols(sc::SymbolCache) = variable_symbols(sc) -all_symbols(sc::SymbolCache) = vcat(variable_symbols(sc), parameter_symbols(sc), independent_variable_symbols(sc)) +function all_symbols(sc::SymbolCache) + vcat(variable_symbols(sc), parameter_symbols(sc), independent_variable_symbols(sc)) +end function Base.copy(sc::SymbolCache) return SymbolCache(sc.variables === nothing ? nothing : copy(sc.variables), diff --git a/test/parameter_indexing_test.jl b/test/parameter_indexing_test.jl index 455dba7..b67fb61 100644 --- a/test/parameter_indexing_test.jl +++ b/test/parameter_indexing_test.jl @@ -1,7 +1,7 @@ using SymbolicIndexingInterface using Test -struct FakeIntegrator{S,P} +struct FakeIntegrator{S, P} sys::S p::P end From 7a8e647cda11db022aeca14cb8fef80d441c10a5 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 2 Jan 2024 11:01:27 +0530 Subject: [PATCH 3/4] feat: add `set_state!` and `set_parameter!`, update docs --- docs/src/api.md | 31 ++++++++++-- docs/src/complete_sii.md | 83 +++++++++++++++++++++++++++----- src/SymbolicIndexingInterface.jl | 5 +- src/interface.jl | 4 +- src/parameter_indexing.jl | 25 ++++++++-- src/state_indexing.jl | 24 ++++++--- 6 files changed, 141 insertions(+), 31 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 4470842..dbc3a3f 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -1,5 +1,7 @@ # Interface Functions +## Mandatory methods + ```@docs symbolic_container is_variable @@ -11,26 +13,45 @@ parameter_symbols is_independent_variable independent_variable_symbols is_observed -observed is_time_dependent constant_structure all_variable_symbols all_symbols solvedvariables allvariables +``` + +## Optional Methods + +### Observed equation handling + +```@docs +observed +``` + +### Parameter indexing + +```@docs +parameter_values +set_parameter! +getp +setp +``` + +### State indexing + +```@docs Timeseries NotTimeseries is_timeseries state_values -parameter_values +set_state! current_time -getp -setp getu setu ``` -# Traits +# Symbolic Trait ```@docs ScalarSymbolic diff --git a/docs/src/complete_sii.md b/docs/src/complete_sii.md index 035eec1..12c1623 100644 --- a/docs/src/complete_sii.md +++ b/docs/src/complete_sii.md @@ -123,6 +123,9 @@ function SymbolicIndexingInterface.observed(sys::ExampleSystem, sym::Expr) end ``` +In case a type does not support such observed quantities, `is_observed` must be +defined to always return `false`, and `observed` does not need to be implemented. + ### Note about constant structure Note that the method definitions are all assuming `constant_structure(p) == true`. @@ -174,16 +177,11 @@ mutable struct ExampleIntegrator u::Vector{Float64} p::Vector{Float64} t::Float64 - state_index::Dict{Symbol,Int} - parameter_index::Dict{Symbol,Int} - independent_variable::Symbol + sys::ExampleSystem end -``` -Assume that it implements the mandatory part of the interface as described above, and -the following methods below: - -```julia +# define a fallback for the interface methods +SymbolicIndexingInterface.symbolic_container(integ::ExampleIntegrator) = integ.sys SymbolicIndexingInterface.state_values(sys::ExampleIntegrator) = sys.u SymbolicIndexingInterface.parameter_values(sys::ExampleIntegrator) = sys.p SymbolicIndexingInterface.current_time(sys::ExampleIntegrator) = sys.t @@ -191,18 +189,79 @@ SymbolicIndexingInterface.current_time(sys::ExampleIntegrator) = sys.t Then the following example would work: ```julia -integrator = ExampleIntegrator([1.0, 2.0, 3.0], [4.0, 5.0], 6.0, Dict(:x => 1, :y => 2, :z => 3), Dict(:a => 1, :b => 2), :t) -getx = getu(integrator, :x) +sys = ExampleSystem(Dict(:x => 1, :y => 2, :z => 3), Dict(:a => 1, :b => 2), :t, Dict()) +integrator = ExampleIntegrator([1.0, 2.0, 3.0], [4.0, 5.0], 6.0, sys) +getx = getu(sys, :x) getx(integrator) # 1.0 -get_expr = getu(integrator, :(x + y + t)) +get_expr = getu(sys, :(x + y + t)) get_expr(integrator) # 13.0 -setx! = setu(integrator, :y) +setx! = setu(sys, :y) setx!(integrator, 0.0) getx(integrator) # 0.0 ``` +In case a type stores timeseries data (such as solutions), then it must also implement +the [`Timeseries`](@ref) trait. The type would then return a timeseries from +[`state_values`](@ref) and [`current_time`](@ref) and the function returned from +[`getu`](@ref) would then return a timeseries as well. For example, consider the +`ExampleSolution` below: + +```julia +struct ExampleSolution + u::Vector{Vector{Float64}} + t::Vector{Float64} + p::Vector{Float64} + sys::ExampleSystem +end + +# define a fallback for the interface methods +SymbolicIndexingInterface.symbolic_container(sol::ExampleSolution) = sol.sys +SymbolicIndexingInterface.parameter_values(sol::ExampleSolution) = sol.p +# define the trait +SymbolicIndexingInterface.is_timeseries(::Type{ExampleSolution}) = Timeseries() +# both state_values and current_time return a timeseries, which must be +# the same length +SymbolicIndexingInterface.state_values(sol::ExampleSolution) = sol.u +SymbolicIndexingInterface.current_time(sol::ExampleSolution) = sol.t +``` + +Then the following example would work: +```julia +# using the same system that the ExampleIntegrator used +sol = ExampleSolution([[1.0, 2.0, 3.0], [1.5, 2.5, 3.5]], [4.0, 5.0], [6.0, 7.0], sys) +getx = getu(sys, :x) +getx(sol) # [1.0, 1.5] + +get_expr = getu(sys, :(x + y + t)) +get_expr(sol) # [9.0, 11.0] + +get_arr = getu(sys, [:y, :(x + a)]) +get_arr(sol) # [[2.0, 5.0], [2.5, 5.5]] + +get_tuple = getu(sys, (:z, :(z * t))) +get_tuple(sol) # [(3.0, 18.0), (3.5, 24.5)] +``` + +Note that `setu` is not designed to work for `Timeseries` objects. + +If a type needs to perform some additional actions when updating the state/parameters +or if it is not possible to return a mutable reference to the state/parameter vector +which can directly be modified, the functions [`set_state!`](@ref) and/or +[`set_parameter!`](@ref) can be used. For example, suppose our `ExampleIntegrator` +had an additional field `u_modified::Bool` to allow it to keep track of when a +discontinuity occurs and handle it appropriately. This flag needs to be set to `true` +whenever the state is modified. The `set_state!` function can then be implemented as +follows: + +```julia +function SymbolicIndexingInterface.set_state!(integrator::ExampleIntegrator, val, idx) + integrator.u[idx] = val + integrator.u_modified = true +end +``` + # Implementing the `SymbolicTypeTrait` for a type The `SymbolicTypeTrait` is used to identify values that can act as symbolic variables. It diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index 4663bdc..a730d3e 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -13,10 +13,11 @@ include("interface.jl") export SymbolCache include("symbol_cache.jl") -export parameter_values, getp, setp +export parameter_values, set_parameter!, getp, setp include("parameter_indexing.jl") -export Timeseries, NotTimeseries, is_timeseries, state_values, current_time, getu, setu +export Timeseries, + NotTimeseries, is_timeseries, state_values, set_state!, current_time, getu, setu include("state_indexing.jl") end diff --git a/src/interface.jl b/src/interface.jl index 9838b1e..36dac86 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -91,7 +91,9 @@ have the signature `(u, p) -> [values...]` where `u` and `p` is the current stat parameter vector, respectively. If `istimedependent(sys) == true`, the function should accept the current time `t` as its third parameter. If `constant_structure(sys) == false`, accept a third parameter, which can either be a vector of symbols indicating the order -of states or a time index, which identifies the order of states. +of states or a time index, which identifies the order of states. This function +does not need to be defined if [`is_observed`](@ref) always returns `false`. Thus, +it is mandatory to always check `is_observed` before using this function. See also: [`is_time_dependent`](@ref), [`constant_structure`](@ref) """ diff --git a/src/parameter_indexing.jl b/src/parameter_indexing.jl index c6745ab..5243072 100644 --- a/src/parameter_indexing.jl +++ b/src/parameter_indexing.jl @@ -5,6 +5,20 @@ Return an indexable collection containing the value of each parameter in `p`. """ function parameter_values end +""" + set_parameter!(sys, val, idx) + +Set the parameter at index `idx` to `val` for system `sys`. This defaults to modifying +`parameter_values(sys)`. If any additional bookkeeping needs to be performed or the +default implementation does not work for a particular type, this method needs to be +defined to enable the proper functioning of [`setp`](@ref). + +See: [`parameter_values`](@ref) +""" +function set_parameter!(sys, val, idx) + parameter_values(sys)[idx] = val +end + """ getp(sys, p) @@ -55,8 +69,9 @@ Return a function that takes an integrator of `sys` and a value, and sets the parameter `p` to that value. Note that `p` can be a direct numerical index or a symbolic value. Requires that the integrator implement [`parameter_values`](@ref) and the returned collection be a mutable reference to the parameter vector in the integrator. In -case `parameter_values` cannot return such a mutable reference, `setp` needs to be -implemented manually. +case `parameter_values` cannot return such a mutable reference, or additional actions +need to be performed when updating parameters, [`set_parameter!`](@ref) must be +implemented. """ function setp(sys, p) symtype = symbolic_type(p) @@ -70,21 +85,21 @@ end function _setp(sys, ::NotSymbolic, p) return function setter!(sol, val) - parameter_values(sol)[p] = val + set_parameter!(sol, val, p) end end function _setp(sys, ::ScalarSymbolic, p) idx = parameter_index(sys, p) return function setter!(sol, val) - parameter_values(sol)[idx] = val + set_parameter!(sol, val, idx) end end function _setp(sys, ::ScalarSymbolic, p::Union{Tuple, AbstractArray}) idxs = parameter_index.((sys,), p) return function setter!(sol, val) - setindex!.((parameter_values(sol),), val, idxs) + set_parameter!.((sol,), val, idxs) end end diff --git a/src/state_indexing.jl b/src/state_indexing.jl index 284cfc9..94a02a6 100644 --- a/src/state_indexing.jl +++ b/src/state_indexing.jl @@ -48,6 +48,20 @@ See: [`is_timeseries`](@ref) """ function state_values end +""" + set_state!(sys, val, idx) + +Set the state at index `idx` to `val` for system `sys`. This defaults to modifying +`state_values(sys)`. If any additional bookkeeping needs to be performed or the +default implementation does not work for a particular type, this method needs to be +defined to enable the proper functioning of [`setu`](@ref). + +See: [`state_values`](@ref) +""" +function set_state!(sys, val, idx) + state_values(sys)[idx] = val +end + """ current_time(p) @@ -164,12 +178,10 @@ Return a function that takes an integrator or problem of `sys` and a value, and the state `sym` to that value. Note that `sym` can be a direct numerical index, a symbolic state, or an array/tuple of the aforementioned. Requires that the integrator implement [`state_values`](@ref) and the -returned collection be a mutable reference to the state vector in the integrator/problem. +returned collection be a mutable reference to the state vector in the integrator/problem. Alternatively, if this is not possible or additional actions need to +be performed when updating state, [`set_state!`](@ref) can be defined. This function does not work on types for which [`is_timeseries`](@ref) is [`Timeseries`](@ref). - -In case `state_values` cannot return such a mutable reference, `setu` needs to be -implemented manually. """ function setu(sys, sym) symtype = symbolic_type(sym) @@ -184,7 +196,7 @@ end function _setu(sys, ::NotSymbolic, sym) return function setter!(prob, val) - state_values(prob)[sym] = val + set_state!(prob, val, sym) end end @@ -192,7 +204,7 @@ function _setu(sys, ::ScalarSymbolic, sym) is_variable(sys, sym) || error("Invalid symbol $sym for `setu`") idx = variable_index(sys, sym) return function setter!(prob, val) - state_values(prob)[idx] = val + set_state!(prob, val, idx) end end From 29d692b283ed6c0a108437482a4117d0f4285590 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 2 Jan 2024 16:17:04 +0530 Subject: [PATCH 4/4] build: run SciML SII tests in downstream CI --- .github/workflows/Downstream.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/Downstream.yml b/.github/workflows/Downstream.yml index a5ecc57..701df92 100644 --- a/.github/workflows/Downstream.yml +++ b/.github/workflows/Downstream.yml @@ -19,6 +19,7 @@ jobs: package: - {user: SciML, repo: RecursiveArrayTools.jl, group: All} - {user: JuliaSymbolics, repo: Symbolics.jl, group: SymbolicIndexingInterface} + - {user: SciML, repo: SciMLBase.jl, group: SymbolicIndexingInterface} steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1