diff --git a/src/host/indexing.jl b/src/host/indexing.jl index d270ab40..43121aec 100644 --- a/src/host/indexing.jl +++ b/src/host/indexing.jl @@ -138,15 +138,15 @@ struct EachIndex{T,N,IS} <: AbstractArray{T,N} dims::NTuple{N,Int} indices::IS end -EachIndex(xs::AbstractArray) = - EachIndex{typeof(firstindex(xs)), ndims(xs), typeof(eachindex(xs))}( - size(xs), eachindex(xs)) +EachIndex(A::AbstractArray) = + EachIndex{typeof(firstindex(A)), ndims(A), typeof(eachindex(A))}( + size(A), eachindex(A)) Base.size(ei::EachIndex) = ei.dims Base.getindex(ei::EachIndex, i::Int) = ei.indices[i] Base.IndexStyle(::Type{<:EachIndex}) = Base.IndexLinear() -function Base.findfirst(f::Function, xs::AnyGPUArray) - indices = EachIndex(xs) +function Base.findfirst(f::Function, A::AnyGPUArray) + indices = EachIndex(A) dummy_index = first(indices) # given two pairs of (istrue, index), return the one with the smallest index @@ -161,23 +161,23 @@ function Base.findfirst(f::Function, xs::AnyGPUArray) return (false, dummy_index) end - res = mapreduce((x, y)->(f(x), y), reduction, xs, indices; + res = mapreduce((x, y)->(f(x), y), reduction, A, indices; init = (false, dummy_index)) if res[1] # out of consistency with Base.findarray, return a CartesianIndex # when the input is a multidimensional array - ndims(xs) == 1 && return res[2] - return CartesianIndices(xs)[res[2]] + ndims(A) == 1 && return res[2] + return CartesianIndices(A)[res[2]] else return nothing end end -Base.findfirst(xs::AnyGPUArray{Bool}) = findfirst(identity, xs) +Base.findfirst(A::AnyGPUArray{Bool}) = findfirst(identity, A) -function findminmax(binop, xs::AnyGPUArray; init, dims) - indices = EachIndex(xs) - dummy_index = firstindex(xs) +function findminmax(binop, A::AnyGPUArray; init, dims) + indices = EachIndex(A) + dummy_index = firstindex(A) function reduction(t1, t2) (x, i), (y, j) = t1, t2 @@ -188,16 +188,16 @@ function findminmax(binop, xs::AnyGPUArray; init, dims) end if dims == Colon() - res = mapreduce(tuple, reduction, xs, indices; init = (init, dummy_index)) + res = mapreduce(tuple, reduction, A, indices; init = (init, dummy_index)) # out of consistency with Base.findarray, return a CartesianIndex # when the input is a multidimensional array - return (res[1], ndims(xs) == 1 ? res[2] : CartesianIndices(xs)[res[2]]) + return (res[1], ndims(A) == 1 ? res[2] : CartesianIndices(A)[res[2]]) else - res = mapreduce(tuple, reduction, xs, indices; + res = mapreduce(tuple, reduction, A, indices; init = (init, dummy_index), dims=dims) vals = map(x->x[1], res) - inds = map(x->ndims(xs) == 1 ? x[2] : CartesianIndices(xs)[x[2]], res) + inds = map(x->ndims(A) == 1 ? x[2] : CartesianIndices(A)[x[2]], res) return (vals, inds) end end