Skip to content

Commit

Permalink
ensure all GPUs having same allreduces to avoid hanging
Browse files Browse the repository at this point in the history
  • Loading branch information
lijinf2 committed Nov 22, 2024
1 parent 06958c4 commit 6c56fa1
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions cpp/src/glm/qn/mg/standardization.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,13 @@ void mean_stddev(const raft::handle_t& handle,

SimpleVec<T> stddevVec(stddev_vector, D);
stddevVec.fill(0., stream);

// call allreduces on zeroes to sync with other GPUs to avoid hanging
auto& comm = handle.get_comms();
comm.allreduce(mean_vector, mean_vector, D, raft::comms::op_t::SUM, stream);
comm.sync_stream(stream);
comm.allreduce(stddev_vector, stddev_vector, D, raft::comms::op_t::SUM, stream);
comm.sync_stream(stream);
return;
}

Expand Down

0 comments on commit 6c56fa1

Please sign in to comment.