Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add macro to create custom Ops also on aarch64 #871

Merged
merged 23 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions docs/examples/03-reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,23 @@ function pool(S1::SummaryStat, S2::SummaryStat)
SummaryStat(m,v,n)
end

# Register the custom reduction operator. This is necessary only on platforms
# where Julia doesn't support closures as cfunctions (e.g. ARM), but can be used
# on all platforms for consistency.
MPI.@RegisterOp(pool, SummaryStat)

X = randn(10,3) .* [1,3,7]'

# Perform a scalar reduction
summ = MPI.Reduce(SummaryStat(X), pool, root, comm)
summ = MPI.Reduce(SummaryStat(X), pool, comm; root)

if MPI.Comm_rank(comm) == root
@show summ.var
end

# Perform a vector reduction:
# the reduction operator is applied elementwise
col_summ = MPI.Reduce(mapslices(SummaryStat,X,dims=1), pool, root, comm)
col_summ = MPI.Reduce(mapslices(SummaryStat,X,dims=1), pool, comm; root)

if MPI.Comm_rank(comm) == root
col_var = map(summ -> summ.var, col_summ)
Expand Down
2 changes: 2 additions & 0 deletions docs/src/knownissues.md
Original file line number Diff line number Diff line change
Expand Up @@ -210,3 +210,5 @@ However they have two limitations:

