Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add BatchedInterface #65

Merged
merged 5 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ jobs:
strategy:
matrix:
group:
- All
- Core
- Downstream
version:
- '1'
steps:
Expand All @@ -32,3 +33,5 @@ jobs:
${{ runner.os }}-
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
env:
GROUP: ${{ matrix.group }}
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Accessors = "0.1.36"
Aqua = "0.8"
ArrayInterface = "7.9"
MacroTools = "0.5.13"
Pkg = "1"
RuntimeGeneratedFunctions = "0.5"
SafeTestsets = "0.0.1"
StaticArrays = "1.9"
Expand All @@ -24,9 +25,10 @@ julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "Test", "SafeTestsets", "StaticArrays"]
test = ["Aqua", "Pkg", "Test", "SafeTestsets", "StaticArrays"]
7 changes: 7 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,10 @@ symbolic_evaluate
SymbolCache
ProblemState
```

### Batched Queries and Updates

```@docs
BatchedInterface
associated_systems
```
3 changes: 3 additions & 0 deletions src/SymbolicIndexingInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ include("parameter_indexing.jl")
export state_values, set_state!, current_time, getu, setu
include("state_indexing.jl")

export BatchedInterface, associated_systems
include("batched_interface.jl")

export ProblemState
include("problem_state.jl")

Expand Down
269 changes: 269 additions & 0 deletions src/batched_interface.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,269 @@
"""
struct BatchedInterface{S <: AbstractVector, I}
function BatchedInterface(syssyms::Tuple...)

A struct which stores information for batched calls to [`getu`](@ref) or [`setu`](@ref).
Given `Tuple`s, where the first element of each tuple is a system and the second an
array of symbols (either variables or parameters) in the system, `BatchedInterface` will
compute the union of all symbols and associate each symbol with the first system with
which it occurs.

For example, given two systems `s1 = SymbolCache([:x, :y, :z])` and
`s2 = SymbolCache([:y, :z, :w])`, `BatchedInterface((s1, [:x, :y]), (s2, [:y, :z]))` will
associate `:x` and `:y` with `s1` and `:z` with `s2`. The information that `s1` had
associated symbols `:x` and `:y` and `s2` had associated symbols `:y` and `:z` will also
be retained internally.

`BatchedInterface` implements [`variable_symbols`](@ref), [`is_variable`](@ref),
[`variable_index`](@ref) to query the order of symbols in the union.

See [`getu`](@ref) and [`setu`](@ref) for further details.

See also: [`associated_systems`](@ref).
"""
struct BatchedInterface{S <: AbstractVector, I, T}
"Order of symbols in the union."
symbol_order::S
"Index of the system each symbol in the union is associated with."
associated_systems::Vector{Int}
"Index of symbol in the system it is associated with."
associated_indexes::I
"Whether the symbol is a state in the system it is associated with."
isstate::BitVector
"Map from system to indexes of its symbols in the union."
system_to_symbol_subset::Vector{Vector{Int}}
"Map from system to indexes of its symbols in the system."
system_to_symbol_indexes::Vector{Vector{T}}
"Map from system to whether each of its symbols is a state in the system."
system_to_isstate::Vector{BitVector}
end

function BatchedInterface(syssyms::Tuple...)
symbol_order = []
associated_systems = Int[]
associated_indexes = []
isstate = BitVector()
system_to_symbol_subset = Vector{Int}[]
system_to_symbol_indexes = []
system_to_isstate = BitVector[]
for (i, (sys, syms)) in enumerate(syssyms)
symbol_subset = Int[]
symbol_indexes = []
system_isstate = BitVector()
allsyms = []
for sym in syms
if symbolic_type(sym) === NotSymbolic()
error("Only symbolic variables allowed in BatchedInterface.")
end
if symbolic_type(sym) === ArraySymbolic()
append!(allsyms, collect(sym))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This removes the need for ArrayPartitions?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

else
push!(allsyms, sym)
end
end
for sym in allsyms
if !is_variable(sys, sym) && !is_parameter(sys, sym)
error("Only variables and parameters allowed in BatchedInterface.")
end
if !any(isequal(sym), symbol_order)
push!(symbol_order, sym)
push!(associated_systems, i)
push!(isstate, is_variable(sys, sym))
if isstate[end]
push!(associated_indexes, variable_index(sys, sym))
else
push!(associated_indexes, parameter_index(sys, sym))
end
end
push!(symbol_subset, findfirst(isequal(sym), symbol_order))
push!(system_isstate, is_variable(sys, sym))
push!(symbol_indexes,
system_isstate[end] ? variable_index(sys, sym) : parameter_index(sys, sym))
end
push!(system_to_symbol_subset, symbol_subset)
push!(system_to_symbol_indexes, identity.(symbol_indexes))
push!(system_to_isstate, system_isstate)
end
symbol_order = identity.(symbol_order)
associated_indexes = identity.(associated_indexes)
system_to_symbol_indexes = identity.(system_to_symbol_indexes)

return BatchedInterface{typeof(symbol_order), typeof(associated_indexes),
eltype(eltype(system_to_symbol_indexes))}(
symbol_order, associated_systems, associated_indexes, isstate,
system_to_symbol_subset, system_to_symbol_indexes, system_to_isstate)
end

variable_symbols(bi::BatchedInterface) = bi.symbol_order
variable_index(bi::BatchedInterface, sym) = findfirst(isequal(sym), bi.symbol_order)
is_variable(bi::BatchedInterface, sym) = variable_index(bi, sym) !== nothing

"""
associated_systems(bi::BatchedInterface)

Return an array of integers of the same length as `variable_symbols(bi)` where each value
is the index of the system associated with the corresponding symbol in
`variable_symbols(bi)`.
"""
associated_systems(bi::BatchedInterface) = bi.associated_systems

"""
getu(bi::BatchedInterface)

Given a [`BatchedInterface`](@ref) composed from `n` systems (and corresponding symbols),
return a function which takes `n` corresponding problems and returns an array of the values
of the symbols in the union. The returned function can also be passed an `AbstractArray` of
the appropriate `eltype` and size as its first argument, in which case the operation will
populate the array in-place with the values of the symbols in the union.

Note that all of the problems passed to the function returned by `getu` must satisfy
`is_timeseries(prob) === NotTimeseries()`.

The value of the `i`th symbol in the union (obtained through `variable_symbols(bi)[i]`) is
obtained from the problem corresponding to the associated system (i.e. the problem at
index `associated_systems(bi)[i]`).

See also: [`variable_symbols`](@ref), [`associated_systems`](@ref), [`is_timeseries`](@ref),
[`NotTimeseries`](@ref).
"""
function getu(bi::BatchedInterface)
numprobs = length(bi.system_to_symbol_subset)
probnames = [Symbol(:prob, i) for i in 1:numprobs]

fnbody = quote end
for (i, (prob, idx, isstate)) in enumerate(zip(
bi.associated_systems, bi.associated_indexes, bi.isstate))
symname = Symbol(:sym, i)
getter = isstate ? state_values : parameter_values
probname = probnames[prob]
push!(fnbody.args, :($symname = $getter($probname, $idx)))
end

oop_expr = Expr(:vect)
for i in 1:length(bi.symbol_order)
push!(oop_expr.args, Symbol(:sym, i))
end

iip_expr = quote end
for i in 1:length(bi.symbol_order)
symname = Symbol(:sym, i)
push!(iip_expr.args, :(out[$i] = $symname))
end

oopfn = Expr(
:function,
Expr(:tuple, probnames...),
quote
$fnbody
$oop_expr
end
)
iipfn = Expr(
:function,
Expr(:tuple, :out, probnames...),
quote
$fnbody
$iip_expr
out
end
)

return let oop = @RuntimeGeneratedFunction(oopfn),
iip = @RuntimeGeneratedFunction(iipfn)

_getter(probs...) = oop(probs...)
_getter(out::AbstractArray, probs...) = iip(out, probs...)
_getter
end
end

"""
setu(bi::BatchedInterface)

Given a [`BatchedInterface`](@ref) composed from `n` systems (and corresponding symbols),
return a function which takes `n` corresponding problems and an array of the values, and
updates each of the problems with the values of the corresponding symbols.

Note that all of the problems passed to the function returned by `setu` must satisfy
`is_timeseries(prob) === NotTimeseries()`.

Note that if any subset of the `n` systems share common symbols (among those passed to
`BatchedInterface`) then all of the corresponding problems in the subset will be updated
with the values of the common symbols.

See also: [`is_timeseries`](@ref), [`NotTimeseries`](@ref).
"""
function setu(bi::BatchedInterface)
numprobs = length(bi.system_to_symbol_subset)
probnames = [Symbol(:prob, i) for i in 1:numprobs]

full_update_fnexpr = let fnbody = quote end
for (sys_idx, subset) in enumerate(bi.system_to_symbol_subset)
probname = probnames[sys_idx]
for (idx_in_subset, idx_in_union) in enumerate(subset)
idx = bi.system_to_symbol_indexes[sys_idx][idx_in_subset]
isstate = bi.system_to_isstate[sys_idx][idx_in_subset]
setter = isstate ? set_state! : set_parameter!
push!(fnbody.args, :($setter($probname, vals[$idx_in_union], $idx)))
end
# also run hook
if !all(bi.system_to_isstate[sys_idx])
paramidxs = [bi.system_to_symbol_indexes[sys_idx][idx_in_subset]
for idx_in_subset in 1:length(subset)
if !bi.system_to_isstate[sys_idx][idx_in_subset]]
push!(fnbody.args, :($finalize_parameters_hook!($probname, $paramidxs)))
end
end
push!(fnbody.args, :(return vals))
Expr(
:function,
Expr(:tuple, probnames..., :vals),
fnbody
)
end

partial_update_fnexpr = let fnbody = quote end
curfnbody = fnbody
for (sys_idx, subset) in enumerate(bi.system_to_symbol_subset)
newcurfnbody = if sys_idx == 1
Expr(:if, :(idx == $sys_idx))
else
Expr(:elseif, :(idx == $sys_idx))
end
push!(curfnbody.args, newcurfnbody)
curfnbody = newcurfnbody

ifbody = quote end
push!(curfnbody.args, ifbody)

probname = :prob
for (idx_in_subset, idx_in_union) in enumerate(subset)
idx = bi.system_to_symbol_indexes[sys_idx][idx_in_subset]
isstate = bi.system_to_isstate[sys_idx][idx_in_subset]
setter = isstate ? set_state! : set_parameter!
push!(ifbody.args, :($setter($probname, vals[$idx_in_union], $idx)))
end
# also run hook
if !all(bi.system_to_isstate[sys_idx])
paramidxs = [bi.system_to_symbol_indexes[sys_idx][idx_in_subset]
for idx_in_subset in 1:length(subset)
if !bi.system_to_isstate[sys_idx][idx_in_subset]]
push!(ifbody.args, :($finalize_parameters_hook!($probname, $paramidxs)))
end
end
push!(curfnbody.args, :(error("Invalid problem index $idx")))
push!(fnbody.args, :(return nothing))
Expr(
:function,
Expr(:tuple, :prob, :idx, :vals),
fnbody
)
end
return let full_update = @RuntimeGeneratedFunction(full_update_fnexpr),
partial_update = @RuntimeGeneratedFunction(partial_update_fnexpr)

setter!(args...) = full_update(args...)
setter!(prob, idx::Int, vals::AbstractVector) = partial_update(prob, idx, vals)
setter!
end
end
56 changes: 56 additions & 0 deletions test/batched_interface_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
using SymbolicIndexingInterface

syss = [
SymbolCache([:x, :y, :z], [:a, :b, :c], :t),
SymbolCache([:z, :w, :v], [:c, :e, :f]),
SymbolCache([:w, :x, :u], [:e, :a, :f])
]
syms = [
[:x, :z, :b, :c],
[:z, :w, :c, :f],
[:w, :x, :e, :a]
]
probs = [
ProblemState(; u = [1.0, 2.0, 3.0], p = [0.1, 0.2, 0.3]),
ProblemState(; u = [4.0, 5.0, 6.0], p = [0.4, 0.5, 0.6]),
ProblemState(; u = [7.0, 8.0, 9.0], p = [0.7, 0.8, 0.9])
]

@test_throws ErrorException BatchedInterface((syss[1], [:x, 3]))
@test_throws ErrorException BatchedInterface((syss[1], [:(x + y)]))
@test_throws ErrorException BatchedInterface((syss[1], [:t]))

bi = BatchedInterface(zip(syss, syms)...)
@test variable_symbols(bi) == [:x, :z, :b, :c, :w, :f, :e, :a]
@test variable_index.((bi,), [:a, :b, :c, :e, :f, :x, :y, :z, :w, :v, :u]) ==
[8, 3, 4, 7, 6, 1, nothing, 2, 5, nothing, nothing]
@test is_variable.((bi,), [:a, :b, :c, :e, :f, :x, :y, :z, :w, :v, :u]) ==
Bool[1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0]
@test associated_systems(bi) == [1, 1, 1, 1, 2, 2, 3, 3]

getter = getu(bi)
@test (@inferred getter(probs...)) == [1.0, 3.0, 0.2, 0.3, 5.0, 0.6, 0.7, 0.8]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't seem quite right, you'd want an array for each to make remaking easy?

buf = zeros(8)
@inferred getter(buf, probs...)
@test buf == [1.0, 3.0, 0.2, 0.3, 5.0, 0.6, 0.7, 0.8]

setter! = setu(bi)
buf .*= 100
setter!(probs..., buf)

@test state_values(probs[1]) == [100.0, 2.0, 300.0]
# :a isn't updated here because it wasn't part of the symbols associated with syss[1] (syms[1])
@test parameter_values(probs[1]) == [0.1, 20.0, 30.0]
@test state_values(probs[2]) == [300.0, 500.0, 6.0]
# Similarly for :e
@test parameter_values(probs[2]) == [30.0, 0.5, 60.0]
@test state_values(probs[3]) == [500.0, 100.0, 9.0]
# Similarly for :f
@test parameter_values(probs[3]) == [70.0, 80.0, 0.9]

buf ./= 100
setter!(probs[1], 1, buf)
@test state_values(probs[1]) == [1.0, 2.0, 3.0]
@test parameter_values(probs[1]) == [0.1, 0.2, 0.3]

@test_throws ErrorException setter!(probs[1], 4, buf)
3 changes: 3 additions & 0 deletions test/downstream/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[deps]
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Loading
Loading