Skip to content

Commit 19c0eca

Browse files
committed
Implement MPI.Recv!
1 parent 7eb4107 commit 19c0eca

File tree

2 files changed

+78
-20
lines changed

2 files changed

+78
-20
lines changed

ext/ReactantMPIExt/Ops.jl

+54-8
Original file line numberDiff line numberDiff line change
@@ -369,20 +369,66 @@ function isend(
369369
end
370370

371371
function recv!(
372-
ref::TracedRArray,
372+
recvbuf::TracedRArray,
373373
tag::TracedRNumber,
374374
src::TracedRNumber;
375375
location=mlir_stacktrace("mpi.recv", @__FILE__, @__LINE__),
376376
)
377-
# return mpi.recv(ref.mlir_data, tag.mlir_data, src.mlir_data; location)
377+
T = Reactant.unwrapped_eltype(recvbuf)
378+
mpi_datatype = convert_julia_type_to_mpi_datatype(T)
379+
mpi_datatype_name = inject_mpi_datatype!(mpi_datatype)
378380

379-
# TODO emit constant for size and datatype, and pass as args
380-
inputs = IR.Value[ref.mlir_data, tag.mlir_data, src.mlir_data]
381-
sym = IR.FlatSymbolRefAttribute("enzymexla_wrapper_MPI_Recv")
382-
rettype = IR.Type[]
381+
sym_name = "enzymexla_wrapper_MPI_Recv_$(mpi_datatype_name)"
382+
sym_attr = IR.FlatSymbolRefAttribute(sym_name)
383383

384-
IR.result(enzymexla.jit_call(inputs; fn=sym, result_0=rettype, location))
385-
return ref
384+
IR.inject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr")
385+
IR.inject!(
386+
"MPI_STATUS_IGNORE", "llvm.mlir.global constant @MPI_STATUS_IGNORE() : !llvm.ptr"
387+
)
388+
IR.inject!(
389+
"MPI_Recv",
390+
"llvm.func @MPI_Recv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32",
391+
)
392+
393+
#! format: off
394+
IR.inject!(sym_name, """
395+
func.func @$sym_name(%buf : !llvm.ptr, %count_ptr : !llvm.ptr, %source_ptr : !llvm.ptr, %tag_ptr : !llvm.ptr) -> () {
396+
%comm = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr
397+
%datatype = llvm.mlir.addressof @$mpi_datatype_name : !llvm.ptr
398+
%status = llvm.mlir.addressof @MPI_STATUS_IGNORE : !llvm.ptr
399+
%count = llvm.load %count_ptr : !llvm.ptr -> i32
400+
%source = llvm.load %source_ptr : !llvm.ptr -> i32
401+
%tag = llvm.load %tag_ptr : !llvm.ptr -> i32
402+
%errcode = llvm.call @MPI_Recv(%buf, %count, %datatype, %source, %tag, %comm, %status) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> (i32)
403+
func.return
404+
}
405+
""")
406+
#! format: on
407+
408+
count = Reactant.Ops.constant(fill(length(recvbuf)))
409+
410+
output_operand_aliases = IR.Attribute([
411+
IR.Attribute(
412+
MLIR.API.stablehloOutputOperandAliasGet(
413+
MLIR.IR.context(), 0, C_NULL, 0, 0, C_NULL
414+
),
415+
),
416+
])
417+
418+
res = IR.result(
419+
enzymexla.jit_call(
420+
IR.Value[recvbuf.mlir_data, count.mlir_data, src.mlir_data, tag.mlir_data];
421+
fn=sym_attr,
422+
result_0=[mlir_type(recvbuf)],
423+
output_operand_aliases,
424+
location,
425+
),
426+
1,
427+
)
428+
429+
recvbuf.mlir_data = res
430+
431+
return recvbuf
386432
end
387433

388434
# TODO need c-function for creating MLIR `mpi.request` type?

ext/ReactantMPIExt/Overrides.jl

+24-12
Original file line numberDiff line numberDiff line change
@@ -75,22 +75,34 @@ function MPI.Isend(
7575
return req
7676
end
7777

78-
# TODO use `make_tracer` to delinearize arbitrary types? check out `MPI.Buffer`
78+
function MPI.Recv!(buf::TracedRArray, source::Integer, tag::Integer, comm::MPI.Comm)
79+
tag = Reactant.Ops.constant(tag)
80+
source = Reactant.Ops.constant(source)
81+
return MPI.Recv!(buf, source, tag, comm)
82+
end
83+
7984
function MPI.Recv!(
80-
recvbuf::TracedRArray, source::Number, tag::Number, comm::MPI.Comm, status
85+
recvbuf::TracedRArray,
86+
source::Integer,
87+
tag::Integer,
88+
comm::MPI.Comm,
89+
::Type{MPI.API.MPI_Status},
8190
)
82-
@assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently"
83-
@assert isnothing(status) "Status not supported yet"
84-
85-
tag = if !(tag isa TracedRNumber)
86-
Reactant.Ops.constant(tag)
87-
end
91+
return MPI.Recv!(recvbuf, source, tag, comm)
92+
end
8893

89-
source = if !(source isa TracedRNumber)
90-
Reactant.Ops.constant(source)
91-
end
94+
function MPI.Recv!(
95+
recvbuf::TracedRArray, source::Integer, tag::Integer, comm::MPI.Comm, ::Nothing
96+
)
97+
return MPI.Recv!(recvbuf, source, tag, comm)
98+
end
9299

93-
return Ops.recv(recvbuf, tag, source)
100+
# TODO use `make_tracer` to delinearize arbitrary types? check out `MPI.Buffer`
101+
function MPI.Recv!(
102+
recvbuf::TracedRArray, source::TracedRNumber, tag::TracedRNumber, comm::MPI.Comm
103+
)
104+
@assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently"
105+
return Ops.recv!(recvbuf, tag, source)
94106
end
95107

96108
# TODO use `make_tracer` to delinearize arbitrary types? check out `MPI.Buffer`

0 commit comments

Comments
 (0)