* [Julia's C-compatible function pointers](https://docs.julialang.org/en/v1/manual/calling-c-and-fortran-code/index.html#Creating-C-Compatible-Julia-Function-Pointers-1) cannot be used where the `stdcall` calling convention is expected, which is the case for 32-bit Microsoft MPI,
* closure cfunctions in Julia are based on LLVM trampolines, which are not supported on ARM architecture.

As an alternative [`MPI.@RegisterOp`](@ref) may be used to statically register reduction operations.
1 change: 1 addition & 0 deletions docs/src/reference/advanced.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ MPI.Types.duplicate

```@docs
MPI.Op
MPI.@RegisterOp
```

## Info objects
Expand Down
99 changes: 91 additions & 8 deletions src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ associative, and if `iscommutative` is true, assumed to be commutative as well.
- [`Allreduce!`](@ref)/[`Allreduce`](@ref)
- [`Scan!`](@ref)/[`Scan`](@ref)
- [`Exscan!`](@ref)/[`Exscan`](@ref)
- [`@RegisterOp`](@ref)
"""
mutable struct Op
val::MPI_Op
Expand Down Expand Up @@ -81,21 +82,36 @@ end

function (w::OpWrapper{F,T})(_a::Ptr{Cvoid}, _b::Ptr{Cvoid}, _len::Ptr{Cint}, t::Ptr{MPI_Datatype}) where {F,T}
len = unsafe_load(_len)
@assert isconcretetype(T)
a = Ptr{T}(_a)
b = Ptr{T}(_b)
for i = 1:len
unsafe_store!(b, w.f(unsafe_load(a,i), unsafe_load(b,i)), i)
if !isconcretetype(T)
concrete_T = to_type(Datatype(unsafe_load(t))) # Ptr might actually point to a Julia object so we could unsafe_pointer_to_objref?
else
concrete_T = T
end
function copy(::Type{T}) where T
@assert isconcretetype(T)
a = Ptr{T}(_a)
b = Ptr{T}(_b)
for i = 1:len
unsafe_store!(b, w.f(unsafe_load(a,i), unsafe_load(b,i)), i)
end
end
copy(concrete_T)
return nothing
end


function Op(f, T=Any; iscommutative=false)
@static if MPI_LIBRARY == "MicrosoftMPI" && Sys.WORD_SIZE == 32
error("User-defined reduction operators are not supported on 32-bit Windows.\nSee https://github.com/JuliaParallel/MPI.jl/issues/246 for more details.")
error("""
User-defined reduction operators are not supported on 32-bit Windows.
See https://github.com/JuliaParallel/MPI.jl/issues/246 for more details.
""")
elseif Sys.ARCH ∈ (:aarch64, :ppc64le, :powerpc64le) || startswith(lowercase(String(Sys.ARCH)), "arm")
error("User-defined reduction operators are currently not supported on non-Intel architectures.\nSee https://github.com/JuliaParallel/MPI.jl/issues/404 for more details.")
error("""
User-defined reduction operators are currently not supported on non-Intel architectures.
See https://github.com/JuliaParallel/MPI.jl/issues/404 for more details.

You may want to use `@RegisterOp` to statically register `f`.
""")
end
w = OpWrapper{typeof(f),T}(f)
fptr = @cfunction($w, Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}, Ptr{MPI_Datatype}))
Expand All @@ -107,3 +123,70 @@ function Op(f, T=Any; iscommutative=false)
finalizer(free, op)
return op
end

"""
@RegisterOp(f, T)

Register a custom operator [`Op`](@ref) using the function `f` statically.
On platfroms like AArch64, Julia does not support runtime closures,
being passed to C. The generic version of [`Op`](@ref) uses runtime closures
to support arbitrary functions being passed as MPI reduction operators.
`@RegisterOp` statically adds a function to the set of functions allowed as
as an MPI operator.

```julia
function my_reduce(x, y)
2x+y-x
end
MPI.@RegisterOp(my_reduce, Int)
# ...
MPI.Reduce!(send_arr, recv_arr, my_reduce, MPI.COMM_WORLD; root=root)
#...
```
vchuravy marked this conversation as resolved.
Show resolved Hide resolved
!!! warning
Note that `@RegisterOp` works be introducing a new method of the generic function `Op`.
It can only be used as a top-level statement and may trigger method invalidations.

!!! note
`T` can be `Any`, but this will lead to a runtime dispatch.
"""
macro RegisterOp(f, T)
name_wrapper = gensym(Symbol(f, :_, T, :_wrapper))
name_fptr = gensym(Symbol(f, :_, T, :_ptr))
name_module = gensym(Symbol(f, :_, T, :_module))
# The gist is that we can use a method very similar to how we handle `min`/`max`
vchuravy marked this conversation as resolved.
Show resolved Hide resolved
# but since this might be used from user code we can't use add_load_time_hook!
# this is why we introduce a new module that has a `__init__` function.
# If this module approach is too costly for loading MPI.jl for internal use we could use
# `add_load_time_hook`
expr = quote
module $(name_module)
# import ..$f, ..$T
$(Expr(:import, Expr(:., :., :., f), Expr(:., :., :., T))) # Julia 1.6 strugles with import ..$f, ..$T
const $(name_wrapper) = $OpWrapper{typeof($f),$T}($f)
const $(name_fptr) = Ref(@cfunction($(name_wrapper), Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}, Ptr{$MPI_Datatype})))
function __init__()
$(name_fptr)[] = @cfunction($(name_wrapper), Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}, Ptr{$MPI_Datatype}))
end
import MPI: Op
# we can't create a const Op since MPI needs to be initialized?
function Op(::typeof($f), ::Type{<:$T}; iscommutative=false)
op = Op($OP_NULL.val, $(name_fptr)[])
# int MPI_Op_create(MPI_User_function* user_fn, int commute, MPI_Op* op)
$API.MPI_Op_create($(name_fptr)[], iscommutative, op)

finalizer($free, op)
end
end
end
expr.head = :toplevel
esc(expr)
end

@RegisterOp(min, Any)
@RegisterOp(max, Any)
@RegisterOp(+, Any)
@RegisterOp(*, Any)
@RegisterOp(&, Any)
@RegisterOp(|, Any)
@RegisterOp(⊻, Any)
8 changes: 4 additions & 4 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

[compat]
AMDGPU = "0.6, 0.7, 0.8, 0.9, 1"
CUDA = "3, 4, 5"
Expand All @@ -16,7 +20,3 @@ TOML = "< 0.0.1, 1.0"
[extras]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
27 changes: 15 additions & 12 deletions test/test_reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,15 @@ if isroot
@test sum_mesg == sz .* mesg
end

function my_reduce(x, y)
2x+y-x
end
MPI.@RegisterOp(my_reduce, Any)

if can_do_closures
operators = [MPI.SUM, +, (x,y) -> 2x+y-x]
operators = [MPI.SUM, +, my_reduce, (x,y) -> 2x+y-x]
else
operators = [MPI.SUM, +]
operators = [MPI.SUM, +, my_reduce]
end

for T = [Int]
Expand Down Expand Up @@ -117,19 +122,17 @@ end

MPI.Barrier( MPI.COMM_WORLD )

if can_do_closures
send_arr = [Double64(i)/10 for i = 1:10]

result = MPI.Reduce(send_arr, +, MPI.COMM_WORLD; root=root)
if rank == root
@test result ≈ [Double64(sz*i)/10 for i = 1:10] rtol=sz*eps(Double64)
else
@test result === nothing
end
send_arr = [Double64(i)/10 for i = 1:10]

MPI.Barrier( MPI.COMM_WORLD )
result = MPI.Reduce(send_arr, +, MPI.COMM_WORLD; root=root)
if rank == root
@test result ≈ [Double64(sz*i)/10 for i = 1:10] rtol=sz*eps(Double64)
else
@test result === nothing
end

MPI.Barrier( MPI.COMM_WORLD )

GC.gc()
MPI.Finalize()
@test MPI.Finalized()
Loading