@@ -18,94 +18,44 @@ using MPI: MPI
18
18
# return mpi.finalize(; location)
19
19
# end
20
20
21
- # TODO change to this kind of MLIR
22
- # module {
23
- # llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32
24
- # func.func @$sym_name(%comm_ptr : !llvm.ptr, %rank_ptr : !llvm.ptr) -> () {
25
- # %comm = llvm.load %comm_ptr : !llvm.ptr -> i32
26
- # %world_ptr = arith.constant dense<0x0asdfa> : tensor<i32>
27
- # memref.get_global # global variable MPI_COMM_GLOBAL
28
- # %status = llvm.call @MPI_Comm_rank(%comm, %rank_ptr) : (i32, !llvm.ptr) -> (i32)
29
- # func.return
30
- # }
31
- # func.func @real_$sym_name() -> tensor<> {
32
- # %rank_ptr = stablehlo.constant dense<-1> : tensor<i32> # this is a placeholder
33
- # %rank = enzymexla.jit_call @$sym_name(%world_ptr, %rank_ptr) {
34
- # output_operand_alias = [
35
- # #stablehlo.output_operand_alias<output_tuple_indices = [],
36
- # operand_index = 1,
37
- # operand_tuple_indices = []>
38
- # ]
39
- # }
40
- # }
41
- # }
42
-
43
21
function comm_rank (; location= mlir_stacktrace (" mpi.comm_rank" , @__FILE__ , @__LINE__ ))
44
22
sym_name = " enzymexla_wrapper_MPI_Comm_rank"
45
- # sym_attr = IR.FlatSymbolRefAttribute(sym_name)
46
- comm = MPI. COMM_WORLD
47
-
48
- @show IR. mmodule ()
23
+ sym_attr = IR. FlatSymbolRefAttribute (sym_name)
49
24
50
- # memref.global constant @MPI_COMM_WORLD : memref<i32>
51
- # llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32
25
+ # dirty hack: since MPI constants are i32, we pass the info as the pointer and then bitcast
26
+ # DONT LOAD FROM THEM!
27
+ IR. inject! (" MPI_COMM_WORLD" , " llvm.mlir.global constant @MPI_COMM_WORLD() : i32" )
28
+ IR. inject! (" MPI_Comm_rank" , " llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32" )
52
29
53
30
# ! format: off
54
- # IR.tryinjectop!("MPI_COMM_WORLD", "memref.global @MPI_COMM_WORLD : memref<i32>")
55
- # IR.tryinjectop!("MPI_Comm_rank", "module { llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32 }")
56
- IR. inject! (" $(sym_name) _jit" , """
57
- func.func @$(sym_name) _jit(%rank_ptr : !llvm.ptr) -> () {
58
- %comm_ref = memref.get_global @MPI_COMM_WORLD : memref<i32>
59
- %comm_ptr = "enzymexla.memref2pointer"(%comm_ref) : (memref<i32>) -> (!llvm.ptr)
31
+ IR. inject! (sym_name, """
32
+ func.func @$sym_name (%rank_ptr : !llvm.ptr) -> () {
33
+ %comm_ptr = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr
60
34
%comm = llvm.ptrtoint %comm_ptr : !llvm.ptr to i32
61
35
%status = llvm.call @MPI_Comm_rank(%comm, %rank_ptr) : (i32, !llvm.ptr) -> (i32)
62
36
func.return
63
37
}
64
38
""" )
65
- @show res
66
- # ! format: on
67
-
68
- # %comm_ref = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr
69
- # %comm = llvm.ptrtoint %comm_ref : !llvm.ptr to i32
70
-
71
- # ! format: off
72
- # return Reactant.Ops.hlo_call("""module {
73
- # memref.global constant @MPI_COMM_WORLD : memref<i32>
74
- # llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32
75
- # func.func @$(sym_name)_jit(%rank_ptr : !llvm.ptr) -> () {
76
- # %comm_ref = memref.get_global @MPI_COMM_WORLD : memref<i32>
77
- # %comm_ptr = "enzymexla.memref2pointer"(%comm_ref) : (memref<i32>) -> (!llvm.ptr)
78
- # %comm = llvm.ptrtoint %comm_ptr : !llvm.ptr to i32
79
- # %status = llvm.call @MPI_Comm_rank(%comm, %rank_ptr) : (i32, !llvm.ptr) -> (i32)
80
- # func.return
81
- # }
82
- # func.func @$sym_name() -> tensor<i32> {
83
- # %rank_placeholder = stablehlo.constant dense<-1> : tensor<i32>
84
- # %rank = enzymexla.jit_call @$(sym_name)_jit(%rank_placeholder) {
85
- # output_operand_aliases = [
86
- # #stablehlo.output_operand_alias<output_tuple_indices = [],
87
- # operand_index = 1,
88
- # operand_tuple_indices = []>
89
- # ]
90
- # } : (tensor<i32>) -> (tensor<i32>)
91
- # func.return %rank : tensor<i32>
92
- # }
93
- # }"""; func_name=sym_name)
94
39
# ! format: on
40
+ rank_placeholder = Reactant. Ops. constant (fill (Cint (- 1 )))
41
+ output_operand_aliases = IR. Attribute ([
42
+ IR. Attribute (
43
+ MLIR. API. stablehloOutputOperandAliasGet (
44
+ MLIR. IR. context (), 0 , C_NULL , 0 , 0 , C_NULL
45
+ ),
46
+ ),
47
+ ])
95
48
96
- # NOTE we assume here that `MPI_Comm` is of word-size
97
- # comm = Reactant.Ops.constant(Base.unsafe_convert(Cint, comm))
98
- # value_out = Reactant.Ops.constant(fill(Cint(-1)))
99
- # inputs = IR.Value[comm.mlir_data, value_out.mlir_data]
100
-
101
- # tensor_int_type = IR.TensorType(Int[], IR.Type(Cint))
102
- # signature = IR.Type[tensor_int_type, tensor_int_type]
103
-
104
- # # TODO output_operand_aliases
105
- # res = IR.result(
106
- # enzymexla.jit_call(inputs; fn=sym_attr, result_0=signature, location), 2
107
- # )
108
- # return TracedRNumber{Cint}((), res)
49
+ res = IR. result (
50
+ enzymexla. jit_call (
51
+ IR. Value[rank_placeholder. mlir_data];
52
+ fn= sym_attr,
53
+ result_0= [IR. TensorType (Int[], IR. Type (Cint))],
54
+ location,
55
+ output_operand_aliases,
56
+ ),
57
+ )
58
+ return TracedRNumber {Cint} ((), res)
109
59
end
110
60
111
61
function comm_size (comm; location= mlir_stacktrace (" mpi.comm_size" , @__FILE__ , @__LINE__ ))
0 commit comments