Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

VariableAggregator test pack CellSpace variables #94

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion src/utils/IteratorUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,9 @@ end

"""
reduce_longtuple(f, rinit, t1::Tuple, t2, ... tm; errmsg="iterables lengths differ") -> r
reduce_longtuple_p(f, rinit, t1::Tuple, t2, ... tm, p; errmsg="iterables lengths differ") -> r

Call `r += f(r, t1[n], t2[n], ... tm[n])` for each element
Call `r += f(r, t1[n], t2[n], ... tm[n])` or `r += f(r, t1[n], t2[n], ... tm[n], p)` for each element
`n` of `t1::Tuple`, `t2`, ... `tm`. Initial value of `r = rinit`

# Implementation
Expand Down Expand Up @@ -286,5 +287,29 @@ end
return ex
end

@inline reduce_longtuple_p(f::F, rinit, t1, p; errmsg="iterables lengths differ") where{F} =
reduce_longtuple_unchecked_p(f, rinit, t1, p)
@generated function reduce_longtuple_unchecked_p(f, rinit, t1::Tuple, p)
ex = quote ; end # empty expression
for j=1:fieldcount(t1)
push!(ex.args, quote; rinit = f(rinit, t1[$j], p); end)
end
push!(ex.args, quote; return rinit; end)

return ex
end

@inline reduce_longtuple_p(f::F, rinit, t1, t2, p; errmsg="iterables lengths differ") where{F} =
(check_lengths_equal(t1, t2; errmsg=errmsg); reduce_longtuple_unchecked_p(f, rinit, t1, t2, p))
@generated function reduce_longtuple_unchecked_p(f, rinit, t1::Tuple, t2, p)
ex = quote ; end # empty expression
for j=1:fieldcount(t1)
push!(ex.args, quote; rinit = f(rinit, t1[$j], t2[$j], p); end)
end
push!(ex.args, quote; return rinit; end)

return ex
end


end # module
75 changes: 58 additions & 17 deletions src/variableaggregators/VariableAggregator.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@

# import Infiltrator
import Infiltrator

#################################################################
# VariableAggregator
#################################################################

struct VariableAggregator{T, F <: Tuple, C <: Tuple}
struct VariableAggregator{T, F <: Tuple, C <: Tuple, PF <: Tuple, PC <: AbstractVector}
# modeldata used for data arrays (for diagnostic output only)
modeldata::AbstractModelData # not typed
arrays_idx::Int
Expand All @@ -19,10 +19,16 @@ struct VariableAggregator{T, F <: Tuple, C <: Tuple}
# field and corresponding CellRange for each Variable
fields::F
cellranges::C

# optional packed fields and corresponding cellranges
packed_vars::Vector{VariableDomain}
packed_fields::PF
packed_cellranges::PC
packed_indices::UnitRange{Int64}
end

"""
VariableAggregator(vars, cellranges, modeldata, arrays_idx) -> VariableAggregator
VariableAggregator(vars, cellranges, modeldata, arrays_idx [, packed_vars] [, packed_cellranges]) -> VariableAggregator

Aggregate multiple VariableDomains into a flattened list (a contiguous Vector).

Expand All @@ -31,12 +37,14 @@ for data arrays in `modeldata` with `arrays_idx`.

`cellranges` may contain `nothing` entries to indicate whole Domain.

Optionally add `packed_vars` which should be CellSpace variables from a single Domain with cells defined by `packed_cellranges`

This is mostly useful for aggregating state Variables, derivatives, etc to implement an interface to a generic ODE/DAE etc solver.

Values may be copied to and from a Vector using [`copyto!`](@ref)

