Skip to content

Commit

Permalink
feat: add getu/setu, update docs, add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jan 1, 2024
1 parent 06f2cb2 commit 8fcceb2
Show file tree
Hide file tree
Showing 12 changed files with 256 additions and 20 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ version = "0.3.1"

[compat]
Aqua = "0.8"
SafeTestsets = "0.0.1"
Test = "1"
julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "Test"]
test = ["Aqua", "Test", "SafeTestsets"]
4 changes: 4 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,13 @@ all_variable_symbols
all_symbols
solvedvariables
allvariables
state_values
parameter_values
current_time
getp
setp
getu
setu
```

# Traits
Expand Down
83 changes: 73 additions & 10 deletions docs/src/complete_sii.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,22 +85,41 @@ These are for handling symbolic expressions and generating equations which are n
in the solution vector.

```julia
using RuntimeGeneratedFunctions
RuntimeGeneratedFunctions.init(@__MODULE__)

# this type accepts `Expr` for observed expressions involving state/parameter/observed
# variables
SymbolicIndexingInterface.is_observed(sys::ExampleSystem, sym) = sym isa Expr || sym isa Symbol && haskey(sys.observed, sym)

function SymbolicIndexingInterface.observed(sys::ExampleSystem, sym::Expr)
# generate a function with the appropriate signature
if is_time_dependent(sys)
return function (u, p, t)
# compute value from `sym`, leveraging `variable_index` and
# `parameter_index` to turn symbols into indices
end
fn_expr = :(
function gen(u, p, t)
# assign a variable for each state symbol it's value in u
$([:($var = u[$idx]) for (var, idx) in pairs(sys.state_index)]...)
# assign a variable for each parameter symbol it's value in p
$([:($var = p[$idx]) for (var, idx) in pairs(sys.parameter_index)]...)
# assign a variable for the independent variable
$(sys.independent_variable) = t
# return the value of the expression
return $sym
end
)
else
return function (u, p)
# compute value from `sym`, leveraging `variable_index` and
# `parameter_index` to turn symbols into indices
end
fn_expr = :(
function gen(u, p)
# assign a variable for each state symbol it's value in u
$([:($var = u[$idx]) for (var, idx) in pairs(sys.state_index)]...)
# assign a variable for each parameter symbol it's value in p
$([:($var = p[$idx]) for (var, idx) in pairs(sys.parameter_index)]...)
# return the value of the expression
return $sym
end
)
end
return @RuntimeGeneratedFunction(fn_expr)
end
```

Expand All @@ -127,7 +146,7 @@ only responsible for identifying observed values and `observed` will always be c
on a type that wraps this type. An example is `ModelingToolkit.AbstractSystem`, which
can identify whether a value is observed, but cannot implement `observed` itself.

Other optional methods relate to parameter indexing. If a type contains the values of
Other optional methods relate to indexing functions. If a type contains the values of
parameter variables, it must implement [`parameter_values`](@ref). This allows the
default definitions of [`getp`](@ref) and [`setp`](@ref) to work. While `setp` is
not typically useful for solution objects, it may be useful for integrators. Typically,
Expand All @@ -140,7 +159,51 @@ function SymbolicIndexingInterface.parameter_values(sys::ExampleSystem)
end
```

## Implementing the `SymbolicTypeTrait` for a type
If a type contains the value of state variables, it can define [`state_values`](@ref) to
enable the usage of [`getu`](@ref) and [`setu`](@ref). These methods retturn getter/
setter functions to access or update the value of a state variable (or a collection of
them). If the type also supports generating [`observed`](@ref) functions, `getu` also
enables returning functions to access the value of arbitrary expressions involving
the system's symbols. This also requires that the type implement
[`parameter_values`](@ref) and [`current_time`](@ref) (if the system is time-dependent).

Consider the following `ExampleIntegrator`

```julia
mutable struct ExampleIntegrator
u::Vector{Float64}
p::Vector{Float64}
t::Float64
state_index::Dict{Symbol,Int}
parameter_index::Dict{Symbol,Int}
independent_variable::Symbol
end
```

Assume that it implements the mandatory part of the interface as described above, and
the following methods below:

```julia
SymbolicIndexingInterface.state_values(sys::ExampleIntegrator) = sys.u
SymbolicIndexingInterface.parameter_values(sys::ExampleIntegrator) = sys.p
SymbolicIndexingInterface.current_time(sys::ExampleIntegrator) = sys.t
```

Then the following example would work:
```julia
integrator = ExampleIntegrator([1.0, 2.0, 3.0], [4.0, 5.0], 6.0, Dict(:x => 1, :y => 2, :z => 3), Dict(:a => 1, :b => 2), :t)
getx = getu(integrator, :x)
getx(integrator) # 1.0

get_expr = getu(integrator, :(x + y + t))
get_expr(integrator) # 13.0

setx! = setu(integrator, :y)
setx!(integrator, 0.0)
getx(integrator) # 0.0
```

