Skip to content

Commit

Permalink
Merge pull request #328 from yuehhua/cartesianidx
Browse files Browse the repository at this point in the history
Fix the output type of reverse_indices
  • Loading branch information
CarloLucibello authored Jun 30, 2021
2 parents a62a73f + 87f4e68 commit a40daf4
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
17 changes: 14 additions & 3 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,21 @@ function reverse_indices!(rev::AbstractArray, idx::AbstractArray)
rev
end

function reverse_indices(idx::AbstractArray)
rev = Array{Vector{CartesianIndex}}(undef, maximum_dims(idx)...)
"""
reverse_indices(idx)
Return the reverse indices of `idx`. The indices of `idx` will be values, and values of `idx` will be index.
# Arguments
- `idx`: The indices to be reversed. Accepts array or cuarray of integer, tuple or `CartesianIndex`.
"""
function reverse_indices(idx::AbstractArray{<:Any,N}) where N
max_dims = maximum_dims(idx)
T = CartesianIndex{N}
rev = Array{Vector{T}}(undef, max_dims...)
for i in eachindex(rev)
rev[i] = CartesianIndex[]
rev[i] = T[]
end
return reverse_indices!(rev, idx)
end
Expand Down
3 changes: 3 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,16 @@ end
4 2 1 3;
3 5 5 3]
@test NNlib.reverse_indices(idx) == res
@test NNlib.reverse_indices(idx) isa typeof(res)
idx = [(1,) (2,) (3,) (4,);
(4,) (2,) (1,) (3,);
(3,) (5,) (5,) (3,)]
@test NNlib.reverse_indices(idx) == res
@test NNlib.reverse_indices(idx) isa typeof(res)
idx = CartesianIndex.(
[(1,) (2,) (3,) (4,);
(4,) (2,) (1,) (3,);
(3,) (5,) (5,) (3,)])
@test NNlib.reverse_indices(idx) == res
@test NNlib.reverse_indices(idx) isa typeof(res)
end

0 comments on commit a40daf4

Please sign in to comment.