Skip to content

Commit

Permalink
Rework host indexing.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Oct 30, 2023
1 parent e21889b commit 591c6a0
Showing 1 changed file with 45 additions and 27 deletions.
72 changes: 45 additions & 27 deletions src/host/indexing.jl
Original file line number Diff line number Diff line change
@@ -1,43 +1,62 @@
# host-level indexing


# basic indexing with integers
# indexing operators

Base.IndexStyle(::Type{<:AbstractGPUArray}) = Base.IndexLinear()

function Base.getindex(xs::AbstractGPUArray{T}, I::Integer...) where T
vectorized_indices(Is::Union{Integer,CartesianIndex}...) = Val{false}()
vectorized_indices(Is...) = Val{true}()

# TODO: re-use Base functionality for the conversion of indices to a linear index,
# by only implementing `getindex(A, ::Int)` etc. this is difficult if we want
# to also want to match the case where we take any vectorized index...

Base.@propagate_inbounds Base.getindex(A::AbstractGPUArray, Is...) =
_getindex(vectorized_indices(Is...), A, to_indices(A, Is)...)
Base.@propagate_inbounds _getindex(::Val{false}, A::AbstractGPUArray, Is...) =
scalar_getindex(A, to_indices(A, Is)...)
Base.@propagate_inbounds _getindex(::Val{true}, A::AbstractGPUArray, Is...) =
vectorized_getindex(A, to_indices(A, Is)...)

Base.@propagate_inbounds Base.setindex!(A::AbstractGPUArray, v, Is...) =
_setindex!(vectorized_indices(Is...), A, v, to_indices(A, Is)...)
Base.@propagate_inbounds _setindex!(::Val{false}, A::AbstractGPUArray, v, Is...) =
scalar_setindex!(A, v, to_indices(A, Is)...)
Base.@propagate_inbounds _setindex!(::Val{true}, A::AbstractGPUArray, v, Is...) =
vectorized_setindex!(A, v, to_indices(A, Is)...)

## scalar indexing

function scalar_getindex(A::AbstractGPUArray{T}, Is...) where T
assertscalar("getindex")
i = Base._to_linear_index(xs, I...)
@boundscheck checkbounds(A, Is...)
i = Base._to_linear_index(A, Is...)
x = Array{T}(undef, 1)
copyto!(x, 1, xs, i, 1)
copyto!(x, 1, A, i, 1)
return x[1]
end

function Base.setindex!(xs::AbstractGPUArray{T}, v::T, I::Integer...) where T
function scalar_setindex!(A::AbstractGPUArray{T}, v, Is...) where T
assertscalar("setindex!")
i = Base._to_linear_index(xs, I...)
@boundscheck checkbounds(A, Is...)
i = Base._to_linear_index(A, Is...)
x = T[v]
copyto!(xs, i, x, 1, 1)
return xs
copyto!(A, i, x, 1, 1)
return A
end

Base.setindex!(xs::AbstractGPUArray, v, I::Integer...) =
setindex!(xs, convert(eltype(xs), v), I...)

## vectorized indexing

# basic indexing with cartesian indices

Base.@propagate_inbounds Base.getindex(A::AbstractGPUArray, I::Union{Integer, CartesianIndex}...) =
A[Base.to_indices(A, I)...]
Base.@propagate_inbounds Base.setindex!(A::AbstractGPUArray, v, I::Union{Integer, CartesianIndex}...) =
(A[Base.to_indices(A, I)...] = v; A)


# generalized multidimensional indexing

Base.getindex(A::AbstractGPUArray, I...) = _getindex(A, to_indices(A, I)...)
function vectorized_checkbounds(src, Is)
# Base's boundscheck accesses the indices, so make sure they reside on the CPU.
# this is expensive, but it's a bounds check after all.
Is_cpu = map(I->adapt(BackToCPU(), I), Is)
checkbounds(src, Is_cpu...)
end

function _getindex(src::AbstractGPUArray, Is...)
function vectorized_getindex(src::AbstractGPUArray, Is...)
@boundscheck vectorized_checkbounds(src, Is)
shape = Base.index_shape(Is...)
dest = similar(src, shape)
any(isempty, Is) && return dest # indexing with empty array
Expand All @@ -61,9 +80,8 @@ end
end
end

Base.setindex!(A::AbstractGPUArray, v, I...) = _setindex!(A, v, to_indices(A, I)...)

function _setindex!(dest::AbstractGPUArray, src, Is...)
function vectorized_setindex!(dest::AbstractGPUArray, src, Is...)
@boundscheck vectorized_checkbounds(dest, Is)
isempty(Is) && return dest
idims = length.(Is)
len = prod(idims)
Expand Down Expand Up @@ -96,7 +114,7 @@ end
end


## find*
# find*

# simple array type that returns the index used to access an element, while
# retaining the dimensionality of the original array. this can be used to
Expand Down

0 comments on commit 591c6a0

Please sign in to comment.