Skip to content

Commit 62f4ec5

Browse files
committed
Fix MLIR of Ops.comm_rank
1 parent 95f723d commit 62f4ec5

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

ext/ReactantMPIExt/Ops.jl

+4-5
Original file line numberDiff line numberDiff line change
@@ -66,23 +66,22 @@ function comm_rank(; location=mlir_stacktrace("mpi.comm_rank", @__FILE__, @__LIN
6666
comm = MPI.COMM_WORLD
6767

6868
#! format: off
69-
return Ops.hlo_call("""module {
69+
return Reactant.Ops.hlo_call("""module {
7070
llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32
71-
func.func @$(sym_name)_jit(%comm_ptr : !llvm.ptr, %rank_ptr : !llvm.ptr) -> () {
72-
%comm = llvm.load %comm_ptr : !llvm.ptr -> i32
71+
func.func @$(sym_name)_jit(%rank_ptr : !llvm.ptr) -> () {
7372
%comm = arith.constant $(Base.unsafe_convert(Cint, comm)) : i32
7473
%status = llvm.call @MPI_Comm_rank(%comm, %rank_ptr) : (i32, !llvm.ptr) -> (i32)
7574
func.return
7675
}
7776
func.func @$sym_name() -> tensor<i32> {
7877
%rank_placeholder = stablehlo.constant dense<-1> : tensor<i32>
7978
%rank = enzymexla.jit_call @$(sym_name)_jit(%rank_placeholder) {
80-
output_operand_alias = [
79+
output_operand_aliases = [
8180
#stablehlo.output_operand_alias<output_tuple_indices = [],
8281
operand_index = 1,
8382
operand_tuple_indices = []>
8483
]
85-
}
84+
} : (tensor<i32>) -> (tensor<i32>)
8685
func.return %rank : tensor<i32>
8786
}
8887
}"""; func_name=sym_name)

0 commit comments

Comments
 (0)