Skip to content

Commit

Permalink
refactor: directly generate the region for simple_scatter_op
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 12, 2024
1 parent 21c17f7 commit a9a8b2a
Showing 1 changed file with 7 additions and 16 deletions.
23 changes: 7 additions & 16 deletions src/linear_algebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -256,14 +256,11 @@ function LinearAlgebra.diagm(v::AnyTracedRArray{T,1}) where {T}
end
function LinearAlgebra.diagm(m::Integer, n::Integer, v::AnyTracedRArray{T,1}) where {T}
m, n = LinearAlgebra.diagm_size((m, n), 0 => v) # size check

indices = Ops.constant(diagonal_indices_zero_indexed(m, n, 0)[1:length(v), :])
return simple_scatter_op((m, n), indices, materialize_traced_array(v))
end

# Common Utilities
simple_update_overwrite(x, y) = y

## This is quite handy to have but is not generalized enough to be put into Ops? Or maybe
## we can document it and place it there under a different name. It takes a list of values
## and a list of indices and constructs a matrix with the values at the indices.
Expand All @@ -273,21 +270,15 @@ function simple_scatter_op(
@assert length(updates) == size(scatter_indices, 1)
@assert size(scatter_indices, 2) == 2

# TODO: Directly use `Ops.hlo_call` for this part
(_, update_function) = make_mlir_fn(
simple_update_overwrite,
(promote_to(TracedRNumber{T}, 0), promote_to(TracedRNumber{T}, 0)),
(),
string(gensym("update_computation")),
false;
return_dialect=:stablehlo,
no_args_in_result=true,
)
update_computation = MLIR.IR.Region()
MLIR.API.mlirRegionTakeBody(
update_computation, MLIR.API.mlirOperationGetRegion(update_function, 0)
block = MLIR.IR.Block(
[mlir_type(TracedRNumber{T}), mlir_type(TracedRNumber{T})],
[MLIR.IR.Location(), MLIR.IR.Location()],
)
MLIR.IR.rmfromparent!(update_function)
return_op = MLIR.Dialects.stablehlo.return_([MLIR.IR.argument(block, 2)])
MLIR.IR.rmfromparent!(return_op)
push!(block, return_op)
pushfirst!(update_computation, block)

init_array = Ops.constant(fill(zero(T), shape)).mlir_data

Expand Down

0 comments on commit a9a8b2a

Please sign in to comment.