# Implementing the `SymbolicTypeTrait` for a type

The `SymbolicTypeTrait` is used to identify values that can act as symbolic variables. It
has three variants:
Expand Down
3 changes: 3 additions & 0 deletions src/SymbolicIndexingInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,7 @@ include("symbol_cache.jl")
export parameter_values, getp, setp
include("parameter_indexing.jl")

export state_values, current_time, getu, setu
include("state_indexing.jl")

end
127 changes: 127 additions & 0 deletions src/state_indexing.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
"""
state_values(p)
Return an indexable collection containing the values of all states in the integrator or
problem `p`.
"""
function state_values end

"""
current_time(p)
Return the current time in the integrator or problem `p`.
"""
function current_time end

"""
getu(sys, sym)
Return a function that takes an integrator or problem of `sys`, and returns the value of
the symbolic `sym`. `sym` can be a direct index into the state vector, a symbolic state,
a symbolic expression involving symbolic quantities in the system `sys`, or an
array/tuple of the aforementioned.
At minimum, this requires that the integrator or problem implement [`state_values`](@ref).
To support symbolic expressions, the integrator or problem must implement
[`observed`](@ref), [`parameter_values`](@ref) and [`current_time`](@ref).
This function typically does not need to be implemented, and has a default implementation
relying on the above functions.
"""
function getu(sys, sym)
symtype = symbolic_type(sym)
elsymtype = symbolic_type(eltype(sym))

Check warning on line 33 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L31-L33

Added lines #L31 - L33 were not covered by tests

if symtype != NotSymbolic()
_getu(sys, symtype, sym)

Check warning on line 36 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L35-L36

Added lines #L35 - L36 were not covered by tests
else
_getu(sys, elsymtype, sym)

Check warning on line 38 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L38

Added line #L38 was not covered by tests
end
end

function _getu(sys, ::NotSymbolic, sym)
return function getter(prob)
return state_values(prob)[sym]

Check warning on line 44 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L42-L44

Added lines #L42 - L44 were not covered by tests
end
end

function _getu(sys, ::ScalarSymbolic, sym)
if is_variable(sys, sym)
idx = variable_index(sys, sym)
return function getter1(prob)
return state_values(prob)[idx]

Check warning on line 52 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L48-L52

Added lines #L48 - L52 were not covered by tests
end
elseif is_observed(sys, sym)
fn = observed(sys, sym)
if is_time_dependent(sys)
function getter2(prob)
return fn(state_values(prob), parameter_values(prob), current_time(prob))

Check warning on line 58 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L54-L58

Added lines #L54 - L58 were not covered by tests
end
else
function getter3(prob)
return fn(state_values(prob), parameter_values(prob))

Check warning on line 62 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L61-L62

Added lines #L61 - L62 were not covered by tests
end
end
end
error("Invalid symbol $sym for `getu`")

Check warning on line 66 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L66

Added line #L66 was not covered by tests
end

function _getu(sys, ::ScalarSymbolic, sym::Union{<:Tuple,<:AbstractArray})
getters = getu.((sys,), sym)
_call(getter, prob) = getter(prob)
return function getter(prob)
return _call.(getters, (prob,))

Check warning on line 73 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L69-L73

Added lines #L69 - L73 were not covered by tests
end
end

function _getu(sys, ::ArraySymbolic, sym)
return getu(sys, collect(sym))

Check warning on line 78 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L77-L78

Added lines #L77 - L78 were not covered by tests
end

"""
setu(sys, sym)
Return a function that takes an integrator or problem of `sys` and a value, and sets the
the state `sym` to that value. Note that `sym` can be a direct numerical index, a symbolic state, or an array/tuple of the aforementioned.
Requires that the integrator implement [`state_values`](@ref) and the
returned collection be a mutable reference to the state vector in the integrator/problem.
In case `state_values` cannot return such a mutable reference, `setu` needs to be
implemented manually.
"""
function setu(sys, sym)
symtype = symbolic_type(sym)
elsymtype = symbolic_type(eltype(sym))

Check warning on line 94 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L92-L94

Added lines #L92 - L94 were not covered by tests

if symtype != NotSymbolic()
_setu(sys, symtype, sym)

Check warning on line 97 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L96-L97

Added lines #L96 - L97 were not covered by tests
else
_setu(sys, elsymtype, sym)

Check warning on line 99 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L99

Added line #L99 was not covered by tests
end
end

function _setu(sys, ::NotSymbolic, sym)
return function setter!(prob, val)
state_values(prob)[sym] = val

Check warning on line 105 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L103-L105

Added lines #L103 - L105 were not covered by tests
end
end

function _setu(sys, ::ScalarSymbolic, sym)
is_variable(sys, sym) || error("Invalid symbol $sym for `setu`")
idx = variable_index(sys, sym)
return function setter!(prob, val)
state_values(prob)[idx] = val

Check warning on line 113 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L109-L113

