Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix multi-dimensional VectorOfArray broadcast #374

Merged
merged 12 commits into from
May 6, 2024
41 changes: 33 additions & 8 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,15 @@
returns a vector of the series for each component, that is, `A[i,:]` for each `i`.
A plot recipe is provided, which plots the `A[i,:]` series.

There is also support for `VectorOfArray` with constructed from multi-dimensional arrays

There is also support for `VectorOfArray` constructed from multi-dimensional arrays
```julia
VectorOfArray(u::AbstractArray{AT}) where {T, N, AT <: AbstractArray{T, N}}
```

where `IndexStyle(typeof(u)) isa IndexLinear`.
"""
mutable struct VectorOfArray{T, N, A} <: AbstractVectorOfArray{T, N, A}
u::A # A <: AbstractVector{<: AbstractArray{T, N - 1}}
u::A # A <: AbstractArray{<: AbstractArray{T, N - 1}}
end
# VectorOfArray with an added series for time

Expand Down Expand Up @@ -719,7 +718,7 @@
# for VectorOfArray with multi-dimensional parent arrays of arrays where all elements are the same type
function Base.similar(vec::VectorOfArray{
T, N, AT}) where {T, N, AT <: AbstractArray{<:AbstractArray{T}}}
return VectorOfArray(similar(Base.parent(vec)))
return VectorOfArray(similar.(Base.parent(vec)))

Check warning on line 721 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L721

Added line #L721 was not covered by tests
Copy link
Contributor Author

@jlchan jlchan May 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Broadcasting similar here avoids some #undef issues.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Side note: this implementation could replace these 3 functions

@inline function Base.similar(VA::VectorOfArray, ::Type{T} = eltype(VA)) where {T}
VectorOfArray([similar(VA[:, i], T) for i in eachindex(VA.u)])
end
# for VectorOfArray with multi-dimensional parent arrays of arrays where all elements are the same type
function Base.similar(vec::VectorOfArray{
T, N, AT}) where {T, N, AT <: AbstractArray{<:AbstractArray{T}}}
return VectorOfArray(similar(Base.parent(vec)))
end
# special-case when the multi-dimensional parent array is just an AbstractVector (call the old method)
function Base.similar(vec::VectorOfArray{
T, N, AT}) where {T, N, AT <: AbstractVector{<:AbstractArray{T}}}
return Base.similar(vec, eltype(vec))
end

However, doing so breaks these tests because similar can't be called on <:Real values:

x = [1, 2, 3, 4, 5, 6, 7, 8, 9]
testva = DiffEqArray(x, x)
testvb = DiffEqArray(x, x)
mulX = sqrt.(abs.(testva .* testvb))
ref = sqrt.(abs.(x .* x))
@test mulX == ref
fill!(mulX, 0)
mulX .= sqrt.(abs.(testva .* testvb))
@test mulX == ref

These tests don't seem correct to me; the parent array x doesn't have underlying data structure Vector{AbstractArray{T}}, so it shouldn't be a valid edge case. Maybe I'm missing something?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Numbers are also allowed. That's used a lot downstream.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can special-case that case where the elements are numbers, but we cannot throw it out.

end

# special-case when the multi-dimensional parent array is just an AbstractVector (call the old method)
Expand All @@ -728,6 +727,7 @@
return Base.similar(vec, eltype(vec))
end


# fill!
# For DiffEqArray it ignores ts and fills only u
function Base.fill!(VA::AbstractVectorOfArray, x)
Expand Down Expand Up @@ -840,12 +840,37 @@
# make vectorofarrays broadcastable so they aren't collected
Broadcast.broadcastable(x::AbstractVectorOfArray) = x

# recurse through broadcast arguments and return a parent array for
# the first VoA or DiffEqArray in the bc arguments
function find_VoA_parent(args)
arg = Base.first(args)
if arg isa AbstractDiffEqArray

Check warning on line 847 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L845-L847

Added lines #L845 - L847 were not covered by tests
# if first(args) is a DiffEqArray, use the underlying
# field `u` of DiffEqArray as a parent array.
return arg.u
elseif arg isa AbstractVectorOfArray
return parent(arg)

Check warning on line 852 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L850-L852

Added lines #L850 - L852 were not covered by tests
else
return find_VoA_parent(Base.tail(args))

Check warning on line 854 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L854

Added line #L854 was not covered by tests
end
end

@inline function Base.copy(bc::Broadcast.Broadcasted{<:VectorOfArrayStyle})
bc = Broadcast.flatten(bc)
N = narrays(bc)
VectorOfArray(map(1:N) do i
copy(unpack_voa(bc, i))
end)

parent = find_VoA_parent(bc.args)

Check warning on line 861 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L861

Added line #L861 was not covered by tests

if parent isa AbstractVector

Check warning on line 863 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L863

Added line #L863 was not covered by tests
# this is the default behavior in v3.15.0
N = narrays(bc)
return VectorOfArray(map(1:N) do i
copy(unpack_voa(bc, i))

Check warning on line 867 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L865-L867

Added lines #L865 - L867 were not covered by tests
end)
else # if parent isa AbstractArray
return VectorOfArray(map(enumerate(Iterators.product(axes(parent)...))) do (i, _)
copy(unpack_voa(bc, i))

Check warning on line 871 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L870-L871

Added lines #L870 - L871 were not covered by tests
end)
end
Comment on lines +863 to +873
Copy link
Contributor Author

@jlchan jlchan May 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Side note, this entire block can be replaced by the second branch

return VectorOfArray(map(enumerate(Iterators.product(axes(parent)...))) do (i, _)
    copy(unpack_voa(bc, i))
end)

since enumerate(Iterators.product(axes(parent)...))) do (i, _) basically recovers map(1:N) do i for an AbstractVector.

I left the old behavior in since I wasn't sure if it would break something upstream or change performance (and it's easier to read).

end

for (type, N_expr) in [
Expand Down
10 changes: 6 additions & 4 deletions test/basic_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,8 @@ a[[1, 3, 8]]
# multidimensional array of arrays
####################################################################

u_matrix = VectorOfArray(fill([1, 2], 2, 3))
u_vector = VectorOfArray(vec(u_matrix.u))
u_matrix = VectorOfArray([[1, 2] for i in 1:2, j in 1:3])
u_vector = VectorOfArray([[1, 2] for i in 1:6])

# test broadcasting
function foo!(u)
Expand All @@ -248,11 +248,13 @@ function foo!(u)
end
foo!(u_matrix)
foo!(u_vector)
@test u_matrix ≈ u_vector
@test all(u_matrix .== [3, 10])
@test all(vec(u_matrix) .≈ vec(u_vector))

# test that, for VectorOfArray with multi-dimensional parent arrays,
# `similar` preserves the structure of the parent array
# broadcast and `similar` preserve the structure of the parent array
@test typeof(parent(similar(u_matrix))) == typeof(parent(u_matrix))
@test typeof(parent((x->x).(u_matrix))) == typeof(parent(u_matrix))

# test efficiency
num_allocs = @allocations foo!(u_matrix)
Expand Down
Loading