From f39adeb743cbd00339d66b00d0c79eb4fd8fe019 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 19 Dec 2024 11:53:00 +0530 Subject: [PATCH] fix: incorrect rebase --- src/Ops.jl | 12 ++++++------ src/TracedUtils.jl | 2 +- src/stdlibs/LinearAlgebra.jl | 12 +++++++----- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index a31416cc8..80a84e0b1 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1414,12 +1414,12 @@ instead. #! format: off scatter_dimension_numbers = MLIR.API.stablehloScatterDimensionNumbersGet( MLIR.IR.context(), - 0, Int64[], - N, collect(Int64, 0:(N - 1)), - 0, Int64[], - 0, Int64[], - N, collect(Int64, 0:(N - 1)), - 1 + Int64(0), Int64[], + Int64(N), collect(Int64, 0:(N - 1)), + Int64(0), Int64[], + Int64(0), Int64[], + Int64(N), collect(Int64, 0:(N - 1)), + Int64(1) ) #! format: on diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 887916287..e0c41a101 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -58,7 +58,7 @@ function set_mlir_data!( return x end -function set_mlir_data!(x::AnyTracedRArray, data) +function set_mlir_data!(x::AnyTracedRArray{T}, data) where {T} setindex!(x, TracedRArray{T}(data), axes(x)...) return x end diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index 82f7900f4..934730f1a 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -28,7 +28,7 @@ function TracedUtils.materialize_traced_array(x::Diagonal{T,TracedRArray{T,1}}) return diagm(parent(x)) end -function materialize_traced_array(x::Tridiagonal{T,TracedRArray{T,1}}) where {T} +function TracedUtils.materialize_traced_array(x::Tridiagonal{T,TracedRArray{T,1}}) where {T} return diagm(-1 => x.dl, 0 => x.d, 1 => x.du) end @@ -108,7 +108,7 @@ for (AT, dcomp, ocomp) in ( (:UpperTriangular, "LE", "GT"), (:UnitUpperTriangular, "LT", "GE"), ) - @eval function set_mlir_data!( + @eval function TracedUtils.set_mlir_data!( x::LinearAlgebra.$(AT){T,TracedRArray{T,2}}, data ) where {T} tdata = TracedRArray{T}(data) @@ -126,7 +126,9 @@ for (AT, dcomp, ocomp) in ( end end -function set_mlir_data!(x::LinearAlgebra.Symmetric{T,TracedRArray{T,2}}, data) where {T} +function TracedUtils.set_mlir_data!( + x::LinearAlgebra.Symmetric{T,TracedRArray{T,2}}, data +) where {T} if x.uplo == 'L' set_mlir_data!(LinearAlgebra.LowerTriangular(parent(x)), data) else @@ -135,7 +137,7 @@ function set_mlir_data!(x::LinearAlgebra.Symmetric{T,TracedRArray{T,2}}, data) w return x end -function set_mlir_data!(x::Tridiagonal{T,TracedRArray{T,1}}, data) where {T} +function TracedUtils.set_mlir_data!(x::Tridiagonal{T,TracedRArray{T,1}}, data) where {T} tdata = TracedRArray{T}(data) set_mlir_data!(x.dl, diag(tdata, -1).mlir_data) set_mlir_data!(x.d, diag(tdata, 0).mlir_data) @@ -249,7 +251,7 @@ function LinearAlgebra.diag(x::AnyTracedRArray{T,2}, k::Integer=0) where {T} # :0: note: see current operation: %0 = "tensor.empty"() : () -> tensor<0xf64> length(indices) ≤ 0 && return TracedUtils.promote_to(TracedRArray{T,1}, T[]) - return Ops.gather_getindex(x, promote_to(TracedRArray{Int,2}, indices)) + return Ops.gather_getindex(x, TracedUtils.promote_to(TracedRArray{Int,2}, indices)) end function LinearAlgebra._diagm(