Skip to content

Commit

Permalink
fix: non-contiguous indexing is now supported
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 12, 2024
1 parent 099ae3b commit 004cce1
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 5 deletions.
43 changes: 43 additions & 0 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1278,4 +1278,47 @@ function scatter_setindex(
)
end

"""
gather_getindex(src, gather_indices)
Uses [`MLIR.Dialects.stablehlo.gather`](@ref) to get the values of `src` at the indices
specified by `gather_indices`. If the indices are contiguous it is recommended to directly
use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead.
"""
function gather_getindex(
src::TracedRArray{T,N}, gather_indices::TracedRArray{Int64,2}
) where {T,N}
@assert size(gather_indices, 2) == N

#! format: off
dimension_numbers = MLIR.API.stablehloGatherDimensionNumbersGet(
MLIR.IR.context(),
Int64(1), Int64[1],
Int64(N - 1), collect(Int64, 0:(N - 2)),
Int64(0), Int64[],
Int64(0), Int64[],
Int64(N), collect(Int64, 0:(N - 1)),
1
)
#! format: on

return reshape(
TracedRArray{T,2}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.gather(
src.mlir_data,
gather_indices.mlir_data;
dimension_numbers,
slice_sizes=fill(Int64(1), N),
indices_are_sorted=false,
),
1,
),
(size(gather_indices, 1), 1),
),
size(gather_indices, 1),
)
end

end # module Ops
23 changes: 18 additions & 5 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ mutable struct TracedRArray{T,N} <: RArray{T,N}
) where {T,N}
shape = Tuple(shape)
if !isnothing(mlir_data)
@assert size(MLIR.IR.type(mlir_data)) == shape
@assert size(MLIR.IR.type(mlir_data)) == shape "Expected: $(shape), got: $(size(MLIR.IR.type(mlir_data)))"
end
return new{T,N}(paths, mlir_data, shape)
end
Expand Down Expand Up @@ -114,21 +114,34 @@ function Base.getindex(a::TracedRArray{T,0}) where {T}
return TracedRNumber{T}((), a.mlir_data)
end

# XXX: We want to support https://github.com/EnzymeAD/Reactant.jl/issues/242 eventually
function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
indices = map(enumerate(indices)) do (idx, i)
i isa Colon && return 1:size(a, idx)
i isa CartesianIndex && return Tuple(i)
return i
end

non_contiguous_getindex = false
for idxs in indices
idxs isa Number && continue
contiguous = all(isone, diff(idxs))
# XXX: We want to throw error even for dynamic indexing
if typeof(contiguous) <: Bool
contiguous || error("non-contiguous indexing is not supported")
if typeof(contiguous) <: Bool && !contiguous
non_contiguous_getindex = true
break
end
end

if non_contiguous_getindex
indices_tuples = collect(Iterators.product(indices...))
indices = Matrix{Int}(undef, (length(indices_tuples), 2))
for (i, idx) in enumerate(indices_tuples)
indices[i, 1] = idx[1] - 1
indices[i, 2] = idx[2] - 1
end
indices = promote_to(TracedRArray{Int,2}, indices)
res = Ops.gather_getindex(a, indices)
return Ops.reshape(res, size(indices_tuples)...)
end

start_indices = map(indices) do i
Expand Down Expand Up @@ -179,7 +192,7 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {
indices[i, 1] = idx[1] - 1
indices[i, 2] = idx[2] - 1
end
indices = promote_to(TracedRArray{Int, 2}, indices)
indices = promote_to(TracedRArray{Int,2}, indices)
res = Ops.scatter_setindex(a, indices, Ops.reshape(v, length(v)))
a.mlir_data = res.mlir_data
return v
Expand Down

0 comments on commit 004cce1

Please sign in to comment.