Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework host indexing. #499

Merged
merged 6 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 1 addition & 12 deletions src/host/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -308,21 +308,10 @@ function Adapt.adapt_storage(to::ToGPU, xs::Array)
arr
end

# we don't really want an array, so don't call `adapt(Array, ...)`,
# but just want GPUArray indices to get downloaded back to the CPU.
# this makes sure we preserve array-like containers, like Base.Slice.
struct BackToCPU end
Adapt.adapt_storage(::BackToCPU, xs::AbstractGPUArray) = convert(Array, xs)

@inline function Base.view(A::AbstractGPUArray, I::Vararg{Any,N}) where {N}
J = to_indices(A, I)
@boundscheck begin
# Base's boundscheck accesses the indices, so make sure they reside on the CPU.
# this is expensive, but it's a bounds check after all.
J_cpu = map(j->adapt(BackToCPU(), j), J)
checkbounds(A, J_cpu...)
end
J_gpu = map(j->adapt(ToGPU(A), j), J)
@boundscheck checkbounds(A, J...)
unsafe_view(A, J_gpu, GPUIndexStyle(I...))
end

Expand Down
147 changes: 98 additions & 49 deletions src/host/indexing.jl
Original file line number Diff line number Diff line change
@@ -1,51 +1,77 @@
# host-level indexing


# basic indexing with integers
# indexing operators

Base.IndexStyle(::Type{<:AbstractGPUArray}) = Base.IndexLinear()

function Base.getindex(xs::AbstractGPUArray{T}, I::Integer...) where T
vectorized_indices(Is::Union{Integer,CartesianIndex}...) = Val{false}()
vectorized_indices(Is...) = Val{true}()

# TODO: re-use Base functionality for the conversion of indices to a linear index,
# by only implementing `getindex(A, ::Int)` etc. this is difficult due to
# ambiguities with the vectorized method that can take any index type.

Base.@propagate_inbounds Base.getindex(A::AbstractGPUArray, Is...) =
_getindex(vectorized_indices(Is...), A, to_indices(A, Is)...)
Base.@propagate_inbounds _getindex(::Val{false}, A::AbstractGPUArray, Is...) =
scalar_getindex(A, to_indices(A, Is)...)
Base.@propagate_inbounds _getindex(::Val{true}, A::AbstractGPUArray, Is...) =
vectorized_getindex(A, to_indices(A, Is)...)

Base.@propagate_inbounds Base.setindex!(A::AbstractGPUArray, v, Is...) =
_setindex!(vectorized_indices(Is...), A, v, to_indices(A, Is)...)
Base.@propagate_inbounds _setindex!(::Val{false}, A::AbstractGPUArray, v, Is...) =
scalar_setindex!(A, v, to_indices(A, Is)...)
Base.@propagate_inbounds _setindex!(::Val{true}, A::AbstractGPUArray, v, Is...) =
vectorized_setindex!(A, v, to_indices(A, Is)...)

## scalar indexing

function scalar_getindex(A::AbstractGPUArray{T}, Is...) where T
@boundscheck checkbounds(A, Is...)
I = Base._to_linear_index(A, Is...)
getindex(A, I)
end

function scalar_setindex!(A::AbstractGPUArray{T}, v, Is...) where T
@boundscheck checkbounds(A, Is...)
I = Base._to_linear_index(A, Is...)
setindex!(A, v, I)
end

# we still dispatch to `Base.getindex(a, ::Int)` etc so that there's a single method to
# override when a back-end (e.g. with unified memory) wants to allow scalar indexing.

function Base.getindex(A::AbstractGPUArray{T}, I::Int) where T
@boundscheck checkbounds(A, I)
assertscalar("getindex")
i = Base._to_linear_index(xs, I...)
x = Array{T}(undef, 1)
copyto!(x, 1, xs, i, 1)
copyto!(x, 1, A, I, 1)
return x[1]
end

function Base.setindex!(xs::AbstractGPUArray{T}, v::T, I::Integer...) where T
function Base.setindex!(A::AbstractGPUArray{T}, v, I::Int) where T
@boundscheck checkbounds(A, I)
assertscalar("setindex!")
i = Base._to_linear_index(xs, I...)
x = T[v]
copyto!(xs, i, x, 1, 1)
return xs
copyto!(A, I, x, 1, 1)
return A
end

Base.setindex!(xs::AbstractGPUArray, v, I::Integer...) =
setindex!(xs, convert(eltype(xs), v), I...)

## vectorized indexing

# basic indexing with cartesian indices

Base.@propagate_inbounds Base.getindex(A::AbstractGPUArray, I::Union{Integer, CartesianIndex}...) =
A[Base.to_indices(A, I)...]
Base.@propagate_inbounds Base.setindex!(A::AbstractGPUArray, v, I::Union{Integer, CartesianIndex}...) =
(A[Base.to_indices(A, I)...] = v; A)


# generalized multidimensional indexing

Base.getindex(A::AbstractGPUArray, I...) = _getindex(A, to_indices(A, I)...)

function _getindex(src::AbstractGPUArray, Is...)
function vectorized_getindex(src::AbstractGPUArray, Is...)
shape = Base.index_shape(Is...)
dest = similar(src, shape)
any(isempty, Is) && return dest # indexing with empty array
idims = map(length, Is)

AT = typeof(src).name.wrapper
# NOTE: we are pretty liberal here supporting non-GPU indices...
gpu_call(getindex_kernel, dest, src, idims, adapt(AT, Is)...)
Is = map(x->adapt(ToGPU(src), x), Is)
@boundscheck checkbounds(src, Is...)

gpu_call(getindex_kernel, dest, src, idims, Is...)
return dest
end

Expand All @@ -61,9 +87,7 @@ end
end
end

