Skip to content

Commit

Permalink
Merge pull request #374 from jlchan/jc/fix_bcast_multidim_VoA
Browse files Browse the repository at this point in the history
Fix multi-dimensional `VectorOfArray` broadcast
  • Loading branch information
ChrisRackauckas authored May 6, 2024
2 parents eacbe3f + b7f9c84 commit f7415e8
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 12 deletions.
41 changes: 33 additions & 8 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,15 @@ the `VectorOfArray` into a matrix/tensor. Also, `vecarr_to_vectors(VA::AbstractV
returns a vector of the series for each component, that is, `A[i,:]` for each `i`.
A plot recipe is provided, which plots the `A[i,:]` series.
There is also support for `VectorOfArray` with constructed from multi-dimensional arrays
There is also support for `VectorOfArray` constructed from multi-dimensional arrays
```julia
VectorOfArray(u::AbstractArray{AT}) where {T, N, AT <: AbstractArray{T, N}}
```
where `IndexStyle(typeof(u)) isa IndexLinear`.
"""
mutable struct VectorOfArray{T, N, A} <: AbstractVectorOfArray{T, N, A}
u::A # A <: AbstractVector{<: AbstractArray{T, N - 1}}
u::A # A <: AbstractArray{<: AbstractArray{T, N - 1}}
end
# VectorOfArray with an added series for time

Expand Down Expand Up @@ -719,7 +718,7 @@ end
# for VectorOfArray with multi-dimensional parent arrays of arrays where all elements are the same type
function Base.similar(vec::VectorOfArray{
T, N, AT}) where {T, N, AT <: AbstractArray{<:AbstractArray{T}}}
return VectorOfArray(similar(Base.parent(vec)))
return VectorOfArray(similar.(Base.parent(vec)))
end

# special-case when the multi-dimensional parent array is just an AbstractVector (call the old method)
Expand All @@ -728,6 +727,7 @@ function Base.similar(vec::VectorOfArray{
return Base.similar(vec, eltype(vec))
end


# fill!
# For DiffEqArray it ignores ts and fills only u
function Base.fill!(VA::AbstractVectorOfArray, x)
Expand Down Expand Up @@ -840,12 +840,37 @@ end
# make vectorofarrays broadcastable so they aren't collected
Broadcast.broadcastable(x::AbstractVectorOfArray) = x

# recurse through broadcast arguments and return a parent array for
# the first VoA or DiffEqArray in the bc arguments
function find_VoA_parent(args)
arg = Base.first(args)
if arg isa AbstractDiffEqArray
# if first(args) is a DiffEqArray, use the underlying
# field `u` of DiffEqArray as a parent array.
return arg.u
elseif arg isa AbstractVectorOfArray
return parent(arg)
else
return find_VoA_parent(Base.tail(args))
end
end

@inline function Base.copy(bc::Broadcast.Broadcasted{<:VectorOfArrayStyle})
bc = Broadcast.flatten(bc)
N = narrays(bc)
VectorOfArray(map(1:N) do i
copy(unpack_voa(bc, i))
end)

parent = find_VoA_parent(bc.args)

if parent isa AbstractVector
# this is the default behavior in v3.15.0
N = narrays(bc)
return VectorOfArray(map(1:N) do i
copy(unpack_voa(bc, i))
end)
else # if parent isa AbstractArray
return VectorOfArray(map(enumerate(Iterators.product(axes(parent)...))) do (i, _)
copy(unpack_voa(bc, i))
end)
end
end

for (type, N_expr) in [
Expand Down
10 changes: 6 additions & 4 deletions test/basic_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,8 @@ a[[1, 3, 8]]
# multidimensional array of arrays
####################################################################

u_matrix = VectorOfArray(fill([1, 2], 2, 3))
u_vector = VectorOfArray(vec(u_matrix.u))
u_matrix = VectorOfArray([[1, 2] for i in 1:2, j in 1:3])
u_vector = VectorOfArray([[1, 2] for i in 1:6])

# test broadcasting
function foo!(u)
Expand All @@ -248,11 +248,13 @@ function foo!(u)
end
foo!(u_matrix)
foo!(u_vector)
@test u_matrix u_vector
@test all(u_matrix .== [3, 10])
@test all(vec(u_matrix) .≈ vec(u_vector))

# test that, for VectorOfArray with multi-dimensional parent arrays,
# `similar` preserves the structure of the parent array
# broadcast and `similar` preserve the structure of the parent array
@test typeof(parent(similar(u_matrix))) == typeof(parent(u_matrix))
@test typeof(parent((x->x).(u_matrix))) == typeof(parent(u_matrix))

# test efficiency
num_allocs = @allocations foo!(u_matrix)
Expand Down

0 comments on commit f7415e8

Please sign in to comment.