Skip to content

Commit

Permalink
Add communicator to the distributed datawrangling utils (#282)
Browse files Browse the repository at this point in the history
* add commuincator

* Update distributed_utils.jl

* Update distributed_utils.jl

* add a safety net to avoid conflicts

* Update test_distributed_utils.jl
  • Loading branch information
simone-silvestri authored Dec 3, 2024
1 parent 4b0d77d commit 0fff4fa
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 43 deletions.
106 changes: 71 additions & 35 deletions src/distributed_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,66 +6,81 @@ using MPI
#####

# Utilities to make the macro work importing only ClimaOcean and not MPI
mpi_initialized() = MPI.Initialized()
mpi_rank() = MPI.Comm_rank(MPI.COMM_WORLD)
mpi_size() = MPI.Comm_size(MPI.COMM_WORLD)
global_barrier() = mpi_initialized() ? MPI.Barrier(MPI.COMM_WORLD) : nothing
mpi_initialized() = MPI.Initialized()
mpi_rank(comm) = MPI.Comm_rank(comm)
mpi_size(comm) = MPI.Comm_size(comm)
global_barrier(comm) = MPI.Barrier(comm)

"""
@root exs...
@root communicator exs...
Perform `exs` only on rank 0, otherwise know as "root" rank.
Other ranks will wait for the root rank to finish before continuing
Perform `exs` only on rank 0 in communicator, otherwise known as the "root" rank.
Other ranks will wait for the root rank to finish before continuing.
If `communicator` is not provided, `MPI.COMM_WORLD` is used.
"""
macro root(exp)
macro root(communicator, exp)
command = quote
if ClimaOcean.mpi_initialized()
rank = ClimaOcean.mpi_rank()
rank = ClimaOcean.mpi_rank($communicator)
if rank == 0
$exp
end
ClimaOcean.global_barrier()
ClimaOcean.global_barrier($communicator)
else
$exp
end
end
return esc(command)
end

macro root(exp)
command = quote
@root MPI.COMM_WORLD $exp
end
return esc(command)
end

"""
@onrank rank, exs...
@onrank communicator rank exs...
Perform `exp` only on rank `rank`
Other ranks will wait for the root rank to finish before continuing.
The expression is run anyways if MPI in not initialized
Perform `exp` only on rank `rank` (0-based index) in `communicator`.
Other ranks will wait for rank `rank` to finish before continuing.
The expression is run anyways if MPI in not initialized.
If `communicator` is not provided, `MPI.COMM_WORLD` is used.
"""
macro onrank(exp_with_rank)
on_rank = exp_with_rank.args[1]
exp = exp_with_rank.args[2]
macro onrank(communicator, on_rank, exp)
command = quote
mpi_initialized = ClimaOcean.mpi_initialized()
rank = ClimaOcean.mpi_rank()
if !mpi_initialized
$exp
else
rank = ClimaOcean.mpi_rank($communicator)
if rank == $on_rank
$exp
end
ClimaOcean.global_barrier()
ClimaOcean.global_barrier($communicator)
end
end

return esc(command)
end

macro onrank(rank, exp)
command = quote
@onrank MPI.COMM_WORLD $rank $exp
end
return esc(command)
end

"""
@distribute for i in iterable
@distribute communicator for i in iterable
...
end
Distribute a `for` loop among different ranks
Distribute a `for` loop among different ranks in `communicator`.
If `communicator` is not provided, `MPI.COMM_WORLD` is used.
"""
macro distribute(exp)
macro distribute(communicator, exp)
if exp.head != :for
error("The `@distribute` macro expects a `for` loop")
end
Expand All @@ -74,46 +89,67 @@ macro distribute(exp)
variable = exp.args[1].args[1]
forbody = exp.args[2]

# Safety net if the iterable variable has the same name as the
# reserved variable names (nprocs, counter, rank)
nprocs = ifelse(variable == :nprocs, :othernprocs, :nprocs)
counter = ifelse(variable == :counter, :othercounter, :counter)
rank = ifelse(variable == :rank, :otherrank, :rank)

new_loop = quote
mpi_initialized = ClimaOcean.mpi_initialized()
if !mpi_initialized
$exp
else
rank = ClimaOcean.mpi_rank()
nprocs = ClimaOcean.mpi_size()
for (counter, $variable) in enumerate($iterable)
if (counter - 1) % nprocs == rank
$rank = ClimaOcean.mpi_rank($communicator)
$nprocs = ClimaOcean.mpi_size($communicator)
for ($counter, $variable) in enumerate($iterable)
if ($counter - 1) % $nprocs == $rank
$forbody
end
end
ClimaOcean.global_barrier()
ClimaOcean.global_barrier($communicator)
end
end

return esc(new_loop)
end

macro distribute(exp)
command = quote
@distribute MPI.COMM_WORLD $exp
end
return esc(command)
end

"""
@handshake exs...
@handshake communicator exs...
perform `exs` on all ranks, but only one rank at a time, where
ranks `r2 > r1` wait for rank `r1` to finish before executing `exs`
perform `exs` on all ranks in `communicator`, but only one rank at a time, where
ranks `r2 > r1` wait for rank `r1` to finish before executing `exs`.
If `communicator` is not provided, `MPI.COMM_WORLD` is used.
"""
macro handshake(exp)
macro handshake(communicator, exp)
command = quote
mpi_initialized = ClimaOcean.mpi_initialized()
if !mpi_initialized
$exp
else
rank = ClimaOcean.mpi_rank()
nprocs = ClimaOcean.mpi_size()
rank = ClimaOcean.mpi_rank($communicator)
nprocs = ClimaOcean.mpi_size($communicator)
for r in 0 : nprocs -1
if rank == r
$exp
end
ClimaOcean.global_barrier()
ClimaOcean.global_barrier($communicator)
end
end
end
return esc(command)
end
end

macro handshake(exp)
command = quote
@handshake MPI.COMM_WORLD $exp
end
return esc(command)
end
27 changes: 19 additions & 8 deletions test/test_distributed_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,23 @@ MPI.Init()
@testset begin
rank = MPI.Comm_rank(MPI.COMM_WORLD)

@onrank 0, begin
@onrank 0 begin
@test rank == 0
end

@root begin
@test rank == 0
end

@onrank 1, begin
@onrank 1 begin
@test rank == 1
end

@onrank 2, begin
@onrank 2 begin
@test rank == 2
end

@onrank 3, begin
@onrank 3 begin
@test rank == 3
end

Expand All @@ -36,15 +36,26 @@ MPI.Init()
@test a == [1, 5, 9]
end

@onrank 1, begin
@onrank 1 begin
@test a == [2, 6, 10]
end

@onrank 2, begin
@onrank 2 begin
@test a == [3, 7]
end

@onrank 3, begin
@onrank 3 begin
@test a == [4, 8]
end
end

split_comm = MPI.Comm_split(MPI.COMM_WORLD, rank % 2, rank)

a = Int[]

@distribute split_comm for i in 1:10
push!(a, i)
end

@onrank split_comm 0 @test a == [1, 3, 5, 7, 9]
@onrank split_comm 1 @test a == [2, 4, 6, 8, 10]
end

0 comments on commit 0fff4fa

Please sign in to comment.