Skip to content

Commit f4acb15

Browse files
committed
Update Ops.comm_rank
1 parent 6e9b1c5 commit f4acb15

File tree

1 file changed

+26
-76
lines changed

1 file changed

+26
-76
lines changed

ext/ReactantMPIExt/Ops.jl

+26-76
Original file line numberDiff line numberDiff line change
@@ -18,94 +18,44 @@ using MPI: MPI
1818
# return mpi.finalize(; location)
1919
# end
2020

21-
# TODO change to this kind of MLIR
22-
# module {
23-
# llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32
24-
# func.func @$sym_name(%comm_ptr : !llvm.ptr, %rank_ptr : !llvm.ptr) -> () {
25-
# %comm = llvm.load %comm_ptr : !llvm.ptr -> i32
26-
# %world_ptr = arith.constant dense<0x0asdfa> : tensor<i32>
27-
# memref.get_global # global variable MPI_COMM_GLOBAL
28-
# %status = llvm.call @MPI_Comm_rank(%comm, %rank_ptr) : (i32, !llvm.ptr) -> (i32)
29-
# func.return
30-
# }
31-
# func.func @real_$sym_name() -> tensor<> {
32-
# %rank_ptr = stablehlo.constant dense<-1> : tensor<i32> # this is a placeholder
33-
# %rank = enzymexla.jit_call @$sym_name(%world_ptr, %rank_ptr) {
34-
# output_operand_alias = [
35-
# #stablehlo.output_operand_alias<output_tuple_indices = [],
36-
# operand_index = 1,
37-
# operand_tuple_indices = []>
38-
# ]
39-
# }
40-
# }
41-
# }
42-
4321
function comm_rank(; location=mlir_stacktrace("mpi.comm_rank", @__FILE__, @__LINE__))
4422
sym_name = "enzymexla_wrapper_MPI_Comm_rank"
45-
# sym_attr = IR.FlatSymbolRefAttribute(sym_name)
46-
comm = MPI.COMM_WORLD
47-
48-
@show IR.mmodule()
23+
sym_attr = IR.FlatSymbolRefAttribute(sym_name)
4924

50-
# memref.global constant @MPI_COMM_WORLD : memref<i32>
51-
# llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32
25+
# dirty hack: since MPI constants are i32, we pass the info as the pointer and then bitcast
26+
# DONT LOAD FROM THEM!
27+
IR.inject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : i32")
28+
IR.inject!("MPI_Comm_rank", "llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32")
5229

5330
#! format: off
54-
# IR.tryinjectop!("MPI_COMM_WORLD", "memref.global @MPI_COMM_WORLD : memref<i32>")
55-
# IR.tryinjectop!("MPI_Comm_rank", "module { llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32 }")
56-
IR.inject!("$(sym_name)_jit", """
57-
func.func @$(sym_name)_jit(%rank_ptr : !llvm.ptr) -> () {
58-
%comm_ref = memref.get_global @MPI_COMM_WORLD : memref<i32>
59-
%comm_ptr = "enzymexla.memref2pointer"(%comm_ref) : (memref<i32>) -> (!llvm.ptr)
31+
IR.inject!(sym_name, """
32+
func.func @$sym_name(%rank_ptr : !llvm.ptr) -> () {
33+
%comm_ptr = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr
6034
%comm = llvm.ptrtoint %comm_ptr : !llvm.ptr to i32
6135
%status = llvm.call @MPI_Comm_rank(%comm, %rank_ptr) : (i32, !llvm.ptr) -> (i32)
6236
func.return
6337
}
6438
""")
65-
@show res
66-
#! format: on
67-
68-
# %comm_ref = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr
69-
# %comm = llvm.ptrtoint %comm_ref : !llvm.ptr to i32
70-
71-
#! format: off
72-
# return Reactant.Ops.hlo_call("""module {
73-
# memref.global constant @MPI_COMM_WORLD : memref<i32>
74-
# llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32
75-
# func.func @$(sym_name)_jit(%rank_ptr : !llvm.ptr) -> () {
76-
# %comm_ref = memref.get_global @MPI_COMM_WORLD : memref<i32>
77-
# %comm_ptr = "enzymexla.memref2pointer"(%comm_ref) : (memref<i32>) -> (!llvm.ptr)
78-
# %comm = llvm.ptrtoint %comm_ptr : !llvm.ptr to i32
79-
# %status = llvm.call @MPI_Comm_rank(%comm, %rank_ptr) : (i32, !llvm.ptr) -> (i32)
80-
# func.return
81-
# }
82-
# func.func @$sym_name() -> tensor<i32> {
83-
# %rank_placeholder = stablehlo.constant dense<-1> : tensor<i32>
84-
# %rank = enzymexla.jit_call @$(sym_name)_jit(%rank_placeholder) {
85-
# output_operand_aliases = [
86-
# #stablehlo.output_operand_alias<output_tuple_indices = [],
87-
# operand_index = 1,
88-
# operand_tuple_indices = []>
89-
# ]
90-
# } : (tensor<i32>) -> (tensor<i32>)
91-
# func.return %rank : tensor<i32>
92-
# }
93-
# }"""; func_name=sym_name)
9439
#! format: on
40+
rank_placeholder = Reactant.Ops.constant(fill(Cint(-1)))
41+
output_operand_aliases = IR.Attribute([
42+
IR.Attribute(
43+
MLIR.API.stablehloOutputOperandAliasGet(
44+
MLIR.IR.context(), 0, C_NULL, 0, 0, C_NULL
45+
),
46+
),
47+
])
9548

96-
# NOTE we assume here that `MPI_Comm` is of word-size
97-
# comm = Reactant.Ops.constant(Base.unsafe_convert(Cint, comm))
98-
# value_out = Reactant.Ops.constant(fill(Cint(-1)))
99-
# inputs = IR.Value[comm.mlir_data, value_out.mlir_data]
100-
101-
# tensor_int_type = IR.TensorType(Int[], IR.Type(Cint))
102-
# signature = IR.Type[tensor_int_type, tensor_int_type]
103-
104-
# # TODO output_operand_aliases
105-
# res = IR.result(
106-
# enzymexla.jit_call(inputs; fn=sym_attr, result_0=signature, location), 2
107-
# )
108-
# return TracedRNumber{Cint}((), res)
49+
res = IR.result(
50+
enzymexla.jit_call(
51+
IR.Value[rank_placeholder.mlir_data];
52+
fn=sym_attr,
53+
result_0=[IR.TensorType(Int[], IR.Type(Cint))],
54+
location,
55+
output_operand_aliases,
56+
),
57+
)
58+
return TracedRNumber{Cint}((), res)
10959
end
11060

11161
function comm_size(comm; location=mlir_stacktrace("mpi.comm_size", @__FILE__, @__LINE__))

0 commit comments

Comments
 (0)