diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b78ae7e..309ea18 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -12,7 +12,8 @@ jobs: strategy: matrix: group: - - All + - Core + - Downstream version: - '1' steps: @@ -32,3 +33,5 @@ jobs: ${{ runner.os }}- - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 + env: + GROUP: ${{ matrix.group }} diff --git a/Project.toml b/Project.toml index 0d387ab..9860345 100644 --- a/Project.toml +++ b/Project.toml @@ -15,7 +15,8 @@ Accessors = "0.1.36" Aqua = "0.8" ArrayInterface = "7.9" MacroTools = "0.5.13" -RuntimeGeneratedFunctions = "0.5" +Pkg = "1" +RuntimeGeneratedFunctions = "0.5.12" SafeTestsets = "0.0.1" StaticArrays = "1.9" StaticArraysCore = "1.4" @@ -24,9 +25,10 @@ julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "Test", "SafeTestsets", "StaticArrays"] +test = ["Aqua", "Pkg", "Test", "SafeTestsets", "StaticArrays"] diff --git a/docs/src/api.md b/docs/src/api.md index 3157ba7..d6b3cab 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -90,3 +90,10 @@ symbolic_evaluate SymbolCache ProblemState ``` + +### Batched Queries and Updates + +```@docs +BatchedInterface +associated_systems +``` diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index e5f8e9a..492ae77 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -31,6 +31,9 @@ include("parameter_indexing.jl") export state_values, set_state!, current_time, getu, setu include("state_indexing.jl") +export BatchedInterface, associated_systems +include("batched_interface.jl") + export ProblemState include("problem_state.jl") diff --git a/src/batched_interface.jl b/src/batched_interface.jl new file mode 100644 index 0000000..5c3617b --- /dev/null +++ b/src/batched_interface.jl @@ -0,0 +1,269 @@ +""" + struct BatchedInterface{S <: AbstractVector, I} + function BatchedInterface(syssyms::Tuple...) + +A struct which stores information for batched calls to [`getu`](@ref) or [`setu`](@ref). +Given `Tuple`s, where the first element of each tuple is a system and the second an +array of symbols (either variables or parameters) in the system, `BatchedInterface` will +compute the union of all symbols and associate each symbol with the first system with +which it occurs. + +For example, given two systems `s1 = SymbolCache([:x, :y, :z])` and +`s2 = SymbolCache([:y, :z, :w])`, `BatchedInterface((s1, [:x, :y]), (s2, [:y, :z]))` will +associate `:x` and `:y` with `s1` and `:z` with `s2`. The information that `s1` had +associated symbols `:x` and `:y` and `s2` had associated symbols `:y` and `:z` will also +be retained internally. + +`BatchedInterface` implements [`variable_symbols`](@ref), [`is_variable`](@ref), +[`variable_index`](@ref) to query the order of symbols in the union. + +See [`getu`](@ref) and [`setu`](@ref) for further details. + +See also: [`associated_systems`](@ref). +""" +struct BatchedInterface{S <: AbstractVector, I, T} + "Order of symbols in the union." + symbol_order::S + "Index of the system each symbol in the union is associated with." + associated_systems::Vector{Int} + "Index of symbol in the system it is associated with." + associated_indexes::I + "Whether the symbol is a state in the system it is associated with." + isstate::BitVector + "Map from system to indexes of its symbols in the union." + system_to_symbol_subset::Vector{Vector{Int}} + "Map from system to indexes of its symbols in the system." + system_to_symbol_indexes::Vector{Vector{T}} + "Map from system to whether each of its symbols is a state in the system." + system_to_isstate::Vector{BitVector} +end + +function BatchedInterface(syssyms::Tuple...) + symbol_order = [] + associated_systems = Int[] + associated_indexes = [] + isstate = BitVector() + system_to_symbol_subset = Vector{Int}[] + system_to_symbol_indexes = [] + system_to_isstate = BitVector[] + for (i, (sys, syms)) in enumerate(syssyms) + symbol_subset = Int[] + symbol_indexes = [] + system_isstate = BitVector() + allsyms = [] + for sym in syms + if symbolic_type(sym) === NotSymbolic() + error("Only symbolic variables allowed in BatchedInterface.") + end + if symbolic_type(sym) === ArraySymbolic() + append!(allsyms, collect(sym)) + else + push!(allsyms, sym) + end + end + for sym in allsyms + if !is_variable(sys, sym) && !is_parameter(sys, sym) + error("Only variables and parameters allowed in BatchedInterface.") + end + if !any(isequal(sym), symbol_order) + push!(symbol_order, sym) + push!(associated_systems, i) + push!(isstate, is_variable(sys, sym)) + if isstate[end] + push!(associated_indexes, variable_index(sys, sym)) + else + push!(associated_indexes, parameter_index(sys, sym)) + end + end + push!(symbol_subset, findfirst(isequal(sym), symbol_order)) + push!(system_isstate, is_variable(sys, sym)) + push!(symbol_indexes, + system_isstate[end] ? variable_index(sys, sym) : parameter_index(sys, sym)) + end + push!(system_to_symbol_subset, symbol_subset) + push!(system_to_symbol_indexes, identity.(symbol_indexes)) + push!(system_to_isstate, system_isstate) + end + symbol_order = identity.(symbol_order) + associated_indexes = identity.(associated_indexes) + system_to_symbol_indexes = identity.(system_to_symbol_indexes) + + return BatchedInterface{typeof(symbol_order), typeof(associated_indexes), + eltype(eltype(system_to_symbol_indexes))}( + symbol_order, associated_systems, associated_indexes, isstate, + system_to_symbol_subset, system_to_symbol_indexes, system_to_isstate) +end + +variable_symbols(bi::BatchedInterface) = bi.symbol_order +variable_index(bi::BatchedInterface, sym) = findfirst(isequal(sym), bi.symbol_order) +is_variable(bi::BatchedInterface, sym) = variable_index(bi, sym) !== nothing + +""" + associated_systems(bi::BatchedInterface) + +Return an array of integers of the same length as `variable_symbols(bi)` where each value +is the index of the system associated with the corresponding symbol in +`variable_symbols(bi)`. +""" +associated_systems(bi::BatchedInterface) = bi.associated_systems + +""" + getu(bi::BatchedInterface) + +Given a [`BatchedInterface`](@ref) composed from `n` systems (and corresponding symbols), +return a function which takes `n` corresponding problems and returns an array of the values +of the symbols in the union. The returned function can also be passed an `AbstractArray` of +the appropriate `eltype` and size as its first argument, in which case the operation will +populate the array in-place with the values of the symbols in the union. + +Note that all of the problems passed to the function returned by `getu` must satisfy +`is_timeseries(prob) === NotTimeseries()`. + +The value of the `i`th symbol in the union (obtained through `variable_symbols(bi)[i]`) is +obtained from the problem corresponding to the associated system (i.e. the problem at +index `associated_systems(bi)[i]`). + +See also: [`variable_symbols`](@ref), [`associated_systems`](@ref), [`is_timeseries`](@ref), +[`NotTimeseries`](@ref). +""" +function getu(bi::BatchedInterface) + numprobs = length(bi.system_to_symbol_subset) + probnames = [Symbol(:prob, i) for i in 1:numprobs] + + fnbody = quote end + for (i, (prob, idx, isstate)) in enumerate(zip( + bi.associated_systems, bi.associated_indexes, bi.isstate)) + symname = Symbol(:sym, i) + getter = isstate ? state_values : parameter_values + probname = probnames[prob] + push!(fnbody.args, :($symname = $getter($probname, $idx))) + end + + oop_expr = Expr(:vect) + for i in 1:length(bi.symbol_order) + push!(oop_expr.args, Symbol(:sym, i)) + end + + iip_expr = quote end + for i in 1:length(bi.symbol_order) + symname = Symbol(:sym, i) + push!(iip_expr.args, :(out[$i] = $symname)) + end + + oopfn = Expr( + :function, + Expr(:tuple, probnames...), + quote + $fnbody + $oop_expr + end + ) + iipfn = Expr( + :function, + Expr(:tuple, :out, probnames...), + quote + $fnbody + $iip_expr + out + end + ) + + return let oop = @RuntimeGeneratedFunction(oopfn), + iip = @RuntimeGeneratedFunction(iipfn) + + _getter(probs...) = oop(probs...) + _getter(out::AbstractArray, probs...) = iip(out, probs...) + _getter + end +end + +""" + setu(bi::BatchedInterface) + +Given a [`BatchedInterface`](@ref) composed from `n` systems (and corresponding symbols), +return a function which takes `n` corresponding problems and an array of the values, and +updates each of the problems with the values of the corresponding symbols. + +Note that all of the problems passed to the function returned by `setu` must satisfy +`is_timeseries(prob) === NotTimeseries()`. + +Note that if any subset of the `n` systems share common symbols (among those passed to +`BatchedInterface`) then all of the corresponding problems in the subset will be updated +with the values of the common symbols. + +See also: [`is_timeseries`](@ref), [`NotTimeseries`](@ref). +""" +function setu(bi::BatchedInterface) + numprobs = length(bi.system_to_symbol_subset) + probnames = [Symbol(:prob, i) for i in 1:numprobs] + + full_update_fnexpr = let fnbody = quote end + for (sys_idx, subset) in enumerate(bi.system_to_symbol_subset) + probname = probnames[sys_idx] + for (idx_in_subset, idx_in_union) in enumerate(subset) + idx = bi.system_to_symbol_indexes[sys_idx][idx_in_subset] + isstate = bi.system_to_isstate[sys_idx][idx_in_subset] + setter = isstate ? set_state! : set_parameter! + push!(fnbody.args, :($setter($probname, vals[$idx_in_union], $idx))) + end + # also run hook + if !all(bi.system_to_isstate[sys_idx]) + paramidxs = [bi.system_to_symbol_indexes[sys_idx][idx_in_subset] + for idx_in_subset in 1:length(subset) + if !bi.system_to_isstate[sys_idx][idx_in_subset]] + push!(fnbody.args, :($finalize_parameters_hook!($probname, $paramidxs))) + end + end + push!(fnbody.args, :(return vals)) + Expr( + :function, + Expr(:tuple, probnames..., :vals), + fnbody + ) + end + + partial_update_fnexpr = let fnbody = quote end + curfnbody = fnbody + for (sys_idx, subset) in enumerate(bi.system_to_symbol_subset) + newcurfnbody = if sys_idx == 1 + Expr(:if, :(idx == $sys_idx)) + else + Expr(:elseif, :(idx == $sys_idx)) + end + push!(curfnbody.args, newcurfnbody) + curfnbody = newcurfnbody + + ifbody = quote end + push!(curfnbody.args, ifbody) + + probname = :prob + for (idx_in_subset, idx_in_union) in enumerate(subset) + idx = bi.system_to_symbol_indexes[sys_idx][idx_in_subset] + isstate = bi.system_to_isstate[sys_idx][idx_in_subset] + setter = isstate ? set_state! : set_parameter! + push!(ifbody.args, :($setter($probname, vals[$idx_in_union], $idx))) + end + # also run hook + if !all(bi.system_to_isstate[sys_idx]) + paramidxs = [bi.system_to_symbol_indexes[sys_idx][idx_in_subset] + for idx_in_subset in 1:length(subset) + if !bi.system_to_isstate[sys_idx][idx_in_subset]] + push!(ifbody.args, :($finalize_parameters_hook!($probname, $paramidxs))) + end + end + push!(curfnbody.args, :(error("Invalid problem index $idx"))) + push!(fnbody.args, :(return nothing)) + Expr( + :function, + Expr(:tuple, :prob, :idx, :vals), + fnbody + ) + end + return let full_update = @RuntimeGeneratedFunction(full_update_fnexpr), + partial_update = @RuntimeGeneratedFunction(partial_update_fnexpr) + + setter!(args...) = full_update(args...) + setter!(prob, idx::Int, vals::AbstractVector) = partial_update(prob, idx, vals) + setter! + end +end diff --git a/test/batched_interface_test.jl b/test/batched_interface_test.jl new file mode 100644 index 0000000..3e622cd --- /dev/null +++ b/test/batched_interface_test.jl @@ -0,0 +1,56 @@ +using SymbolicIndexingInterface + +syss = [ + SymbolCache([:x, :y, :z], [:a, :b, :c], :t), + SymbolCache([:z, :w, :v], [:c, :e, :f]), + SymbolCache([:w, :x, :u], [:e, :a, :f]) +] +syms = [ + [:x, :z, :b, :c], + [:z, :w, :c, :f], + [:w, :x, :e, :a] +] +probs = [ + ProblemState(; u = [1.0, 2.0, 3.0], p = [0.1, 0.2, 0.3]), + ProblemState(; u = [4.0, 5.0, 6.0], p = [0.4, 0.5, 0.6]), + ProblemState(; u = [7.0, 8.0, 9.0], p = [0.7, 0.8, 0.9]) +] + +@test_throws ErrorException BatchedInterface((syss[1], [:x, 3])) +@test_throws ErrorException BatchedInterface((syss[1], [:(x + y)])) +@test_throws ErrorException BatchedInterface((syss[1], [:t])) + +bi = BatchedInterface(zip(syss, syms)...) +@test variable_symbols(bi) == [:x, :z, :b, :c, :w, :f, :e, :a] +@test variable_index.((bi,), [:a, :b, :c, :e, :f, :x, :y, :z, :w, :v, :u]) == + [8, 3, 4, 7, 6, 1, nothing, 2, 5, nothing, nothing] +@test is_variable.((bi,), [:a, :b, :c, :e, :f, :x, :y, :z, :w, :v, :u]) == + Bool[1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0] +@test associated_systems(bi) == [1, 1, 1, 1, 2, 2, 3, 3] + +getter = getu(bi) +@test (@inferred getter(probs...)) == [1.0, 3.0, 0.2, 0.3, 5.0, 0.6, 0.7, 0.8] +buf = zeros(8) +@inferred getter(buf, probs...) +@test buf == [1.0, 3.0, 0.2, 0.3, 5.0, 0.6, 0.7, 0.8] + +setter! = setu(bi) +buf .*= 100 +setter!(probs..., buf) + +@test state_values(probs[1]) == [100.0, 2.0, 300.0] +# :a isn't updated here because it wasn't part of the symbols associated with syss[1] (syms[1]) +@test parameter_values(probs[1]) == [0.1, 20.0, 30.0] +@test state_values(probs[2]) == [300.0, 500.0, 6.0] +# Similarly for :e +@test parameter_values(probs[2]) == [30.0, 0.5, 60.0] +@test state_values(probs[3]) == [500.0, 100.0, 9.0] +# Similarly for :f +@test parameter_values(probs[3]) == [70.0, 80.0, 0.9] + +buf ./= 100 +setter!(probs[1], 1, buf) +@test state_values(probs[1]) == [1.0, 2.0, 3.0] +@test parameter_values(probs[1]) == [0.1, 0.2, 0.3] + +@test_throws ErrorException setter!(probs[1], 4, buf) diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml new file mode 100644 index 0000000..7450e66 --- /dev/null +++ b/test/downstream/Project.toml @@ -0,0 +1,3 @@ +[deps] +SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" +Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" diff --git a/test/downstream/batchedinterface_arrayvars.jl b/test/downstream/batchedinterface_arrayvars.jl new file mode 100644 index 0000000..83c003d --- /dev/null +++ b/test/downstream/batchedinterface_arrayvars.jl @@ -0,0 +1,42 @@ +using Symbolics +using SymbolicIndexingInterface + +@variables x[1:2] y z + +syss = [ + SymbolCache([x..., y]), + SymbolCache([x[1], y, z]) +] +syms = [ + [x, y], + [x[1], y] +] +probs = [ + ProblemState(; u = [1.0, 2.0, 3.0]), + ProblemState(; u = [4.0, 5.0, 6.0]) +] + +bi = BatchedInterface(zip(syss, syms)...) + +@test all(isequal.(variable_symbols(bi), [x..., y])) +@test variable_index.((bi,), [x..., y, z]) == [1, 2, 3, nothing] +@test is_variable.((bi,), [x..., y, z]) == Bool[1, 1, 1, 0] +@test associated_systems(bi) == [1, 1, 1] + +getter = getu(bi) +@test (@inferred getter(probs...)) == [1.0, 2.0, 3.0] +buf = zeros(3) +@inferred getter(buf, probs...) +@test buf == [1.0, 2.0, 3.0] + +setter! = setu(bi) +buf .*= 10 +setter!(probs..., buf) + +@test state_values(probs[1]) == [10.0, 20.0, 30.0] +@test state_values(probs[2]) == [10.0, 30.0, 6.0] + +buf ./= 10 + +setter!(probs[1], 1, buf) +@test state_values(probs[1]) == [1.0, 2.0, 3.0] diff --git a/test/runtests.jl b/test/runtests.jl index d334c17..a295448 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,31 +1,52 @@ using SymbolicIndexingInterface using SafeTestsets using Test +using Pkg -@safetestset "Quality Assurance" begin - @time include("qa.jl") -end -@safetestset "Interface test" begin - @time include("example_test.jl") -end -@safetestset "Trait test" begin - @time include("trait_test.jl") -end -@safetestset "SymbolCache test" begin - @time include("symbol_cache_test.jl") -end -@safetestset "Fallback test" begin - @time include("fallback_test.jl") -end -@safetestset "Parameter indexing test" begin - @time include("parameter_indexing_test.jl") -end -@safetestset "State indexing test" begin - @time include("state_indexing_test.jl") +const GROUP = get(ENV, "GROUP", "All") + +function activate_downstream_env() + Pkg.activate("downstream") + Pkg.develop(PackageSpec(path = dirname(@__DIR__))) + Pkg.instantiate() end -@safetestset "Remake test" begin - @time include("remake_test.jl") + +if GROUP == "All" || GROUP == "Core" + @safetestset "Quality Assurance" begin + @time include("qa.jl") + end + @safetestset "Interface test" begin + @time include("example_test.jl") + end + @safetestset "Trait test" begin + @time include("trait_test.jl") + end + @safetestset "SymbolCache test" begin + @time include("symbol_cache_test.jl") + end + @safetestset "Fallback test" begin + @time include("fallback_test.jl") + end + @safetestset "Parameter indexing test" begin + @time include("parameter_indexing_test.jl") + end + @safetestset "State indexing test" begin + @time include("state_indexing_test.jl") + end + @safetestset "Remake test" begin + @time include("remake_test.jl") + end + @safetestset "ProblemState test" begin + @time include("problem_state_test.jl") + end + @safetestset "BatchedInterface test" begin + @time include("batched_interface_test.jl") + end end -@safetestset "ProblemState test" begin - @time include("problem_state_test.jl") + +if GROUP == "All" || GROUP == "Downstream" + activate_downstream_env() + @safetestset "BatchedInterface with array symbolics test" begin + @time include("downstream/batchedinterface_arrayvars.jl") + end end