Skip to content

Commit

Permalink
Fix process divergence issues (could hang when autotuning) in generic…
Browse files Browse the repository at this point in the history
…PrintMatrix routine. Apply the same patch to genericPrintVector for future-proofing.
  • Loading branch information
maddyscientist committed Nov 27, 2024
1 parent 2b81309 commit 0a3f608
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions lib/color_spinor_util.in.cu
Original file line number Diff line number Diff line change
Expand Up @@ -378,8 +378,6 @@ namespace quda {

void genericPrintVector(const ColorSpinorField &a, int parity, unsigned int x_cb, int rank)
{
if (rank != comm_rank()) return;

ColorSpinorParam param(a);
param.location = QUDA_CPU_FIELD_LOCATION;
param.create = QUDA_COPY_FIELD_CREATE;
Expand All @@ -388,6 +386,8 @@ namespace quda {
std::unique_ptr<ColorSpinorField> clone_a = !host_clone ? nullptr : std::make_unique<ColorSpinorField>(param);
const ColorSpinorField &a_ = !host_clone ? a : *clone_a.get();

if (rank != comm_rank()) return; // rank returns after potential copy to host to prevent tuning hang

switch (a.Precision()) {
case QUDA_DOUBLE_PRECISION: genericPrintVector<double>(a_, parity, x_cb); break;
case QUDA_SINGLE_PRECISION: genericPrintVector<float>(a_, parity, x_cb); break;
Expand Down
4 changes: 2 additions & 2 deletions lib/gauge_norm.in.cu
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,6 @@ namespace quda {

void genericPrintMatrix(const GaugeField &a, int d, int parity, unsigned int x_cb, int rank)
{
if (rank != comm_rank()) return;

GaugeFieldParam param(a);
param.field = const_cast<GaugeField *>(&a);
param.location = QUDA_CPU_FIELD_LOCATION;
Expand All @@ -172,6 +170,8 @@ namespace quda {
std::unique_ptr<GaugeField> clone_a = !host_clone ? nullptr : std::make_unique<GaugeField>(param);
const GaugeField &a_ = !host_clone ? a : *clone_a.get();

if (rank != comm_rank()) return; // rank returns after potential copy to host to prevent tuning hang

switch (a.Precision()) {
case QUDA_DOUBLE_PRECISION: genericPrintMatrix<double>(a_, d, parity, x_cb); break;
case QUDA_SINGLE_PRECISION: genericPrintMatrix<float>(a_, d, parity, x_cb); break;
Expand Down

0 comments on commit 0a3f608

Please sign in to comment.