"""
function VariableAggregator(vars, cellranges, modeldata::AbstractModelData, arrays_idx::Int)
function VariableAggregator(vars, cellranges, modeldata::AbstractModelData, arrays_idx::Int, packed_vars=[], packed_cellranges=[])

IteratorUtils.check_lengths_equal(vars, cellranges; errmsg="'vars' and 'cellranges' must be of same length")

Expand All @@ -56,24 +64,36 @@ function VariableAggregator(vars, cellranges, modeldata::AbstractModelData, arra
nextidx += dof
end

fields=Tuple(fields)
cellranges=Tuple(cellranges)
packed_fields = []
packed_dof = 0
for v in packed_vars
f = get_field(v, modeldata, arrays_idx)
push!(packed_fields, f)
for cr in packed_cellranges
packed_dof += dof_field(f, cr)
end
end
packed_indices = nextidx:(nextidx+packed_dof-1)

fields = Tuple(fields)
cellranges = Tuple(cellranges)
packed_fields = Tuple(packed_fields)

return VariableAggregator{eltype(modeldata, arrays_idx), typeof(fields), typeof(cellranges)}(
modeldata, arrays_idx, copy(vars), indices, fields, cellranges,
return VariableAggregator{eltype(modeldata, arrays_idx), typeof(fields), typeof(cellranges), typeof(packed_fields), typeof(packed_cellranges)}(
modeldata, arrays_idx, copy(vars), indices, fields, cellranges, copy(packed_vars), packed_fields, packed_cellranges, packed_indices
)
end


# compact form
function Base.show(io::IO, va::VariableAggregator)
print(io, "VariableAggregator(modeldata=$(va.modeldata), arrays_idx=$(va.arrays_idx), length=$(length(va)), number of variables=$(length(va.vars)))")
print(io, "VariableAggregator(modeldata=$(va.modeldata), arrays_idx=$(va.arrays_idx), length=$(length(va)), number of variables=$(length(va.vars)), number of packed variables=$(length(va.packed_vars)))")
return nothing
end

# multiline form
function Base.show(io::IO, ::MIME"text/plain", va::VariableAggregator)
println(io, "VariableAggregator(modeldata=$(va.modeldata), arrays_idx=$(va.arrays_idx), length=$(length(va)), number of variables=$(length(va.vars))):")
println(io, "VariableAggregator(modeldata=$(va.modeldata), arrays_idx=$(va.arrays_idx), length=$(length(va)), number of variables=$(length(va.vars)), number of packed variables=$(length(va.packed_vars))):")
if length(va) > 0
println(
io,
Expand All @@ -93,19 +113,32 @@ function Base.show(io::IO, ::MIME"text/plain", va::VariableAggregator)
rpad(isnothing(va.cellranges[i]) ? "-" : va.cellranges[i].indices, 14),
)
end

if !isempty(va.packed_vars)
println(
io,
" ",
rpad("var", 6),
rpad("indices", 14),
rpad("name", 40),
)
for i in eachindex(va.packed_vars)
println(
io,
" ",
rpad(i,6),
rpad(va.packed_indices, 14),
rpad("$(fullname(va.packed_vars[i]))", 40),
)
end
end
end
return nothing
end

Base.eltype(va::VariableAggregator{T, F, C}) where{T, F, C} = T

function Base.length(va::VariableAggregator)
if isempty(va.indices)
return 0
else
return last(last(va.indices))
end
end
Base.length(va::VariableAggregator) = last(last(va.packed_indices))

"""
get_indices(va::VariableAggregator, varnamefull::AbstractString; allow_not_found=false) -> indices::UnitRange{Int64}
Expand Down Expand Up @@ -141,6 +174,10 @@ function Base.copyto!(dest::VariableAggregator, src::AbstractVector; sof::Int=1)

fof = IteratorUtils.reduce_longtuple(copy, sof, dest.fields, dest.cellranges)

for cr in dest.packed_cellranges
fof = IteratorUtils.reduce_longtuple_p(copy, fof, dest.packed_fields, cr)
end

return fof - sof
end

Expand All @@ -161,6 +198,10 @@ function Base.copyto!(dest::AbstractVector, src::VariableAggregator; dof::Int=1)

fof = IteratorUtils.reduce_longtuple(copy, dof, src.fields, src.cellranges)

for cr in src.packed_cellranges
fof = IteratorUtils.reduce_longtuple_p(copy, fof, src.packed_fields, cr)
end

return fof - dof
end

Expand Down