From 1ecc9661e3e34c7c7f864d967c9db876207047d3 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 5 Jan 2024 17:42:44 +0530 Subject: [PATCH] fix: VectorOfArray mapreduce, sum/prod performance, minor bugs --- src/vector_of_array.jl | 44 ++++++++++++++++++++++++++++++++--------- test/interface_tests.jl | 38 +++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 9 deletions(-) diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index 9dcfc409..14476500 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -133,7 +133,16 @@ function VectorOfArray(vec::AbstractVector{T}, ::NTuple{N}) where {T, N} VectorOfArray{eltype(T), N, typeof(vec)}(vec) end # Assume that the first element is representative of all other elements -VectorOfArray(vec::AbstractVector) = VectorOfArray(vec, (size(vec[1])..., length(vec))) +function VectorOfArray(vec::AbstractVector) + T = eltype(vec[1]) + N = ndims(vec[1]) + if all(x isa Union{<:AbstractArray, <:AbstractVectorOfArray} for x in vec) + A = Vector{Union{typeof.(vec)...}} + else + A = typeof(vec) + end + VectorOfArray{T, N + 1, A}(vec) +end function VectorOfArray(vec::AbstractVector{VT}) where {T, N, VT <: AbstractArray{T, N}} VectorOfArray{T, N + 1, typeof(vec)}(vec) end @@ -482,6 +491,10 @@ function Base.append!(VA::AbstractVectorOfArray{T, N}, return VA end +function Base.stack(VA::AbstractVectorOfArray; dims = :) + stack(VA.u; dims) +end + # AbstractArray methods function Base.view(A::AbstractVectorOfArray, I::Vararg{Any,M}) where {M} @inline @@ -489,14 +502,19 @@ function Base.view(A::AbstractVectorOfArray, I::Vararg{Any,M}) where {M} @boundscheck checkbounds(A, J...) SubArray(IndexStyle(A), A, J, Base.index_dimsum(J...)) end +function Base.SubArray(parent::AbstractVectorOfArray, indices::Tuple) + @inline + SubArray(IndexStyle(Base.viewindexing(indices), IndexStyle(parent)), parent, Base.ensure_indexable(indices), Base.index_dimsum(indices...)) +end +Base.isassigned(VA::AbstractVectorOfArray, idxs...) = checkbounds(Bool, VA, idxs...) Base.check_parent_index_match(::RecursiveArrayTools.AbstractVectorOfArray{T,N}, ::NTuple{N,Bool}) where {T,N} = nothing Base.ndims(::AbstractVectorOfArray{T, N}) where {T, N} = N function Base.checkbounds(::Type{Bool}, VA::AbstractVectorOfArray, idx...) if checkbounds(Bool, VA.u, last(idx)) if last(idx) isa Integer - return all(checkbounds.(Bool, (VA.u[last(idx)],), Base.front(idx))) + return all(checkbounds.(Bool, (VA.u[last(idx)],), Base.front(idx)...)) else - return all(checkbounds.(Bool, VA.u[last(idx)], Base.front(idx))) + return all(checkbounds.(Bool, VA.u[last(idx)], tuple.(Base.front(idx))...)) end end return false @@ -595,10 +613,14 @@ function Base.convert(::Type{Array}, VA::AbstractVectorOfArray) end # statistics -@inline Base.sum(f, VA::AbstractVectorOfArray) = sum(f, Array(VA)) -@inline Base.sum(VA::AbstractVectorOfArray; kwargs...) = sum(Array(VA); kwargs...) -@inline Base.prod(f, VA::AbstractVectorOfArray) = prod(f, Array(VA)) -@inline Base.prod(VA::AbstractVectorOfArray; kwargs...) = prod(Array(VA); kwargs...) +@inline Base.sum(VA::AbstractVectorOfArray; kwargs...) = sum(identity, VA; kwargs...) +@inline function Base.sum(f, VA::AbstractVectorOfArray; kwargs...) + mapreduce(f, Base.add_sum, VA; kwargs...) +end +@inline Base.prod(VA::AbstractVectorOfArray; kwargs...) = prod(identity, VA; kwargs...) +@inline function Base.prod(f, VA::AbstractVectorOfArray; kwargs...) + mapreduce(f, Base.mul_prod, VA; kwargs...) +end @inline Statistics.mean(VA::AbstractVectorOfArray; kwargs...) = mean(Array(VA); kwargs...) @inline function Statistics.median(VA::AbstractVectorOfArray; kwargs...) @@ -638,8 +660,12 @@ end end Base.map(f, A::RecursiveArrayTools.AbstractVectorOfArray) = map(f, A.u) -function Base.mapreduce(f, op, A::AbstractVectorOfArray) - mapreduce(f, op, (mapreduce(f, op, x) for x in A.u)) + +function Base.mapreduce(f, op, A::AbstractVectorOfArray; kwargs...) + mapreduce(f, op, view(A, ntuple(_ -> :, ndims(A))...); kwargs...) +end +function Base.mapreduce(f, op, A::AbstractVectorOfArray{T,1,<:AbstractVector{T}}; kwargs...) where {T} + mapreduce(f, op, A.u; kwargs...) end ## broadcasting diff --git a/test/interface_tests.jl b/test/interface_tests.jl index d2f4030b..3aba249a 100644 --- a/test/interface_tests.jl +++ b/test/interface_tests.jl @@ -57,6 +57,44 @@ push!(testda, [-1, -2, -3, -4]) @test_throws MethodError push!(testda, [-1 -2 -3 -4]) @test_throws MethodError push!(testda, [-1 -2; -3 -4]) +# Type inference +@inferred sum(testva) +@inferred sum(VectorOfArray([VectorOfArray([zeros(4,4)])])) +@inferred mapreduce(string, *, testva) + +# mapreduce +testva = VectorOfArray([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) +@test mapreduce(x -> string(x) * "q", *, testva) == "1q2q3q4q5q6q7q8q9q" + +testvb = VectorOfArray([rand(1:10, 3, 3, 3) for _ in 1:4]) +arrvb = Array(testvb) +for i in 1:ndims(arrvb) + @test sum(arrvb; dims=i) == sum(testvb; dims=i) + @test prod(arrvb; dims=i) == prod(testvb; dims=i) + @test mapreduce(string, *, arrvb; dims=i) == mapreduce(string, *, testvb; dims=i) +end + +# Test when ndims == 1 +testvb = VectorOfArray(collect(1.0:0.1:2.0)) +arrvb = Array(testvb) +@test sum(arrvb) == sum(testvb) +@test prod(arrvb) == prod(testvb) +@test mapreduce(string, *, arrvb) == mapreduce(string, *, testvb) + +# view +testvc = VectorOfArray([rand(1:10, 3, 3) for _ in 1:3]) +arrvc = Array(testvc) +for idxs in [(2, 2, :), (2, :, 2), (:, 2, 2), (:, :, 2), (:, 2, :), (2, : ,:), (:, :, :)] + arr_view = view(arrvc, idxs...) + voa_view = view(testvc, idxs...) + @test size(arr_view) == size(voa_view) + @test all(arr_view .== voa_view) +end + +# test stack +@test stack(testva) == [1 4 7; 2 5 8; 3 6 9] +@test stack(testva; dims = 1) == [1 2 3; 4 5 6; 7 8 9] + # convert array from VectorOfArray/DiffEqArray t = 1:8 recs = [rand(10, 7) for i in 1:8]