Skip to content

Commit

Permalink
fix: VectorOfArray mapreduce, sum/prod performance, minor bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jan 5, 2024
1 parent 873a072 commit 1ecc966
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 9 deletions.
44 changes: 35 additions & 9 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,16 @@ function VectorOfArray(vec::AbstractVector{T}, ::NTuple{N}) where {T, N}
VectorOfArray{eltype(T), N, typeof(vec)}(vec)
end
# Assume that the first element is representative of all other elements
VectorOfArray(vec::AbstractVector) = VectorOfArray(vec, (size(vec[1])..., length(vec)))
function VectorOfArray(vec::AbstractVector)
T = eltype(vec[1])
N = ndims(vec[1])
if all(x isa Union{<:AbstractArray, <:AbstractVectorOfArray} for x in vec)
A = Vector{Union{typeof.(vec)...}}
else
A = typeof(vec)
end
VectorOfArray{T, N + 1, A}(vec)
end
function VectorOfArray(vec::AbstractVector{VT}) where {T, N, VT <: AbstractArray{T, N}}
VectorOfArray{T, N + 1, typeof(vec)}(vec)
end
Expand Down Expand Up @@ -482,21 +491,30 @@ function Base.append!(VA::AbstractVectorOfArray{T, N},
return VA
end

function Base.stack(VA::AbstractVectorOfArray; dims = :)
stack(VA.u; dims)
end

# AbstractArray methods
function Base.view(A::AbstractVectorOfArray, I::Vararg{Any,M}) where {M}
@inline
J = map(i->Base.unalias(A,i), to_indices(A, I))
@boundscheck checkbounds(A, J...)
SubArray(IndexStyle(A), A, J, Base.index_dimsum(J...))
end
function Base.SubArray(parent::AbstractVectorOfArray, indices::Tuple)
@inline
SubArray(IndexStyle(Base.viewindexing(indices), IndexStyle(parent)), parent, Base.ensure_indexable(indices), Base.index_dimsum(indices...))
end
Base.isassigned(VA::AbstractVectorOfArray, idxs...) = checkbounds(Bool, VA, idxs...)
Base.check_parent_index_match(::RecursiveArrayTools.AbstractVectorOfArray{T,N}, ::NTuple{N,Bool}) where {T,N} = nothing
Base.ndims(::AbstractVectorOfArray{T, N}) where {T, N} = N
function Base.checkbounds(::Type{Bool}, VA::AbstractVectorOfArray, idx...)
if checkbounds(Bool, VA.u, last(idx))
if last(idx) isa Integer
return all(checkbounds.(Bool, (VA.u[last(idx)],), Base.front(idx)))
return all(checkbounds.(Bool, (VA.u[last(idx)],), Base.front(idx)...))
else
return all(checkbounds.(Bool, VA.u[last(idx)], Base.front(idx)))
return all(checkbounds.(Bool, VA.u[last(idx)], tuple.(Base.front(idx))...))
end
end
return false
Expand Down Expand Up @@ -595,10 +613,14 @@ function Base.convert(::Type{Array}, VA::AbstractVectorOfArray)
end

# statistics
@inline Base.sum(f, VA::AbstractVectorOfArray) = sum(f, Array(VA))
@inline Base.sum(VA::AbstractVectorOfArray; kwargs...) = sum(Array(VA); kwargs...)
@inline Base.prod(f, VA::AbstractVectorOfArray) = prod(f, Array(VA))
@inline Base.prod(VA::AbstractVectorOfArray; kwargs...) = prod(Array(VA); kwargs...)
@inline Base.sum(VA::AbstractVectorOfArray; kwargs...) = sum(identity, VA; kwargs...)
@inline function Base.sum(f, VA::AbstractVectorOfArray; kwargs...)
mapreduce(f, Base.add_sum, VA; kwargs...)
end
@inline Base.prod(VA::AbstractVectorOfArray; kwargs...) = prod(identity, VA; kwargs...)
@inline function Base.prod(f, VA::AbstractVectorOfArray; kwargs...)
mapreduce(f, Base.mul_prod, VA; kwargs...)
end

@inline Statistics.mean(VA::AbstractVectorOfArray; kwargs...) = mean(Array(VA); kwargs...)
@inline function Statistics.median(VA::AbstractVectorOfArray; kwargs...)
Expand Down Expand Up @@ -638,8 +660,12 @@ end
end

Base.map(f, A::RecursiveArrayTools.AbstractVectorOfArray) = map(f, A.u)
function Base.mapreduce(f, op, A::AbstractVectorOfArray)
mapreduce(f, op, (mapreduce(f, op, x) for x in A.u))

function Base.mapreduce(f, op, A::AbstractVectorOfArray; kwargs...)
mapreduce(f, op, view(A, ntuple(_ -> :, ndims(A))...); kwargs...)
end
function Base.mapreduce(f, op, A::AbstractVectorOfArray{T,1,<:AbstractVector{T}}; kwargs...) where {T}
mapreduce(f, op, A.u; kwargs...)
end

## broadcasting
Expand Down
38 changes: 38 additions & 0 deletions test/interface_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,44 @@ push!(testda, [-1, -2, -3, -4])
@test_throws MethodError push!(testda, [-1 -2 -3 -4])
@test_throws MethodError push!(testda, [-1 -2; -3 -4])

# Type inference
@inferred sum(testva)
@inferred sum(VectorOfArray([VectorOfArray([zeros(4,4)])]))
@inferred mapreduce(string, *, testva)

# mapreduce
testva = VectorOfArray([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
@test mapreduce(x -> string(x) * "q", *, testva) == "1q2q3q4q5q6q7q8q9q"

testvb = VectorOfArray([rand(1:10, 3, 3, 3) for _ in 1:4])
arrvb = Array(testvb)
for i in 1:ndims(arrvb)
@test sum(arrvb; dims=i) == sum(testvb; dims=i)
@test prod(arrvb; dims=i) == prod(testvb; dims=i)
@test mapreduce(string, *, arrvb; dims=i) == mapreduce(string, *, testvb; dims=i)
end

# Test when ndims == 1
testvb = VectorOfArray(collect(1.0:0.1:2.0))
arrvb = Array(testvb)
@test sum(arrvb) == sum(testvb)
@test prod(arrvb) == prod(testvb)
@test mapreduce(string, *, arrvb) == mapreduce(string, *, testvb)

# view
testvc = VectorOfArray([rand(1:10, 3, 3) for _ in 1:3])
arrvc = Array(testvc)
for idxs in [(2, 2, :), (2, :, 2), (:, 2, 2), (:, :, 2), (:, 2, :), (2, : ,:), (:, :, :)]
arr_view = view(arrvc, idxs...)
voa_view = view(testvc, idxs...)
@test size(arr_view) == size(voa_view)
@test all(arr_view .== voa_view)
end

# test stack
@test stack(testva) == [1 4 7; 2 5 8; 3 6 9]
@test stack(testva; dims = 1) == [1 2 3; 4 5 6; 7 8 9]

# convert array from VectorOfArray/DiffEqArray
t = 1:8
recs = [rand(10, 7) for i in 1:8]
Expand Down

0 comments on commit 1ecc966

Please sign in to comment.