Base.setindex!(A::AbstractGPUArray, v, I...) = _setindex!(A, v, to_indices(A, I)...)

function _setindex!(dest::AbstractGPUArray, src, Is...)
function vectorized_setindex!(dest::AbstractGPUArray, src, Is...)
isempty(Is) && return dest
idims = length.(Is)
len = prod(idims)
Expand All @@ -76,9 +100,11 @@ function _setindex!(dest::AbstractGPUArray, src, Is...)
end
end

AT = typeof(dest).name.wrapper
# NOTE: we are pretty liberal here supporting non-GPU sources and indices...
gpu_call(setindex_kernel, dest, adapt(AT, src), idims, len, adapt(AT, Is)...;
# NOTE: we are pretty liberal here supporting non-GPU indices...
Is = map(x->adapt(ToGPU(dest), x), Is)
@boundscheck checkbounds(dest, Is...)

gpu_call(setindex_kernel, dest, adapt(ToGPU(dest), src), idims, len, Is...;
elements=len)
return dest
end
Expand All @@ -96,7 +122,30 @@ end
end


## find*
# bounds checking

# indices residing on the GPU should be bounds-checked on the GPU to avoid iteration.

# not all wrapped GPU arrays make sense as indices, so we use a subset of `AnyGPUArray`
const IndexGPUArray{T} = Union{AbstractGPUArray{T},
SubArray{T, <:Any, <:AbstractGPUArray},
LinearAlgebra.Adjoint{T}}

@inline function Base.checkindex(::Type{Bool}, inds::AbstractUnitRange, I::IndexGPUArray)
all(broadcast(I) do i
Base.checkindex(Bool, inds, i)
end)
end

@inline function Base.checkindex(::Type{Bool}, inds::Tuple,
I::IndexGPUArray{<:CartesianIndex})
all(broadcast(I) do i
Base.checkbounds_indices(Bool, inds, (i,))
end)
end


# find*

# simple array type that returns the index used to access an element, while
# retaining the dimensionality of the original array. this can be used to
Expand All @@ -107,15 +156,15 @@ struct EachIndex{T,N,IS} <: AbstractArray{T,N}
dims::NTuple{N,Int}
indices::IS
end
EachIndex(xs::AbstractArray) =
EachIndex{typeof(firstindex(xs)), ndims(xs), typeof(eachindex(xs))}(
size(xs), eachindex(xs))
EachIndex(A::AbstractArray) =
EachIndex{typeof(firstindex(A)), ndims(A), typeof(eachindex(A))}(
size(A), eachindex(A))
Base.size(ei::EachIndex) = ei.dims
Base.getindex(ei::EachIndex, i::Int) = ei.indices[i]
Base.IndexStyle(::Type{<:EachIndex}) = Base.IndexLinear()

function Base.findfirst(f::Function, xs::AnyGPUArray)
indices = EachIndex(xs)
function Base.findfirst(f::Function, A::AnyGPUArray)
indices = EachIndex(A)
dummy_index = first(indices)

# given two pairs of (istrue, index), return the one with the smallest index
Expand All @@ -130,23 +179,23 @@ function Base.findfirst(f::Function, xs::AnyGPUArray)
return (false, dummy_index)
end

res = mapreduce((x, y)->(f(x), y), reduction, xs, indices;
res = mapreduce((x, y)->(f(x), y), reduction, A, indices;
init = (false, dummy_index))
if res[1]
# out of consistency with Base.findarray, return a CartesianIndex
# when the input is a multidimensional array
ndims(xs) == 1 && return res[2]
return CartesianIndices(xs)[res[2]]
ndims(A) == 1 && return res[2]
return CartesianIndices(A)[res[2]]
else
return nothing
end
end

Base.findfirst(xs::AnyGPUArray{Bool}) = findfirst(identity, xs)
Base.findfirst(A::AnyGPUArray{Bool}) = findfirst(identity, A)

function findminmax(binop, xs::AnyGPUArray; init, dims)
indices = EachIndex(xs)
dummy_index = firstindex(xs)
function findminmax(binop, A::AnyGPUArray; init, dims)
indices = EachIndex(A)
dummy_index = firstindex(A)

function reduction(t1, t2)
(x, i), (y, j) = t1, t2
Expand All @@ -157,16 +206,16 @@ function findminmax(binop, xs::AnyGPUArray; init, dims)
end

if dims == Colon()
res = mapreduce(tuple, reduction, xs, indices; init = (init, dummy_index))
res = mapreduce(tuple, reduction, A, indices; init = (init, dummy_index))

# out of consistency with Base.findarray, return a CartesianIndex
# when the input is a multidimensional array
return (res[1], ndims(xs) == 1 ? res[2] : CartesianIndices(xs)[res[2]])
return (res[1], ndims(A) == 1 ? res[2] : CartesianIndices(A)[res[2]])
else
res = mapreduce(tuple, reduction, xs, indices;
res = mapreduce(tuple, reduction, A, indices;
init = (init, dummy_index), dims=dims)
vals = map(x->x[1], res)
inds = map(x->ndims(xs) == 1 ? x[2] : CartesianIndices(xs)[x[2]], res)
inds = map(x->ndims(A) == 1 ? x[2] : CartesianIndices(A)[x[2]], res)
return (vals, inds)
end
end
Expand Down
6 changes: 6 additions & 0 deletions test/testsuite/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,12 @@ end
@test_throws DimensionMismatch x[1:9,1:9,:,:] = y
end

@testset "mismatching axes/indices" begin
a = rand(Float32, 1,1)
@test compare(a->a[1:1], AT, a)
@test compare(a->a[1:1,1:1], AT, a)
@test compare(a->a[1:1,1:1,1:1], AT, a)
end
end

@testsuite "indexing find" (AT, eltypes)->begin
Expand Down
Loading