feat: add support for the remaining wrapper types #369

Merged: 15 commits, Dec 29, 2024
Changes from all commits
5 changes: 4 additions & 1 deletion src/Compiler.jl
@@ -41,6 +41,9 @@ function create_result(tocopy::T, path, result_stores) where {T}
elems = Union{Symbol,Expr}[]

for i in 1:fieldcount(T)
# If the field is undefined we don't set it. A common example for this is `du2`
# for Tridiagonal
isdefined(tocopy, i) || continue
ev = create_result(getfield(tocopy, i), append_path(path, i), result_stores)
push!(elems, ev)
end
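# Illustrative aside, not part of this patch: `Tridiagonal` is the motivating case
# for the `isdefined` guard above, since a plainly constructed instance leaves its
# `du2` field (extra storage used by pivoted factorizations) unset:
#
#   using LinearAlgebra
#   T = Tridiagonal(rand(3), rand(4), rand(3))
#   isdefined(T, :du2)   # false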
@@ -102,7 +105,7 @@ function create_result(tocopy::D, path, result_stores) where {K,V,D<:AbstractDic
end

function create_result(
tocopy::Union{Integer,AbstractFloat,AbstractString,Nothing,Type,Symbol},
tocopy::Union{Integer,AbstractFloat,AbstractString,Nothing,Type,Symbol,Char},
path,
result_stores,
)
110 changes: 110 additions & 0 deletions src/Ops.jl
@@ -1418,4 +1418,114 @@ julia> Reactant.@jit(
end
end

"""
scatter_setindex(dest, scatter_indices, updates)

Uses [`MLIR.Dialects.stablehlo.scatter`](@ref) to set the values of `dest` at the indices
specified by `scatter_indices` to the values in `updates`. If the indices are contiguous it
is recommended to directly use [`MLIR.Dialects.stablehlo.dynamic_update_slice`](@ref)
instead.
"""
@noinline function scatter_setindex(
dest::TracedRArray{T,N},
scatter_indices::TracedRArray{Int64,2},
updates::TracedRArray{T,1},
) where {T,N}
@assert length(updates) == size(scatter_indices, 1)
@assert size(scatter_indices, 2) == N

update_computation = MLIR.IR.Region()
block = MLIR.IR.Block(
[mlir_type(TracedRNumber{T}), mlir_type(TracedRNumber{T})],
[MLIR.IR.Location(), MLIR.IR.Location()],
)
return_op = MLIR.Dialects.stablehlo.return_([MLIR.IR.argument(block, 2)])
MLIR.IR.rmfromparent!(return_op)
push!(block, return_op)
pushfirst!(update_computation, block)

#! format: off
update_window_dims = Int64[]
inserted_window_dims = collect(Int64, 0:(N - 1))
input_batching_dims = Int64[]
scatter_indices_batching_dims = Int64[]
scatter_dims_to_operand_dims = collect(Int64, 0:(N - 1))
index_vector_dim = Int64(1)

scatter_dimension_numbers = MLIR.API.stablehloScatterDimensionNumbersGet(
MLIR.IR.context(),
length(update_window_dims), update_window_dims,
length(inserted_window_dims), inserted_window_dims,
length(input_batching_dims), input_batching_dims,
length(scatter_indices_batching_dims), scatter_indices_batching_dims,
length(scatter_dims_to_operand_dims), scatter_dims_to_operand_dims,
index_vector_dim,
)
#! format: on

return TracedRArray{T,N}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.scatter(
[dest.mlir_data],
scatter_indices.mlir_data,
[updates.mlir_data];
result_0=[mlir_type(TracedRArray{T,N}, size(dest))],
update_computation,
scatter_dimension_numbers,
),
1,
),
size(dest),
)
end
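# Illustrative call pattern, not part of this patch (it mirrors the non-contiguous
# setindex! path added in src/TracedRArray.jl below); `dest` and `v` are placeholder
# traced arrays, and each row of the index matrix is one zero-based coordinate:
#
#   coords  = [0 1; 2 3]                                    # write to dest[1, 2] and dest[3, 4]
#   idxs    = TracedUtils.promote_to(TracedRArray{Int,2}, coords)
#   updates = Ops.reshape(v, length(v))                     # flatten the new values
#   dest2   = Ops.scatter_setindex(dest, idxs, updates)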

"""
gather_getindex(src, gather_indices)

Uses [`MLIR.Dialects.stablehlo.gather`](@ref) to get the values of `src` at the indices
specified by `gather_indices`. If the indices are contiguous it is recommended to directly
use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead.
"""
@noinline function gather_getindex(
src::TracedRArray{T,N}, gather_indices::TracedRArray{Int64,2}
) where {T,N}
@assert size(gather_indices, 2) == N

#! format: off
offset_dims = Int64[1]
collapsed_slice_dims = collect(Int64, 0:(N - 2))
operand_batching_dims = Int64[]
start_indices_batching_dims = Int64[]
start_index_map = collect(Int64, 0:(N - 1))
index_vector_dim = Int64(1)

dimension_numbers = MLIR.API.stablehloGatherDimensionNumbersGet(
MLIR.IR.context(),
Int64(length(offset_dims)), offset_dims,
Int64(length(collapsed_slice_dims)), collapsed_slice_dims,
Int64(length(operand_batching_dims)), operand_batching_dims,
Int64(length(start_indices_batching_dims)), start_indices_batching_dims,
Int64(length(start_index_map)), start_index_map,
Int64(index_vector_dim),
)
#! format: on

return reshape(
TracedRArray{T}(
MLIR.IR.result(
MLIR.Dialects.stablehlo.gather(
src.mlir_data,
gather_indices.mlir_data;
dimension_numbers,
slice_sizes=fill(Int64(1), N),
indices_are_sorted=false,
),
1,
),
),
size(gather_indices, 1),
)
end
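# Illustrative call pattern, not part of this patch (it mirrors the non-contiguous
# getindex path added in src/TracedRArray.jl below); `src` is a placeholder traced
# array, and each row of the index matrix is one zero-based coordinate:
#
#   coords = [0 1; 2 3]                                  # read src[1, 2] and src[3, 4]
#   idxs   = TracedUtils.promote_to(TracedRArray{Int,2}, coords)
#   vals   = Ops.gather_getindex(src, idxs)              # length-2 traced vector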

end # module Ops
27 changes: 27 additions & 0 deletions src/Overlay.jl
@@ -115,3 +115,30 @@ for randfun in (:rand, :randn, :randexp)
# end
end
end

# LinearAlgebra.jl overloads
## `_mul!` goes through too many layers of abstractions and we aren't able to overload
## without specializing on every possible combination of types
for (cT, aT, bT) in (
(:AbstractVector, :AbstractMatrix, :AbstractVector),
(:AbstractMatrix, :AbstractMatrix, :AbstractVecOrMat),
)
@eval begin
@reactant_overlay @noinline function LinearAlgebra.mul!(
C::$cT, A::$aT, B::$bT, α::Number, β::Number
)
if any(Base.Fix2(isa, TracedRArray) ∘ ancestor, (C, A, B))
TracedLinearAlgebra.overloaded_mul!(C, A, B, α, β)
else
LinearAlgebra._mul!(C, A, B, α, β)
end
return C
end

# Needed mostly for 1.10 where 3-arg mul is often specialized
@reactant_overlay @noinline function LinearAlgebra.mul!(C::$cT, A::$aT, B::$bT)
call_with_reactant(LinearAlgebra.mul!, C, A, B, true, false)
return C
end
end
end
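A minimal usage sketch of what this overlay enables (illustrative only, not part of the patch; `Reactant.to_rarray` is assumed here as the host-to-device conversion, so substitute the package's actual constructor if it differs):

using LinearAlgebra, Reactant

A = Reactant.to_rarray(rand(Float32, 4, 4))
B = Reactant.to_rarray(rand(Float32, 4, 4))
C = Reactant.to_rarray(zeros(Float32, 4, 4))

# Inside a compiled function the overlay routes mul! on traced arguments to
# TracedLinearAlgebra.overloaded_mul!; plain arrays fall back to LinearAlgebra._mul!.
f!(C, A, B) = (mul!(C, A, B); C)
Reactant.@jit f!(C, A, B)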
20 changes: 14 additions & 6 deletions src/Reactant.jl
@@ -105,7 +105,7 @@ mutable struct TracedRArray{T,N} <: RArray{TracedRNumber{T},N}
) where {T,N}
shape = Tuple(shape)
if !isnothing(mlir_data)
@assert size(MLIR.IR.type(mlir_data)) == shape
@assert size(MLIR.IR.type(mlir_data)) == shape "Expected: $(shape), got: $(size(MLIR.IR.type(mlir_data)))"
end
return new{T,N}(paths, mlir_data, shape)
end
@@ -119,15 +119,23 @@ const WrappedTracedRArray{T,N} = WrappedArray{
const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}}
const AnyTracedRVector{T} = AnyTracedRArray{T,1}
const AnyTracedRMatrix{T} = Union{
AnyTracedRArray{T,2},LinearAlgebra.Diagonal{T,TracedRArray{T,1}}
AnyTracedRArray{T,2},
LinearAlgebra.Diagonal{TracedRNumber{T},TracedRArray{T,1}},
LinearAlgebra.Tridiagonal{TracedRNumber{T},TracedRArray{T,1}},
}
const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}}

function TracedRArray(data::MLIR.IR.Value)
function TracedRArray{T}(data::MLIR.IR.Value) where {T}
data_type = MLIR.IR.type(data)
return TracedRArray{eltype(MLIR.IR.julia_type(data_type)),ndims(data_type)}(
(), data, size(data_type)
)
if T == eltype(MLIR.IR.julia_type(data_type))
return TracedRArray{T,ndims(data_type)}((), data, size(data_type))
end
tdata = TracedRArray(data)
return Ops.convert(TracedRArray{T,ndims(data_type)}, tdata)
end

function TracedRArray(data::MLIR.IR.Value)
return TracedRArray{eltype(MLIR.IR.julia_type(MLIR.IR.type(data)))}(data)
end
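# Note (not part of the patch): `TracedRArray{T}(data)` now tolerates an element-type
# mismatch between `T` and the MLIR value's element type by inserting an `Ops.convert`,
# whereas the untyped constructor simply adopts the value's own element type.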

struct XLAArray{T,N} <: RArray{T,N} end
74 changes: 54 additions & 20 deletions src/TracedRArray.jl
@@ -14,8 +14,9 @@ using ..Reactant:
MLIR,
ancestor,
unwrapped_eltype
using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data!, materialize_traced_array

using ReactantCore: ReactantCore
using ..TracedUtils: TracedUtils, materialize_traced_array
using GPUArraysCore: GPUArraysCore

ReactantCore.is_traced(::TracedRArray) = true
@@ -55,25 +56,37 @@ function Base.getindex(
return TracedRNumber{T}((), res2)
end

function Base.getindex(a::TracedRArray{T,0}) where {T}
return TracedRNumber{T}((), a.mlir_data)
end
Base.getindex(a::TracedRArray{T,0}) where {T} = TracedRNumber{T}((), a.mlir_data)

# XXX: We want to support https://github.com/EnzymeAD/Reactant.jl/issues/242 eventually
function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
indices = map(enumerate(indices)) do (idx, i)
i isa Colon && return 1:size(a, idx)
i isa CartesianIndex && return Tuple(i)
return i
end

foreach(indices) do idxs
idxs isa Number && return nothing
non_contiguous_getindex = false
for idxs in indices
idxs isa Number && continue
contiguous = all(isone, diff(idxs))
# XXX: We want to throw error even for dynamic indexing
if typeof(a) <: Bool
contiguous || error("non-contiguous indexing is not supported")
if typeof(contiguous) <: Bool && !contiguous
non_contiguous_getindex = true
break
end
end

if non_contiguous_getindex
indices_tuples = collect(Iterators.product(indices...))
indices = Matrix{Int}(
undef, (length(indices_tuples), length(first(indices_tuples)))
)
for (i, idx) in enumerate(indices_tuples)
indices[i, :] .= idx .- 1
end
indices = TracedUtils.promote_to(TracedRArray{Int,2}, indices)
res = Ops.gather_getindex(a, indices)
return Ops.reshape(res, size(indices_tuples)...)
end

start_indices = map(indices) do i
@@ -99,16 +112,41 @@ function Base.getindex(a::WrappedTracedRArray, indices...)
return getindex(ancestor(a), TracedUtils.get_ancestor_indices(a, indices...)...)
end
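# Illustrative sketch, not part of this patch: with the gather-based getindex branch
# above, non-contiguous indexing of traced arrays now lowers to stablehlo.gather
# instead of raising "non-contiguous indexing is not supported". `Reactant.to_rarray`
# is an assumed name for the host-to-device conversion.
#
#   using Reactant
#   x = Reactant.to_rarray(reshape(collect(Float32, 1:16), 4, 4))
#   f(x) = x[[1, 3], [2, 4]]      # non-contiguous rows and columns
#   Reactant.@jit f(x)            # compiles and returns the 2×2 result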

function Base.setindex!(
a::TracedRArray{T,N},
v,
indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int,TracedRNumber{Int}},N},
) where {T,N}
function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {T,N}
indices = map(enumerate(indices)) do (idx, i)
i isa Int ? (i:i) : (i isa Colon ? (1:size(a, idx)) : i)
i isa Colon && return 1:size(a, idx)
i isa CartesianIndex && return Tuple(i)
return i
end

non_contiguous_setindex = false
for idxs in indices
idxs isa Number && continue
contiguous = all(isone, diff(idxs))
# XXX: We want to throw error even for dynamic indexing
if typeof(contiguous) <: Bool && !contiguous
non_contiguous_setindex = true
break
end
Review comment on lines +127 to +130

Collaborator (author):

How can we introduce a runtime error in the generated MLIR?

@mofeing (Collaborator), Dec 13, 2024:

I think you can't, at least not in the MLIR itself, but the verifier will error when verifying the ops.

Collaborator (author):

I was thinking of some sort of custom_call, which JAX is using here: https://github.com/openxla/stablehlo/blob/main/docs/tutorials/jax-export.ipynb

stablehlo.custom_call @shape_assertion(%1, %0) {api_version = 2 : i32, error_message = "Input shapes do not match the polymorphic shapes specification. Expected value >= 1 for dimension variable 'a'. Using the following polymorphic shapes specifications: args[0].shape = (a, 3, 224, 224). Obtained dimension variables: 'a' = {0} from specification 'a' for dimension args[0].shape[0] (= {0}), . Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details.", has_side_effect = true} : (tensor<i1>, tensor<i32>) -> ()

For TracedRArray indices we should probably always do a (dynamic_)gather (and we might be able to write an optimization later to transform that into a slice if contiguous).

end

if non_contiguous_setindex
indices_tuples = collect(Iterators.product(indices...))
indices = Matrix{Int}(
undef, (length(indices_tuples), length(first(indices_tuples)))
)
for (i, idx) in enumerate(indices_tuples)
indices[i, :] .= idx .- 1
end
indices = TracedUtils.promote_to(TracedRArray{Int,2}, indices)
res = Ops.scatter_setindex(a, indices, Ops.reshape(v, length(v)))
a.mlir_data = res.mlir_data
return v
end

v = TracedUtils.broadcast_to_size(v, length.(indices))
v = TracedUtils.promote_to(TracedRArray{T,N}, v)

indices = [
(
TracedUtils.promote_to(TracedRNumber{Int}, i isa Colon ? 1 : first(i)) - 1
@@ -124,11 +162,7 @@ function Base.setindex!(
return v
end

function Base.setindex!(
a::AnyTracedRArray{T,N},
v,
indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int,TracedRNumber{Int}},N},
) where {T,N}
function Base.setindex!(a::AnyTracedRArray{T,N}, v, indices::Vararg{Any,N}) where {T,N}
ancestor_indices = TracedUtils.get_ancestor_indices(a, indices...)
setindex!(ancestor(a), v, ancestor_indices...)
return a