-
-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #65 from SciML/as/batched-getu
feat: add `BatchedInterface`
- Loading branch information
Showing
9 changed files
with
433 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,269 @@ | ||
""" | ||
struct BatchedInterface{S <: AbstractVector, I} | ||
function BatchedInterface(syssyms::Tuple...) | ||
A struct which stores information for batched calls to [`getu`](@ref) or [`setu`](@ref). | ||
Given `Tuple`s, where the first element of each tuple is a system and the second an | ||
array of symbols (either variables or parameters) in the system, `BatchedInterface` will | ||
compute the union of all symbols and associate each symbol with the first system with | ||
which it occurs. | ||
For example, given two systems `s1 = SymbolCache([:x, :y, :z])` and | ||
`s2 = SymbolCache([:y, :z, :w])`, `BatchedInterface((s1, [:x, :y]), (s2, [:y, :z]))` will | ||
associate `:x` and `:y` with `s1` and `:z` with `s2`. The information that `s1` had | ||
associated symbols `:x` and `:y` and `s2` had associated symbols `:y` and `:z` will also | ||
be retained internally. | ||
`BatchedInterface` implements [`variable_symbols`](@ref), [`is_variable`](@ref), | ||
[`variable_index`](@ref) to query the order of symbols in the union. | ||
See [`getu`](@ref) and [`setu`](@ref) for further details. | ||
See also: [`associated_systems`](@ref). | ||
""" | ||
struct BatchedInterface{S <: AbstractVector, I, T} | ||
"Order of symbols in the union." | ||
symbol_order::S | ||
"Index of the system each symbol in the union is associated with." | ||
associated_systems::Vector{Int} | ||
"Index of symbol in the system it is associated with." | ||
associated_indexes::I | ||
"Whether the symbol is a state in the system it is associated with." | ||
isstate::BitVector | ||
"Map from system to indexes of its symbols in the union." | ||
system_to_symbol_subset::Vector{Vector{Int}} | ||
"Map from system to indexes of its symbols in the system." | ||
system_to_symbol_indexes::Vector{Vector{T}} | ||
"Map from system to whether each of its symbols is a state in the system." | ||
system_to_isstate::Vector{BitVector} | ||
end | ||
|
||
function BatchedInterface(syssyms::Tuple...) | ||
symbol_order = [] | ||
associated_systems = Int[] | ||
associated_indexes = [] | ||
isstate = BitVector() | ||
system_to_symbol_subset = Vector{Int}[] | ||
system_to_symbol_indexes = [] | ||
system_to_isstate = BitVector[] | ||
for (i, (sys, syms)) in enumerate(syssyms) | ||
symbol_subset = Int[] | ||
symbol_indexes = [] | ||
system_isstate = BitVector() | ||
allsyms = [] | ||
for sym in syms | ||
if symbolic_type(sym) === NotSymbolic() | ||
error("Only symbolic variables allowed in BatchedInterface.") | ||
end | ||
if symbolic_type(sym) === ArraySymbolic() | ||
append!(allsyms, collect(sym)) | ||
else | ||
push!(allsyms, sym) | ||
end | ||
end | ||
for sym in allsyms | ||
if !is_variable(sys, sym) && !is_parameter(sys, sym) | ||
error("Only variables and parameters allowed in BatchedInterface.") | ||
end | ||
if !any(isequal(sym), symbol_order) | ||
push!(symbol_order, sym) | ||
push!(associated_systems, i) | ||
push!(isstate, is_variable(sys, sym)) | ||
if isstate[end] | ||
push!(associated_indexes, variable_index(sys, sym)) | ||
else | ||
push!(associated_indexes, parameter_index(sys, sym)) | ||
end | ||
end | ||
push!(symbol_subset, findfirst(isequal(sym), symbol_order)) | ||
push!(system_isstate, is_variable(sys, sym)) | ||
push!(symbol_indexes, | ||
system_isstate[end] ? variable_index(sys, sym) : parameter_index(sys, sym)) | ||
end | ||
push!(system_to_symbol_subset, symbol_subset) | ||
push!(system_to_symbol_indexes, identity.(symbol_indexes)) | ||
push!(system_to_isstate, system_isstate) | ||
end | ||
symbol_order = identity.(symbol_order) | ||
associated_indexes = identity.(associated_indexes) | ||
system_to_symbol_indexes = identity.(system_to_symbol_indexes) | ||
|
||
return BatchedInterface{typeof(symbol_order), typeof(associated_indexes), | ||
eltype(eltype(system_to_symbol_indexes))}( | ||
symbol_order, associated_systems, associated_indexes, isstate, | ||
system_to_symbol_subset, system_to_symbol_indexes, system_to_isstate) | ||
end | ||
|
||
variable_symbols(bi::BatchedInterface) = bi.symbol_order | ||
variable_index(bi::BatchedInterface, sym) = findfirst(isequal(sym), bi.symbol_order) | ||
is_variable(bi::BatchedInterface, sym) = variable_index(bi, sym) !== nothing | ||
|
||
""" | ||
associated_systems(bi::BatchedInterface) | ||
Return an array of integers of the same length as `variable_symbols(bi)` where each value | ||
is the index of the system associated with the corresponding symbol in | ||
`variable_symbols(bi)`. | ||
""" | ||
associated_systems(bi::BatchedInterface) = bi.associated_systems | ||
|
||
""" | ||
getu(bi::BatchedInterface) | ||
Given a [`BatchedInterface`](@ref) composed from `n` systems (and corresponding symbols), | ||
return a function which takes `n` corresponding problems and returns an array of the values | ||
of the symbols in the union. The returned function can also be passed an `AbstractArray` of | ||
the appropriate `eltype` and size as its first argument, in which case the operation will | ||
populate the array in-place with the values of the symbols in the union. | ||
Note that all of the problems passed to the function returned by `getu` must satisfy | ||
`is_timeseries(prob) === NotTimeseries()`. | ||
The value of the `i`th symbol in the union (obtained through `variable_symbols(bi)[i]`) is | ||
obtained from the problem corresponding to the associated system (i.e. the problem at | ||
index `associated_systems(bi)[i]`). | ||
See also: [`variable_symbols`](@ref), [`associated_systems`](@ref), [`is_timeseries`](@ref), | ||
[`NotTimeseries`](@ref). | ||
""" | ||
function getu(bi::BatchedInterface) | ||
numprobs = length(bi.system_to_symbol_subset) | ||
probnames = [Symbol(:prob, i) for i in 1:numprobs] | ||
|
||
fnbody = quote end | ||
for (i, (prob, idx, isstate)) in enumerate(zip( | ||
bi.associated_systems, bi.associated_indexes, bi.isstate)) | ||
symname = Symbol(:sym, i) | ||
getter = isstate ? state_values : parameter_values | ||
probname = probnames[prob] | ||
push!(fnbody.args, :($symname = $getter($probname, $idx))) | ||
end | ||
|
||
oop_expr = Expr(:vect) | ||
for i in 1:length(bi.symbol_order) | ||
push!(oop_expr.args, Symbol(:sym, i)) | ||
end | ||
|
||
iip_expr = quote end | ||
for i in 1:length(bi.symbol_order) | ||
symname = Symbol(:sym, i) | ||
push!(iip_expr.args, :(out[$i] = $symname)) | ||
end | ||
|
||
oopfn = Expr( | ||
:function, | ||
Expr(:tuple, probnames...), | ||
quote | ||
$fnbody | ||
$oop_expr | ||
end | ||
) | ||
iipfn = Expr( | ||
:function, | ||
Expr(:tuple, :out, probnames...), | ||
quote | ||
$fnbody | ||
$iip_expr | ||
out | ||
end | ||
) | ||
|
||
return let oop = @RuntimeGeneratedFunction(oopfn), | ||
iip = @RuntimeGeneratedFunction(iipfn) | ||
|
||
_getter(probs...) = oop(probs...) | ||
_getter(out::AbstractArray, probs...) = iip(out, probs...) | ||
_getter | ||
end | ||
end | ||
|
||
""" | ||
setu(bi::BatchedInterface) | ||
Given a [`BatchedInterface`](@ref) composed from `n` systems (and corresponding symbols), | ||
return a function which takes `n` corresponding problems and an array of the values, and | ||
updates each of the problems with the values of the corresponding symbols. | ||
Note that all of the problems passed to the function returned by `setu` must satisfy | ||
`is_timeseries(prob) === NotTimeseries()`. | ||
Note that if any subset of the `n` systems share common symbols (among those passed to | ||
`BatchedInterface`) then all of the corresponding problems in the subset will be updated | ||
with the values of the common symbols. | ||
See also: [`is_timeseries`](@ref), [`NotTimeseries`](@ref). | ||
""" | ||
function setu(bi::BatchedInterface) | ||
numprobs = length(bi.system_to_symbol_subset) | ||
probnames = [Symbol(:prob, i) for i in 1:numprobs] | ||
|
||
full_update_fnexpr = let fnbody = quote end | ||
for (sys_idx, subset) in enumerate(bi.system_to_symbol_subset) | ||
probname = probnames[sys_idx] | ||
for (idx_in_subset, idx_in_union) in enumerate(subset) | ||
idx = bi.system_to_symbol_indexes[sys_idx][idx_in_subset] | ||
isstate = bi.system_to_isstate[sys_idx][idx_in_subset] | ||
setter = isstate ? set_state! : set_parameter! | ||
push!(fnbody.args, :($setter($probname, vals[$idx_in_union], $idx))) | ||
end | ||
# also run hook | ||
if !all(bi.system_to_isstate[sys_idx]) | ||
paramidxs = [bi.system_to_symbol_indexes[sys_idx][idx_in_subset] | ||
for idx_in_subset in 1:length(subset) | ||
if !bi.system_to_isstate[sys_idx][idx_in_subset]] | ||
push!(fnbody.args, :($finalize_parameters_hook!($probname, $paramidxs))) | ||
end | ||
end | ||
push!(fnbody.args, :(return vals)) | ||
Expr( | ||
:function, | ||
Expr(:tuple, probnames..., :vals), | ||
fnbody | ||
) | ||
end | ||
|
||
partial_update_fnexpr = let fnbody = quote end | ||
curfnbody = fnbody | ||
for (sys_idx, subset) in enumerate(bi.system_to_symbol_subset) | ||
newcurfnbody = if sys_idx == 1 | ||
Expr(:if, :(idx == $sys_idx)) | ||
else | ||
Expr(:elseif, :(idx == $sys_idx)) | ||
end | ||
push!(curfnbody.args, newcurfnbody) | ||
curfnbody = newcurfnbody | ||
|
||
ifbody = quote end | ||
push!(curfnbody.args, ifbody) | ||
|
||
probname = :prob | ||
for (idx_in_subset, idx_in_union) in enumerate(subset) | ||
idx = bi.system_to_symbol_indexes[sys_idx][idx_in_subset] | ||
isstate = bi.system_to_isstate[sys_idx][idx_in_subset] | ||
setter = isstate ? set_state! : set_parameter! | ||
push!(ifbody.args, :($setter($probname, vals[$idx_in_union], $idx))) | ||
end | ||
# also run hook | ||
if !all(bi.system_to_isstate[sys_idx]) | ||
paramidxs = [bi.system_to_symbol_indexes[sys_idx][idx_in_subset] | ||
for idx_in_subset in 1:length(subset) | ||
if !bi.system_to_isstate[sys_idx][idx_in_subset]] | ||
push!(ifbody.args, :($finalize_parameters_hook!($probname, $paramidxs))) | ||
end | ||
end | ||
push!(curfnbody.args, :(error("Invalid problem index $idx"))) | ||
push!(fnbody.args, :(return nothing)) | ||
Expr( | ||
:function, | ||
Expr(:tuple, :prob, :idx, :vals), | ||
fnbody | ||
) | ||
end | ||
return let full_update = @RuntimeGeneratedFunction(full_update_fnexpr), | ||
partial_update = @RuntimeGeneratedFunction(partial_update_fnexpr) | ||
|
||
setter!(args...) = full_update(args...) | ||
setter!(prob, idx::Int, vals::AbstractVector) = partial_update(prob, idx, vals) | ||
setter! | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
using SymbolicIndexingInterface | ||
|
||
syss = [ | ||
SymbolCache([:x, :y, :z], [:a, :b, :c], :t), | ||
SymbolCache([:z, :w, :v], [:c, :e, :f]), | ||
SymbolCache([:w, :x, :u], [:e, :a, :f]) | ||
] | ||
syms = [ | ||
[:x, :z, :b, :c], | ||
[:z, :w, :c, :f], | ||
[:w, :x, :e, :a] | ||
] | ||
probs = [ | ||
ProblemState(; u = [1.0, 2.0, 3.0], p = [0.1, 0.2, 0.3]), | ||
ProblemState(; u = [4.0, 5.0, 6.0], p = [0.4, 0.5, 0.6]), | ||
ProblemState(; u = [7.0, 8.0, 9.0], p = [0.7, 0.8, 0.9]) | ||
] | ||
|
||
@test_throws ErrorException BatchedInterface((syss[1], [:x, 3])) | ||
@test_throws ErrorException BatchedInterface((syss[1], [:(x + y)])) | ||
@test_throws ErrorException BatchedInterface((syss[1], [:t])) | ||
|
||
bi = BatchedInterface(zip(syss, syms)...) | ||
@test variable_symbols(bi) == [:x, :z, :b, :c, :w, :f, :e, :a] | ||
@test variable_index.((bi,), [:a, :b, :c, :e, :f, :x, :y, :z, :w, :v, :u]) == | ||
[8, 3, 4, 7, 6, 1, nothing, 2, 5, nothing, nothing] | ||
@test is_variable.((bi,), [:a, :b, :c, :e, :f, :x, :y, :z, :w, :v, :u]) == | ||
Bool[1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0] | ||
@test associated_systems(bi) == [1, 1, 1, 1, 2, 2, 3, 3] | ||
|
||
getter = getu(bi) | ||
@test (@inferred getter(probs...)) == [1.0, 3.0, 0.2, 0.3, 5.0, 0.6, 0.7, 0.8] | ||
buf = zeros(8) | ||
@inferred getter(buf, probs...) | ||
@test buf == [1.0, 3.0, 0.2, 0.3, 5.0, 0.6, 0.7, 0.8] | ||
|
||
setter! = setu(bi) | ||
buf .*= 100 | ||
setter!(probs..., buf) | ||
|
||
@test state_values(probs[1]) == [100.0, 2.0, 300.0] | ||
# :a isn't updated here because it wasn't part of the symbols associated with syss[1] (syms[1]) | ||
@test parameter_values(probs[1]) == [0.1, 20.0, 30.0] | ||
@test state_values(probs[2]) == [300.0, 500.0, 6.0] | ||
# Similarly for :e | ||
@test parameter_values(probs[2]) == [30.0, 0.5, 60.0] | ||
@test state_values(probs[3]) == [500.0, 100.0, 9.0] | ||
# Similarly for :f | ||
@test parameter_values(probs[3]) == [70.0, 80.0, 0.9] | ||
|
||
buf ./= 100 | ||
setter!(probs[1], 1, buf) | ||
@test state_values(probs[1]) == [1.0, 2.0, 3.0] | ||
@test parameter_values(probs[1]) == [0.1, 0.2, 0.3] | ||
|
||
@test_throws ErrorException setter!(probs[1], 4, buf) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
[deps] | ||
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" | ||
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" |
Oops, something went wrong.