Using Multiple GPUs with YAKL
YAKL does not have a concept of multiple GPUs per MPI task (i.e., per process): each MPI rank drives a single GPU, so using multiple GPUs means launching one MPI rank per GPU. Below are some examples of using YAKL in a few different MPI contexts.
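First, a note on how each rank ends up with its own GPU, since that happens outside of YAKL itself. On many HPC systems the job launcher already handles it by restricting each rank's visible devices (e.g., via per-rank CUDA_VISIBLE_DEVICES or ROCR_VISIBLE_DEVICES), in which case every rank simply sees one device. If you need to do it manually, here is a minimal sketch of one common approach. This is not a YAKL API: it just uses a node-local MPI communicator to round-robin ranks onto devices before yakl::init(), and it assumes CUDA (swap in the HIP or SYCL equivalents as appropriate).

```cpp
#include <mpi.h>
#include <cuda_runtime.h>
#include "YAKL.h"

int main(int argc, char **argv) {
  MPI_Init(&argc, &argv);

  // Ranks that share this node also share its GPUs
  MPI_Comm node_comm;
  MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &node_comm);
  int node_rank;
  MPI_Comm_rank(node_comm, &node_rank);

  // Round-robin the node-local ranks onto the devices visible on this node
  int num_devices;
  cudaGetDeviceCount(&num_devices);
  cudaSetDevice(node_rank % num_devices);

  // Initialize YAKL only after the device has been selected
  yakl::init();

  // ... application code (e.g., the halo exchange below) ...

  yakl::finalize();
  MPI_Comm_free(&node_comm);
  MPI_Finalize();
  return 0;
}
```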
Here's an example of a halo exchange in the x-direction from a fluids code of mine. Note that for GPU-aware MPI, I pass the GPU pointers directly to MPI. For the host-memory path, I create host copies first. Importantly, note the fence() at the beginning of the GPU-aware block: GPU-aware MPI does not share YAKL's GPU queue, so you cannot assume all GPU work has finished before MPI touches the device buffers. You have to synchronize yourself.
```cpp
auto &neigh = coupler.get_neighbor_rankid_matrix();
auto dtype  = coupler.get_mpi_data_type();
MPI_Request sReq [2], rReq [2];
MPI_Status  sStat[2], rStat[2];
auto comm   = MPI_COMM_WORLD;
int  npack  = fields.extent(0);
// x-direction exchanges
{
  real5d halo_send_buf_W("halo_send_buf_W",npack,nz,ny,hs,nens);
  real5d halo_send_buf_E("halo_send_buf_E",npack,nz,ny,hs,nens);
  real5d halo_recv_buf_W("halo_recv_buf_W",npack,nz,ny,hs,nens);
  real5d halo_recv_buf_E("halo_recv_buf_E",npack,nz,ny,hs,nens);
  // Pack the west and east halos into contiguous device send buffers
  parallel_for( YAKL_AUTO_LABEL() , SimpleBounds<5>(npack,nz,ny,hs,nens) ,
                YAKL_LAMBDA (int v, int k, int j, int ii, int iens) {
    halo_send_buf_W(v,k,j,ii,iens) = fields(v,hs+k,hs+j,hs+ii,iens);
    halo_send_buf_E(v,k,j,ii,iens) = fields(v,hs+k,hs+j,nx+ii,iens);
  });
  yakl::timer_start("halo_exchange_mpi");
  #ifdef MW_GPU_AWARE_MPI
    // GPU-aware path: hand device pointers straight to MPI.
    // MPI does not share YAKL's GPU queue, so fence() first to make sure
    // the packing kernel has finished.
    yakl::fence();
    MPI_Irecv( halo_recv_buf_W.data() , halo_recv_buf_W.size() , dtype , neigh(1,0) , 0 , comm , &rReq[0] );
    MPI_Irecv( halo_recv_buf_E.data() , halo_recv_buf_E.size() , dtype , neigh(1,2) , 1 , comm , &rReq[1] );
    MPI_Isend( halo_send_buf_W.data() , halo_send_buf_W.size() , dtype , neigh(1,0) , 1 , comm , &sReq[0] );
    MPI_Isend( halo_send_buf_E.data() , halo_send_buf_E.size() , dtype , neigh(1,2) , 0 , comm , &sReq[1] );
    MPI_Waitall(2, sReq, sStat);
    MPI_Waitall(2, rReq, rStat);
    yakl::timer_stop("halo_exchange_mpi");
  #else
    // Host-memory path: stage the exchange through host copies of the buffers
    auto halo_send_buf_W_host = halo_send_buf_W.createHostObject();
    auto halo_send_buf_E_host = halo_send_buf_E.createHostObject();
    auto halo_recv_buf_W_host = halo_recv_buf_W.createHostObject();
    auto halo_recv_buf_E_host = halo_recv_buf_E.createHostObject();
    MPI_Irecv( halo_recv_buf_W_host.data() , halo_recv_buf_W_host.size() , dtype , neigh(1,0) , 0 , comm , &rReq[0] );
    MPI_Irecv( halo_recv_buf_E_host.data() , halo_recv_buf_E_host.size() , dtype , neigh(1,2) , 1 , comm , &rReq[1] );
    halo_send_buf_W.deep_copy_to(halo_send_buf_W_host);
    halo_send_buf_E.deep_copy_to(halo_send_buf_E_host);
    yakl::fence();   // make sure the device-to-host copies are complete before sending
    MPI_Isend( halo_send_buf_W_host.data() , halo_send_buf_W_host.size() , dtype , neigh(1,0) , 1 , comm , &sReq[0] );
    MPI_Isend( halo_send_buf_E_host.data() , halo_send_buf_E_host.size() , dtype , neigh(1,2) , 0 , comm , &sReq[1] );
    MPI_Waitall(2, sReq, sStat);
    MPI_Waitall(2, rReq, rStat);
    yakl::timer_stop("halo_exchange_mpi");
    halo_recv_buf_W_host.deep_copy_to(halo_recv_buf_W);
    halo_recv_buf_E_host.deep_copy_to(halo_recv_buf_E);
  #endif
  // Unpack the received halos into the fields array
  parallel_for( YAKL_AUTO_LABEL() , SimpleBounds<5>(npack,nz,ny,hs,nens) ,
                YAKL_LAMBDA (int v, int k, int j, int ii, int iens) {
    fields(v,hs+k,hs+j,      ii,iens) = halo_recv_buf_W(v,k,j,ii,iens);
    fields(v,hs+k,hs+j,nx+hs+ii,iens) = halo_recv_buf_E(v,k,j,ii,iens);
  });
}
```
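Two side notes on the code above. The GPU-aware branch is selected by the MW_GPU_AWARE_MPI preprocessor macro, which you would only define at compile time (e.g., as a -D compiler flag) when the MPI library is actually GPU-aware; handing device pointers to an MPI that is not GPU-aware will typically fail. Also, the examples on this page pull their MPI datatype from coupler.get_mpi_data_type(). I don't show that coupler code here, but a minimal sketch of what such a helper might look like, assuming real is either float or double, is:

```cpp
#include <mpi.h>
#include <type_traits>

// Hypothetical helper: pick the MPI datatype that matches YAKL's "real" type
MPI_Datatype get_mpi_data_type() {
  if (std::is_same<real,double>::value) { return MPI_DOUBLE; }
  else                                  { return MPI_FLOAT;  }
}
```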
The next example computes a horizontally averaged vertical column of the state, reduced across all MPI tasks with MPI_Allreduce. The same rules apply: with GPU-aware MPI, fence() and pass the device pointers directly; otherwise, reduce through host copies and copy the result back to the device.

```cpp
using yakl::c::parallel_for;
using yakl::c::Bounds;
int nx_glob = coupler.get_nx_glob();
int ny_glob = coupler.get_ny_glob();
int nens    = coupler.get_nens();
int nx      = coupler.get_nx();
int ny      = coupler.get_ny();
int nz      = coupler.get_nz();
// Zero the accumulator, then sum this task's state over the horizontal directions
real3d column_loc("column_loc",num_fields,nz,nens);
column_loc = 0;
parallel_for( YAKL_AUTO_LABEL() , Bounds<5>(num_fields,nz,ny,nx,nens) ,
              YAKL_LAMBDA (int l, int k, int j, int i, int iens) {
  yakl::atomicAdd( column_loc(l,k,iens) , state(l,k,j,i,iens) );
});
#ifdef MW_GPU_AWARE_MPI
  // GPU-aware path: fence, then reduce directly on device pointers
  auto column_total = column_loc.createDeviceObject();
  yakl::fence();
  MPI_Allreduce( column_loc.data() , column_total.data() , column_total.size() ,
                 coupler.get_mpi_data_type() , MPI_SUM , MPI_COMM_WORLD );
#else
  // Host-memory path: reduce on host copies, then copy the result back to the device
  auto column_total_host = column_loc.createHostObject();
  MPI_Allreduce( column_loc.createHostCopy().data() , column_total_host.data() , column_total_host.size() ,
                 coupler.get_mpi_data_type() , MPI_SUM , MPI_COMM_WORLD );
  auto column_total = column_total_host.createDeviceCopy();
#endif
// Divide the global sums by the global number of columns to get horizontal averages
parallel_for( YAKL_AUTO_LABEL() , Bounds<3>(num_fields,nz,nens) , YAKL_LAMBDA (int l, int k, int iens) {
  column_loc(l,k,iens) = column_total(l,k,iens) / (nx_glob*ny_glob);
});
```
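A third MPI context that comes up constantly is reducing a single scalar, for example a globally stable time step. This isn't from the code above; it's a sketch assuming a device array dt3d of local time-step limits and the same coupler datatype helper:

```cpp
using yakl::intrinsics::minval;

// Minimum stable time step over this task's cells
// (device reduction that returns a host scalar)
real dt_loc = minval( dt3d );

// Minimum over all MPI tasks
real dt_glob;
MPI_Allreduce( &dt_loc , &dt_glob , 1 , coupler.get_mpi_data_type() , MPI_MIN , MPI_COMM_WORLD );
```

No explicit fence() is needed here: minval cannot return a host value until the device reduction has finished, so the scalar handed to MPI is already up to date.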