diff --git a/src/Compiler.jl b/src/Compiler.jl
index 62cf685b6..3c4c3996d 100644
--- a/src/Compiler.jl
+++ b/src/Compiler.jl
@@ -41,6 +41,9 @@ function create_result(tocopy::T, path, result_stores) where {T}
     elems = Union{Symbol,Expr}[]

     for i in 1:fieldcount(T)
+        # If the field is undefined, we don't set it. A common example of this is `du2`
+        # for `Tridiagonal`.
+        isdefined(tocopy, i) || continue
         ev = create_result(getfield(tocopy, i), append_path(path, i), result_stores)
         push!(elems, ev)
     end
@@ -102,7 +105,7 @@ function create_result(tocopy::D, path, result_stores) where {K,V,D<:AbstractDict{K,V}}
 end

 function create_result(
-    tocopy::Union{Integer,AbstractFloat,AbstractString,Nothing,Type,Symbol},
+    tocopy::Union{Integer,AbstractFloat,AbstractString,Nothing,Type,Symbol,Char},
     path,
     result_stores,
 )
diff --git a/src/Ops.jl b/src/Ops.jl
index 928dfbefe..f67300787 100644
--- a/src/Ops.jl
+++ b/src/Ops.jl
@@ -1418,4 +1418,114 @@ julia> Reactant.@jit(
     end
 end

+"""
+    scatter_setindex(dest, scatter_indices, updates)
+
+Uses [`MLIR.Dialects.stablehlo.scatter`](@ref) to set the values of `dest` at the indices
+specified by `scatter_indices` to the values in `updates`. If the indices are contiguous,
+it is recommended to directly use [`MLIR.Dialects.stablehlo.dynamic_update_slice`](@ref)
+instead.
+"""
+@noinline function scatter_setindex(
+    dest::TracedRArray{T,N},
+    scatter_indices::TracedRArray{Int64,2},
+    updates::TracedRArray{T,1},
+) where {T,N}
+    @assert length(updates) == size(scatter_indices, 1)
+    @assert size(scatter_indices, 2) == N
+
+    update_computation = MLIR.IR.Region()
+    block = MLIR.IR.Block(
+        [mlir_type(TracedRNumber{T}), mlir_type(TracedRNumber{T})],
+        [MLIR.IR.Location(), MLIR.IR.Location()],
+    )
+    return_op = MLIR.Dialects.stablehlo.return_([MLIR.IR.argument(block, 2)])
+    MLIR.IR.rmfromparent!(return_op)
+    push!(block, return_op)
+    pushfirst!(update_computation, block)
+
+    #! format: off
+    update_window_dims = Int64[]
+    inserted_window_dims = collect(Int64, 0:(N - 1))
+    input_batching_dims = Int64[]
+    scatter_indices_batching_dims = Int64[]
+    scatter_dims_to_operand_dims = collect(Int64, 0:(N - 1))
+    index_vector_dim = Int64(1)
+
+    scatter_dimension_numbers = MLIR.API.stablehloScatterDimensionNumbersGet(
+        MLIR.IR.context(),
+        length(update_window_dims), update_window_dims,
+        length(inserted_window_dims), inserted_window_dims,
+        length(input_batching_dims), input_batching_dims,
+        length(scatter_indices_batching_dims), scatter_indices_batching_dims,
+        length(scatter_dims_to_operand_dims), scatter_dims_to_operand_dims,
+        index_vector_dim,
+    )
+    #! format: on
+
+    return TracedRArray{T,N}(
+        (),
+        MLIR.IR.result(
+            MLIR.Dialects.stablehlo.scatter(
+                [dest.mlir_data],
+                scatter_indices.mlir_data,
+                [updates.mlir_data];
+                result_0=[mlir_type(TracedRArray{T,N}, size(dest))],
+                update_computation,
+                scatter_dimension_numbers,
+            ),
+            1,
+        ),
+        size(dest),
+    )
+end
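`scatter_setindex` is the building block for the non-contiguous `setindex!` path added in `src/TracedRArray.jl` below. A hedged sketch of a hand-built call (hypothetical values; each row of the index matrix is a zero-indexed `(i, j)` coordinate):

```julia
# inside a traced function, with dest::TracedRArray{Float64,2} of size (6, 6)
idxs = TracedUtils.promote_to(TracedRArray{Int,2}, [0 0; 2 1; 1 3])
vals = TracedUtils.promote_to(TracedRArray{Float64,1}, [10.0, 20.0, 30.0])
res = Ops.scatter_setindex(dest, idxs, vals)
# res is a new traced value equal to dest except that, in one-indexed terms,
# res[1, 1] == 10.0, res[3, 2] == 20.0, res[2, 4] == 30.0
```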
+"""
+    gather_getindex(src, gather_indices)
+
+Uses [`MLIR.Dialects.stablehlo.gather`](@ref) to get the values of `src` at the indices
+specified by `gather_indices`. If the indices are contiguous, it is recommended to
+directly use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead.
+"""
+@noinline function gather_getindex(
+    src::TracedRArray{T,N}, gather_indices::TracedRArray{Int64,2}
+) where {T,N}
+    @assert size(gather_indices, 2) == N
+
+    #! format: off
+    offset_dims = Int64[1]
+    collapsed_slice_dims = collect(Int64, 0:(N - 2))
+    operand_batching_dims = Int64[]
+    start_indices_batching_dims = Int64[]
+    start_index_map = collect(Int64, 0:(N - 1))
+    index_vector_dim = Int64(1)
+
+    dimension_numbers = MLIR.API.stablehloGatherDimensionNumbersGet(
+        MLIR.IR.context(),
+        Int64(length(offset_dims)), offset_dims,
+        Int64(length(collapsed_slice_dims)), collapsed_slice_dims,
+        Int64(length(operand_batching_dims)), operand_batching_dims,
+        Int64(length(start_indices_batching_dims)), start_indices_batching_dims,
+        Int64(length(start_index_map)), start_index_map,
+        Int64(index_vector_dim),
+    )
+    #! format: on
+
+    return reshape(
+        TracedRArray{T}(
+            MLIR.IR.result(
+                MLIR.Dialects.stablehlo.gather(
+                    src.mlir_data,
+                    gather_indices.mlir_data;
+                    dimension_numbers,
+                    slice_sizes=fill(Int64(1), N),
+                    indices_are_sorted=false,
+                ),
+                1,
+            ),
+        ),
+        size(gather_indices, 1),
+    )
+end
+
 end # module Ops
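`gather_getindex` is the dual operation: it returns one scalar per row of `gather_indices`, and is what both the non-contiguous `getindex` and the new `LinearAlgebra.diag` below lower to. A rough sketch (hypothetical values, zero-indexed coordinate rows):

```julia
# src::TracedRArray{Float64,2} of size (3, 3)
idxs = TracedUtils.promote_to(TracedRArray{Int,2}, [0 0; 2 1])
vals = Ops.gather_getindex(src, idxs)
# vals is a length-2 traced vector: [src[1, 1], src[3, 2]] in one-indexed terms
```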
diff --git a/src/Overlay.jl b/src/Overlay.jl
index b9785b7fa..0b5844464 100644
--- a/src/Overlay.jl
+++ b/src/Overlay.jl
@@ -115,3 +115,30 @@ for randfun in (:rand, :randn, :randexp)
         # end
     end
 end
+
+# LinearAlgebra.jl overloads
+## `_mul!` goes through too many layers of abstraction, and we aren't able to overload it
+## without specializing on every possible combination of types
+for (cT, aT, bT) in (
+    (:AbstractVector, :AbstractMatrix, :AbstractVector),
+    (:AbstractMatrix, :AbstractMatrix, :AbstractVecOrMat),
+)
+    @eval begin
+        @reactant_overlay @noinline function LinearAlgebra.mul!(
+            C::$cT, A::$aT, B::$bT, α::Number, β::Number
+        )
+            if any(Base.Fix2(isa, TracedRArray) ∘ ancestor, (C, A, B))
+                TracedLinearAlgebra.overloaded_mul!(C, A, B, α, β)
+            else
+                LinearAlgebra._mul!(C, A, B, α, β)
+            end
+            return C
+        end
+
+        # Needed mostly for 1.10, where the 3-arg `mul!` is often specialized
+        @reactant_overlay @noinline function LinearAlgebra.mul!(C::$cT, A::$aT, B::$bT)
+            call_with_reactant(LinearAlgebra.mul!, C, A, B, true, false)
+            return C
+        end
+    end
+end
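As context for the dispatch above: the 5-arg `mul!` computes `C .= α * A * B + β * C`, so the 3-arg overlay forwards `α = true, β = false`. The traced branch is selected by unwrapping wrappers first; a sketch using this PR's `ancestor`:

```julia
# A and C are plain Arrays; B wraps traced data, e.g. B = transpose(b_traced)
any(Base.Fix2(isa, TracedRArray) ∘ ancestor, (C, A, B))
# true → overloaded_mul! is taken instead of LinearAlgebra._mul!
```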
diff --git a/src/Reactant.jl b/src/Reactant.jl
index 73cfb516b..d06784c13 100644
--- a/src/Reactant.jl
+++ b/src/Reactant.jl
@@ -105,7 +105,7 @@ mutable struct TracedRArray{T,N} <: RArray{TracedRNumber{T},N}
     ) where {T,N}
         shape = Tuple(shape)
         if !isnothing(mlir_data)
-            @assert size(MLIR.IR.type(mlir_data)) == shape
+            @assert size(MLIR.IR.type(mlir_data)) == shape "Expected: $(shape), got: $(size(MLIR.IR.type(mlir_data)))"
         end
         return new{T,N}(paths, mlir_data, shape)
     end
@@ -119,15 +119,23 @@ const WrappedTracedRArray{T,N} = WrappedArray{
 const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}}
 const AnyTracedRVector{T} = AnyTracedRArray{T,1}
 const AnyTracedRMatrix{T} = Union{
-    AnyTracedRArray{T,2},LinearAlgebra.Diagonal{T,TracedRArray{T,1}}
+    AnyTracedRArray{T,2},
+    LinearAlgebra.Diagonal{TracedRNumber{T},TracedRArray{T,1}},
+    LinearAlgebra.Tridiagonal{TracedRNumber{T},TracedRArray{T,1}},
 }
 const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}}

-function TracedRArray(data::MLIR.IR.Value)
+function TracedRArray{T}(data::MLIR.IR.Value) where {T}
     data_type = MLIR.IR.type(data)
-    return TracedRArray{eltype(MLIR.IR.julia_type(data_type)),ndims(data_type)}(
-        (), data, size(data_type)
-    )
+    if T == eltype(MLIR.IR.julia_type(data_type))
+        return TracedRArray{T,ndims(data_type)}((), data, size(data_type))
+    end
+    tdata = TracedRArray(data)
+    return Ops.convert(TracedRArray{T,ndims(data_type)}, tdata)
+end
+
+function TracedRArray(data::MLIR.IR.Value)
+    return TracedRArray{eltype(MLIR.IR.julia_type(MLIR.IR.type(data)))}(data)
 end

 struct XLAArray{T,N} <: RArray{T,N} end
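A quick sketch of what the new eltype-aware constructor does (hypothetical MLIR value `v` of type `tensor<4xf32>`):

```julia
TracedRArray{Float32}(v)  # eltypes match: wraps v directly
TracedRArray{Float64}(v)  # eltypes differ: wraps, then inserts an Ops.convert to f64
```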
diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl
index 2f8c07eb3..275f6dd92 100644
--- a/src/TracedRArray.jl
+++ b/src/TracedRArray.jl
@@ -14,8 +14,9 @@ using ..Reactant:
     MLIR,
     ancestor,
     unwrapped_eltype
+using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data!, materialize_traced_array
+
 using ReactantCore: ReactantCore
-using ..TracedUtils: TracedUtils, materialize_traced_array
 using GPUArraysCore: GPUArraysCore

 ReactantCore.is_traced(::TracedRArray) = true
@@ -55,11 +56,8 @@ function Base.getindex(
     return TracedRNumber{T}((), res2)
 end

-function Base.getindex(a::TracedRArray{T,0}) where {T}
-    return TracedRNumber{T}((), a.mlir_data)
-end
+Base.getindex(a::TracedRArray{T,0}) where {T} = TracedRNumber{T}((), a.mlir_data)

-# XXX: We want to support https://github.com/EnzymeAD/Reactant.jl/issues/242 eventually
 function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
     indices = map(enumerate(indices)) do (idx, i)
         i isa Colon && return 1:size(a, idx)
@@ -67,13 +65,28 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
         return i
     end

-    foreach(indices) do idxs
-        idxs isa Number && return nothing
+    non_contiguous_getindex = false
+    for idxs in indices
+        idxs isa Number && continue
         contiguous = all(isone, diff(idxs))
         # XXX: We want to throw error even for dynamic indexing
-        if typeof(a) <: Bool
-            contiguous || error("non-contiguous indexing is not supported")
+        if typeof(contiguous) <: Bool && !contiguous
+            non_contiguous_getindex = true
+            break
         end
     end
+
+    if non_contiguous_getindex
+        indices_tuples = collect(Iterators.product(indices...))
+        indices = Matrix{Int}(
+            undef, (length(indices_tuples), length(first(indices_tuples)))
+        )
+        for (i, idx) in enumerate(indices_tuples)
+            indices[i, :] .= idx .- 1
+        end
+        indices = TracedUtils.promote_to(TracedRArray{Int,2}, indices)
+        res = Ops.gather_getindex(a, indices)
+        return Ops.reshape(res, size(indices_tuples)...)
+    end

     start_indices = map(indices) do i
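A worked example of the index construction above: for a (hypothetical) `a[[1, 3], [2, 4]]`, the Cartesian product of the index vectors is materialized and flattened into one zero-indexed coordinate row per output element:

```julia
indices_tuples = collect(Iterators.product([1, 3], [2, 4]))
# 2×2 Matrix{Tuple{Int,Int}}: (1, 2) (1, 4); (3, 2) (3, 4)
idxs = Matrix{Int}(undef, (length(indices_tuples), 2))
for (i, idx) in enumerate(indices_tuples)
    idxs[i, :] .= idx .- 1
end
idxs  # [0 1; 2 1; 0 3; 2 3] — column-major order, matching the final reshape to (2, 2)
```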
@@ -99,16 +112,41 @@ function Base.getindex(a::WrappedTracedRArray, indices...)
     return getindex(ancestor(a), TracedUtils.get_ancestor_indices(a, indices...)...)
 end

-function Base.setindex!(
-    a::TracedRArray{T,N},
-    v,
-    indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int,TracedRNumber{Int}},N},
-) where {T,N}
+function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {T,N}
     indices = map(enumerate(indices)) do (idx, i)
-        i isa Int ? (i:i) : (i isa Colon ? (1:size(a, idx)) : i)
+        i isa Colon && return 1:size(a, idx)
+        i isa CartesianIndex && return Tuple(i)
+        return i
     end
+
+    non_contiguous_setindex = false
+    for idxs in indices
+        idxs isa Number && continue
+        contiguous = all(isone, diff(idxs))
+        # XXX: We want to throw error even for dynamic indexing
+        if typeof(contiguous) <: Bool && !contiguous
+            non_contiguous_setindex = true
+            break
+        end
+    end
+
+    if non_contiguous_setindex
+        indices_tuples = collect(Iterators.product(indices...))
+        indices = Matrix{Int}(
+            undef, (length(indices_tuples), length(first(indices_tuples)))
+        )
+        for (i, idx) in enumerate(indices_tuples)
+            indices[i, :] .= idx .- 1
+        end
+        indices = TracedUtils.promote_to(TracedRArray{Int,2}, indices)
+        res = Ops.scatter_setindex(a, indices, Ops.reshape(v, length(v)))
+        a.mlir_data = res.mlir_data
+        return v
+    end
+
     v = TracedUtils.broadcast_to_size(v, length.(indices))
     v = TracedUtils.promote_to(TracedRArray{T,N}, v)
+
     indices = [
         (
             TracedUtils.promote_to(TracedRNumber{Int}, i isa Colon ? 1 : first(i)) - 1
@@ -124,11 +162,7 @@
     return v
 end

-function Base.setindex!(
-    a::AnyTracedRArray{T,N},
-    v,
-    indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int,TracedRNumber{Int}},N},
-) where {T,N}
+function Base.setindex!(a::AnyTracedRArray{T,N}, v, indices::Vararg{Any,N}) where {T,N}
     ancestor_indices = TracedUtils.get_ancestor_indices(a, indices...)
     setindex!(ancestor(a), v, ancestor_indices...)
     return a
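The `setindex!` scatter path mirrors the gather path; for the test added below (`x[[1, 3, 2], [1, 2, 3, 4]] .= 1.0` on a 6×6 array), the shapes work out as sketched here:

```julia
# indices :: 12×2 Matrix{Int} — Cartesian product of the index vectors, zero-indexed
# updates :: Ops.reshape(v, length(v)) — a length-12 traced vector
# the scatter result replaces a.mlir_data in place; the Julia-level return value is v
```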
diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl
index b69027549..7b491f4b7 100644
--- a/src/TracedUtils.jl
+++ b/src/TracedUtils.jl
@@ -3,7 +3,6 @@
 # within compilation. However, it means these functions are a _lot_ faster to compile.
 module TracedUtils

-using LinearAlgebra: LinearAlgebra
 using Adapt: Adapt, WrappedReshapedArray
 using ..Reactant:
     Reactant,
@@ -19,34 +18,20 @@ using ..Reactant:
     Ops

 materialize_traced_array(x::TracedRArray) = x
+
 materialize_traced_array(x::WrappedTracedRArray) = x[axes(x)...]
+
 function materialize_traced_array(
     x::WrappedReshapedArray{TracedRNumber{T},N,TracedRArray{T,M}}
 ) where {T,N,M}
     return Ops.reshape(materialize_traced_array(parent(x)), size(x)...)
 end
-function materialize_traced_array(
-    x::LinearAlgebra.Transpose{TracedRNumber{T},TracedRArray{T,N}}
-) where {T,N}
-    px = parent(x)
-    A = ndims(px) == 1 ? reshape(px, :, 1) : px
-    return permutedims(A, (2, 1))
-end
-function materialize_traced_array(
-    x::LinearAlgebra.Adjoint{TracedRNumber{T},TracedRArray{T,N}}
-) where {T,N}
-    return conj(materialize_traced_array(transpose(parent(x))))
-end
+
 function materialize_traced_array(
     x::PermutedDimsArray{TracedRNumber{T},N,perm,iperm,TracedRArray{T,N}}
 ) where {T,N,perm,iperm}
     return permutedims(parent(x), perm)
 end
-function materialize_traced_array(
-    x::LinearAlgebra.Diagonal{TracedRNumber{T},TracedRArray{T,1}}
-) where {T}
-    return LinearAlgebra.diagm(parent(x))
-end

 get_mlir_data(x::TracedRNumber) = x.mlir_data
 set_mlir_data!(x::TracedRNumber, data) = (x.mlir_data = data; return x)
@@ -58,51 +43,24 @@ function set_mlir_data!(x::TracedRArray, data)
     x.mlir_data = data
     return x
 end
+
 function set_mlir_data!(
     x::WrappedReshapedArray{TracedRNumber{T},N,TracedRArray{T,M}}, data
 ) where {T,N,M}
-    res_mlir_data = Ops.reshape(TracedRArray(data), size(parent(x))...).mlir_data
+    res_mlir_data = Ops.reshape(TracedRArray{T}(data), size(parent(x))...).mlir_data
     set_mlir_data!(parent(x), res_mlir_data)
     return x
 end
-function set_mlir_data!(
-    x::LinearAlgebra.Transpose{TracedRNumber{T},TracedRArray{T,N}}, data
-) where {T,N}
-    tdata = TracedRArray(data)
-    px = parent(x)
-    px.mlir_data = (
-        if ndims(px) == 1
-            Ops.reshape(tdata, length(tdata))
-        else
-            Ops.transpose(tdata, [2, 1])
-        end
-    ).mlir_data
-    return x
-end
-function set_mlir_data!(
-    x::LinearAlgebra.Adjoint{TracedRNumber{T},TracedRArray{T,N}}, data
-) where {T,N}
-    tdata = TracedRArray(data)
-    px = parent(x)
-    transposed_data =
-        ndims(px) == 1 ? Ops.reshape(tdata, length(tdata)) : Ops.transpose(tdata, [2, 1])
-    px.mlir_data = (T <: Real ? transposed_data : Ops.conj(transposed_data)).mlir_data
-    return x
-end
+
 function set_mlir_data!(
     x::PermutedDimsArray{TracedRNumber{T},N,perm,iperm,TracedRArray{T,N}}, data
 ) where {T,N,perm,iperm}
-    parent(x).mlir_data = permutedims(TracedRArray(data), iperm).mlir_data
+    parent(x).mlir_data = permutedims(TracedRArray{T}(data), iperm).mlir_data
    return x
 end
-function set_mlir_data!(
-    x::LinearAlgebra.Diagonal{TracedRNumber{T},TracedRArray{T,1}}, data
-) where {T}
-    parent(x).mlir_data = LinearAlgebra.diag(TracedRArray(data)).mlir_data
-    return x
-end
-function set_mlir_data!(x::AnyTracedRArray, data)
-    setindex!(x, TracedRArray(data), axes(x)...)
+
+function set_mlir_data!(x::AnyTracedRArray{T}, data) where {T}
+    setindex!(x, TracedRArray{T}(data), axes(x)...)
     return x
 end
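With the LinearAlgebra methods moved out, `TracedUtils` keeps only the generic wrapper handling; for example, a view is materialized by indexing back into its ancestor (a sketch, assuming `x_traced::TracedRArray{Float64,2}`):

```julia
y = view(x_traced, 1:2, 2:3)   # a WrappedTracedRArray
materialize_traced_array(y)    # ≡ y[axes(y)...], i.e. x_traced[1:2, 2:3] as a TracedRArray
```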
diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl
index c011f8aec..aa56c7b92 100644
--- a/src/stdlibs/LinearAlgebra.jl
+++ b/src/stdlibs/LinearAlgebra.jl
@@ -1,20 +1,170 @@
 module TracedLinearAlgebra

-using ..Reactant
-import ..TracedRArray
-import ..TracedRNumber
-import ..AnyTracedRArray
-import ..AnyTracedRMatrix
-import ..AnyTracedRVector
-
-import ..TracedUtils
-using ..TracedUtils: get_mlir_data, materialize_traced_array, set_mlir_data!
-
-import ..Ops
-import ..MLIR
+using ..Reactant:
+    TracedRArray,
+    TracedRNumber,
+    AnyTracedRArray,
+    AnyTracedRMatrix,
+    AnyTracedRVector,
+    Ops,
+    MLIR
+
+using ..TracedUtils: TracedUtils, get_mlir_data, materialize_traced_array, set_mlir_data!
+
 using LinearAlgebra

-function LinearAlgebra.mul!(
+# Various Wrapper Arrays defined in LinearAlgebra
+function TracedUtils.materialize_traced_array(
+    x::Transpose{TracedRNumber{T},TracedRArray{T,N}}
+) where {T,N}
+    px = parent(x)
+    A = ndims(px) == 1 ? reshape(px, :, 1) : px
+    return permutedims(A, (2, 1))
+end
+
+function TracedUtils.materialize_traced_array(
+    x::Adjoint{TracedRNumber{T},TracedRArray{T,N}}
+) where {T,N}
+    return conj(materialize_traced_array(transpose(parent(x))))
+end
+
+function TracedUtils.materialize_traced_array(
+    x::Diagonal{TracedRNumber{T},TracedRArray{T,1}}
+) where {T}
+    return diagm(parent(x))
+end
+
+function TracedUtils.materialize_traced_array(
+    x::Tridiagonal{TracedRNumber{T},TracedRArray{T,1}}
+) where {T}
+    return diagm(-1 => x.dl, 0 => x.d, 1 => x.du)
+end
+
+for (AT, comp) in ((:LowerTriangular, "GE"), (:UpperTriangular, "LE"))
+    uAT = Symbol(:Unit, AT)
+    @eval begin
+        function TracedUtils.materialize_traced_array(
+            x::$(AT){TracedRNumber{T},TracedRArray{T,2}}
+        ) where {T}
+            m, n = size(x)
+            row_idxs = Ops.iota(Int, [m, n]; iota_dimension=1)
+            col_idxs = Ops.iota(Int, [m, n]; iota_dimension=2)
+            indicator = Ops.compare(row_idxs, col_idxs; comparison_direction=$(comp))
+            return Ops.select(indicator, parent(x), zero(parent(x)))
+        end
+
+        function TracedUtils.materialize_traced_array(
+            x::$(uAT){TracedRNumber{T},TracedRArray{T,2}}
+        ) where {T}
+            m, n = size(x)
+            row_idxs = Ops.iota(Int, [m, n]; iota_dimension=1)
+            col_idxs = Ops.iota(Int, [m, n]; iota_dimension=2)
+            nondiag_indicator = Ops.compare(row_idxs, col_idxs; comparison_direction="NE")
+            x = materialize_traced_array($(AT)(parent(x)))
+            return Ops.select(nondiag_indicator, x, one.(x))
+        end
+    end
+end
+
+function TracedUtils.materialize_traced_array(
+    x::Symmetric{TracedRNumber{T},TracedRArray{T,2}}
+) where {T}
+    m, n = size(x)
+    row_idxs = Ops.iota(Int, [m, n]; iota_dimension=1)
+    col_idxs = Ops.iota(Int, [m, n]; iota_dimension=2)
+    if x.uplo == 'L'
+        indicator = Ops.compare(row_idxs, col_idxs; comparison_direction="GT")
+        x_lt = Ops.select(indicator, parent(x), zero(parent(x)))
+        x_ltd = materialize_traced_array(LowerTriangular(parent(x)))
+        return Ops.add(x_lt, Ops.transpose(x_ltd, [2, 1]))
+    else
+        indicator = Ops.compare(row_idxs, col_idxs; comparison_direction="LT")
+        x_ut = Ops.select(indicator, parent(x), zero(parent(x)))
+        x_utd = materialize_traced_array(UpperTriangular(parent(x)))
+        return Ops.add(Ops.transpose(x_utd, [2, 1]), x_ut)
+    end
+end
+
+function TracedUtils.set_mlir_data!(
+    x::Transpose{TracedRNumber{T},TracedRArray{T,N}}, data
+) where {T,N}
+    tdata = TracedRArray{T}(data)
+    px = parent(x)
+    px.mlir_data = (
+        if ndims(px) == 1
+            Ops.reshape(tdata, length(tdata))
+        else
+            Ops.transpose(tdata, [2, 1])
+        end
+    ).mlir_data
+    return x
+end
+
+function TracedUtils.set_mlir_data!(
+    x::Adjoint{TracedRNumber{T},TracedRArray{T,N}}, data
+) where {T,N}
+    tdata = TracedRArray{T}(data)
+    px = parent(x)
+    transposed_data =
+        ndims(px) == 1 ? Ops.reshape(tdata, length(tdata)) : Ops.transpose(tdata, [2, 1])
+    px.mlir_data = (T <: Real ? transposed_data : Ops.conj(transposed_data)).mlir_data
+    return x
+end
+
+function TracedUtils.set_mlir_data!(
+    x::Diagonal{TracedRNumber{T},TracedRArray{T,1}}, data
+) where {T}
+    parent(x).mlir_data = diag(TracedRArray{T}(data)).mlir_data
+    return x
+end
+
+for (AT, dcomp, ocomp) in (
+    (:LowerTriangular, "GE", "LT"),
+    (:UnitLowerTriangular, "GT", "LE"),
+    (:UpperTriangular, "LE", "GT"),
+    (:UnitUpperTriangular, "LT", "GE"),
+)
+    @eval function TracedUtils.set_mlir_data!(
+        x::$(AT){TracedRNumber{T},TracedRArray{T,2}}, data
+    ) where {T}
+        tdata = TracedRArray{T}(data)
+        z = zero(tdata)
+        m, n = size(x)
+        row_idxs = Ops.iota(Int, [m, n]; iota_dimension=1)
+        col_idxs = Ops.iota(Int, [m, n]; iota_dimension=2)
+        data_indicator = Ops.compare(row_idxs, col_idxs; comparison_direction=$(dcomp))
+        original_indicator = Ops.compare(row_idxs, col_idxs; comparison_direction=$(ocomp))
+        res = Ops.add(
+            Ops.select(data_indicator, tdata, z), Ops.select(original_indicator, x.data, z)
+        )
+        set_mlir_data!(x.data, res.mlir_data)
+        return x
+    end
+end
+
+function TracedUtils.set_mlir_data!(
+    x::Symmetric{TracedRNumber{T},TracedRArray{T,2}}, data
+) where {T}
+    if x.uplo == 'L'
+        set_mlir_data!(LowerTriangular(parent(x)), data)
+    else
+        set_mlir_data!(UpperTriangular(parent(x)), data)
+    end
+    return x
+end
+
+function TracedUtils.set_mlir_data!(
+    x::Tridiagonal{TracedRNumber{T},TracedRArray{T,1}}, data
+) where {T}
+    tdata = TracedRArray{T}(data)
+    set_mlir_data!(x.dl, diag(tdata, -1).mlir_data)
+    set_mlir_data!(x.d, diag(tdata, 0).mlir_data)
+    set_mlir_data!(x.du, diag(tdata, 1).mlir_data)
+    return x
+end
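The iota/compare/select pattern used above is easiest to see in plain Julia; the traced version builds the same mask with stablehlo ops (the iota base offset cancels in the comparison):

```julia
using LinearAlgebra

m, n = 3, 3
row = [i for i in 1:m, _ in 1:n]  # analogue of Ops.iota(Int, [m, n]; iota_dimension=1)
col = [j for _ in 1:m, j in 1:n]  # analogue of iota_dimension=2
A = rand(m, n)
A .* (row .>= col) == Matrix(LowerTriangular(A))  # "GE" keeps the lower triangle → true
```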
+# Core functions
+function overloaded_mul!(
     @nospecialize(C::TracedRArray{T,1}),
     @nospecialize(A::AnyTracedRMatrix),
     @nospecialize(B::AnyTracedRVector),
     α::Number=true,
     β::Number=false,
 ) where {T}
     # TODO: The reshape operations are not getting optimized, we should directly call dot_general
     rC = Ops.reshape(C, length(C), 1)
-    LinearAlgebra.mul!(rC, A, reshape(B, :, 1), α, β)
+    overloaded_mul!(rC, A, reshape(B, :, 1), α, β)
     C.mlir_data = get_mlir_data(vec(rC))
     return C
 end

-function LinearAlgebra.mul!(
+function overloaded_mul!(
     @nospecialize(C::TracedRArray{T,2}),
     @nospecialize(A::AnyTracedRMatrix),
     @nospecialize(B::AnyTracedRVector),
     α::Number=true,
     β::Number=false,
 ) where {T}
-    LinearAlgebra.mul!(C, A, reshape(B, :, 1), α, β)
+    overloaded_mul!(C, A, reshape(B, :, 1), α, β)
     return C
 end

-function LinearAlgebra.mul!(
+function overloaded_mul!(
     @nospecialize(C::TracedRArray{T,2}),
     @nospecialize(A::AnyTracedRMatrix),
     @nospecialize(B::AnyTracedRMatrix),
@@ -119,50 +269,52 @@ function LinearAlgebra.diag(x::AnyTracedRArray{T,2}, k::Integer=0) where {T}
     # :0: note: see current operation: %0 = "tensor.empty"() : () -> tensor<0xf64>
     length(indices) ≤ 0 && return TracedUtils.promote_to(TracedRArray{T,1}, T[])

-    idxs = get_mlir_data(TracedUtils.promote_to(TracedRArray{Int,2}, indices))
-
-    #! format: off
-    dimension_numbers = MLIR.API.stablehloGatherDimensionNumbersGet(
-        MLIR.IR.context(),
-        Int64(0), Int64[],
-        Int64(2), Int64[0, 1],
-        Int64(0), Int64[],
-        Int64(0), Int64[],
-        Int64(2), Int64[0, 1],
-        Int64(1)
-    )
-    #! format: on
-
-    slice_sizes = get_mlir_data(
-        Reactant.TracedUtils.promote_to(TracedRArray{Int,1}, [1, 1])
-    )
-    res = MLIR.IR.result(
-        MLIR.Dialects.stablehlo.dynamic_gather(
-            get_mlir_data(y), idxs, slice_sizes; dimension_numbers
-        ),
-        1,
-    )
-    return TracedRArray{T,1}((), res, (diag_length,))
+    return Ops.gather_getindex(x, TracedUtils.promote_to(TracedRArray{Int,2}, indices))
 end

-function LinearAlgebra.diagm(v::AnyTracedRArray{T,1}) where {T}
-    return LinearAlgebra.diagm(length(v), length(v), v)
-end
-function LinearAlgebra.diagm(m::Integer, n::Integer, v::AnyTracedRArray{T,1}) where {T}
-    m, n = LinearAlgebra.diagm_size((m, n), 0 => v) # size check
+function LinearAlgebra._diagm(
+    shape, kv::Pair{<:Integer,<:AnyTracedRArray{T,1}}...
+) where {T}
+    m, n = LinearAlgebra.diagm_size(shape, kv...)

-    v = materialize_traced_array(v)
-    D = length(v)
-    row_idxs = Ops.iota(Int, [D, D]; iota_dimension=1)
-    col_idxs = Ops.iota(Int, [D, D]; iota_dimension=2)
-    diag_indicator = Ops.compare(row_idxs, col_idxs; comparison_direction="EQ")
+    # For repeated indices, we need to aggregate the values
+    kv_updated = Dict{Integer,AnyTracedRArray{T,1}}()
+    for (k, v) in kv
+        if haskey(kv_updated, k)
+            kv_updated[k] = kv_updated[k] + v
+        else
+            kv_updated[k] = v
+        end
+    end

-    mat = (v .+ zero(v)') .* diag_indicator
-    return Ops.pad(
-        mat,
-        TracedUtils.promote_to(TracedRNumber{T}, 0);
-        high=[m - length(v), n - length(v)],
+    scatter_indices = Matrix{Int64}[]
+    concat_inputs = MLIR.IR.Value[]
+    for (k, v) in pairs(kv_updated)
+        push!(scatter_indices, diagonal_indices_zero_indexed(m, n, k)[1:length(v), :])
+        push!(concat_inputs, get_mlir_data(v))
+    end
+    scatter_indices = Ops.constant(reduce(vcat, scatter_indices))
+    values = TracedRArray{T,1}(
+        (),
+        MLIR.IR.result(MLIR.Dialects.stablehlo.concatenate(concat_inputs; dimension=0), 1),
+        (size(scatter_indices, 1),),
+    )
+    return Ops.scatter_setindex(
+        Ops.constant(fill(zero(T), (m, n))), scatter_indices, values
     )
 end

+# Common Utilities
+## The Cartesian version of `diagind` doesn't exist in Julia 1.10
+function diagonal_indices_zero_indexed(m::Integer, n::Integer, k::Integer=0)
+    idx1, idx2 = 1 + max(0, -k), 1 + max(0, k)
+    L = max(0, k ≤ 0 ? min(m + k, n) : min(m, n - k))
+    indices = Matrix{Int}(undef, (L, 2))
+    for i in axes(indices, 1)
+        indices[i, 1] = idx1 + i - 2
+        indices[i, 2] = idx2 + i - 2
+    end
+    return indices
+end
+
 end
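A minimal check of the zero-indexed helper, following its definition above:

```julia
diagonal_indices_zero_indexed(3, 4, 1)
# 3×2 Matrix{Int}: [0 1; 1 2; 2 3] — the first superdiagonal of a 3×4 matrix
```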
diff --git a/test/basic.jl b/test/basic.jl
index 8e97ed41d..3522cd59e 100644
--- a/test/basic.jl
+++ b/test/basic.jl
@@ -442,6 +442,26 @@ end
     @test @allowscalar all(isone, x_ra_array[4, :])
 end

+function non_contiguous_setindex!(x)
+    x[[1, 3, 2], [1, 2, 3, 4]] .= 1.0
+    return x
+end
+
+@testset "non-contiguous setindex!" begin
+    x = rand(6, 6)
+    x_ra = Reactant.to_rarray(x)
+
+    y = @jit(non_contiguous_setindex!(x_ra))
+    y = Array(y)
+    x_ra = Array(x_ra)
+    @test all(isone, y[1:3, 1:4])
+    @test all(isone, x_ra[1:3, 1:4])
+    @test !all(isone, y[4:end, :])
+    @test !all(isone, x_ra[4:end, :])
+    @test !all(isone, y[:, 5:end])
+    @test !all(isone, x_ra[:, 5:end])
+end
+
 tuple_byref(x) = (; a=(; b=x))
 tuple_byref2(x) = abs2.(x), tuple_byref2(x)
@@ -717,3 +737,57 @@ end
     @test res[1] isa ConcreteRArray{Float64,2}
     @test res[2] isa ConcreteRNumber{Float64}
 end
+
+@testset "non-contiguous indexing" begin
+    x = rand(4, 4, 3)
+    x_ra = Reactant.to_rarray(x)
+
+    non_contiguous_indexing1(x) = x[[1, 3, 2], :, :]
+    non_contiguous_indexing2(x) = x[:, [1, 2, 1, 3], [1, 3]]
+
+    @test @jit(non_contiguous_indexing1(x_ra)) ≈ non_contiguous_indexing1(x)
+    @test @jit(non_contiguous_indexing2(x_ra)) ≈ non_contiguous_indexing2(x)
+
+    x = rand(4, 2)
+    x_ra = Reactant.to_rarray(x)
+
+    non_contiguous_indexing1(x) = x[[1, 3, 2], :]
+    non_contiguous_indexing2(x) = x[:, [1, 2, 2]]
+
+    @test @jit(non_contiguous_indexing1(x_ra)) ≈ non_contiguous_indexing1(x)
+    @test @jit(non_contiguous_indexing2(x_ra)) ≈ non_contiguous_indexing2(x)
+
+    x = rand(4, 4, 3)
+    x_ra = Reactant.to_rarray(x)
+
+    non_contiguous_indexing1!(x) = x[[1, 3, 2], :, :] .= 2
+    non_contiguous_indexing2!(x) = x[:, [1, 2, 1, 3], [1, 3]] .= 2
+
+    @jit(non_contiguous_indexing1!(x_ra))
+    non_contiguous_indexing1!(x)
+    @test x_ra ≈ x
+
+    x = rand(4, 4, 3)
+    x_ra = Reactant.to_rarray(x)
+
+    @jit(non_contiguous_indexing2!(x_ra))
+    non_contiguous_indexing2!(x)
+    @test x_ra ≈ x
+
+    x = rand(4, 2)
+    x_ra = Reactant.to_rarray(x)
+
+    non_contiguous_indexing1!(x) = x[[1, 3, 2], :] .= 2
+    non_contiguous_indexing2!(x) = x[:, [1, 2, 2]] .= 2
+
+    @jit(non_contiguous_indexing1!(x_ra))
+    non_contiguous_indexing1!(x)
+    @test x_ra ≈ x
+
+    x = rand(4, 2)
+    x_ra = Reactant.to_rarray(x)
+
+    @jit(non_contiguous_indexing2!(x_ra))
+    non_contiguous_indexing2!(x)
+    @test x_ra ≈ x
+end
diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl
index 0c6efc5fd..ea39556f9 100644
--- a/test/integration/linear_algebra.jl
+++ b/test/integration/linear_algebra.jl
@@ -1,4 +1,4 @@
-using LinearAlgebra, Reactant
+using LinearAlgebra, Reactant, Test

 function muladd2(A, x, b)
     C = similar(A, promote_type(eltype(A), eltype(b)), size(A, 1), size(x, 2))
@@ -130,15 +130,42 @@ end
     @test @jit(diagm(4, 5, x_ra)) ≈ diagm(4, 5, x)
     @test @jit(diagm(6, 6, x_ra)) ≈ diagm(6, 6, x)
     @test_throws DimensionMismatch @jit(diagm(3, 3, x_ra))
+
+    x1 = rand(3)
+    x2 = rand(3)
+    x3 = rand(2)
+    x_ra1 = Reactant.to_rarray(x1)
+    x_ra2 = Reactant.to_rarray(x2)
+    x_ra3 = Reactant.to_rarray(x3)
+
+    @test @jit(diagm(1 => x_ra1)) ≈ diagm(1 => x1)
+    @test @jit(diagm(1 => x_ra1, -1 => x_ra3)) ≈ diagm(1 => x1, -1 => x3)
+    @test @jit(diagm(1 => x_ra1, 1 => x_ra2)) ≈ diagm(1 => x1, 1 => x2)
 end

-# TODO: Currently Diagonal(x) * x goes down the generic matmul path but it should clearly be
-# optimized
+# TODO: Currently `Wrapper(x) * x` (for each wrapper type below) goes down the generic
+# matmul path but it should clearly be optimized
 mul_diagonal(x) = Diagonal(x) * x
-
-@testset "mul_diagonal" begin
-    x = rand(4)
+mul_tridiagonal(x) = Tridiagonal(x) * x
+mul_unit_lower_triangular(x) = UnitLowerTriangular(x) * x
+mul_unit_upper_triangular(x) = UnitUpperTriangular(x) * x
+mul_lower_triangular(x) = LowerTriangular(x) * x
+mul_upper_triangular(x) = UpperTriangular(x) * x
+mul_symmetric(x) = Symmetric(x) * x
+
+@testset "Wrapper Types Matrix Multiplication" begin
+    x = rand(4, 4)
     x_ra = Reactant.to_rarray(x)

-    @test @jit(mul_diagonal(x_ra)) ≈ mul_diagonal(x)
+    @testset "$(wrapper_type)" for (wrapper_type, fn) in [
+        (Diagonal, mul_diagonal),
+        (Tridiagonal, mul_tridiagonal),
+        (UnitLowerTriangular, mul_unit_lower_triangular),
+        (UnitUpperTriangular, mul_unit_upper_triangular),
+        (LowerTriangular, mul_lower_triangular),
+        (UpperTriangular, mul_upper_triangular),
+        (Symmetric, mul_symmetric),
+    ]
+        @test @jit(fn(x_ra)) ≈ fn(x)
+    end
 end
diff --git a/test/wrapped_arrays.jl b/test/wrapped_arrays.jl
index f5418e5c8..c522bcd17 100644
--- a/test/wrapped_arrays.jl
+++ b/test/wrapped_arrays.jl
@@ -172,3 +172,33 @@ end
         @test all(iszero, y_res)
     end
 end
+
+function lower_triangular_write(x)
+    y = LowerTriangular(copy(x))
+    @. y *= 2
+    return y
+end
+
+function upper_triangular_write(x)
+    y = UpperTriangular(copy(x))
+    @. y *= 2
+    return y
+end
+
+function tridiagonal_write(x)
+    y = Tridiagonal(copy(x))
+    @. y *= 2
+    return y
+end
+
+@testset "Broadcasted Multiply and Allocate" begin
+    @testset "$(aType)" for (aType, fn) in [
+        ("LowerTriangular", lower_triangular_write),
+        ("UpperTriangular", upper_triangular_write),
+        ("Tridiagonal", tridiagonal_write),
+    ]
+        x = rand(4, 4)
+        x_ra = Reactant.to_rarray(x)
+        @test @jit(fn(x_ra)) ≈ fn(x)
+    end
+end