diff --git a/src/operators.jl b/src/operators.jl index fec95eb7a..48ddfba99 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -10,14 +10,23 @@ An MPI reduction operator, for use with [Reduce/Scan collective operations](@ref Wrap the Julia reduction function `op` for arguments of type `T`. `op` is assumed to be associative, and if `iscommutative` is true, assumed to be commutative as well. +Certain combinations of `op` and `T` will use the predefined MPI intrinsic operations, +otherwise it will wrap the function in a Julia closure at runtime. The macro [`@Op`](@ref) +can be used to wrap functions ahead of time, which may reduce runtime overhead, and is +required on platforms where closures are not supported (such as ARM and PPC). + +User usage of this function is generally unnecessary since it will be called directly +by the relevant MPI collective operations. + ## See also - [`Reduce!`](@ref)/[`Reduce`](@ref) - [`Allreduce!`](@ref)/[`Allreduce`](@ref) - [`Scan!`](@ref)/[`Scan`](@ref) - [`Exscan!`](@ref)/[`Exscan`](@ref) + """ -@mpi_handle Op MPI_Op fptr +@mpi_handle Op MPI_Op cfunc::Union{Base.CFunction, Nothing} const OP_NULL = _Op(MPI_OP_NULL, nothing) const BAND = _Op(MPI_BAND, nothing) @@ -74,16 +83,56 @@ function Op(f, T=Any; iscommutative=false) error("User-defined reduction operators are not supported on 32-bit Windows.\nSee https://github.com/JuliaParallel/MPI.jl/issues/246 for more details.") end w = OpWrapper{typeof(f),T}(f) - fptr = @cfunction($w, Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}, Ptr{MPI_Datatype})) - - op = Op(OP_NULL.val, fptr) + cfunc = @cfunction($w, Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}, Ptr{MPI_Datatype})) + + op = Op(OP_NULL.val, cfunc) # int MPI_Op_create(MPI_User_function* user_fn, int commute, MPI_Op* op) @mpichk ccall((:MPI_Op_create, libmpi), Cint, (Ptr{Cvoid}, Cint, Ptr{MPI_Op}), - fptr, iscommutative, op) + cfunc, iscommutative, op) refcount_inc() finalizer(free, op) return op end +""" + @declareOp(op, T[, iscommutative]) + +Declare a Julia function `op` to be used as a custom MPI operator [`Op`](@ref) for +variables of type `T`. This will create the [`Op`](@ref) object and define an appropriate +constructor method to `Op`. The `iscommutative` argument indicates to MPI whether or not +MPI can assume the operation is commutative (default is `false`). + +The usage of this macro is optional: the main advantage of this is that will avoid the use +of a closure (see ["Closure cfunctions" in the Julia +manual](https://docs.julialang.org/en/v1/manual/calling-c-and-fortran-code/#Closure-cfunctions-1), +which may offer some performance advantages. + +This should only be called once per combination of `op` and `T`, and should be at the +top-level (e.g. not inside a function). It can be safely used before `MPI.Init()` and +inside a precompiled module. +""" +macro declareOp(f, T, iscommutative=false) + opwrap = gensym(:opwrap) # we need to manually gensym for use with `@cfunction` macro + quote + if !Base.issingletontype(typeof($(esc(f)))) + error("@declareOp macro can only be used with instances of singleton types") + end + const op = Op(OP_NULL.val, nothing) + const $(esc(opwrap)) = OpWrapper{typeof($(esc(f))),$(esc(T))}($(esc(f))) + function initop() + fptr = @cfunction($opwrap, Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}, Ptr{MPI_Datatype})) + @mpichk ccall((:MPI_Op_create, libmpi), Cint, + (Ptr{Cvoid}, Cint, Ptr{MPI_Op}), + fptr, $iscommutative, op) + end + if Initialized() && !Finalized() + initop() + else + push!(mpi_init_hooks, initop) + end + MPI.Op(::typeof($(esc(f))), ::Type{$(esc(T))}; iscommutative=$iscommutative) = op + op + end +end