diff --git a/src/utils/IteratorUtils.jl b/src/utils/IteratorUtils.jl index 3586923..1903b53 100644 --- a/src/utils/IteratorUtils.jl +++ b/src/utils/IteratorUtils.jl @@ -254,8 +254,9 @@ end """ reduce_longtuple(f, rinit, t1::Tuple, t2, ... tm; errmsg="iterables lengths differ") -> r + reduce_longtuple_p(f, rinit, t1::Tuple, t2, ... tm, p; errmsg="iterables lengths differ") -> r -Call `r += f(r, t1[n], t2[n], ... tm[n])` for each element +Call `r += f(r, t1[n], t2[n], ... tm[n])` or `r += f(r, t1[n], t2[n], ... tm[n], p)` for each element `n` of `t1::Tuple`, `t2`, ... `tm`. Initial value of `r = rinit` # Implementation @@ -286,5 +287,29 @@ end return ex end +@inline reduce_longtuple_p(f::F, rinit, t1, p; errmsg="iterables lengths differ") where{F} = + reduce_longtuple_unchecked_p(f, rinit, t1, p) +@generated function reduce_longtuple_unchecked_p(f, rinit, t1::Tuple, p) + ex = quote ; end # empty expression + for j=1:fieldcount(t1) + push!(ex.args, quote; rinit = f(rinit, t1[$j], p); end) + end + push!(ex.args, quote; return rinit; end) + + return ex +end + +@inline reduce_longtuple_p(f::F, rinit, t1, t2, p; errmsg="iterables lengths differ") where{F} = + (check_lengths_equal(t1, t2; errmsg=errmsg); reduce_longtuple_unchecked_p(f, rinit, t1, t2, p)) +@generated function reduce_longtuple_unchecked_p(f, rinit, t1::Tuple, t2, p) + ex = quote ; end # empty expression + for j=1:fieldcount(t1) + push!(ex.args, quote; rinit = f(rinit, t1[$j], t2[$j], p); end) + end + push!(ex.args, quote; return rinit; end) + + return ex +end + end # module diff --git a/src/variableaggregators/VariableAggregator.jl b/src/variableaggregators/VariableAggregator.jl index aa6c968..f7ccf46 100644 --- a/src/variableaggregators/VariableAggregator.jl +++ b/src/variableaggregators/VariableAggregator.jl @@ -1,11 +1,11 @@ -# import Infiltrator +import Infiltrator ################################################################# # VariableAggregator ################################################################# -struct VariableAggregator{T, F <: Tuple, C <: Tuple} +struct VariableAggregator{T, F <: Tuple, C <: Tuple, PF <: Tuple, PC <: AbstractVector} # modeldata used for data arrays (for diagnostic output only) modeldata::AbstractModelData # not typed arrays_idx::Int @@ -19,10 +19,16 @@ struct VariableAggregator{T, F <: Tuple, C <: Tuple} # field and corresponding CellRange for each Variable fields::F cellranges::C + + # optional packed fields and corresponding cellranges + packed_vars::Vector{VariableDomain} + packed_fields::PF + packed_cellranges::PC + packed_indices::UnitRange{Int64} end """ - VariableAggregator(vars, cellranges, modeldata, arrays_idx) -> VariableAggregator + VariableAggregator(vars, cellranges, modeldata, arrays_idx [, packed_vars] [, packed_cellranges]) -> VariableAggregator Aggregate multiple VariableDomains into a flattened list (a contiguous Vector). @@ -31,12 +37,14 @@ for data arrays in `modeldata` with `arrays_idx`. `cellranges` may contain `nothing` entries to indicate whole Domain. +Optionally add `packed_vars` which should be CellSpace variables from a single Domain with cells defined by `packed_cellranges` + This is mostly useful for aggregating state Variables, derivatives, etc to implement an interface to a generic ODE/DAE etc solver. Values may be copied to and from a Vector using [`copyto!`](@ref) """ -function VariableAggregator(vars, cellranges, modeldata::AbstractModelData, arrays_idx::Int) +function VariableAggregator(vars, cellranges, modeldata::AbstractModelData, arrays_idx::Int, packed_vars=[], packed_cellranges=[]) IteratorUtils.check_lengths_equal(vars, cellranges; errmsg="'vars' and 'cellranges' must be of same length") @@ -56,24 +64,36 @@ function VariableAggregator(vars, cellranges, modeldata::AbstractModelData, arra nextidx += dof end - fields=Tuple(fields) - cellranges=Tuple(cellranges) + packed_fields = [] + packed_dof = 0 + for v in packed_vars + f = get_field(v, modeldata, arrays_idx) + push!(packed_fields, f) + for cr in packed_cellranges + packed_dof += dof_field(f, cr) + end + end + packed_indices = nextidx:(nextidx+packed_dof-1) + + fields = Tuple(fields) + cellranges = Tuple(cellranges) + packed_fields = Tuple(packed_fields) - return VariableAggregator{eltype(modeldata, arrays_idx), typeof(fields), typeof(cellranges)}( - modeldata, arrays_idx, copy(vars), indices, fields, cellranges, + return VariableAggregator{eltype(modeldata, arrays_idx), typeof(fields), typeof(cellranges), typeof(packed_fields), typeof(packed_cellranges)}( + modeldata, arrays_idx, copy(vars), indices, fields, cellranges, copy(packed_vars), packed_fields, packed_cellranges, packed_indices ) end # compact form function Base.show(io::IO, va::VariableAggregator) - print(io, "VariableAggregator(modeldata=$(va.modeldata), arrays_idx=$(va.arrays_idx), length=$(length(va)), number of variables=$(length(va.vars)))") + print(io, "VariableAggregator(modeldata=$(va.modeldata), arrays_idx=$(va.arrays_idx), length=$(length(va)), number of variables=$(length(va.vars)), number of packed variables=$(length(va.packed_vars)))") return nothing end # multiline form function Base.show(io::IO, ::MIME"text/plain", va::VariableAggregator) - println(io, "VariableAggregator(modeldata=$(va.modeldata), arrays_idx=$(va.arrays_idx), length=$(length(va)), number of variables=$(length(va.vars))):") + println(io, "VariableAggregator(modeldata=$(va.modeldata), arrays_idx=$(va.arrays_idx), length=$(length(va)), number of variables=$(length(va.vars)), number of packed variables=$(length(va.packed_vars))):") if length(va) > 0 println( io, @@ -93,19 +113,32 @@ function Base.show(io::IO, ::MIME"text/plain", va::VariableAggregator) rpad(isnothing(va.cellranges[i]) ? "-" : va.cellranges[i].indices, 14), ) end + + if !isempty(va.packed_vars) + println( + io, + " ", + rpad("var", 6), + rpad("indices", 14), + rpad("name", 40), + ) + for i in eachindex(va.packed_vars) + println( + io, + " ", + rpad(i,6), + rpad(va.packed_indices, 14), + rpad("$(fullname(va.packed_vars[i]))", 40), + ) + end + end end return nothing end Base.eltype(va::VariableAggregator{T, F, C}) where{T, F, C} = T -function Base.length(va::VariableAggregator) - if isempty(va.indices) - return 0 - else - return last(last(va.indices)) - end -end +Base.length(va::VariableAggregator) = last(last(va.packed_indices)) """ get_indices(va::VariableAggregator, varnamefull::AbstractString; allow_not_found=false) -> indices::UnitRange{Int64} @@ -141,6 +174,10 @@ function Base.copyto!(dest::VariableAggregator, src::AbstractVector; sof::Int=1) fof = IteratorUtils.reduce_longtuple(copy, sof, dest.fields, dest.cellranges) + for cr in dest.packed_cellranges + fof = IteratorUtils.reduce_longtuple_p(copy, fof, dest.packed_fields, cr) + end + return fof - sof end @@ -161,6 +198,10 @@ function Base.copyto!(dest::AbstractVector, src::VariableAggregator; dof::Int=1) fof = IteratorUtils.reduce_longtuple(copy, dof, src.fields, src.cellranges) + for cr in src.packed_cellranges + fof = IteratorUtils.reduce_longtuple_p(copy, fof, src.packed_fields, cr) + end + return fof - dof end