Skip to content

Commit

Permalink
fix: incorrect rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 19, 2024
1 parent 1c0d744 commit f39adeb
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 12 deletions.
12 changes: 6 additions & 6 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1414,12 +1414,12 @@ instead.
#! format: off
scatter_dimension_numbers = MLIR.API.stablehloScatterDimensionNumbersGet(
MLIR.IR.context(),
0, Int64[],
N, collect(Int64, 0:(N - 1)),
0, Int64[],
0, Int64[],
N, collect(Int64, 0:(N - 1)),
1
Int64(0), Int64[],
Int64(N), collect(Int64, 0:(N - 1)),
Int64(0), Int64[],
Int64(0), Int64[],
Int64(N), collect(Int64, 0:(N - 1)),
Int64(1)
)
#! format: on

Expand Down
2 changes: 1 addition & 1 deletion src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ function set_mlir_data!(
return x
end

function set_mlir_data!(x::AnyTracedRArray, data)
function set_mlir_data!(x::AnyTracedRArray{T}, data) where {T}
setindex!(x, TracedRArray{T}(data), axes(x)...)
return x
end
Expand Down
12 changes: 7 additions & 5 deletions src/stdlibs/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ function TracedUtils.materialize_traced_array(x::Diagonal{T,TracedRArray{T,1}})
return diagm(parent(x))
end

function materialize_traced_array(x::Tridiagonal{T,TracedRArray{T,1}}) where {T}
function TracedUtils.materialize_traced_array(x::Tridiagonal{T,TracedRArray{T,1}}) where {T}
return diagm(-1 => x.dl, 0 => x.d, 1 => x.du)
end

Expand Down Expand Up @@ -108,7 +108,7 @@ for (AT, dcomp, ocomp) in (
(:UpperTriangular, "LE", "GT"),
(:UnitUpperTriangular, "LT", "GE"),
)
@eval function set_mlir_data!(
@eval function TracedUtils.set_mlir_data!(
x::LinearAlgebra.$(AT){T,TracedRArray{T,2}}, data
) where {T}
tdata = TracedRArray{T}(data)
Expand All @@ -126,7 +126,9 @@ for (AT, dcomp, ocomp) in (
end
end

function set_mlir_data!(x::LinearAlgebra.Symmetric{T,TracedRArray{T,2}}, data) where {T}
function TracedUtils.set_mlir_data!(
x::LinearAlgebra.Symmetric{T,TracedRArray{T,2}}, data
) where {T}
if x.uplo == 'L'
set_mlir_data!(LinearAlgebra.LowerTriangular(parent(x)), data)
else
Expand All @@ -135,7 +137,7 @@ function set_mlir_data!(x::LinearAlgebra.Symmetric{T,TracedRArray{T,2}}, data) w
return x
end

function set_mlir_data!(x::Tridiagonal{T,TracedRArray{T,1}}, data) where {T}
function TracedUtils.set_mlir_data!(x::Tridiagonal{T,TracedRArray{T,1}}, data) where {T}
tdata = TracedRArray{T}(data)
set_mlir_data!(x.dl, diag(tdata, -1).mlir_data)
set_mlir_data!(x.d, diag(tdata, 0).mlir_data)
Expand Down Expand Up @@ -249,7 +251,7 @@ function LinearAlgebra.diag(x::AnyTracedRArray{T,2}, k::Integer=0) where {T}
# <unknown>:0: note: see current operation: %0 = "tensor.empty"() : () -> tensor<0xf64>
length(indices) 0 && return TracedUtils.promote_to(TracedRArray{T,1}, T[])

return Ops.gather_getindex(x, promote_to(TracedRArray{Int,2}, indices))
return Ops.gather_getindex(x, TracedUtils.promote_to(TracedRArray{Int,2}, indices))
end

function LinearAlgebra._diagm(
Expand Down

0 comments on commit f39adeb

Please sign in to comment.