diff --git a/src/host/indexing.jl b/src/host/indexing.jl index 25fa0f7f..d270ab40 100644 --- a/src/host/indexing.jl +++ b/src/host/indexing.jl @@ -9,8 +9,8 @@ vectorized_indices(Is::Union{Integer,CartesianIndex}...) = Val{false}() vectorized_indices(Is...) = Val{true}() # TODO: re-use Base functionality for the conversion of indices to a linear index, -# by only implementing `getindex(A, ::Int)` etc. this is difficult if we want -# to also want to match the case where we take any vectorized index... +# by only implementing `getindex(A, ::Int)` etc. this is difficult due to +# ambiguities with the vectorized method that can take any index type. Base.@propagate_inbounds Base.getindex(A::AbstractGPUArray, Is...) = _getindex(vectorized_indices(Is...), A, to_indices(A, Is)...) @@ -29,20 +29,33 @@ Base.@propagate_inbounds _setindex!(::Val{true}, A::AbstractGPUArray, v, Is...) ## scalar indexing function scalar_getindex(A::AbstractGPUArray{T}, Is...) where T - assertscalar("getindex") @boundscheck checkbounds(A, Is...) - i = Base._to_linear_index(A, Is...) + I = Base._to_linear_index(A, Is...) + getindex(A, I) +end + +function scalar_setindex!(A::AbstractGPUArray{T}, v, Is...) where T + @boundscheck checkbounds(A, Is...) + I = Base._to_linear_index(A, Is...) + setindex!(A, v, I) +end + +# we still dispatch to `Base.getindex(a, ::Int)` etc so that there's a single method to +# override when a back-end (e.g. with unified memory) wants to allow scalar indexing. + +function Base.getindex(A::AbstractGPUArray{T}, I::Int) where T + @boundscheck checkbounds(A, I) + assertscalar("getindex") x = Array{T}(undef, 1) - copyto!(x, 1, A, i, 1) + copyto!(x, 1, A, I, 1) return x[1] end -function scalar_setindex!(A::AbstractGPUArray{T}, v, Is...) where T +function Base.setindex!(A::AbstractGPUArray{T}, v, I::Int) where T + @boundscheck checkbounds(A, I) assertscalar("setindex!") - @boundscheck checkbounds(A, Is...) - i = Base._to_linear_index(A, Is...) x = T[v] - copyto!(A, i, x, 1, 1) + copyto!(A, I, x, 1, 1) return A end