-
-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #23 from SciML/as/getu
feat: add `getu`/`setu`, update docs, add tests
- Loading branch information
Showing
12 changed files
with
256 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
""" | ||
state_values(p) | ||
Return an indexable collection containing the values of all states in the integrator or | ||
problem `p`. | ||
""" | ||
function state_values end | ||
|
||
""" | ||
current_time(p) | ||
Return the current time in the integrator or problem `p`. | ||
""" | ||
function current_time end | ||
|
||
""" | ||
getu(sys, sym) | ||
Return a function that takes an integrator or problem of `sys`, and returns the value of | ||
the symbolic `sym`. `sym` can be a direct index into the state vector, a symbolic state, | ||
a symbolic expression involving symbolic quantities in the system `sys`, or an | ||
array/tuple of the aforementioned. | ||
At minimum, this requires that the integrator or problem implement [`state_values`](@ref). | ||
To support symbolic expressions, the integrator or problem must implement | ||
[`observed`](@ref), [`parameter_values`](@ref) and [`current_time`](@ref). | ||
This function typically does not need to be implemented, and has a default implementation | ||
relying on the above functions. | ||
""" | ||
function getu(sys, sym) | ||
symtype = symbolic_type(sym) | ||
elsymtype = symbolic_type(eltype(sym)) | ||
|
||
if symtype != NotSymbolic() | ||
_getu(sys, symtype, sym) | ||
else | ||
_getu(sys, elsymtype, sym) | ||
end | ||
end | ||
|
||
function _getu(sys, ::NotSymbolic, sym) | ||
return function getter(prob) | ||
return state_values(prob)[sym] | ||
end | ||
end | ||
|
||
function _getu(sys, ::ScalarSymbolic, sym) | ||
if is_variable(sys, sym) | ||
idx = variable_index(sys, sym) | ||
return function getter1(prob) | ||
return state_values(prob)[idx] | ||
end | ||
elseif is_observed(sys, sym) | ||
fn = observed(sys, sym) | ||
if is_time_dependent(sys) | ||
function getter2(prob) | ||
return fn(state_values(prob), parameter_values(prob), current_time(prob)) | ||
end | ||
else | ||
function getter3(prob) | ||
return fn(state_values(prob), parameter_values(prob)) | ||
end | ||
end | ||
end | ||
error("Invalid symbol $sym for `getu`") | ||
end | ||
|
||
function _getu(sys, ::ScalarSymbolic, sym::Union{<:Tuple,<:AbstractArray}) | ||
getters = getu.((sys,), sym) | ||
_call(getter, prob) = getter(prob) | ||
return function getter(prob) | ||
return _call.(getters, (prob,)) | ||
end | ||
end | ||
|
||
function _getu(sys, ::ArraySymbolic, sym) | ||
return getu(sys, collect(sym)) | ||
end | ||
|
||
""" | ||
setu(sys, sym) | ||
Return a function that takes an integrator or problem of `sys` and a value, and sets the | ||
the state `sym` to that value. Note that `sym` can be a direct numerical index, a symbolic state, or an array/tuple of the aforementioned. | ||
Requires that the integrator implement [`state_values`](@ref) and the | ||
returned collection be a mutable reference to the state vector in the integrator/problem. | ||
In case `state_values` cannot return such a mutable reference, `setu` needs to be | ||
implemented manually. | ||
""" | ||
function setu(sys, sym) | ||
symtype = symbolic_type(sym) | ||
elsymtype = symbolic_type(eltype(sym)) | ||
|
||
if symtype != NotSymbolic() | ||
_setu(sys, symtype, sym) | ||
else | ||
_setu(sys, elsymtype, sym) | ||
end | ||
end | ||
|
||
function _setu(sys, ::NotSymbolic, sym) | ||
return function setter!(prob, val) | ||
state_values(prob)[sym] = val | ||
end | ||
end | ||
|
||
function _setu(sys, ::ScalarSymbolic, sym) | ||
is_variable(sys, sym) || error("Invalid symbol $sym for `setu`") | ||
idx = variable_index(sys, sym) | ||
return function setter!(prob, val) | ||
state_values(prob)[idx] = val | ||
end | ||
end | ||
|
||
function _setu(sys, ::ScalarSymbolic, sym::Union{<:Tuple,<:AbstractArray}) | ||
setters = setu.((sys,), sym) | ||
_call!(setter!, prob, val) = setter!(prob, val) | ||
return function setter!(prob, val) | ||
_call!.(setters, (prob,), val) | ||
end | ||
end | ||
|
||
function _setu(sys, ::ArraySymbolic, sym) | ||
return setu(sys, collect(sym)) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,6 @@ | ||
using SymbolicIndexingInterface | ||
using Test | ||
|
||
struct SystemMockup | ||
static::Bool | ||
vars::Vector{Symbol} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
using SymbolicIndexingInterface | ||
using Test | ||
|
||
struct Wrapper{W} | ||
wrapped::W | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,25 @@ | ||
using SymbolicIndexingInterface | ||
using SafeTestsets | ||
using Test | ||
@testset "Quality Assurance" begin | ||
|
||
@safetestset "Quality Assurance" begin | ||
@time include("qa.jl") | ||
end | ||
@testset "Interface test" begin | ||
@safetestset "Interface test" begin | ||
@time include("example_test.jl") | ||
end | ||
@testset "Trait test" begin | ||
@safetestset "Trait test" begin | ||
@time include("trait_test.jl") | ||
end | ||
@testset "SymbolCache test" begin | ||
@safetestset "SymbolCache test" begin | ||
@time include("symbol_cache_test.jl") | ||
end | ||
@testset "Fallback test" begin | ||
@safetestset "Fallback test" begin | ||
@time include("fallback_test.jl") | ||
end | ||
@testset "Parameter indexing test" begin | ||
@safetestset "Parameter indexing test" begin | ||
@time include("parameter_indexing_test.jl") | ||
end | ||
@safetestset "State indexing test" begin | ||
@time include("state_indexing_test.jl") | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
using SymbolicIndexingInterface | ||
|
||
struct FakeIntegrator{S,U} | ||
sys::S | ||
u::U | ||
end | ||
|
||
SymbolicIndexingInterface.symbolic_container(fp::FakeIntegrator) = fp.sys | ||
SymbolicIndexingInterface.state_values(fp::FakeIntegrator) = fp.u | ||
|
||
sys = SymbolCache([:x, :y, :z], [:a, :b], [:t]) | ||
u = [1.0, 2.0, 3.0] | ||
fi = FakeIntegrator(sys, copy(u)) | ||
for (i, sym) in [(1, :x), (2, :y), (3, :z), ([1, 2], [:x, :y]), ((3, 2), (:z, :y))] | ||
get = getu(sys, sym) | ||
set! = setu(sys, sym) | ||
true_value = i isa Tuple ? getindex.((u,), i) : u[i] | ||
@test get(fi) == true_value | ||
set!(fi, 0.5 .* i) | ||
@test get(fi) == 0.5 .* i | ||
set!(fi, true_value) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters