diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl index fbce7c973..ad6658b64 100644 --- a/src/linear_algebra.jl +++ b/src/linear_algebra.jl @@ -256,14 +256,11 @@ function LinearAlgebra.diagm(v::AnyTracedRArray{T,1}) where {T} end function LinearAlgebra.diagm(m::Integer, n::Integer, v::AnyTracedRArray{T,1}) where {T} m, n = LinearAlgebra.diagm_size((m, n), 0 => v) # size check - indices = Ops.constant(diagonal_indices_zero_indexed(m, n, 0)[1:length(v), :]) return simple_scatter_op((m, n), indices, materialize_traced_array(v)) end # Common Utilities -simple_update_overwrite(x, y) = y - ## This is quite handy to have but is not generalized enough to be put into Ops? Or maybe ## we can document it and place it there under a different name. It takes a list of values ## and a list of indices and constructs a matrix with the values at the indices. @@ -273,21 +270,15 @@ function simple_scatter_op( @assert length(updates) == size(scatter_indices, 1) @assert size(scatter_indices, 2) == 2 - # TODO: Directly use `Ops.hlo_call` for this part - (_, update_function) = make_mlir_fn( - simple_update_overwrite, - (promote_to(TracedRNumber{T}, 0), promote_to(TracedRNumber{T}, 0)), - (), - string(gensym("update_computation")), - false; - return_dialect=:stablehlo, - no_args_in_result=true, - ) update_computation = MLIR.IR.Region() - MLIR.API.mlirRegionTakeBody( - update_computation, MLIR.API.mlirOperationGetRegion(update_function, 0) + block = MLIR.IR.Block( + [mlir_type(TracedRNumber{T}), mlir_type(TracedRNumber{T})], + [MLIR.IR.Location(), MLIR.IR.Location()], ) - MLIR.IR.rmfromparent!(update_function) + return_op = MLIR.Dialects.stablehlo.return_([MLIR.IR.argument(block, 2)]) + MLIR.IR.rmfromparent!(return_op) + push!(block, return_op) + pushfirst!(update_computation, block) init_array = Ops.constant(fill(zero(T), shape)).mlir_data