Skip to content

Commit

Permalink
feat: generalize diagm
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 12, 2024
1 parent a9a8b2a commit b2c7e54
Showing 1 changed file with 19 additions and 27 deletions.
46 changes: 19 additions & 27 deletions src/linear_algebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,7 @@ end
function materialize_traced_array(
x::LinearAlgebra.Tridiagonal{T,TracedRArray{T,1}}
) where {T}
scatter_indices = vcat(
diagonal_indices_zero_indexed(size(x, 1), size(x, 2), -1),
diagonal_indices_zero_indexed(size(x, 1), size(x, 2), 0),
diagonal_indices_zero_indexed(size(x, 1), size(x, 2), 1),
)
scatter_indices = Ops.constant(scatter_indices)

updates = TracedRArray{T,1}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.concatenate(
[x.dl.mlir_data, x.d.mlir_data, x.du.mlir_data]; dimension=0
),
1,
),
(size(scatter_indices, 1),),
)

return simple_scatter_op(size(x), scatter_indices, updates)
return LinearAlgebra.diagm(-1 => x.dl, 0 => x.d, 1 => x.du)
end

for (AT, comp) in ((:LowerTriangular, "GE"), (:UpperTriangular, "LE"))
Expand Down Expand Up @@ -251,13 +233,23 @@ function LinearAlgebra.diag(x::AnyTracedRArray{T,2}, k::Integer=0) where {T}
return TracedRArray{T,1}((), res, (diag_length,))
end

function LinearAlgebra.diagm(v::AnyTracedRArray{T,1}) where {T}
return LinearAlgebra.diagm(length(v), length(v), v)
end
function LinearAlgebra.diagm(m::Integer, n::Integer, v::AnyTracedRArray{T,1}) where {T}
m, n = LinearAlgebra.diagm_size((m, n), 0 => v) # size check
indices = Ops.constant(diagonal_indices_zero_indexed(m, n, 0)[1:length(v), :])
return simple_scatter_op((m, n), indices, materialize_traced_array(v))
function LinearAlgebra._diagm(
shape, kv::Pair{<:Integer,<:AnyTracedRArray{T,1}}...
) where {T}
m, n = LinearAlgebra.diagm_size(shape, kv...)
scatter_indices = Matrix{Int64}[]
concat_inputs = MLIR.IR.Value[]
for (k, v) in kv
push!(scatter_indices, diagonal_indices_zero_indexed(m, n, k)[1:length(v), :])
push!(concat_inputs, get_mlir_data(v))
end
scatter_indices = Ops.constant(reduce(vcat, scatter_indices))
values = TracedRArray{T,1}(
(),
MLIR.IR.result(MLIR.Dialects.stablehlo.concatenate(concat_inputs; dimension=0), 1),
(size(scatter_indices, 1),),
)
return simple_scatter_op((m, n), scatter_indices, values)
end

# Common Utilities
Expand Down Expand Up @@ -309,7 +301,7 @@ function simple_scatter_op(
return TracedRArray{T,2}((), res, shape)
end

# The cartesian version doesn't exist in julia 1.10
## The cartesian version doesn't exist in julia 1.10
function diagonal_indices_zero_indexed(m::Integer, n::Integer, k::Integer=0)
Cstart = CartesianIndex(1 + max(0, -k), 1 + max(0, k))
Cstep = CartesianIndex(1, 1)
Expand Down

0 comments on commit b2c7e54

Please sign in to comment.