Added lines #L109 - L113 were not covered by tests
end
end

function _setu(sys, ::ScalarSymbolic, sym::Union{<:Tuple,<:AbstractArray})
setters = setu.((sys,), sym)
_call!(setter!, prob, val) = setter!(prob, val)
return function setter!(prob, val)
_call!.(setters, (prob,), val)

Check warning on line 121 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L117-L121

Added lines #L117 - L121 were not covered by tests
end
end

function _setu(sys, ::ArraySymbolic, sym)
return setu(sys, collect(sym))

Check warning on line 126 in src/state_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/state_indexing.jl#L125-L126

Added lines #L125 - L126 were not covered by tests
end
3 changes: 2 additions & 1 deletion src/trait.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,14 @@ struct NotSymbolic <: SymbolicTypeTrait end
symbolic_type(::Type)
Get the symbolic type trait of a type. Default to [`NotSymbolic`](@ref) for all types
except `Symbol`.
except `Symbol` and `Expr`, both of which are [`ScalarSymbolic`](@ref).
See also: [`ScalarSymbolic`](@ref), [`ArraySymbolic`](@ref), [`NotSymbolic`](@ref)
"""
symbolic_type(x) = symbolic_type(typeof(x))
symbolic_type(::Type) = NotSymbolic()
symbolic_type(::Type{Symbol}) = ScalarSymbolic()
symbolic_type(::Type{Expr}) = ScalarSymbolic()

Check warning on line 45 in src/trait.jl

View check run for this annotation

Codecov / codecov/patch

src/trait.jl#L45

Added line #L45 was not covered by tests

"""
hasname(x)
Expand Down
3 changes: 3 additions & 0 deletions test/example_test.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
using SymbolicIndexingInterface
using Test

struct SystemMockup
static::Bool
vars::Vector{Symbol}
Expand Down
1 change: 1 addition & 0 deletions test/fallback_test.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using SymbolicIndexingInterface
using Test

struct Wrapper{W}
wrapped::W
Expand Down
6 changes: 4 additions & 2 deletions test/parameter_indexing_test.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using SymbolicIndexingInterface
using Test

struct FakeIntegrator{P}
struct FakeIntegrator{S,P}
sys::S
p::P
end

Expand All @@ -9,7 +11,7 @@ SymbolicIndexingInterface.parameter_values(fp::FakeIntegrator) = fp.p

sys = SymbolCache([:x, :y, :z], [:a, :b], [:t])
p = [1.0, 2.0]
fi = FakeIntegrator(copy(p))
fi = FakeIntegrator(sys, copy(p))
for (i, sym) in [(1, :a), (2, :b), ([1, 2], [:a, :b]), ((1, 2), (:a, :b))]
get = getp(sys, sym)
set! = setp(sys, sym)
Expand Down
17 changes: 11 additions & 6 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
using SymbolicIndexingInterface
using SafeTestsets
using Test
@testset "Quality Assurance" begin

@safetestset "Quality Assurance" begin
@time include("qa.jl")
end
@testset "Interface test" begin
@safetestset "Interface test" begin
@time include("example_test.jl")
end
@testset "Trait test" begin
@safetestset "Trait test" begin
@time include("trait_test.jl")
end
@testset "SymbolCache test" begin
@safetestset "SymbolCache test" begin
@time include("symbol_cache_test.jl")
end
@testset "Fallback test" begin
@safetestset "Fallback test" begin
@time include("fallback_test.jl")
end
@testset "Parameter indexing test" begin
@safetestset "Parameter indexing test" begin
@time include("parameter_indexing_test.jl")
end
@safetestset "State indexing test" begin
@time include("state_indexing_test.jl")
end
22 changes: 22 additions & 0 deletions test/state_indexing_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using SymbolicIndexingInterface

struct FakeIntegrator{S,U}
sys::S
u::U
end

SymbolicIndexingInterface.symbolic_container(fp::FakeIntegrator) = fp.sys
SymbolicIndexingInterface.state_values(fp::FakeIntegrator) = fp.u

sys = SymbolCache([:x, :y, :z], [:a, :b], [:t])
u = [1.0, 2.0, 3.0]
fi = FakeIntegrator(sys, copy(u))
for (i, sym) in [(1, :x), (2, :y), (3, :z), ([1, 2], [:x, :y]), ((3, 2), (:z, :y))]
get = getu(sys, sym)
set! = setu(sys, sym)
true_value = i isa Tuple ? getindex.((u,), i) : u[i]
@test get(fi) == true_value
set!(fi, 0.5 .* i)
@test get(fi) == 0.5 .* i
set!(fi, true_value)
end
3 changes: 3 additions & 0 deletions test/symbol_cache_test.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
using SymbolicIndexingInterface
using Test

sc = SymbolCache([:x, :y, :z], [:a, :b], [:t])

@test all(is_variable.((sc,), [:x, :y, :z]))
Expand Down

0 comments on commit 8fcceb2

Please sign in to comment.