Skip to content

Commit

Permalink
fixup! fixup! fixup! add macro to create custom Ops also on aarch64
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy committed Sep 3, 2024
1 parent 50c46ac commit 20d1bc9
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 9 deletions.
18 changes: 14 additions & 4 deletions src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,24 +108,34 @@ function Op(f, T=Any; iscommutative=false)
return op
end

"""
@Op(f, T)
Define a custom operator [`Op`](@ref) using the function `f`.
"""
macro Op(f, T)
name_wrapper = gensym(Symbol(f, :_, T, :_wrapper))
name_fptr = gensym(Symbol(f, :_, T, :_ptr))
name_module = gensym(Symbol(f, :_, T, :_module))
esc(quote
expr = quote
module $(name_module)
import ..$f, ..$T
$(name_wrapper) = $OpWrapper{typeof($f),$T}($f)
$(name_fptr) = @cfunction($(name_wrapper), Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}, Ptr{$MPI_Datatype}))
function __init__()
global $(name_fptr) = @cfunction($(name_wrapper), Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}, Ptr{$MPI_Datatype}))
end
function $Op(::typeof($f), ::Type{T}; iscommutative=true)
op = $Op($OP_NULL.val, $(name_fptr))
import MPI: Op
function Op(::typeof($f), ::Type{$T}; iscommutative=true)
op = Op($OP_NULL.val, $(name_fptr))
# int MPI_Op_create(MPI_User_function* user_fn, int commute, MPI_Op* op)
$API.MPI_Op_create($(name_fptr), iscommutative, op)

finalizer($free, op)
end
end
end)
end
expr.head = :toplevel
esc(expr)
end
8 changes: 3 additions & 5 deletions test/test_reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,10 @@ if isroot
@test sum_mesg == sz .* mesg
end

@eval begin
function my_reduce(x, y)
2x+y-x
end
MPI.@Op(my_reduce, Int)
function my_reduce(x, y)
2x+y-x
end
MPI.@Op(my_reduce, Int)

if can_do_closures
operators = [MPI.SUM, +, my_reduce, (x,y) -> 2x+y-x]
Expand Down

0 comments on commit 20d1bc9

Please sign in to comment.