Skip to content

Commit

Permalink
NFC rename.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Oct 31, 2023
1 parent 7c4e74f commit 12cc5f7
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions src/host/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,15 +138,15 @@ struct EachIndex{T,N,IS} <: AbstractArray{T,N}
dims::NTuple{N,Int}
indices::IS
end
EachIndex(xs::AbstractArray) =
EachIndex{typeof(firstindex(xs)), ndims(xs), typeof(eachindex(xs))}(
size(xs), eachindex(xs))
EachIndex(A::AbstractArray) =
EachIndex{typeof(firstindex(A)), ndims(A), typeof(eachindex(A))}(
size(A), eachindex(A))
Base.size(ei::EachIndex) = ei.dims
Base.getindex(ei::EachIndex, i::Int) = ei.indices[i]
Base.IndexStyle(::Type{<:EachIndex}) = Base.IndexLinear()

function Base.findfirst(f::Function, xs::AnyGPUArray)
indices = EachIndex(xs)
function Base.findfirst(f::Function, A::AnyGPUArray)
indices = EachIndex(A)
dummy_index = first(indices)

# given two pairs of (istrue, index), return the one with the smallest index
Expand All @@ -161,23 +161,23 @@ function Base.findfirst(f::Function, xs::AnyGPUArray)
return (false, dummy_index)
end

res = mapreduce((x, y)->(f(x), y), reduction, xs, indices;
res = mapreduce((x, y)->(f(x), y), reduction, A, indices;
init = (false, dummy_index))
if res[1]
# out of consistency with Base.findarray, return a CartesianIndex
# when the input is a multidimensional array
ndims(xs) == 1 && return res[2]
return CartesianIndices(xs)[res[2]]
ndims(A) == 1 && return res[2]
return CartesianIndices(A)[res[2]]
else
return nothing
end
end

Base.findfirst(xs::AnyGPUArray{Bool}) = findfirst(identity, xs)
Base.findfirst(A::AnyGPUArray{Bool}) = findfirst(identity, A)

function findminmax(binop, xs::AnyGPUArray; init, dims)
indices = EachIndex(xs)
dummy_index = firstindex(xs)
function findminmax(binop, A::AnyGPUArray; init, dims)
indices = EachIndex(A)
dummy_index = firstindex(A)

function reduction(t1, t2)
(x, i), (y, j) = t1, t2
Expand All @@ -188,16 +188,16 @@ function findminmax(binop, xs::AnyGPUArray; init, dims)
end

if dims == Colon()
res = mapreduce(tuple, reduction, xs, indices; init = (init, dummy_index))
res = mapreduce(tuple, reduction, A, indices; init = (init, dummy_index))

# out of consistency with Base.findarray, return a CartesianIndex
# when the input is a multidimensional array
return (res[1], ndims(xs) == 1 ? res[2] : CartesianIndices(xs)[res[2]])
return (res[1], ndims(A) == 1 ? res[2] : CartesianIndices(A)[res[2]])
else
res = mapreduce(tuple, reduction, xs, indices;
res = mapreduce(tuple, reduction, A, indices;
init = (init, dummy_index), dims=dims)
vals = map(x->x[1], res)
inds = map(x->ndims(xs) == 1 ? x[2] : CartesianIndices(xs)[x[2]], res)
inds = map(x->ndims(A) == 1 ? x[2] : CartesianIndices(A)[x[2]], res)
return (vals, inds)
end
end
Expand Down

0 comments on commit 12cc5f7

Please sign in to comment.