-
-
Notifications
You must be signed in to change notification settings - Fork 8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: add BatchedInterface
#65
Merged
Merged
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
ac86106
feat: add `BatchedInterface`
AayushSabharwal 0e8e037
feat: support updating individual problems with BatchedInterface
AayushSabharwal d329806
feat: support array symbolics in BatchedInterface
AayushSabharwal 5aad4fa
ci: add Downstream testset to CI
AayushSabharwal 6dba9ff
chore: bump RuntimeGeneratedFunctions compat for Downgrade CI
AayushSabharwal File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,269 @@ | ||
""" | ||
struct BatchedInterface{S <: AbstractVector, I} | ||
function BatchedInterface(syssyms::Tuple...) | ||
|
||
A struct which stores information for batched calls to [`getu`](@ref) or [`setu`](@ref). | ||
Given `Tuple`s, where the first element of each tuple is a system and the second an | ||
array of symbols (either variables or parameters) in the system, `BatchedInterface` will | ||
compute the union of all symbols and associate each symbol with the first system with | ||
which it occurs. | ||
|
||
For example, given two systems `s1 = SymbolCache([:x, :y, :z])` and | ||
`s2 = SymbolCache([:y, :z, :w])`, `BatchedInterface((s1, [:x, :y]), (s2, [:y, :z]))` will | ||
associate `:x` and `:y` with `s1` and `:z` with `s2`. The information that `s1` had | ||
associated symbols `:x` and `:y` and `s2` had associated symbols `:y` and `:z` will also | ||
be retained internally. | ||
|
||
`BatchedInterface` implements [`variable_symbols`](@ref), [`is_variable`](@ref), | ||
[`variable_index`](@ref) to query the order of symbols in the union. | ||
|
||
See [`getu`](@ref) and [`setu`](@ref) for further details. | ||
|
||
See also: [`associated_systems`](@ref). | ||
""" | ||
struct BatchedInterface{S <: AbstractVector, I, T} | ||
"Order of symbols in the union." | ||
symbol_order::S | ||
"Index of the system each symbol in the union is associated with." | ||
associated_systems::Vector{Int} | ||
"Index of symbol in the system it is associated with." | ||
associated_indexes::I | ||
"Whether the symbol is a state in the system it is associated with." | ||
isstate::BitVector | ||
"Map from system to indexes of its symbols in the union." | ||
system_to_symbol_subset::Vector{Vector{Int}} | ||
"Map from system to indexes of its symbols in the system." | ||
system_to_symbol_indexes::Vector{Vector{T}} | ||
"Map from system to whether each of its symbols is a state in the system." | ||
system_to_isstate::Vector{BitVector} | ||
end | ||
|
||
function BatchedInterface(syssyms::Tuple...) | ||
symbol_order = [] | ||
associated_systems = Int[] | ||
associated_indexes = [] | ||
isstate = BitVector() | ||
system_to_symbol_subset = Vector{Int}[] | ||
system_to_symbol_indexes = [] | ||
system_to_isstate = BitVector[] | ||
for (i, (sys, syms)) in enumerate(syssyms) | ||
symbol_subset = Int[] | ||
symbol_indexes = [] | ||
system_isstate = BitVector() | ||
allsyms = [] | ||
for sym in syms | ||
if symbolic_type(sym) === NotSymbolic() | ||
error("Only symbolic variables allowed in BatchedInterface.") | ||
end | ||
if symbolic_type(sym) === ArraySymbolic() | ||
append!(allsyms, collect(sym)) | ||
else | ||
push!(allsyms, sym) | ||
end | ||
end | ||
for sym in allsyms | ||
if !is_variable(sys, sym) && !is_parameter(sys, sym) | ||
error("Only variables and parameters allowed in BatchedInterface.") | ||
end | ||
if !any(isequal(sym), symbol_order) | ||
push!(symbol_order, sym) | ||
push!(associated_systems, i) | ||
push!(isstate, is_variable(sys, sym)) | ||
if isstate[end] | ||
push!(associated_indexes, variable_index(sys, sym)) | ||
else | ||
push!(associated_indexes, parameter_index(sys, sym)) | ||
end | ||
end | ||
push!(symbol_subset, findfirst(isequal(sym), symbol_order)) | ||
push!(system_isstate, is_variable(sys, sym)) | ||
push!(symbol_indexes, | ||
system_isstate[end] ? variable_index(sys, sym) : parameter_index(sys, sym)) | ||
end | ||
push!(system_to_symbol_subset, symbol_subset) | ||
push!(system_to_symbol_indexes, identity.(symbol_indexes)) | ||
push!(system_to_isstate, system_isstate) | ||
end | ||
symbol_order = identity.(symbol_order) | ||
associated_indexes = identity.(associated_indexes) | ||
system_to_symbol_indexes = identity.(system_to_symbol_indexes) | ||
|
||
return BatchedInterface{typeof(symbol_order), typeof(associated_indexes), | ||
eltype(eltype(system_to_symbol_indexes))}( | ||
symbol_order, associated_systems, associated_indexes, isstate, | ||
system_to_symbol_subset, system_to_symbol_indexes, system_to_isstate) | ||
end | ||
|
||
variable_symbols(bi::BatchedInterface) = bi.symbol_order | ||
variable_index(bi::BatchedInterface, sym) = findfirst(isequal(sym), bi.symbol_order) | ||
is_variable(bi::BatchedInterface, sym) = variable_index(bi, sym) !== nothing | ||
|
||
""" | ||
associated_systems(bi::BatchedInterface) | ||
|
||
Return an array of integers of the same length as `variable_symbols(bi)` where each value | ||
is the index of the system associated with the corresponding symbol in | ||
`variable_symbols(bi)`. | ||
""" | ||
associated_systems(bi::BatchedInterface) = bi.associated_systems | ||
|
||
""" | ||
getu(bi::BatchedInterface) | ||
|
||
Given a [`BatchedInterface`](@ref) composed from `n` systems (and corresponding symbols), | ||
return a function which takes `n` corresponding problems and returns an array of the values | ||
of the symbols in the union. The returned function can also be passed an `AbstractArray` of | ||
the appropriate `eltype` and size as its first argument, in which case the operation will | ||
populate the array in-place with the values of the symbols in the union. | ||
|
||
Note that all of the problems passed to the function returned by `getu` must satisfy | ||
`is_timeseries(prob) === NotTimeseries()`. | ||
|
||
The value of the `i`th symbol in the union (obtained through `variable_symbols(bi)[i]`) is | ||
obtained from the problem corresponding to the associated system (i.e. the problem at | ||
index `associated_systems(bi)[i]`). | ||
|
||
See also: [`variable_symbols`](@ref), [`associated_systems`](@ref), [`is_timeseries`](@ref), | ||
[`NotTimeseries`](@ref). | ||
""" | ||
function getu(bi::BatchedInterface) | ||
numprobs = length(bi.system_to_symbol_subset) | ||
probnames = [Symbol(:prob, i) for i in 1:numprobs] | ||
|
||
fnbody = quote end | ||
for (i, (prob, idx, isstate)) in enumerate(zip( | ||
bi.associated_systems, bi.associated_indexes, bi.isstate)) | ||
symname = Symbol(:sym, i) | ||
getter = isstate ? state_values : parameter_values | ||
probname = probnames[prob] | ||
push!(fnbody.args, :($symname = $getter($probname, $idx))) | ||
end | ||
|
||
oop_expr = Expr(:vect) | ||
for i in 1:length(bi.symbol_order) | ||
push!(oop_expr.args, Symbol(:sym, i)) | ||
end | ||
|
||
iip_expr = quote end | ||
for i in 1:length(bi.symbol_order) | ||
symname = Symbol(:sym, i) | ||
push!(iip_expr.args, :(out[$i] = $symname)) | ||
end | ||
|
||
oopfn = Expr( | ||
:function, | ||
Expr(:tuple, probnames...), | ||
quote | ||
$fnbody | ||
$oop_expr | ||
end | ||
) | ||
iipfn = Expr( | ||
:function, | ||
Expr(:tuple, :out, probnames...), | ||
quote | ||
$fnbody | ||
$iip_expr | ||
out | ||
end | ||
) | ||
|
||
return let oop = @RuntimeGeneratedFunction(oopfn), | ||
iip = @RuntimeGeneratedFunction(iipfn) | ||
|
||
_getter(probs...) = oop(probs...) | ||
_getter(out::AbstractArray, probs...) = iip(out, probs...) | ||
_getter | ||
end | ||
end | ||
|
||
""" | ||
setu(bi::BatchedInterface) | ||
|
||
Given a [`BatchedInterface`](@ref) composed from `n` systems (and corresponding symbols), | ||
return a function which takes `n` corresponding problems and an array of the values, and | ||
updates each of the problems with the values of the corresponding symbols. | ||
|
||
Note that all of the problems passed to the function returned by `setu` must satisfy | ||
`is_timeseries(prob) === NotTimeseries()`. | ||
|
||
Note that if any subset of the `n` systems share common symbols (among those passed to | ||
`BatchedInterface`) then all of the corresponding problems in the subset will be updated | ||
with the values of the common symbols. | ||
|
||
See also: [`is_timeseries`](@ref), [`NotTimeseries`](@ref). | ||
""" | ||
function setu(bi::BatchedInterface) | ||
numprobs = length(bi.system_to_symbol_subset) | ||
probnames = [Symbol(:prob, i) for i in 1:numprobs] | ||
|
||
full_update_fnexpr = let fnbody = quote end | ||
for (sys_idx, subset) in enumerate(bi.system_to_symbol_subset) | ||
probname = probnames[sys_idx] | ||
for (idx_in_subset, idx_in_union) in enumerate(subset) | ||
idx = bi.system_to_symbol_indexes[sys_idx][idx_in_subset] | ||
isstate = bi.system_to_isstate[sys_idx][idx_in_subset] | ||
setter = isstate ? set_state! : set_parameter! | ||
push!(fnbody.args, :($setter($probname, vals[$idx_in_union], $idx))) | ||
end | ||
# also run hook | ||
if !all(bi.system_to_isstate[sys_idx]) | ||
paramidxs = [bi.system_to_symbol_indexes[sys_idx][idx_in_subset] | ||
for idx_in_subset in 1:length(subset) | ||
if !bi.system_to_isstate[sys_idx][idx_in_subset]] | ||
push!(fnbody.args, :($finalize_parameters_hook!($probname, $paramidxs))) | ||
end | ||
end | ||
push!(fnbody.args, :(return vals)) | ||
Expr( | ||
:function, | ||
Expr(:tuple, probnames..., :vals), | ||
fnbody | ||
) | ||
end | ||
|
||
partial_update_fnexpr = let fnbody = quote end | ||
curfnbody = fnbody | ||
for (sys_idx, subset) in enumerate(bi.system_to_symbol_subset) | ||
newcurfnbody = if sys_idx == 1 | ||
Expr(:if, :(idx == $sys_idx)) | ||
else | ||
Expr(:elseif, :(idx == $sys_idx)) | ||
end | ||
push!(curfnbody.args, newcurfnbody) | ||
curfnbody = newcurfnbody | ||
|
||
ifbody = quote end | ||
push!(curfnbody.args, ifbody) | ||
|
||
probname = :prob | ||
for (idx_in_subset, idx_in_union) in enumerate(subset) | ||
idx = bi.system_to_symbol_indexes[sys_idx][idx_in_subset] | ||
isstate = bi.system_to_isstate[sys_idx][idx_in_subset] | ||
setter = isstate ? set_state! : set_parameter! | ||
push!(ifbody.args, :($setter($probname, vals[$idx_in_union], $idx))) | ||
end | ||
# also run hook | ||
if !all(bi.system_to_isstate[sys_idx]) | ||
paramidxs = [bi.system_to_symbol_indexes[sys_idx][idx_in_subset] | ||
for idx_in_subset in 1:length(subset) | ||
if !bi.system_to_isstate[sys_idx][idx_in_subset]] | ||
push!(ifbody.args, :($finalize_parameters_hook!($probname, $paramidxs))) | ||
end | ||
end | ||
push!(curfnbody.args, :(error("Invalid problem index $idx"))) | ||
push!(fnbody.args, :(return nothing)) | ||
Expr( | ||
:function, | ||
Expr(:tuple, :prob, :idx, :vals), | ||
fnbody | ||
) | ||
end | ||
return let full_update = @RuntimeGeneratedFunction(full_update_fnexpr), | ||
partial_update = @RuntimeGeneratedFunction(partial_update_fnexpr) | ||
|
||
setter!(args...) = full_update(args...) | ||
setter!(prob, idx::Int, vals::AbstractVector) = partial_update(prob, idx, vals) | ||
setter! | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
using SymbolicIndexingInterface | ||
|
||
syss = [ | ||
SymbolCache([:x, :y, :z], [:a, :b, :c], :t), | ||
SymbolCache([:z, :w, :v], [:c, :e, :f]), | ||
SymbolCache([:w, :x, :u], [:e, :a, :f]) | ||
] | ||
syms = [ | ||
[:x, :z, :b, :c], | ||
[:z, :w, :c, :f], | ||
[:w, :x, :e, :a] | ||
] | ||
probs = [ | ||
ProblemState(; u = [1.0, 2.0, 3.0], p = [0.1, 0.2, 0.3]), | ||
ProblemState(; u = [4.0, 5.0, 6.0], p = [0.4, 0.5, 0.6]), | ||
ProblemState(; u = [7.0, 8.0, 9.0], p = [0.7, 0.8, 0.9]) | ||
] | ||
|
||
@test_throws ErrorException BatchedInterface((syss[1], [:x, 3])) | ||
@test_throws ErrorException BatchedInterface((syss[1], [:(x + y)])) | ||
@test_throws ErrorException BatchedInterface((syss[1], [:t])) | ||
|
||
bi = BatchedInterface(zip(syss, syms)...) | ||
@test variable_symbols(bi) == [:x, :z, :b, :c, :w, :f, :e, :a] | ||
@test variable_index.((bi,), [:a, :b, :c, :e, :f, :x, :y, :z, :w, :v, :u]) == | ||
[8, 3, 4, 7, 6, 1, nothing, 2, 5, nothing, nothing] | ||
@test is_variable.((bi,), [:a, :b, :c, :e, :f, :x, :y, :z, :w, :v, :u]) == | ||
Bool[1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0] | ||
@test associated_systems(bi) == [1, 1, 1, 1, 2, 2, 3, 3] | ||
|
||
getter = getu(bi) | ||
@test (@inferred getter(probs...)) == [1.0, 3.0, 0.2, 0.3, 5.0, 0.6, 0.7, 0.8] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This doesn't seem quite right, you'd want an array for each to make remaking easy? |
||
buf = zeros(8) | ||
@inferred getter(buf, probs...) | ||
@test buf == [1.0, 3.0, 0.2, 0.3, 5.0, 0.6, 0.7, 0.8] | ||
|
||
setter! = setu(bi) | ||
buf .*= 100 | ||
setter!(probs..., buf) | ||
|
||
@test state_values(probs[1]) == [100.0, 2.0, 300.0] | ||
# :a isn't updated here because it wasn't part of the symbols associated with syss[1] (syms[1]) | ||
@test parameter_values(probs[1]) == [0.1, 20.0, 30.0] | ||
@test state_values(probs[2]) == [300.0, 500.0, 6.0] | ||
# Similarly for :e | ||
@test parameter_values(probs[2]) == [30.0, 0.5, 60.0] | ||
@test state_values(probs[3]) == [500.0, 100.0, 9.0] | ||
# Similarly for :f | ||
@test parameter_values(probs[3]) == [70.0, 80.0, 0.9] | ||
|
||
buf ./= 100 | ||
setter!(probs[1], 1, buf) | ||
@test state_values(probs[1]) == [1.0, 2.0, 3.0] | ||
@test parameter_values(probs[1]) == [0.1, 0.2, 0.3] | ||
|
||
@test_throws ErrorException setter!(probs[1], 4, buf) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
[deps] | ||
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" | ||
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This removes the need for
ArrayPartition
s?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes