Skip to content

Commit

Permalink
Merge pull request #327 from AayushSabharwal/as/checkbounds
Browse files Browse the repository at this point in the history
fix: fix checkbounds, view methods, indexing, and add tests
  • Loading branch information
ChrisRackauckas authored Jan 6, 2024
2 parents 33dd4ea + 9afae5a commit 8ddcb4a
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
26 changes: 25 additions & 1 deletion src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,9 @@ Base.IndexStyle(::Type{<:AbstractVectorOfArray}) = IndexCartesian()
@inline function Base.eachindex(VA::AbstractVectorOfArray)
return eachindex(VA.u)
end
@inline function Base.eachindex(::IndexLinear, VA::AbstractVectorOfArray{T,N,<:AbstractVector{T}}) where {T, N}
return eachindex(IndexLinear(), VA.u)
end
@inline Base.IteratorSize(::Type{<:AbstractVectorOfArray}) = Base.HasLength()
@inline Base.first(VA::AbstractVectorOfArray) = first(VA.u)
@inline Base.last(VA::AbstractVectorOfArray) = last(VA.u)
Expand All @@ -245,7 +248,11 @@ __parameterless_type(T) = Base.typename(T).wrapper
Base.@propagate_inbounds function _getindex(A::AbstractVectorOfArray{T, N},
::NotSymbolic, I::Colon...) where {T, N}
@assert length(I) == ndims(A.u[1]) + 1
vecs = vec.(A.u)
vecs = if N == 1
A.u
else
vec.(A.u)
end
return Adapt.adapt(__parameterless_type(T),
reshape(reduce(hcat, vecs), size(A.u[1])..., length(A.u)))
end
Expand Down Expand Up @@ -496,6 +503,16 @@ function Base.stack(VA::AbstractVectorOfArray; dims = :)
end

# AbstractArray methods
function Base.view(A::AbstractVectorOfArray{T,N,<:AbstractVector{T}}, I::Vararg{Any, M}) where {T,N,M}
@inline
if length(I) == 1
J = map(i->Base.unalias(A,i), to_indices(A, I))
elseif length(I) == 2 && (I[1] == Colon() || I[1] == 1)
J = map(i->Base.unalias(A,i), to_indices(A, Base.tail(I)))
end
@boundscheck checkbounds(A, J...)
SubArray(IndexStyle(A), A, J, Base.index_dimsum(J...))
end
function Base.view(A::AbstractVectorOfArray, I::Vararg{Any,M}) where {M}
@inline
J = map(i->Base.unalias(A,i), to_indices(A, I))
Expand All @@ -509,6 +526,13 @@ end
Base.isassigned(VA::AbstractVectorOfArray, idxs...) = checkbounds(Bool, VA, idxs...)
Base.check_parent_index_match(::RecursiveArrayTools.AbstractVectorOfArray{T,N}, ::NTuple{N,Bool}) where {T,N} = nothing
Base.ndims(::AbstractVectorOfArray{T, N}) where {T, N} = N

function Base.checkbounds(::Type{Bool}, VA::AbstractVectorOfArray{T, N, <:AbstractVector{T}}, idxs...) where {T, N}
if length(idxs) == 2 && (idxs[1] == Colon() || idxs[1] == 1)
return checkbounds(Bool, VA.u, idxs[2])
end
return checkbounds(Bool, VA.u, idxs...)
end
function Base.checkbounds(::Type{Bool}, VA::AbstractVectorOfArray, idx...)
if checkbounds(Bool, VA.u, last(idx))
if last(idx) isa Integer
Expand Down
14 changes: 14 additions & 0 deletions test/interface_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,20 @@ for idxs in [(2, 2, :), (2, :, 2), (:, 2, 2), (:, :, 2), (:, 2, :), (2, : ,:), (
@test all(arr_view .== voa_view)
end

testvc = VectorOfArray(collect(1:10))
arrvc = Array(testvc)
for (voaidx, arridx) in [
((:,), (:,)),
((3:5,), (3:5,)),
((:, 3:5), (3:5,)),
((1, 3:5), (3:5,)),
]
arr_view = view(arrvc, arridx...)
voa_view = view(testvc, voaidx...)
@test size(arr_view) == size(voa_view)
@test all(arr_view .== voa_view)
end

# test stack
@test stack(testva) == [1 4 7; 2 5 8; 3 6 9]
@test stack(testva; dims = 1) == [1 2 3; 4 5 6; 7 8 9]
Expand Down

0 comments on commit 8ddcb4a

Please sign in to comment.