From 8fcceb2f5e703e98601c4843bb8f3d703595ed3b Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 1 Jan 2024 13:45:58 +0530 Subject: [PATCH] feat: add `getu`/`setu`, update docs, add tests --- Project.toml | 4 +- docs/src/api.md | 4 + docs/src/complete_sii.md | 83 +++++++++++++++++--- src/SymbolicIndexingInterface.jl | 3 + src/state_indexing.jl | 127 +++++++++++++++++++++++++++++++ src/trait.jl | 3 +- test/example_test.jl | 3 + test/fallback_test.jl | 1 + test/parameter_indexing_test.jl | 6 +- test/runtests.jl | 17 +++-- test/state_indexing_test.jl | 22 ++++++ test/symbol_cache_test.jl | 3 + 12 files changed, 256 insertions(+), 20 deletions(-) create mode 100644 src/state_indexing.jl create mode 100644 test/state_indexing_test.jl diff --git a/Project.toml b/Project.toml index 46cbad57..88114a1e 100644 --- a/Project.toml +++ b/Project.toml @@ -5,12 +5,14 @@ version = "0.3.1" [compat] Aqua = "0.8" +SafeTestsets = "0.0.1" Test = "1" julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "Test"] +test = ["Aqua", "Test", "SafeTestsets"] diff --git a/docs/src/api.md b/docs/src/api.md index bc60cac2..dc6a5792 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -18,9 +18,13 @@ all_variable_symbols all_symbols solvedvariables allvariables +state_values parameter_values +current_time getp setp +getu +setu ``` # Traits diff --git a/docs/src/complete_sii.md b/docs/src/complete_sii.md index e714718b..035eec1d 100644 --- a/docs/src/complete_sii.md +++ b/docs/src/complete_sii.md @@ -85,22 +85,41 @@ These are for handling symbolic expressions and generating equations which are n in the solution vector. ```julia +using RuntimeGeneratedFunctions +RuntimeGeneratedFunctions.init(@__MODULE__) + # this type accepts `Expr` for observed expressions involving state/parameter/observed # variables SymbolicIndexingInterface.is_observed(sys::ExampleSystem, sym) = sym isa Expr || sym isa Symbol && haskey(sys.observed, sym) function SymbolicIndexingInterface.observed(sys::ExampleSystem, sym::Expr) + # generate a function with the appropriate signature if is_time_dependent(sys) - return function (u, p, t) - # compute value from `sym`, leveraging `variable_index` and - # `parameter_index` to turn symbols into indices - end + fn_expr = :( + function gen(u, p, t) + # assign a variable for each state symbol it's value in u + $([:($var = u[$idx]) for (var, idx) in pairs(sys.state_index)]...) + # assign a variable for each parameter symbol it's value in p + $([:($var = p[$idx]) for (var, idx) in pairs(sys.parameter_index)]...) + # assign a variable for the independent variable + $(sys.independent_variable) = t + # return the value of the expression + return $sym + end + ) else - return function (u, p) - # compute value from `sym`, leveraging `variable_index` and - # `parameter_index` to turn symbols into indices - end + fn_expr = :( + function gen(u, p) + # assign a variable for each state symbol it's value in u + $([:($var = u[$idx]) for (var, idx) in pairs(sys.state_index)]...) + # assign a variable for each parameter symbol it's value in p + $([:($var = p[$idx]) for (var, idx) in pairs(sys.parameter_index)]...) + # return the value of the expression + return $sym + end + ) end + return @RuntimeGeneratedFunction(fn_expr) end ``` @@ -127,7 +146,7 @@ only responsible for identifying observed values and `observed` will always be c on a type that wraps this type. An example is `ModelingToolkit.AbstractSystem`, which can identify whether a value is observed, but cannot implement `observed` itself. -Other optional methods relate to parameter indexing. If a type contains the values of +Other optional methods relate to indexing functions. If a type contains the values of parameter variables, it must implement [`parameter_values`](@ref). This allows the default definitions of [`getp`](@ref) and [`setp`](@ref) to work. While `setp` is not typically useful for solution objects, it may be useful for integrators. Typically, @@ -140,7 +159,51 @@ function SymbolicIndexingInterface.parameter_values(sys::ExampleSystem) end ``` -## Implementing the `SymbolicTypeTrait` for a type +If a type contains the value of state variables, it can define [`state_values`](@ref) to +enable the usage of [`getu`](@ref) and [`setu`](@ref). These methods retturn getter/ +setter functions to access or update the value of a state variable (or a collection of +them). If the type also supports generating [`observed`](@ref) functions, `getu` also +enables returning functions to access the value of arbitrary expressions involving +the system's symbols. This also requires that the type implement +[`parameter_values`](@ref) and [`current_time`](@ref) (if the system is time-dependent). + +Consider the following `ExampleIntegrator` + +```julia +mutable struct ExampleIntegrator + u::Vector{Float64} + p::Vector{Float64} + t::Float64 + state_index::Dict{Symbol,Int} + parameter_index::Dict{Symbol,Int} + independent_variable::Symbol +end +``` + +Assume that it implements the mandatory part of the interface as described above, and +the following methods below: + +```julia +SymbolicIndexingInterface.state_values(sys::ExampleIntegrator) = sys.u +SymbolicIndexingInterface.parameter_values(sys::ExampleIntegrator) = sys.p +SymbolicIndexingInterface.current_time(sys::ExampleIntegrator) = sys.t +``` + +Then the following example would work: +```julia +integrator = ExampleIntegrator([1.0, 2.0, 3.0], [4.0, 5.0], 6.0, Dict(:x => 1, :y => 2, :z => 3), Dict(:a => 1, :b => 2), :t) +getx = getu(integrator, :x) +getx(integrator) # 1.0 + +get_expr = getu(integrator, :(x + y + t)) +get_expr(integrator) # 13.0 + +setx! = setu(integrator, :y) +setx!(integrator, 0.0) +getx(integrator) # 0.0 +``` + +# Implementing the `SymbolicTypeTrait` for a type The `SymbolicTypeTrait` is used to identify values that can act as symbolic variables. It has three variants: diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index 32b6f9e5..16357374 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -15,4 +15,7 @@ include("symbol_cache.jl") export parameter_values, getp, setp include("parameter_indexing.jl") +export state_values, current_time, getu, setu +include("state_indexing.jl") + end diff --git a/src/state_indexing.jl b/src/state_indexing.jl new file mode 100644 index 00000000..9ce95c38 --- /dev/null +++ b/src/state_indexing.jl @@ -0,0 +1,127 @@ +""" + state_values(p) + +Return an indexable collection containing the values of all states in the integrator or +problem `p`. +""" +function state_values end + +""" + current_time(p) + +Return the current time in the integrator or problem `p`. +""" +function current_time end + +""" + getu(sys, sym) + +Return a function that takes an integrator or problem of `sys`, and returns the value of +the symbolic `sym`. `sym` can be a direct index into the state vector, a symbolic state, +a symbolic expression involving symbolic quantities in the system `sys`, or an +array/tuple of the aforementioned. + +At minimum, this requires that the integrator or problem implement [`state_values`](@ref). +To support symbolic expressions, the integrator or problem must implement +[`observed`](@ref), [`parameter_values`](@ref) and [`current_time`](@ref). + +This function typically does not need to be implemented, and has a default implementation +relying on the above functions. +""" +function getu(sys, sym) + symtype = symbolic_type(sym) + elsymtype = symbolic_type(eltype(sym)) + + if symtype != NotSymbolic() + _getu(sys, symtype, sym) + else + _getu(sys, elsymtype, sym) + end +end + +function _getu(sys, ::NotSymbolic, sym) + return function getter(prob) + return state_values(prob)[sym] + end +end + +function _getu(sys, ::ScalarSymbolic, sym) + if is_variable(sys, sym) + idx = variable_index(sys, sym) + return function getter1(prob) + return state_values(prob)[idx] + end + elseif is_observed(sys, sym) + fn = observed(sys, sym) + if is_time_dependent(sys) + function getter2(prob) + return fn(state_values(prob), parameter_values(prob), current_time(prob)) + end + else + function getter3(prob) + return fn(state_values(prob), parameter_values(prob)) + end + end + end + error("Invalid symbol $sym for `getu`") +end + +function _getu(sys, ::ScalarSymbolic, sym::Union{<:Tuple,<:AbstractArray}) + getters = getu.((sys,), sym) + _call(getter, prob) = getter(prob) + return function getter(prob) + return _call.(getters, (prob,)) + end +end + +function _getu(sys, ::ArraySymbolic, sym) + return getu(sys, collect(sym)) +end + +""" + setu(sys, sym) + +Return a function that takes an integrator or problem of `sys` and a value, and sets the +the state `sym` to that value. Note that `sym` can be a direct numerical index, a symbolic state, or an array/tuple of the aforementioned. + +Requires that the integrator implement [`state_values`](@ref) and the +returned collection be a mutable reference to the state vector in the integrator/problem. +In case `state_values` cannot return such a mutable reference, `setu` needs to be +implemented manually. +""" +function setu(sys, sym) + symtype = symbolic_type(sym) + elsymtype = symbolic_type(eltype(sym)) + + if symtype != NotSymbolic() + _setu(sys, symtype, sym) + else + _setu(sys, elsymtype, sym) + end +end + +function _setu(sys, ::NotSymbolic, sym) + return function setter!(prob, val) + state_values(prob)[sym] = val + end +end + +function _setu(sys, ::ScalarSymbolic, sym) + is_variable(sys, sym) || error("Invalid symbol $sym for `setu`") + idx = variable_index(sys, sym) + return function setter!(prob, val) + state_values(prob)[idx] = val + end +end + +function _setu(sys, ::ScalarSymbolic, sym::Union{<:Tuple,<:AbstractArray}) + setters = setu.((sys,), sym) + _call!(setter!, prob, val) = setter!(prob, val) + return function setter!(prob, val) + _call!.(setters, (prob,), val) + end +end + +function _setu(sys, ::ArraySymbolic, sym) + return setu(sys, collect(sym)) +end diff --git a/src/trait.jl b/src/trait.jl index 6660e917..42351148 100644 --- a/src/trait.jl +++ b/src/trait.jl @@ -35,13 +35,14 @@ struct NotSymbolic <: SymbolicTypeTrait end symbolic_type(::Type) Get the symbolic type trait of a type. Default to [`NotSymbolic`](@ref) for all types -except `Symbol`. +except `Symbol` and `Expr`, both of which are [`ScalarSymbolic`](@ref). See also: [`ScalarSymbolic`](@ref), [`ArraySymbolic`](@ref), [`NotSymbolic`](@ref) """ symbolic_type(x) = symbolic_type(typeof(x)) symbolic_type(::Type) = NotSymbolic() symbolic_type(::Type{Symbol}) = ScalarSymbolic() +symbolic_type(::Type{Expr}) = ScalarSymbolic() """ hasname(x) diff --git a/test/example_test.jl b/test/example_test.jl index 965d1b41..5d3caa0c 100644 --- a/test/example_test.jl +++ b/test/example_test.jl @@ -1,3 +1,6 @@ +using SymbolicIndexingInterface +using Test + struct SystemMockup static::Bool vars::Vector{Symbol} diff --git a/test/fallback_test.jl b/test/fallback_test.jl index f390178a..622d8685 100644 --- a/test/fallback_test.jl +++ b/test/fallback_test.jl @@ -1,4 +1,5 @@ using SymbolicIndexingInterface +using Test struct Wrapper{W} wrapped::W diff --git a/test/parameter_indexing_test.jl b/test/parameter_indexing_test.jl index 66b18936..455dba7b 100644 --- a/test/parameter_indexing_test.jl +++ b/test/parameter_indexing_test.jl @@ -1,6 +1,8 @@ using SymbolicIndexingInterface +using Test -struct FakeIntegrator{P} +struct FakeIntegrator{S,P} + sys::S p::P end @@ -9,7 +11,7 @@ SymbolicIndexingInterface.parameter_values(fp::FakeIntegrator) = fp.p sys = SymbolCache([:x, :y, :z], [:a, :b], [:t]) p = [1.0, 2.0] -fi = FakeIntegrator(copy(p)) +fi = FakeIntegrator(sys, copy(p)) for (i, sym) in [(1, :a), (2, :b), ([1, 2], [:a, :b]), ((1, 2), (:a, :b))] get = getp(sys, sym) set! = setp(sys, sym) diff --git a/test/runtests.jl b/test/runtests.jl index beeda772..5dbabc29 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,20 +1,25 @@ using SymbolicIndexingInterface +using SafeTestsets using Test -@testset "Quality Assurance" begin + +@safetestset "Quality Assurance" begin @time include("qa.jl") end -@testset "Interface test" begin +@safetestset "Interface test" begin @time include("example_test.jl") end -@testset "Trait test" begin +@safetestset "Trait test" begin @time include("trait_test.jl") end -@testset "SymbolCache test" begin +@safetestset "SymbolCache test" begin @time include("symbol_cache_test.jl") end -@testset "Fallback test" begin +@safetestset "Fallback test" begin @time include("fallback_test.jl") end -@testset "Parameter indexing test" begin +@safetestset "Parameter indexing test" begin @time include("parameter_indexing_test.jl") end +@safetestset "State indexing test" begin + @time include("state_indexing_test.jl") +end diff --git a/test/state_indexing_test.jl b/test/state_indexing_test.jl new file mode 100644 index 00000000..87a1a8f1 --- /dev/null +++ b/test/state_indexing_test.jl @@ -0,0 +1,22 @@ +using SymbolicIndexingInterface + +struct FakeIntegrator{S,U} + sys::S + u::U +end + +SymbolicIndexingInterface.symbolic_container(fp::FakeIntegrator) = fp.sys +SymbolicIndexingInterface.state_values(fp::FakeIntegrator) = fp.u + +sys = SymbolCache([:x, :y, :z], [:a, :b], [:t]) +u = [1.0, 2.0, 3.0] +fi = FakeIntegrator(sys, copy(u)) +for (i, sym) in [(1, :x), (2, :y), (3, :z), ([1, 2], [:x, :y]), ((3, 2), (:z, :y))] + get = getu(sys, sym) + set! = setu(sys, sym) + true_value = i isa Tuple ? getindex.((u,), i) : u[i] + @test get(fi) == true_value + set!(fi, 0.5 .* i) + @test get(fi) == 0.5 .* i + set!(fi, true_value) +end diff --git a/test/symbol_cache_test.jl b/test/symbol_cache_test.jl index 39e388d2..5fb1b03a 100644 --- a/test/symbol_cache_test.jl +++ b/test/symbol_cache_test.jl @@ -1,3 +1,6 @@ +using SymbolicIndexingInterface +using Test + sc = SymbolCache([:x, :y, :z], [:a, :b], [:t]) @test all(is_variable.((sc,), [:x, :y, :z]))