Skip to content

Commit

Permalink
Merge pull request #65 from SciML/as/batched-getu
Browse files Browse the repository at this point in the history
feat: add `BatchedInterface`
  • Loading branch information
ChrisRackauckas authored Apr 11, 2024
2 parents f43b850 + 6dba9ff commit 063783b
Show file tree
Hide file tree
Showing 9 changed files with 433 additions and 27 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ jobs:
strategy:
matrix:
group:
- All
- Core
- Downstream
version:
- '1'
steps:
Expand All @@ -32,3 +33,5 @@ jobs:
${{ runner.os }}-
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
env:
GROUP: ${{ matrix.group }}
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ Accessors = "0.1.36"
Aqua = "0.8"
ArrayInterface = "7.9"
MacroTools = "0.5.13"
RuntimeGeneratedFunctions = "0.5"
Pkg = "1"
RuntimeGeneratedFunctions = "0.5.12"
SafeTestsets = "0.0.1"
StaticArrays = "1.9"
StaticArraysCore = "1.4"
Expand All @@ -24,9 +25,10 @@ julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "Test", "SafeTestsets", "StaticArrays"]
test = ["Aqua", "Pkg", "Test", "SafeTestsets", "StaticArrays"]
7 changes: 7 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,10 @@ symbolic_evaluate
SymbolCache
ProblemState
```

### Batched Queries and Updates

```@docs
BatchedInterface
associated_systems
```
3 changes: 3 additions & 0 deletions src/SymbolicIndexingInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ include("parameter_indexing.jl")
export state_values, set_state!, current_time, getu, setu
include("state_indexing.jl")

export BatchedInterface, associated_systems
include("batched_interface.jl")

export ProblemState
include("problem_state.jl")

Expand Down
269 changes: 269 additions & 0 deletions src/batched_interface.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,269 @@
"""
struct BatchedInterface{S <: AbstractVector, I}
function BatchedInterface(syssyms::Tuple...)
A struct which stores information for batched calls to [`getu`](@ref) or [`setu`](@ref).
Given `Tuple`s, where the first element of each tuple is a system and the second an
array of symbols (either variables or parameters) in the system, `BatchedInterface` will
compute the union of all symbols and associate each symbol with the first system with
which it occurs.
For example, given two systems `s1 = SymbolCache([:x, :y, :z])` and
`s2 = SymbolCache([:y, :z, :w])`, `BatchedInterface((s1, [:x, :y]), (s2, [:y, :z]))` will
associate `:x` and `:y` with `s1` and `:z` with `s2`. The information that `s1` had
associated symbols `:x` and `:y` and `s2` had associated symbols `:y` and `:z` will also
be retained internally.
`BatchedInterface` implements [`variable_symbols`](@ref), [`is_variable`](@ref),
[`variable_index`](@ref) to query the order of symbols in the union.
See [`getu`](@ref) and [`setu`](@ref) for further details.
See also: [`associated_systems`](@ref).
"""
struct BatchedInterface{S <: AbstractVector, I, T}
"Order of symbols in the union."
symbol_order::S
"Index of the system each symbol in the union is associated with."
associated_systems::Vector{Int}
"Index of symbol in the system it is associated with."
associated_indexes::I
"Whether the symbol is a state in the system it is associated with."
isstate::BitVector
"Map from system to indexes of its symbols in the union."
system_to_symbol_subset::Vector{Vector{Int}}
"Map from system to indexes of its symbols in the system."
system_to_symbol_indexes::Vector{Vector{T}}
"Map from system to whether each of its symbols is a state in the system."
system_to_isstate::Vector{BitVector}
end

function BatchedInterface(syssyms::Tuple...)
symbol_order = []
associated_systems = Int[]
associated_indexes = []
isstate = BitVector()
system_to_symbol_subset = Vector{Int}[]
system_to_symbol_indexes = []
system_to_isstate = BitVector[]
for (i, (sys, syms)) in enumerate(syssyms)
symbol_subset = Int[]
symbol_indexes = []
system_isstate = BitVector()
allsyms = []
for sym in syms
if symbolic_type(sym) === NotSymbolic()
error("Only symbolic variables allowed in BatchedInterface.")
end
if symbolic_type(sym) === ArraySymbolic()
append!(allsyms, collect(sym))
else
push!(allsyms, sym)
end
end
for sym in allsyms
if !is_variable(sys, sym) && !is_parameter(sys, sym)
error("Only variables and parameters allowed in BatchedInterface.")
end
if !any(isequal(sym), symbol_order)
push!(symbol_order, sym)
push!(associated_systems, i)
push!(isstate, is_variable(sys, sym))
if isstate[end]
push!(associated_indexes, variable_index(sys, sym))
else
push!(associated_indexes, parameter_index(sys, sym))
end
end
push!(symbol_subset, findfirst(isequal(sym), symbol_order))
push!(system_isstate, is_variable(sys, sym))
push!(symbol_indexes,
system_isstate[end] ? variable_index(sys, sym) : parameter_index(sys, sym))
end
push!(system_to_symbol_subset, symbol_subset)
push!(system_to_symbol_indexes, identity.(symbol_indexes))
push!(system_to_isstate, system_isstate)
end
symbol_order = identity.(symbol_order)
associated_indexes = identity.(associated_indexes)
system_to_symbol_indexes = identity.(system_to_symbol_indexes)

return BatchedInterface{typeof(symbol_order), typeof(associated_indexes),
eltype(eltype(system_to_symbol_indexes))}(
symbol_order, associated_systems, associated_indexes, isstate,
system_to_symbol_subset, system_to_symbol_indexes, system_to_isstate)
end

variable_symbols(bi::BatchedInterface) = bi.symbol_order
variable_index(bi::BatchedInterface, sym) = findfirst(isequal(sym), bi.symbol_order)
is_variable(bi::BatchedInterface, sym) = variable_index(bi, sym) !== nothing

"""
associated_systems(bi::BatchedInterface)
Return an array of integers of the same length as `variable_symbols(bi)` where each value
is the index of the system associated with the corresponding symbol in
`variable_symbols(bi)`.
"""
associated_systems(bi::BatchedInterface) = bi.associated_systems

"""
getu(bi::BatchedInterface)
Given a [`BatchedInterface`](@ref) composed from `n` systems (and corresponding symbols),
return a function which takes `n` corresponding problems and returns an array of the values
of the symbols in the union. The returned function can also be passed an `AbstractArray` of
the appropriate `eltype` and size as its first argument, in which case the operation will
populate the array in-place with the values of the symbols in the union.
Note that all of the problems passed to the function returned by `getu` must satisfy
`is_timeseries(prob) === NotTimeseries()`.
The value of the `i`th symbol in the union (obtained through `variable_symbols(bi)[i]`) is
obtained from the problem corresponding to the associated system (i.e. the problem at
index `associated_systems(bi)[i]`).
See also: [`variable_symbols`](@ref), [`associated_systems`](@ref), [`is_timeseries`](@ref),
[`NotTimeseries`](@ref).
"""
function getu(bi::BatchedInterface)
numprobs = length(bi.system_to_symbol_subset)
probnames = [Symbol(:prob, i) for i in 1:numprobs]

fnbody = quote end
for (i, (prob, idx, isstate)) in enumerate(zip(
bi.associated_systems, bi.associated_indexes, bi.isstate))
symname = Symbol(:sym, i)
getter = isstate ? state_values : parameter_values
probname = probnames[prob]
push!(fnbody.args, :($symname = $getter($probname, $idx)))
end

oop_expr = Expr(:vect)
for i in 1:length(bi.symbol_order)
push!(oop_expr.args, Symbol(:sym, i))
end

iip_expr = quote end
for i in 1:length(bi.symbol_order)
symname = Symbol(:sym, i)
push!(iip_expr.args, :(out[$i] = $symname))
end

oopfn = Expr(
:function,
Expr(:tuple, probnames...),
quote
$fnbody
$oop_expr
end
)
iipfn = Expr(
:function,
Expr(:tuple, :out, probnames...),
quote
$fnbody
$iip_expr
out
end
)

return let oop = @RuntimeGeneratedFunction(oopfn),
iip = @RuntimeGeneratedFunction(iipfn)

_getter(probs...) = oop(probs...)
_getter(out::AbstractArray, probs...) = iip(out, probs...)
_getter
end
end

"""
setu(bi::BatchedInterface)
Given a [`BatchedInterface`](@ref) composed from `n` systems (and corresponding symbols),
return a function which takes `n` corresponding problems and an array of the values, and
updates each of the problems with the values of the corresponding symbols.
Note that all of the problems passed to the function returned by `setu` must satisfy
`is_timeseries(prob) === NotTimeseries()`.
Note that if any subset of the `n` systems share common symbols (among those passed to
`BatchedInterface`) then all of the corresponding problems in the subset will be updated
with the values of the common symbols.
See also: [`is_timeseries`](@ref), [`NotTimeseries`](@ref).
"""
function setu(bi::BatchedInterface)
numprobs = length(bi.system_to_symbol_subset)
probnames = [Symbol(:prob, i) for i in 1:numprobs]

full_update_fnexpr = let fnbody = quote end
for (sys_idx, subset) in enumerate(bi.system_to_symbol_subset)
probname = probnames[sys_idx]
for (idx_in_subset, idx_in_union) in enumerate(subset)
idx = bi.system_to_symbol_indexes[sys_idx][idx_in_subset]
isstate = bi.system_to_isstate[sys_idx][idx_in_subset]
setter = isstate ? set_state! : set_parameter!
push!(fnbody.args, :($setter($probname, vals[$idx_in_union], $idx)))
end
# also run hook
if !all(bi.system_to_isstate[sys_idx])
paramidxs = [bi.system_to_symbol_indexes[sys_idx][idx_in_subset]
for idx_in_subset in 1:length(subset)
if !bi.system_to_isstate[sys_idx][idx_in_subset]]
push!(fnbody.args, :($finalize_parameters_hook!($probname, $paramidxs)))
end
end
push!(fnbody.args, :(return vals))
Expr(
:function,
Expr(:tuple, probnames..., :vals),
fnbody
)
end

partial_update_fnexpr = let fnbody = quote end
curfnbody = fnbody
for (sys_idx, subset) in enumerate(bi.system_to_symbol_subset)
newcurfnbody = if sys_idx == 1
Expr(:if, :(idx == $sys_idx))
else
Expr(:elseif, :(idx == $sys_idx))
end
push!(curfnbody.args, newcurfnbody)
curfnbody = newcurfnbody

ifbody = quote end
push!(curfnbody.args, ifbody)

probname = :prob
for (idx_in_subset, idx_in_union) in enumerate(subset)
idx = bi.system_to_symbol_indexes[sys_idx][idx_in_subset]
isstate = bi.system_to_isstate[sys_idx][idx_in_subset]
setter = isstate ? set_state! : set_parameter!
push!(ifbody.args, :($setter($probname, vals[$idx_in_union], $idx)))
end
# also run hook
if !all(bi.system_to_isstate[sys_idx])
paramidxs = [bi.system_to_symbol_indexes[sys_idx][idx_in_subset]
for idx_in_subset in 1:length(subset)
if !bi.system_to_isstate[sys_idx][idx_in_subset]]
push!(ifbody.args, :($finalize_parameters_hook!($probname, $paramidxs)))
end
end
push!(curfnbody.args, :(error("Invalid problem index $idx")))
push!(fnbody.args, :(return nothing))
Expr(
:function,
Expr(:tuple, :prob, :idx, :vals),
fnbody
)
end
return let full_update = @RuntimeGeneratedFunction(full_update_fnexpr),
partial_update = @RuntimeGeneratedFunction(partial_update_fnexpr)

setter!(args...) = full_update(args...)
setter!(prob, idx::Int, vals::AbstractVector) = partial_update(prob, idx, vals)
setter!
end
end
56 changes: 56 additions & 0 deletions test/batched_interface_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
using SymbolicIndexingInterface

syss = [
SymbolCache([:x, :y, :z], [:a, :b, :c], :t),
SymbolCache([:z, :w, :v], [:c, :e, :f]),
SymbolCache([:w, :x, :u], [:e, :a, :f])
]
syms = [
[:x, :z, :b, :c],
[:z, :w, :c, :f],
[:w, :x, :e, :a]
]
probs = [
ProblemState(; u = [1.0, 2.0, 3.0], p = [0.1, 0.2, 0.3]),
ProblemState(; u = [4.0, 5.0, 6.0], p = [0.4, 0.5, 0.6]),
ProblemState(; u = [7.0, 8.0, 9.0], p = [0.7, 0.8, 0.9])
]

@test_throws ErrorException BatchedInterface((syss[1], [:x, 3]))
@test_throws ErrorException BatchedInterface((syss[1], [:(x + y)]))
@test_throws ErrorException BatchedInterface((syss[1], [:t]))

bi = BatchedInterface(zip(syss, syms)...)
@test variable_symbols(bi) == [:x, :z, :b, :c, :w, :f, :e, :a]
@test variable_index.((bi,), [:a, :b, :c, :e, :f, :x, :y, :z, :w, :v, :u]) ==
[8, 3, 4, 7, 6, 1, nothing, 2, 5, nothing, nothing]
@test is_variable.((bi,), [:a, :b, :c, :e, :f, :x, :y, :z, :w, :v, :u]) ==
Bool[1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0]
@test associated_systems(bi) == [1, 1, 1, 1, 2, 2, 3, 3]

getter = getu(bi)
@test (@inferred getter(probs...)) == [1.0, 3.0, 0.2, 0.3, 5.0, 0.6, 0.7, 0.8]
buf = zeros(8)
@inferred getter(buf, probs...)
@test buf == [1.0, 3.0, 0.2, 0.3, 5.0, 0.6, 0.7, 0.8]

setter! = setu(bi)
buf .*= 100
setter!(probs..., buf)

@test state_values(probs[1]) == [100.0, 2.0, 300.0]
# :a isn't updated here because it wasn't part of the symbols associated with syss[1] (syms[1])
@test parameter_values(probs[1]) == [0.1, 20.0, 30.0]
@test state_values(probs[2]) == [300.0, 500.0, 6.0]
# Similarly for :e
@test parameter_values(probs[2]) == [30.0, 0.5, 60.0]
@test state_values(probs[3]) == [500.0, 100.0, 9.0]
# Similarly for :f
@test parameter_values(probs[3]) == [70.0, 80.0, 0.9]

buf ./= 100
setter!(probs[1], 1, buf)
@test state_values(probs[1]) == [1.0, 2.0, 3.0]
@test parameter_values(probs[1]) == [0.1, 0.2, 0.3]

@test_throws ErrorException setter!(probs[1], 4, buf)
3 changes: 3 additions & 0 deletions test/downstream/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[deps]
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Loading

0 comments on commit 063783b

Please sign in to comment.