diff --git a/src/host/indexing.jl b/src/host/indexing.jl index 2b8e17da..25fa0f7f 100644 --- a/src/host/indexing.jl +++ b/src/host/indexing.jl @@ -1,43 +1,62 @@ # host-level indexing -# basic indexing with integers +# indexing operators Base.IndexStyle(::Type{<:AbstractGPUArray}) = Base.IndexLinear() -function Base.getindex(xs::AbstractGPUArray{T}, I::Integer...) where T +vectorized_indices(Is::Union{Integer,CartesianIndex}...) = Val{false}() +vectorized_indices(Is...) = Val{true}() + +# TODO: re-use Base functionality for the conversion of indices to a linear index, +# by only implementing `getindex(A, ::Int)` etc. this is difficult if we want +# to also want to match the case where we take any vectorized index... + +Base.@propagate_inbounds Base.getindex(A::AbstractGPUArray, Is...) = + _getindex(vectorized_indices(Is...), A, to_indices(A, Is)...) +Base.@propagate_inbounds _getindex(::Val{false}, A::AbstractGPUArray, Is...) = + scalar_getindex(A, to_indices(A, Is)...) +Base.@propagate_inbounds _getindex(::Val{true}, A::AbstractGPUArray, Is...) = + vectorized_getindex(A, to_indices(A, Is)...) + +Base.@propagate_inbounds Base.setindex!(A::AbstractGPUArray, v, Is...) = + _setindex!(vectorized_indices(Is...), A, v, to_indices(A, Is)...) +Base.@propagate_inbounds _setindex!(::Val{false}, A::AbstractGPUArray, v, Is...) = + scalar_setindex!(A, v, to_indices(A, Is)...) +Base.@propagate_inbounds _setindex!(::Val{true}, A::AbstractGPUArray, v, Is...) = + vectorized_setindex!(A, v, to_indices(A, Is)...) + +## scalar indexing + +function scalar_getindex(A::AbstractGPUArray{T}, Is...) where T assertscalar("getindex") - i = Base._to_linear_index(xs, I...) + @boundscheck checkbounds(A, Is...) + i = Base._to_linear_index(A, Is...) x = Array{T}(undef, 1) - copyto!(x, 1, xs, i, 1) + copyto!(x, 1, A, i, 1) return x[1] end -function Base.setindex!(xs::AbstractGPUArray{T}, v::T, I::Integer...) where T +function scalar_setindex!(A::AbstractGPUArray{T}, v, Is...) where T assertscalar("setindex!") - i = Base._to_linear_index(xs, I...) + @boundscheck checkbounds(A, Is...) + i = Base._to_linear_index(A, Is...) x = T[v] - copyto!(xs, i, x, 1, 1) - return xs + copyto!(A, i, x, 1, 1) + return A end -Base.setindex!(xs::AbstractGPUArray, v, I::Integer...) = - setindex!(xs, convert(eltype(xs), v), I...) - +## vectorized indexing -# basic indexing with cartesian indices - -Base.@propagate_inbounds Base.getindex(A::AbstractGPUArray, I::Union{Integer, CartesianIndex}...) = - A[Base.to_indices(A, I)...] -Base.@propagate_inbounds Base.setindex!(A::AbstractGPUArray, v, I::Union{Integer, CartesianIndex}...) = - (A[Base.to_indices(A, I)...] = v; A) - - -# generalized multidimensional indexing - -Base.getindex(A::AbstractGPUArray, I...) = _getindex(A, to_indices(A, I)...) +function vectorized_checkbounds(src, Is) + # Base's boundscheck accesses the indices, so make sure they reside on the CPU. + # this is expensive, but it's a bounds check after all. + Is_cpu = map(I->adapt(BackToCPU(), I), Is) + checkbounds(src, Is_cpu...) +end -function _getindex(src::AbstractGPUArray, Is...) +function vectorized_getindex(src::AbstractGPUArray, Is...) + @boundscheck vectorized_checkbounds(src, Is) shape = Base.index_shape(Is...) dest = similar(src, shape) any(isempty, Is) && return dest # indexing with empty array @@ -61,9 +80,8 @@ end end end -Base.setindex!(A::AbstractGPUArray, v, I...) = _setindex!(A, v, to_indices(A, I)...) - -function _setindex!(dest::AbstractGPUArray, src, Is...) +function vectorized_setindex!(dest::AbstractGPUArray, src, Is...) + @boundscheck vectorized_checkbounds(dest, Is) isempty(Is) && return dest idims = length.(Is) len = prod(idims) @@ -96,7 +114,7 @@ end end -## find* +# find* # simple array type that returns the index used to access an element, while # retaining the dimensionality of the original array. this can be used to