Skip to content

Commit

Permalink
Use Base method for scalar indexing.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Oct 31, 2023
1 parent 591c6a0 commit 7c4e74f
Showing 1 changed file with 22 additions and 9 deletions.
31 changes: 22 additions & 9 deletions src/host/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ vectorized_indices(Is::Union{Integer,CartesianIndex}...) = Val{false}()
vectorized_indices(Is...) = Val{true}()

# TODO: re-use Base functionality for the conversion of indices to a linear index,
# by only implementing `getindex(A, ::Int)` etc. this is difficult if we want
# to also want to match the case where we take any vectorized index...
# by only implementing `getindex(A, ::Int)` etc. this is difficult due to
# ambiguities with the vectorized method that can take any index type.

Base.@propagate_inbounds Base.getindex(A::AbstractGPUArray, Is...) =
_getindex(vectorized_indices(Is...), A, to_indices(A, Is)...)
Expand All @@ -29,20 +29,33 @@ Base.@propagate_inbounds _setindex!(::Val{true}, A::AbstractGPUArray, v, Is...)
## scalar indexing

function scalar_getindex(A::AbstractGPUArray{T}, Is...) where T
assertscalar("getindex")
@boundscheck checkbounds(A, Is...)
i = Base._to_linear_index(A, Is...)
I = Base._to_linear_index(A, Is...)
getindex(A, I)
end

function scalar_setindex!(A::AbstractGPUArray{T}, v, Is...) where T
@boundscheck checkbounds(A, Is...)
I = Base._to_linear_index(A, Is...)
setindex!(A, v, I)
end

# we still dispatch to `Base.getindex(a, ::Int)` etc so that there's a single method to
# override when a back-end (e.g. with unified memory) wants to allow scalar indexing.

function Base.getindex(A::AbstractGPUArray{T}, I::Int) where T
@boundscheck checkbounds(A, I)
assertscalar("getindex")
x = Array{T}(undef, 1)
copyto!(x, 1, A, i, 1)
copyto!(x, 1, A, I, 1)
return x[1]
end

function scalar_setindex!(A::AbstractGPUArray{T}, v, Is...) where T
function Base.setindex!(A::AbstractGPUArray{T}, v, I::Int) where T
@boundscheck checkbounds(A, I)
assertscalar("setindex!")
@boundscheck checkbounds(A, Is...)
i = Base._to_linear_index(A, Is...)
x = T[v]
copyto!(A, i, x, 1, 1)
copyto!(A, I, x, 1, 1)
return A
end

Expand Down

0 comments on commit 7c4e74f

Please sign in to comment.