@@ -369,20 +369,66 @@ function isend(
369
369
end
370
370
371
371
function recv! (
372
- ref :: TracedRArray ,
372
+ recvbuf :: TracedRArray ,
373
373
tag:: TracedRNumber ,
374
374
src:: TracedRNumber ;
375
375
location= mlir_stacktrace (" mpi.recv" , @__FILE__ , @__LINE__ ),
376
376
)
377
- # return mpi.recv(ref.mlir_data, tag.mlir_data, src.mlir_data; location)
377
+ T = Reactant. unwrapped_eltype (recvbuf)
378
+ mpi_datatype = convert_julia_type_to_mpi_datatype (T)
379
+ mpi_datatype_name = inject_mpi_datatype! (mpi_datatype)
378
380
379
- # TODO emit constant for size and datatype, and pass as args
380
- inputs = IR. Value[ref. mlir_data, tag. mlir_data, src. mlir_data]
381
- sym = IR. FlatSymbolRefAttribute (" enzymexla_wrapper_MPI_Recv" )
382
- rettype = IR. Type[]
381
+ sym_name = " enzymexla_wrapper_MPI_Recv_$(mpi_datatype_name) "
382
+ sym_attr = IR. FlatSymbolRefAttribute (sym_name)
383
383
384
- IR. result (enzymexla. jit_call (inputs; fn= sym, result_0= rettype, location))
385
- return ref
384
+ IR. inject! (" MPI_COMM_WORLD" , " llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr" )
385
+ IR. inject! (
386
+ " MPI_STATUS_IGNORE" , " llvm.mlir.global constant @MPI_STATUS_IGNORE() : !llvm.ptr"
387
+ )
388
+ IR. inject! (
389
+ " MPI_Recv" ,
390
+ " llvm.func @MPI_Recv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32" ,
391
+ )
392
+
393
+ # ! format: off
394
+ IR. inject! (sym_name, """
395
+ func.func @$sym_name (%buf : !llvm.ptr, %count_ptr : !llvm.ptr, %source_ptr : !llvm.ptr, %tag_ptr : !llvm.ptr) -> () {
396
+ %comm = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr
397
+ %datatype = llvm.mlir.addressof @$mpi_datatype_name : !llvm.ptr
398
+ %status = llvm.mlir.addressof @MPI_STATUS_IGNORE : !llvm.ptr
399
+ %count = llvm.load %count_ptr : !llvm.ptr -> i32
400
+ %source = llvm.load %source_ptr : !llvm.ptr -> i32
401
+ %tag = llvm.load %tag_ptr : !llvm.ptr -> i32
402
+ %errcode = llvm.call @MPI_Recv(%buf, %count, %datatype, %source, %tag, %comm, %status) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> (i32)
403
+ func.return
404
+ }
405
+ """ )
406
+ # ! format: on
407
+
408
+ count = Reactant. Ops. constant (fill (length (recvbuf)))
409
+
410
+ output_operand_aliases = IR. Attribute ([
411
+ IR. Attribute (
412
+ MLIR. API. stablehloOutputOperandAliasGet (
413
+ MLIR. IR. context (), 0 , C_NULL , 0 , 0 , C_NULL
414
+ ),
415
+ ),
416
+ ])
417
+
418
+ res = IR. result (
419
+ enzymexla. jit_call (
420
+ IR. Value[recvbuf. mlir_data, count. mlir_data, src. mlir_data, tag. mlir_data];
421
+ fn= sym_attr,
422
+ result_0= [mlir_type (recvbuf)],
423
+ output_operand_aliases,
424
+ location,
425
+ ),
426
+ 1 ,
427
+ )
428
+
429
+ recvbuf. mlir_data = res
430
+
431
+ return recvbuf
386
432
end
387
433
388
434
# TODO need c-function for creating MLIR `mpi.request` type?
0 commit comments