Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add BatchedInterface #65

Merged
merged 5 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ jobs:
strategy:
matrix:
group:
- All
- Core
- Downstream
version:
- '1'
steps:
Expand All @@ -32,3 +33,5 @@ jobs:
${{ runner.os }}-
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
env:
GROUP: ${{ matrix.group }}
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ Accessors = "0.1.36"
Aqua = "0.8"
ArrayInterface = "7.9"
MacroTools = "0.5.13"
RuntimeGeneratedFunctions = "0.5"
Pkg = "1"
RuntimeGeneratedFunctions = "0.5.12"
SafeTestsets = "0.0.1"
StaticArrays = "1.9"
StaticArraysCore = "1.4"
Expand All @@ -24,9 +25,10 @@ julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "Test", "SafeTestsets", "StaticArrays"]
test = ["Aqua", "Pkg", "Test", "SafeTestsets", "StaticArrays"]
7 changes: 7 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,10 @@ symbolic_evaluate
SymbolCache
ProblemState
```

### Batched Queries and Updates

```@docs
BatchedInterface
associated_systems
```
3 changes: 3 additions & 0 deletions src/SymbolicIndexingInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ include("parameter_indexing.jl")
export state_values, set_state!, current_time, getu, setu
include("state_indexing.jl")

export BatchedInterface, associated_systems
include("batched_interface.jl")

export ProblemState
include("problem_state.jl")

Expand Down
269 changes: 269 additions & 0 deletions src/batched_interface.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,269 @@
"""
struct BatchedInterface{S <: AbstractVector, I}
function BatchedInterface(syssyms::Tuple...)

A struct which stores information for batched calls to [`getu`](@ref) or [`setu`](@ref).
Given `Tuple`s, where the first element of each tuple is a system and the second an
array of symbols (either variables or parameters) in the system, `BatchedInterface` will
compute the union of all symbols and associate each symbol with the first system with
which it occurs.

For example, given two systems `s1 = SymbolCache([:x, :y, :z])` and
`s2 = SymbolCache([:y, :z, :w])`, `BatchedInterface((s1, [:x, :y]), (s2, [:y, :z]))` will
associate `:x` and `:y` with `s1` and `:z` with `s2`. The information that `s1` had
associated symbols `:x` and `:y` and `s2` had associated symbols `:y` and `:z` will also
be retained internally.

`BatchedInterface` implements [`variable_symbols`](@ref), [`is_variable`](@ref),
[`variable_index`](@ref) to query the order of symbols in the union.

See [`getu`](@ref) and [`setu`](@ref) for further details.

See also: [`associated_systems`](@ref).
"""
struct BatchedInterface{S <: AbstractVector, I, T}
"Order of symbols in the union."
symbol_order::S
"Index of the system each symbol in the union is associated with."
associated_systems::Vector{Int}
"Index of symbol in the system it is associated with."
associated_indexes::I
"Whether the symbol is a state in the system it is associated with."
isstate::BitVector
"Map from system to indexes of its symbols in the union."
system_to_symbol_subset::Vector{Vector{Int}}
"Map from system to indexes of its symbols in the system."
system_to_symbol_indexes::Vector{Vector{T}}
"Map from system to whether each of its symbols is a state in the system."
system_to_isstate::Vector{BitVector}
end

function BatchedInterface(syssyms::Tuple...)
symbol_order = []
associated_systems = Int[]
associated_indexes = []
isstate = BitVector()
system_to_symbol_subset = Vector{Int}[]
system_to_symbol_indexes = []
system_to_isstate = BitVector[]
for (i, (sys, syms)) in enumerate(syssyms)
symbol_subset = Int[]
symbol_indexes = []
system_isstate = BitVector()
allsyms = []
for sym in syms
if symbolic_type(sym) === NotSymbolic()
error("Only symbolic variables allowed in BatchedInterface.")

Check warning on line 56 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L41-L56

Added lines #L41 - L56 were not covered by tests
end
if symbolic_type(sym) === ArraySymbolic()
append!(allsyms, collect(sym))

Check warning on line 59 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L58-L59

Added lines #L58 - L59 were not covered by tests
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This removes the need for ArrayPartitions?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

else
push!(allsyms, sym)

Check warning on line 61 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L61

Added line #L61 was not covered by tests
end
end
for sym in allsyms
if !is_variable(sys, sym) && !is_parameter(sys, sym)
error("Only variables and parameters allowed in BatchedInterface.")

Check warning on line 66 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L63-L66

Added lines #L63 - L66 were not covered by tests
end
if !any(isequal(sym), symbol_order)
push!(symbol_order, sym)
push!(associated_systems, i)
push!(isstate, is_variable(sys, sym))
if isstate[end]
push!(associated_indexes, variable_index(sys, sym))

Check warning on line 73 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L68-L73

Added lines #L68 - L73 were not covered by tests
else
push!(associated_indexes, parameter_index(sys, sym))

Check warning on line 75 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L75

Added line #L75 was not covered by tests
end
end
push!(symbol_subset, findfirst(isequal(sym), symbol_order))
push!(system_isstate, is_variable(sys, sym))
push!(symbol_indexes,

Check warning on line 80 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L78-L80

Added lines #L78 - L80 were not covered by tests
system_isstate[end] ? variable_index(sys, sym) : parameter_index(sys, sym))
end
push!(system_to_symbol_subset, symbol_subset)
push!(system_to_symbol_indexes, identity.(symbol_indexes))
push!(system_to_isstate, system_isstate)
end
symbol_order = identity.(symbol_order)
associated_indexes = identity.(associated_indexes)
system_to_symbol_indexes = identity.(system_to_symbol_indexes)

Check warning on line 89 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L82-L89

Added lines #L82 - L89 were not covered by tests

return BatchedInterface{typeof(symbol_order), typeof(associated_indexes),

Check warning on line 91 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L91

Added line #L91 was not covered by tests
eltype(eltype(system_to_symbol_indexes))}(
symbol_order, associated_systems, associated_indexes, isstate,
system_to_symbol_subset, system_to_symbol_indexes, system_to_isstate)
end

variable_symbols(bi::BatchedInterface) = bi.symbol_order
variable_index(bi::BatchedInterface, sym) = findfirst(isequal(sym), bi.symbol_order)
is_variable(bi::BatchedInterface, sym) = variable_index(bi, sym) !== nothing

Check warning on line 99 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L97-L99

Added lines #L97 - L99 were not covered by tests

"""
associated_systems(bi::BatchedInterface)

Return an array of integers of the same length as `variable_symbols(bi)` where each value
is the index of the system associated with the corresponding symbol in
`variable_symbols(bi)`.
"""
associated_systems(bi::BatchedInterface) = bi.associated_systems

Check warning on line 108 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L108

Added line #L108 was not covered by tests

"""
getu(bi::BatchedInterface)

Given a [`BatchedInterface`](@ref) composed from `n` systems (and corresponding symbols),
return a function which takes `n` corresponding problems and returns an array of the values
of the symbols in the union. The returned function can also be passed an `AbstractArray` of
the appropriate `eltype` and size as its first argument, in which case the operation will
populate the array in-place with the values of the symbols in the union.

Note that all of the problems passed to the function returned by `getu` must satisfy
`is_timeseries(prob) === NotTimeseries()`.

The value of the `i`th symbol in the union (obtained through `variable_symbols(bi)[i]`) is
obtained from the problem corresponding to the associated system (i.e. the problem at
index `associated_systems(bi)[i]`).

See also: [`variable_symbols`](@ref), [`associated_systems`](@ref), [`is_timeseries`](@ref),
[`NotTimeseries`](@ref).
"""
function getu(bi::BatchedInterface)
numprobs = length(bi.system_to_symbol_subset)
probnames = [Symbol(:prob, i) for i in 1:numprobs]

Check warning on line 131 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L129-L131

Added lines #L129 - L131 were not covered by tests

fnbody = quote end
for (i, (prob, idx, isstate)) in enumerate(zip(

Check warning on line 134 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L133-L134

Added lines #L133 - L134 were not covered by tests
bi.associated_systems, bi.associated_indexes, bi.isstate))
symname = Symbol(:sym, i)
getter = isstate ? state_values : parameter_values
probname = probnames[prob]
push!(fnbody.args, :($symname = $getter($probname, $idx)))
end

Check warning on line 140 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L136-L140

Added lines #L136 - L140 were not covered by tests

oop_expr = Expr(:vect)
for i in 1:length(bi.symbol_order)
push!(oop_expr.args, Symbol(:sym, i))
end

Check warning on line 145 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L142-L145

Added lines #L142 - L145 were not covered by tests

iip_expr = quote end
for i in 1:length(bi.symbol_order)
symname = Symbol(:sym, i)
push!(iip_expr.args, :(out[$i] = $symname))
end

Check warning on line 151 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L147-L151

Added lines #L147 - L151 were not covered by tests

oopfn = Expr(

Check warning on line 153 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L153

Added line #L153 was not covered by tests
:function,
Expr(:tuple, probnames...),
quote
$fnbody
$oop_expr

Check warning on line 158 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L157-L158

Added lines #L157 - L158 were not covered by tests
end
)
iipfn = Expr(

Check warning on line 161 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L161

Added line #L161 was not covered by tests
:function,
Expr(:tuple, :out, probnames...),
quote
$fnbody
$iip_expr
out

Check warning on line 167 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L165-L167

Added lines #L165 - L167 were not covered by tests
end
)

return let oop = @RuntimeGeneratedFunction(oopfn),
iip = @RuntimeGeneratedFunction(iipfn)

Check warning on line 172 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L171-L172

Added lines #L171 - L172 were not covered by tests

_getter(probs...) = oop(probs...)
_getter(out::AbstractArray, probs...) = iip(out, probs...)
_getter

Check warning on line 176 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L174-L176

Added lines #L174 - L176 were not covered by tests
end
end

"""
setu(bi::BatchedInterface)

Given a [`BatchedInterface`](@ref) composed from `n` systems (and corresponding symbols),
return a function which takes `n` corresponding problems and an array of the values, and
updates each of the problems with the values of the corresponding symbols.

Note that all of the problems passed to the function returned by `setu` must satisfy
`is_timeseries(prob) === NotTimeseries()`.

Note that if any subset of the `n` systems share common symbols (among those passed to
`BatchedInterface`) then all of the corresponding problems in the subset will be updated
with the values of the common symbols.

See also: [`is_timeseries`](@ref), [`NotTimeseries`](@ref).
"""
function setu(bi::BatchedInterface)
numprobs = length(bi.system_to_symbol_subset)
probnames = [Symbol(:prob, i) for i in 1:numprobs]

Check warning on line 198 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L196-L198

Added lines #L196 - L198 were not covered by tests

full_update_fnexpr = let fnbody = quote end
for (sys_idx, subset) in enumerate(bi.system_to_symbol_subset)
probname = probnames[sys_idx]
for (idx_in_subset, idx_in_union) in enumerate(subset)
idx = bi.system_to_symbol_indexes[sys_idx][idx_in_subset]
isstate = bi.system_to_isstate[sys_idx][idx_in_subset]
setter = isstate ? set_state! : set_parameter!
push!(fnbody.args, :($setter($probname, vals[$idx_in_union], $idx)))
end

Check warning on line 208 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L200-L208

Added lines #L200 - L208 were not covered by tests
# also run hook
if !all(bi.system_to_isstate[sys_idx])
paramidxs = [bi.system_to_symbol_indexes[sys_idx][idx_in_subset]

Check warning on line 211 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L210-L211

Added lines #L210 - L211 were not covered by tests
for idx_in_subset in 1:length(subset)
if !bi.system_to_isstate[sys_idx][idx_in_subset]]
push!(fnbody.args, :($finalize_parameters_hook!($probname, $paramidxs)))

Check warning on line 214 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L214

Added line #L214 was not covered by tests
end
end
push!(fnbody.args, :(return vals))
Expr(

Check warning on line 218 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L216-L218

Added lines #L216 - L218 were not covered by tests
:function,
Expr(:tuple, probnames..., :vals),
fnbody
)
end

partial_update_fnexpr = let fnbody = quote end
curfnbody = fnbody
for (sys_idx, subset) in enumerate(bi.system_to_symbol_subset)
newcurfnbody = if sys_idx == 1
Expr(:if, :(idx == $sys_idx))

Check warning on line 229 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L225-L229

Added lines #L225 - L229 were not covered by tests
else
Expr(:elseif, :(idx == $sys_idx))

Check warning on line 231 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L231

Added line #L231 was not covered by tests
end
push!(curfnbody.args, newcurfnbody)
curfnbody = newcurfnbody

Check warning on line 234 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L233-L234

Added lines #L233 - L234 were not covered by tests

ifbody = quote end
push!(curfnbody.args, ifbody)

Check warning on line 237 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L236-L237

Added lines #L236 - L237 were not covered by tests

probname = :prob
for (idx_in_subset, idx_in_union) in enumerate(subset)
idx = bi.system_to_symbol_indexes[sys_idx][idx_in_subset]
isstate = bi.system_to_isstate[sys_idx][idx_in_subset]
setter = isstate ? set_state! : set_parameter!
push!(ifbody.args, :($setter($probname, vals[$idx_in_union], $idx)))
end

Check warning on line 245 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L239-L245

Added lines #L239 - L245 were not covered by tests
# also run hook
if !all(bi.system_to_isstate[sys_idx])
paramidxs = [bi.system_to_symbol_indexes[sys_idx][idx_in_subset]

Check warning on line 248 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L247-L248

Added lines #L247 - L248 were not covered by tests
for idx_in_subset in 1:length(subset)
if !bi.system_to_isstate[sys_idx][idx_in_subset]]
push!(ifbody.args, :($finalize_parameters_hook!($probname, $paramidxs)))

Check warning on line 251 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L251

Added line #L251 was not covered by tests
end
end
push!(curfnbody.args, :(error("Invalid problem index $idx")))
push!(fnbody.args, :(return nothing))
Expr(

Check warning on line 256 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L253-L256

Added lines #L253 - L256 were not covered by tests
:function,
Expr(:tuple, :prob, :idx, :vals),
fnbody
)
end
return let full_update = @RuntimeGeneratedFunction(full_update_fnexpr),
partial_update = @RuntimeGeneratedFunction(partial_update_fnexpr)

Check warning on line 263 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L262-L263

Added lines #L262 - L263 were not covered by tests

setter!(args...) = full_update(args...)
setter!(prob, idx::Int, vals::AbstractVector) = partial_update(prob, idx, vals)
setter!

Check warning on line 267 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L265-L267

Added lines #L265 - L267 were not covered by tests
end
end
56 changes: 56 additions & 0 deletions test/batched_interface_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
using SymbolicIndexingInterface

syss = [
SymbolCache([:x, :y, :z], [:a, :b, :c], :t),
SymbolCache([:z, :w, :v], [:c, :e, :f]),
SymbolCache([:w, :x, :u], [:e, :a, :f])
]
syms = [
[:x, :z, :b, :c],
[:z, :w, :c, :f],
[:w, :x, :e, :a]
]
probs = [
ProblemState(; u = [1.0, 2.0, 3.0], p = [0.1, 0.2, 0.3]),
ProblemState(; u = [4.0, 5.0, 6.0], p = [0.4, 0.5, 0.6]),
ProblemState(; u = [7.0, 8.0, 9.0], p = [0.7, 0.8, 0.9])
]

@test_throws ErrorException BatchedInterface((syss[1], [:x, 3]))
@test_throws ErrorException BatchedInterface((syss[1], [:(x + y)]))
@test_throws ErrorException BatchedInterface((syss[1], [:t]))

bi = BatchedInterface(zip(syss, syms)...)
@test variable_symbols(bi) == [:x, :z, :b, :c, :w, :f, :e, :a]
@test variable_index.((bi,), [:a, :b, :c, :e, :f, :x, :y, :z, :w, :v, :u]) ==
[8, 3, 4, 7, 6, 1, nothing, 2, 5, nothing, nothing]
@test is_variable.((bi,), [:a, :b, :c, :e, :f, :x, :y, :z, :w, :v, :u]) ==
Bool[1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0]
@test associated_systems(bi) == [1, 1, 1, 1, 2, 2, 3, 3]

getter = getu(bi)
@test (@inferred getter(probs...)) == [1.0, 3.0, 0.2, 0.3, 5.0, 0.6, 0.7, 0.8]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't seem quite right, you'd want an array for each to make remaking easy?

buf = zeros(8)
@inferred getter(buf, probs...)
@test buf == [1.0, 3.0, 0.2, 0.3, 5.0, 0.6, 0.7, 0.8]

setter! = setu(bi)
buf .*= 100
setter!(probs..., buf)

@test state_values(probs[1]) == [100.0, 2.0, 300.0]
# :a isn't updated here because it wasn't part of the symbols associated with syss[1] (syms[1])
@test parameter_values(probs[1]) == [0.1, 20.0, 30.0]
@test state_values(probs[2]) == [300.0, 500.0, 6.0]
# Similarly for :e
@test parameter_values(probs[2]) == [30.0, 0.5, 60.0]
@test state_values(probs[3]) == [500.0, 100.0, 9.0]
# Similarly for :f
@test parameter_values(probs[3]) == [70.0, 80.0, 0.9]

buf ./= 100
setter!(probs[1], 1, buf)
@test state_values(probs[1]) == [1.0, 2.0, 3.0]
@test parameter_values(probs[1]) == [0.1, 0.2, 0.3]

@test_throws ErrorException setter!(probs[1], 4, buf)
3 changes: 3 additions & 0 deletions test/downstream/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[deps]
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Loading
Loading