From 0fff4fa3564d50a5f44ba7fe45b6e647242387b8 Mon Sep 17 00:00:00 2001 From: Simone Silvestri Date: Tue, 3 Dec 2024 09:33:25 +0100 Subject: [PATCH] Add communicator to the distributed datawrangling utils (#282) * add communicator * Update distributed_utils.jl * Update distributed_utils.jl * add a safety net to avoid conflicts * Update test_distributed_utils.jl --- src/distributed_utils.jl | 106 ++++++++++++++++++++++----------- test/test_distributed_utils.jl | 27 ++++++--- 2 files changed, 90 insertions(+), 43 deletions(-) diff --git a/src/distributed_utils.jl b/src/distributed_utils.jl index 089962d8..f24d6722 100644 --- a/src/distributed_utils.jl +++ b/src/distributed_utils.jl @@ -6,25 +6,26 @@ using MPI ##### # Utilities to make the macro work importing only ClimaOcean and not MPI -mpi_initialized() = MPI.Initialized() -mpi_rank() = MPI.Comm_rank(MPI.COMM_WORLD) -mpi_size() = MPI.Comm_size(MPI.COMM_WORLD) -global_barrier() = mpi_initialized() ? MPI.Barrier(MPI.COMM_WORLD) : nothing +mpi_initialized() = MPI.Initialized() +mpi_rank(comm) = MPI.Comm_rank(comm) +mpi_size(comm) = MPI.Comm_size(comm) +global_barrier(comm) = MPI.Barrier(comm) """ - @root exs... + @root communicator exs... -Perform `exs` only on rank 0, otherwise know as "root" rank. -Other ranks will wait for the root rank to finish before continuing +Perform `exs` only on rank 0 in communicator, otherwise known as the "root" rank. +Other ranks will wait for the root rank to finish before continuing. +If `communicator` is not provided, `MPI.COMM_WORLD` is used. 
""" -macro root(exp) +macro root(communicator, exp) command = quote if ClimaOcean.mpi_initialized() - rank = ClimaOcean.mpi_rank() + rank = ClimaOcean.mpi_rank($communicator) if rank == 0 $exp end - ClimaOcean.global_barrier() + ClimaOcean.global_barrier($communicator) else $exp end @@ -32,40 +33,54 @@ macro root(exp) return esc(command) end +macro root(exp) + command = quote + @root MPI.COMM_WORLD $exp + end + return esc(command) +end + """ - @onrank rank, exs... + @onrank communicator rank exs... -Perform `exp` only on rank `rank` -Other ranks will wait for the root rank to finish before continuing. -The expression is run anyways if MPI in not initialized +Perform `exp` only on rank `rank` (0-based index) in `communicator`. +Other ranks will wait for rank `rank` to finish before continuing. +The expression is run anyways if MPI is not initialized. +If `communicator` is not provided, `MPI.COMM_WORLD` is used. """ -macro onrank(exp_with_rank) - on_rank = exp_with_rank.args[1] - exp = exp_with_rank.args[2] +macro onrank(communicator, on_rank, exp) command = quote mpi_initialized = ClimaOcean.mpi_initialized() - rank = ClimaOcean.mpi_rank() if !mpi_initialized $exp else + rank = ClimaOcean.mpi_rank($communicator) if rank == $on_rank $exp end - ClimaOcean.global_barrier() + ClimaOcean.global_barrier($communicator) end end return esc(command) end +macro onrank(rank, exp) + command = quote + @onrank MPI.COMM_WORLD $rank $exp + end + return esc(command) +end + """ - @distribute for i in iterable + @distribute communicator for i in iterable ... end -Distribute a `for` loop among different ranks +Distribute a `for` loop among different ranks in `communicator`. +If `communicator` is not provided, `MPI.COMM_WORLD` is used. 
""" -macro distribute(exp) +macro distribute(communicator, exp) if exp.head != :for error("The `@distribute` macro expects a `for` loop") end @@ -74,46 +89,67 @@ macro distribute(exp) variable = exp.args[1].args[1] forbody = exp.args[2] + # Safety net if the iterable variable has the same name as the + # reserved variable names (nprocs, counter, rank) + nprocs = ifelse(variable == :nprocs, :othernprocs, :nprocs) + counter = ifelse(variable == :counter, :othercounter, :counter) + rank = ifelse(variable == :rank, :otherrank, :rank) + new_loop = quote mpi_initialized = ClimaOcean.mpi_initialized() if !mpi_initialized $exp else - rank = ClimaOcean.mpi_rank() - nprocs = ClimaOcean.mpi_size() - for (counter, $variable) in enumerate($iterable) - if (counter - 1) % nprocs == rank + $rank = ClimaOcean.mpi_rank($communicator) + $nprocs = ClimaOcean.mpi_size($communicator) + for ($counter, $variable) in enumerate($iterable) + if ($counter - 1) % $nprocs == $rank $forbody end end - ClimaOcean.global_barrier() + ClimaOcean.global_barrier($communicator) end end return esc(new_loop) end +macro distribute(exp) + command = quote + @distribute MPI.COMM_WORLD $exp + end + return esc(command) +end + """ - @handshake exs... + @handshake communicator exs... -perform `exs` on all ranks, but only one rank at a time, where -ranks `r2 > r1` wait for rank `r1` to finish before executing `exs` +perform `exs` on all ranks in `communicator`, but only one rank at a time, where +ranks `r2 > r1` wait for rank `r1` to finish before executing `exs`. +If `communicator` is not provided, `MPI.COMM_WORLD` is used. 
""" -macro handshake(exp) +macro handshake(communicator, exp) command = quote mpi_initialized = ClimaOcean.mpi_initialized() if !mpi_initialized $exp else - rank = ClimaOcean.mpi_rank() - nprocs = ClimaOcean.mpi_size() + rank = ClimaOcean.mpi_rank($communicator) + nprocs = ClimaOcean.mpi_size($communicator) for r in 0 : nprocs -1 if rank == r $exp end - ClimaOcean.global_barrier() + ClimaOcean.global_barrier($communicator) end end end return esc(command) -end \ No newline at end of file +end + +macro handshake(exp) + command = quote + @handshake MPI.COMM_WORLD $exp + end + return esc(command) +end diff --git a/test/test_distributed_utils.jl b/test/test_distributed_utils.jl index a6a20075..8e9d8be7 100644 --- a/test/test_distributed_utils.jl +++ b/test/test_distributed_utils.jl @@ -6,7 +6,7 @@ MPI.Init() @testset begin rank = MPI.Comm_rank(MPI.COMM_WORLD) - @onrank 0, begin + @onrank 0 begin @test rank == 0 end @@ -14,15 +14,15 @@ MPI.Init() @test rank == 0 end - @onrank 1, begin + @onrank 1 begin @test rank == 1 end - @onrank 2, begin + @onrank 2 begin @test rank == 2 end - @onrank 3, begin + @onrank 3 begin @test rank == 3 end @@ -36,15 +36,26 @@ MPI.Init() @test a == [1, 5, 9] end - @onrank 1, begin + @onrank 1 begin @test a == [2, 6, 10] end - @onrank 2, begin + @onrank 2 begin @test a == [3, 7] end - @onrank 3, begin + @onrank 3 begin @test a == [4, 8] end -end \ No newline at end of file + + split_comm = MPI.Comm_split(MPI.COMM_WORLD, rank % 2, rank) + + a = Int[] + + @distribute split_comm for i in 1:10 + push!(a, i) + end + + @onrank split_comm 0 @test a == [1, 3, 5, 7, 9] + @onrank split_comm 1 @test a == [2, 4, 6, 8, 10] +end