diff --git a/CMakeLists.txt b/CMakeLists.txt index e66f12f5ce..eb8d85468f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -228,6 +228,8 @@ option(QUDA_CLOVER_DYNAMIC "Dynamically invert the clover term" ON) option(QUDA_CLOVER_RECONSTRUCT "set to ON to enable compressed clover storage (requires QUDA_CLOVER_DYNAMIC)" ON) option(QUDA_CLOVER_CHOLESKY_PROMOTE "Whether to promote the internal precision when inverting the clover term" ON) +option(QUDA_IMPROVED_STAGGERED_EIGENSOLVER_CTEST "Whether to run eigensolver ctests against the improved staggered operator (requires QUDA_DIRAC_STAGGERED)" OFF) + # Set CTest options option(QUDA_CTEST_SEP_DSLASH_POLICIES "Test Dslash policies separately in ctest instead of only autotuning them." OFF) option(QUDA_CTEST_DISABLE_BENCHMARKS "Disable benchmark test" ON) @@ -391,7 +393,11 @@ set(CMAKE_EXE_LINKER_FLAGS_SANITIZE CACHE STRING "Flags used by the linker during sanitizer debug builds.") if(QUDA_CLOVER_RECONSTRUCT AND NOT QUDA_CLOVER_DYNAMIC) - message(SEND_ERROR "QUDA_CLOVER_RECONSTRUCT requires QUDA_CLOVER_DYNAMIC)") + message(SEND_ERROR "QUDA_CLOVER_RECONSTRUCT requires QUDA_CLOVER_DYNAMIC") +endif() + +if (QUDA_IMPROVED_STAGGERED_EIGENSOLVER_CTEST AND NOT QUDA_DIRAC_STAGGERED) + message(SEND_ERROR "QUDA_IMPROVED_STAGGERED_EIGENSOLVER_CTEST requires QUDA_DIRAC_STAGGERED") endif() find_package(Threads REQUIRED) diff --git a/ci/docker/Dockerfile.build b/ci/docker/Dockerfile.build index 3bd1f20e8e..5322ddfd2d 100644 --- a/ci/docker/Dockerfile.build +++ b/ci/docker/Dockerfile.build @@ -40,7 +40,6 @@ RUN QUDA_TEST_GRID_SIZE="1 1 1 2" cmake -S /quda/src \ -DQUDA_DIRAC_DEFAULT_OFF=ON \ -DQUDA_DIRAC_WILSON=ON \ -DQUDA_DIRAC_CLOVER=ON \ - -DQUDA_DIRAC_TWISTED_MASS=ON \ -DQUDA_DIRAC_TWISTED_CLOVER=ON \ -DQUDA_DIRAC_STAGGERED=ON \ -GNinja \ diff --git a/include/invert_quda.h b/include/invert_quda.h index 11ac64708e..7ab7a1138c 100644 --- a/include/invert_quda.h +++ b/include/invert_quda.h @@ -1048,10 +1048,24 @@ namespace quda { private: const DiracMdagM matMdagM; // used by the eigensolver - // pointers to fields to avoid multiple creation overhead - ColorSpinorField *yp, *rp, *pp, *vp, *tmpp, *tp; + + ColorSpinorField y; // Full precision solution accumulator + ColorSpinorField r; // Full precision residual vector + ColorSpinorField p; // Sloppy precision search direction + ColorSpinorField v; // Sloppy precision A * p + ColorSpinorField t; // Sloppy precision vector used for minres step + ColorSpinorField r0; // Bi-orthogonalization vector + ColorSpinorField r_sloppy; // Slopy precision residual vector + ColorSpinorField x_sloppy; // Sloppy solution accumulator vector bool init = false; + /** + @brief Initiate the fields needed by the solver + @param[in] x Solution vector + @param[in] b Source vector + */ + void create(ColorSpinorField &x, const ColorSpinorField &b); + public: BiCGstab(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, const DiracMatrix &matEig, SolverParam ¶m, TimeProfile &profile); @@ -1059,6 +1073,11 @@ namespace quda { void operator()(ColorSpinorField &out, ColorSpinorField &in) override; + /** + @return Return the residual vector from the prior solve + */ + ColorSpinorField &get_residual() override; + virtual bool hermitian() const override { return false; } /** BiCGStab is for any linear system */ virtual QudaInverterType getInverterType() const final { return QUDA_BICGSTAB_INVERTER; } diff --git a/lib/eig_block_trlm.cpp b/lib/eig_block_trlm.cpp index 160af9dff4..890257c1ed 100644 --- a/lib/eig_block_trlm.cpp +++ b/lib/eig_block_trlm.cpp @@ -105,16 +105,25 @@ namespace quda eigensolveFromBlockArrowMat(); profile.TPSTART(QUDA_PROFILE_COMPUTE); - // mat_norm is updated. + // mat_norm is updated and used for LR for (int i = num_locked; i < n_kr; i++) if (fabs(alpha[i]) > mat_norm) mat_norm = fabs(alpha[i]); + // Lambda that returns mat_norm for LR and returns the relevant alpha + // (the corresponding Ritz value) for SR + auto check_norm = [&](double sr_norm) -> double { + if (eig_param->spectrum == QUDA_SPECTRUM_LR_EIG) + return mat_norm; + else + return sr_norm; + }; + // Locking check iter_locked = 0; for (int i = 1; i < (n_kr - num_locked); i++) { - if (residua[i + num_locked] < epsilon * mat_norm) { + if (residua[i + num_locked] < epsilon * check_norm(alpha[i + num_locked])) { logQuda(QUDA_DEBUG_VERBOSE, "**** Locking %d resid=%+.6e condition=%.6e ****\n", i, residua[i + num_locked], - epsilon * mat_norm); + epsilon * check_norm(alpha[i + num_locked])); iter_locked = i; } else { // Unlikely to find new locked pairs @@ -125,9 +134,9 @@ namespace quda // Convergence check iter_converged = iter_locked; for (int i = iter_locked + 1; i < n_kr - num_locked; i++) { - if (residua[i + num_locked] < tol * mat_norm) { + if (residua[i + num_locked] < tol * check_norm(alpha[i + num_locked])) { logQuda(QUDA_DEBUG_VERBOSE, "**** Converged %d resid=%+.6e condition=%.6e ****\n", i, residua[i + num_locked], - tol * mat_norm); + tol * check_norm(alpha[i + num_locked])); iter_converged = i; } else { // Unlikely to find new converged pairs diff --git a/lib/eig_trlm.cpp b/lib/eig_trlm.cpp index 00d3941527..02d8c26246 100644 --- a/lib/eig_trlm.cpp +++ b/lib/eig_trlm.cpp @@ -86,16 +86,25 @@ namespace quda eigensolveFromArrowMat(); profile.TPSTART(QUDA_PROFILE_COMPUTE); - // mat_norm is updated. + // mat_norm is updated and used for LR for (int i = num_locked; i < n_kr; i++) if (fabs(alpha[i]) > mat_norm) mat_norm = fabs(alpha[i]); + // Lambda that returns mat_norm for LR and returns the relevant alpha + // (the corresponding Ritz value) for SR + auto check_norm = [&](double sr_norm) -> double { + if (eig_param->spectrum == QUDA_SPECTRUM_LR_EIG) + return mat_norm; + else + return sr_norm; + }; + // Locking check iter_locked = 0; for (int i = 1; i < (n_kr - num_locked); i++) { - if (residua[i + num_locked] < epsilon * mat_norm) { + if (residua[i + num_locked] < epsilon * check_norm(alpha[i + num_locked])) { logQuda(QUDA_DEBUG_VERBOSE, "**** Locking %d resid=%+.6e condition=%.6e ****\n", i, residua[i + num_locked], - epsilon * mat_norm); + epsilon * check_norm(alpha[i + num_locked])); iter_locked = i; } else { // Unlikely to find new locked pairs @@ -106,9 +115,9 @@ namespace quda // Convergence check iter_converged = iter_locked; for (int i = iter_locked + 1; i < n_kr - num_locked; i++) { - if (residua[i + num_locked] < tol * mat_norm) { + if (residua[i + num_locked] < tol * check_norm(alpha[i + num_locked])) { logQuda(QUDA_DEBUG_VERBOSE, "**** Converged %d resid=%+.6e condition=%.6e ****\n", i, residua[i + num_locked], - tol * mat_norm); + tol * check_norm(alpha[i + num_locked])); iter_converged = i; } else { // Unlikely to find new converged pairs @@ -165,8 +174,8 @@ namespace quda logQuda(QUDA_SUMMARIZE, "TRLM computed the requested %d vectors in %d restart steps and %d OP*x operations.\n", n_conv, restart_iter, iter); - // Dump all Ritz values and residua if using Chebyshev - for (int i = 0; i < n_conv && eig_param->use_poly_acc; i++) { + // Dump all Ritz values and residua + for (int i = 0; i < n_conv; i++) { logQuda(QUDA_SUMMARIZE, "RitzValue[%04d]: (%+.16e, %+.16e) residual %.16e\n", i, alpha[i], 0.0, residua[i]); } diff --git a/lib/interface_quda.cpp b/lib/interface_quda.cpp index b07a0b2c7b..46fed2c43d 100644 --- a/lib/interface_quda.cpp +++ b/lib/interface_quda.cpp @@ -3014,7 +3014,7 @@ void loadFatLongGaugeQuda(QudaInvertParam *inv_param, QudaGaugeParam *gauge_para template void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // color spinor field pointers, and inv_param void *h_gauge, void *milc_fatlinks, void *milc_longlinks, - QudaGaugeParam *gauge_param, // gauge field pointers + QudaGaugeParam *gauge_param_, // gauge field pointers void *h_clover, void *h_clovinv, // clover field pointers Interface op, Args... args) { @@ -3034,14 +3034,17 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col errorQuda("split_key = [%d,%d,%d,%d] is not valid", split_key[0], split_key[1], split_key[2], split_key[3]); } + // Create a local copy of gauge_param that we can modify without perturbing + // the original one + if (!gauge_param_) errorQuda("Input gauge_param is null"); + QudaGaugeParam gauge_param = *gauge_param_; + if (num_sub_partition == 1) { // In this case we don't split the grid. for (int n = 0; n < param->num_src; n++) { op(_hp_x[n], _hp_b[n], param, args...); } } else { - if (gauge_param == nullptr) { errorQuda("gauge_param == nullptr"); } - // Doing the sub-partition arithmatics if (param->num_src_per_sub_partition * num_sub_partition != param->num_src) { errorQuda("We need to have split_grid[0](=%d) * split_grid[1](=%d) * split_grid[2](=%d) * split_grid[3](=%d) * " @@ -3058,44 +3061,50 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col checkInvertParam(param, _hp_x[0], _hp_b[0]); - bool is_staggered; + bool is_staggered = false; + bool is_asqtad = false; if (h_gauge) { is_staggered = false; } else if (milc_fatlinks) { is_staggered = true; + if (param->dslash_type == QUDA_ASQTAD_DSLASH) { + if (!milc_longlinks) errorQuda("milc_longlinks is null for an asqtad dslash"); + is_asqtad = true; + } } else { errorQuda("Both h_gauge and milc_fatlinks are null."); - is_staggered = true; // to suppress compiler warning/error. } // Gauge fields/params - GaugeFieldParam *gf_param = nullptr; - GaugeField *in = nullptr; + GaugeFieldParam gf_param; + GaugeField in; // Staggered gauge fields/params - GaugeFieldParam *milc_fatlink_param = nullptr; - GaugeFieldParam *milc_longlink_param = nullptr; - GaugeField *milc_fatlink_field = nullptr; - GaugeField *milc_longlink_field = nullptr; + GaugeFieldParam milc_fatlink_param; + GaugeFieldParam milc_longlink_param; + quda::GaugeField milc_fatlink_field; + quda::GaugeField milc_longlink_field; // set up the gauge field params. if (!is_staggered) { // not staggered - gf_param = new GaugeFieldParam(*gauge_param, h_gauge); - if (gf_param->order <= 4) gf_param->ghostExchange = QUDA_GHOST_EXCHANGE_NO; - in = GaugeField::Create(*gf_param); + gf_param = GaugeFieldParam(gauge_param, h_gauge); + in = GaugeField(gf_param); } else { // staggered - milc_fatlink_param = new GaugeFieldParam(*gauge_param, milc_fatlinks); - if (milc_fatlink_param->order <= 4) milc_fatlink_param->ghostExchange = QUDA_GHOST_EXCHANGE_NO; - milc_fatlink_field = GaugeField::Create(*milc_fatlink_param); - milc_longlink_param = new GaugeFieldParam(*gauge_param, milc_longlinks); - if (milc_longlink_param->order <= 4) milc_longlink_param->ghostExchange = QUDA_GHOST_EXCHANGE_NO; - milc_longlink_field = GaugeField::Create(*milc_longlink_param); + milc_fatlink_param = GaugeFieldParam(gauge_param, milc_fatlinks); + milc_fatlink_param.order = QUDA_MILC_GAUGE_ORDER; + milc_fatlink_field = GaugeField(milc_fatlink_param); + + if (is_asqtad) { + milc_longlink_param = GaugeFieldParam(gauge_param, milc_longlinks); + milc_longlink_param.order = QUDA_MILC_GAUGE_ORDER; + milc_longlink_field = GaugeField(milc_longlink_param); + } } // Create the temp host side helper fields, which are just wrappers of the input pointers. bool pc_solution = (param->solution_type == QUDA_MATPC_SOLUTION) || (param->solution_type == QUDA_MATPCDAG_MATPC_SOLUTION); - lat_dim_t X = {gauge_param->X[0], gauge_param->X[1], gauge_param->X[2], gauge_param->X[3]}; + lat_dim_t X = {gauge_param.X[0], gauge_param.X[1], gauge_param.X[2], gauge_param.X[3]}; ColorSpinorParam cpuParam(_hp_b[0], *param, X, pc_solution, param->input_location); std::vector _h_b(param->num_src); for (int i = 0; i < param->num_src; i++) { @@ -3119,16 +3128,14 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col errorQuda("Split not possible: %2d %% %2d != 0", comm_dim(d), split_key[d]); } if (!is_staggered) { - gf_param->x[d] *= split_key[d]; - gf_param->pad *= split_key[d]; + gf_param.x[d] *= split_key[d]; + gf_param.pad *= split_key[d]; } else { - milc_fatlink_param->x[d] *= split_key[d]; - milc_fatlink_param->pad *= split_key[d]; - milc_longlink_param->x[d] *= split_key[d]; - milc_longlink_param->pad *= split_key[d]; + milc_fatlink_param.x[d] *= split_key[d]; + if (is_asqtad) milc_longlink_param.x[d] *= split_key[d]; } - gauge_param->X[d] *= split_key[d]; - gauge_param->ga_pad *= split_key[d]; + gauge_param.X[d] *= split_key[d]; + if (!is_staggered) gauge_param.ga_pad *= split_key[d]; } // Deal with clover field. For Multi source computatons, clover field construction is done @@ -3171,26 +3178,30 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col } } - quda::GaugeField *collected_gauge = nullptr; - quda::GaugeField *collected_milc_fatlink_field = nullptr; - quda::GaugeField *collected_milc_longlink_field = nullptr; + quda::GaugeField collected_gauge; + quda::GaugeField collected_milc_fatlink_field; + quda::GaugeField collected_milc_longlink_field; if (!is_staggered) { - gf_param->create = QUDA_NULL_FIELD_CREATE; - collected_gauge = new quda::GaugeField(*gf_param); + gf_param.create = QUDA_NULL_FIELD_CREATE; + collected_gauge = quda::GaugeField(gf_param); std::vector v_g(1); - v_g[0] = in; - quda::split_field(*collected_gauge, v_g, split_key); + v_g[0] = ∈ + quda::split_field(collected_gauge, v_g, split_key); } else { - milc_fatlink_param->create = QUDA_NULL_FIELD_CREATE; - milc_longlink_param->create = QUDA_NULL_FIELD_CREATE; - collected_milc_fatlink_field = new quda::GaugeField(*milc_fatlink_param); - collected_milc_longlink_field = new quda::GaugeField(*milc_longlink_param); std::vector v_g(1); - v_g[0] = milc_fatlink_field; - quda::split_field(*collected_milc_fatlink_field, v_g, split_key); - v_g[0] = milc_longlink_field; - quda::split_field(*collected_milc_longlink_field, v_g, split_key); + + milc_fatlink_param.create = QUDA_NULL_FIELD_CREATE; + collected_milc_fatlink_field = GaugeField(milc_fatlink_param); + v_g[0] = &milc_fatlink_field; + quda::split_field(collected_milc_fatlink_field, v_g, split_key); + + if (is_asqtad) { + milc_longlink_param.create = QUDA_NULL_FIELD_CREATE; + collected_milc_longlink_field = GaugeField(milc_longlink_param); + v_g[0] = &milc_longlink_field; + quda::split_field(collected_milc_longlink_field, v_g, split_key); + } } profileInvertMultiSrc.TPSTART(QUDA_PROFILE_PREAMBLE); @@ -3223,10 +3234,10 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col // the split topology. logQuda(QUDA_DEBUG_VERBOSE, "Split grid loading gauge field...\n"); if (!is_staggered) { - loadGaugeQuda(collected_gauge->raw_pointer(), gauge_param); + loadGaugeQuda(collected_gauge.raw_pointer(), &gauge_param); } else { - loadFatLongGaugeQuda(param, gauge_param, collected_milc_fatlink_field->raw_pointer(), - collected_milc_longlink_field->raw_pointer()); + loadFatLongGaugeQuda(param, &gauge_param, collected_milc_fatlink_field.raw_pointer(), + (is_asqtad) ? collected_milc_longlink_field.raw_pointer() : nullptr); } logQuda(QUDA_DEBUG_VERBOSE, "Split grid loaded gauge field...\n"); @@ -3251,8 +3262,8 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col comm_barrier(); for (int d = 0; d < CommKey::n_dim; d++) { - gauge_param->X[d] /= split_key[d]; - gauge_param->ga_pad /= split_key[d]; + gauge_param.X[d] /= split_key[d]; + if (!is_staggered) gauge_param.ga_pad /= split_key[d]; } for (int n = 0; n < param->num_src_per_sub_partition; n++) { @@ -3268,16 +3279,6 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col for (auto p : _h_x) { delete p; } for (auto p : _h_b) { delete p; } - if (!is_staggered) { - delete in; - delete collected_gauge; - } else { - delete milc_fatlink_field; - delete milc_longlink_field; - delete collected_milc_fatlink_field; - delete collected_milc_longlink_field; - } - if (input_clover) { delete input_clover; } if (collected_clover) { delete collected_clover; } @@ -3285,10 +3286,10 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col // Restore the gauge field if (!is_staggered) { - loadGaugeQuda(h_gauge, gauge_param); + loadGaugeQuda(h_gauge, gauge_param_); } else { freeGaugeQuda(); - loadFatLongGaugeQuda(param, gauge_param, milc_fatlinks, milc_longlinks); + loadFatLongGaugeQuda(param, gauge_param_, milc_fatlinks, milc_longlinks); } if (param->dslash_type == QUDA_CLOVER_WILSON_DSLASH || param->dslash_type == QUDA_TWISTED_CLOVER_DSLASH) { diff --git a/lib/inv_bicgstab_quda.cpp b/lib/inv_bicgstab_quda.cpp index 4fdf08020a..742e026bd2 100644 --- a/lib/inv_bicgstab_quda.cpp +++ b/lib/inv_bicgstab_quda.cpp @@ -21,19 +21,37 @@ namespace quda { BiCGstab::~BiCGstab() { profile.TPSTART(QUDA_PROFILE_FREE); - - if(init) { - delete yp; - delete rp; - delete pp; - delete vp; - delete tmpp; - delete tp; - } destroyDeflationSpace(); profile.TPSTOP(QUDA_PROFILE_FREE); } + void BiCGstab::create(ColorSpinorField &x, const ColorSpinorField &b) + { + Solver::create(x, b); + + if (!init) { + if (!param.is_preconditioner) profile.TPSTART(QUDA_PROFILE_INIT); + ColorSpinorParam csParam(x); + csParam.create = QUDA_ZERO_FIELD_CREATE; + y = ColorSpinorField(csParam); + r = ColorSpinorField(csParam); + csParam.setPrecision(param.precision_sloppy); + p = ColorSpinorField(csParam); + v = ColorSpinorField(csParam); + t = ColorSpinorField(csParam); + + if (!param.is_preconditioner) profile.TPSTOP(QUDA_PROFILE_INIT); + init = true; + } // init + } + + ColorSpinorField &BiCGstab::get_residual() + { + if (!init) errorQuda("No residual vector present"); + if (!param.return_residual) errorQuda("SolverParam::return_residual not enabled"); + return r; + } + int reliable(double &rNorm, double &maxrx, double &maxrr, const double &r2, const double &delta) { // reliable updates rNorm = sqrt(r2); @@ -50,33 +68,12 @@ namespace quda { void BiCGstab::operator()(ColorSpinorField &x, ColorSpinorField &b) { - profile.TPSTART(QUDA_PROFILE_PREAMBLE); - - if (!init) { - ColorSpinorParam csParam(x); - csParam.create = QUDA_ZERO_FIELD_CREATE; - yp = new ColorSpinorField(csParam); - rp = new ColorSpinorField(csParam); - csParam.setPrecision(param.precision_sloppy); - pp = new ColorSpinorField(csParam); - vp = new ColorSpinorField(csParam); - tmpp = new ColorSpinorField(csParam); - tp = new ColorSpinorField(csParam); - - init = true; - } + create(x, b); - ColorSpinorField &y = *yp; - ColorSpinorField &r = *rp; - ColorSpinorField &p = *pp; - ColorSpinorField &v = *vp; - ColorSpinorField &tmp = *tmpp; - ColorSpinorField &t = *tp; - - ColorSpinorField *x_sloppy, *r_sloppy, *r_0; + if (!param.is_preconditioner) profile.TPSTART(QUDA_PROFILE_INIT); double b2 = blas::norm2(b); // norm sq of source - double r2; // norm sq of residual + double r2 = 0.0; // norm sq of residual if (param.deflate) { // Construct the eigensolver and deflation space if requested. @@ -89,15 +86,15 @@ namespace quda { } if (deflate_compute) { // compute the deflation space. - if (!param.is_preconditioner) profile.TPSTOP(QUDA_PROFILE_PREAMBLE); + if (!param.is_preconditioner) profile.TPSTOP(QUDA_PROFILE_INIT); (*eig_solve)(evecs, evals); + if (!param.is_preconditioner) profile.TPSTART(QUDA_PROFILE_INIT); if (param.deflate) { // double the size of the Krylov space extendSVDDeflationSpace(); // populate extra memory with L/R singular vectors eig_solve->computeSVD(evecs, evals); } - if (!param.is_preconditioner) profile.TPSTART(QUDA_PROFILE_PREAMBLE); deflate_compute = false; } if (recompute_evals) { @@ -134,7 +131,7 @@ namespace quda { x = b; param.true_res = 0.0; param.true_res_hq = 0.0; - profile.TPSTOP(QUDA_PROFILE_PREAMBLE); + if (!param.is_preconditioner) profile.TPSTOP(QUDA_PROFILE_INIT); return; } else if (param.use_init_guess == QUDA_USE_INIT_GUESS_YES) { b2 = r2; @@ -145,46 +142,40 @@ namespace quda { // set field aliasing according to whether we are doing mixed precision or not if (param.precision_sloppy == x.Precision()) { - r_sloppy = &r; + r_sloppy = r.create_alias(); - if(param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO) - { - r_0 = &b; - } - else - { + if (param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO) { + r0 = b.create_alias(); + } else { ColorSpinorParam csParam(r); - csParam.create = QUDA_ZERO_FIELD_CREATE; - r_0 = new ColorSpinorField(csParam); // remember to delete this pointer. - *r_0 = r; + csParam.create = QUDA_NULL_FIELD_CREATE; + r0 = ColorSpinorField(csParam); + blas::copy(r0, r); } } else { ColorSpinorParam csParam(x); csParam.setPrecision(param.precision_sloppy); csParam.create = QUDA_NULL_FIELD_CREATE; - r_sloppy = new ColorSpinorField(csParam); - *r_sloppy = r; - r_0 = new ColorSpinorField(csParam); - *r_0 = r; + r_sloppy = ColorSpinorField(csParam); + blas::copy(r_sloppy, r); + r0 = ColorSpinorField(csParam); + blas::copy(r0, r); } - if (param.precision_sloppy == x.Precision() || !param.use_sloppy_partial_accumulator) - { - x_sloppy = &x; - blas::zero(*x_sloppy); - } - else - { + if (param.precision_sloppy == x.Precision() || !param.use_sloppy_partial_accumulator) { + x_sloppy = x.create_alias(); + blas::zero(x_sloppy); + } else { ColorSpinorParam csParam(x); csParam.create = QUDA_ZERO_FIELD_CREATE; csParam.setPrecision(param.precision_sloppy); - x_sloppy = new ColorSpinorField(csParam); + x_sloppy = ColorSpinorField(csParam); } - // Syntatic sugar - ColorSpinorField &rSloppy = *r_sloppy; - ColorSpinorField &xSloppy = *x_sloppy; - ColorSpinorField &r0 = *r_0; + if (!param.is_preconditioner) { + profile.TPSTOP(QUDA_PROFILE_INIT); + profile.TPSTART(QUDA_PROFILE_PREAMBLE); + } double stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver @@ -214,86 +205,90 @@ namespace quda { PrintStats("BiCGstab", k, r2, b2, heavy_quark_res); - profile.TPSTOP(QUDA_PROFILE_PREAMBLE); - profile.TPSTART(QUDA_PROFILE_COMPUTE); + if (!param.is_preconditioner) { + profile.TPSTOP(QUDA_PROFILE_PREAMBLE); + profile.TPSTART(QUDA_PROFILE_COMPUTE); + } rho = r2; // cDotProductCuda(r0, r_sloppy); // BiCRstab - blas::copy(p, rSloppy); + blas::copy(p, r_sloppy); + + bool converged = convergence(r2, heavy_quark_res, stop, param.tol_hq); if (getVerbosity() >= QUDA_DEBUG_VERBOSE) - printfQuda("BiCGstab debug: x2=%e, r2=%e, v2=%e, p2=%e, tmp2=%e r0=%e t2=%e\n", - blas::norm2(x), blas::norm2(rSloppy), blas::norm2(v), blas::norm2(p), - blas::norm2(tmp), blas::norm2(r0), blas::norm2(t)); + printfQuda("BiCGstab debug: x2=%e, r2=%e, v2=%e, p2=%e, r0=%e, t2=%e\n", blas::norm2(x), blas::norm2(r_sloppy), + blas::norm2(v), blas::norm2(p), blas::norm2(r0), blas::norm2(t)); - while ( !convergence(r2, heavy_quark_res, stop, param.tol_hq) && - k < param.maxiter) { + // track if we just performed an exact recalculation of y, r, r2 + bool just_updated = false; + + while (!converged && k < param.maxiter) { + just_updated = false; matSloppy(v, p); Complex r0v; if (param.pipeline) { - r0v = blas::cDotProduct(r0, v); - if (k>0) rho = blas::cDotProduct(r0, r); + r0v = blas::cDotProduct(r0, v); + if (k > 0) rho = blas::cDotProduct(r0, r); } else { - r0v = blas::cDotProduct(r0, v); + r0v = blas::cDotProduct(r0, v); } if (abs(rho) == 0.0) alpha = 0.0; else alpha = rho / r0v; // r -= alpha*v - blas::caxpy(-alpha, v, rSloppy); + blas::caxpy(-alpha, v, r_sloppy); - matSloppy(t, rSloppy); + matSloppy(t, r_sloppy); int updateR = 0; if (param.pipeline) { - // omega = (t, r) / (t, t) - omega_t2 = blas::cDotProductNormA(t, rSloppy); - Complex tr = Complex(omega_t2.x, omega_t2.y); - double t2 = omega_t2.z; - omega = tr / t2; - double s2 = blas::norm2(rSloppy); - Complex r0t = blas::cDotProduct(r0, t); - beta = -r0t / r0v; - r2 = s2 - real(omega * conj(tr)) ; - - // now we can work out if we need to do a reliable update + // omega = (t, r) / (t, t) + omega_t2 = blas::cDotProductNormA(t, r_sloppy); + Complex tr = Complex(omega_t2.x, omega_t2.y); + double t2 = omega_t2.z; + omega = tr / t2; + double s2 = blas::norm2(r_sloppy); + Complex r0t = blas::cDotProduct(r0, t); + beta = -r0t / r0v; + r2 = s2 - real(omega * conj(tr)); + // now we can work out if we need to do a reliable update updateR = reliable(rNorm, maxrx, maxrr, r2, delta); } else { - // omega = (t, r) / (t, t) - omega_t2 = blas::cDotProductNormA(t, rSloppy); - omega = Complex(omega_t2.x / omega_t2.z, omega_t2.y / omega_t2.z); + // omega = (t, r) / (t, t) + omega_t2 = blas::cDotProductNormA(t, r_sloppy); + omega = Complex(omega_t2.x / omega_t2.z, omega_t2.y / omega_t2.z); } if (param.pipeline && !updateR) { - //x += alpha*p + omega*r, r -= omega*t, p = r - beta*omega*v + beta*p - blas::caxpbypzYmbw(alpha, p, omega, rSloppy, xSloppy, t); - blas::cxpaypbz(rSloppy, -beta*omega, v, beta, p); - //tripleBiCGstabUpdate(alpha, p, omega, rSloppy, xSloppy, t, -beta*omega, v, beta, p + // x += alpha*p + omega*r, r -= omega*t, p = r - beta*omega*v + beta*p + blas::caxpbypzYmbw(alpha, p, omega, r_sloppy, x_sloppy, t); + blas::cxpaypbz(r_sloppy, -beta * omega, v, beta, p); + // tripleBiCGstabUpdate(alpha, p, omega, r_sloppy, x_sloppy, t, -beta*omega, v, beta, p } else { - //x += alpha*p + omega*r, r -= omega*t, r2 = (r,r), rho = (r0, r) - rho_r2 = blas::caxpbypzYmbwcDotProductUYNormY(alpha, p, omega, rSloppy, xSloppy, t, r0); - rho0 = rho; - rho = Complex(rho_r2.x, rho_r2.y); - r2 = rho_r2.z; + // x += alpha*p + omega*r, r -= omega*t, r2 = (r,r), rho = (r0, r) + rho_r2 = blas::caxpbypzYmbwcDotProductUYNormY(alpha, p, omega, r_sloppy, x_sloppy, t, r0); + rho0 = rho; + rho = Complex(rho_r2.x, rho_r2.y); + r2 = rho_r2.z; } - if (use_heavy_quark_res && k%heavy_quark_check==0) { - if (&x != &xSloppy) { - blas::copy(tmp,y); - heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(xSloppy, tmp, rSloppy).z); + if (use_heavy_quark_res && k % heavy_quark_check == 0) { + if (&x != &x_sloppy) { + heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(x_sloppy, r_sloppy).z); } else { - blas::copy(r, rSloppy); - heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(x, y, r).z); + blas::copy(r, r_sloppy); + heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(x, y, r).z); } } if (!param.pipeline) updateR = reliable(rNorm, maxrx, maxrr, r2, delta); if (updateR) { - if (x.Precision() != xSloppy.Precision()) blas::copy(x, xSloppy); + if (x.Precision() != x_sloppy.Precision()) blas::copy(x, x_sloppy); - blas::xpy(x, y); // swap these around? + blas::xpy(x, y); mat(r, y); r2 = blas::xmyNorm(b, r); @@ -307,69 +302,97 @@ namespace quda { r2 = blas::xmyNorm(b, r); } - if (x.Precision() != rSloppy.Precision()) blas::copy(rSloppy, r); - blas::zero(xSloppy); + if (x.Precision() != r_sloppy.Precision()) blas::copy(r_sloppy, r); + blas::zero(x_sloppy); + + rNorm = sqrt(r2); + maxrr = rNorm; + maxrx = rNorm; + // r0Norm = rNorm; + rUpdate++; - rNorm = sqrt(r2); - maxrr = rNorm; - maxrx = rNorm; - //r0Norm = rNorm; - rUpdate++; + just_updated = true; } k++; PrintStats("BiCGstab", k, r2, b2, heavy_quark_res); if (getVerbosity() >= QUDA_DEBUG_VERBOSE) - printfQuda("BiCGstab debug: x2=%e, r2=%e, v2=%e, p2=%e, tmp2=%e r0=%e t2=%e\n", - blas::norm2(x), blas::norm2(rSloppy), blas::norm2(v), blas::norm2(p), - blas::norm2(tmp), blas::norm2(r0), blas::norm2(t)); + printfQuda("BiCGstab debug: x2=%e, r2=%e, v2=%e, p2=%e, r0=%e, t2=%e\n", blas::norm2(x), blas::norm2(r_sloppy), + blas::norm2(v), blas::norm2(p), blas::norm2(r0), blas::norm2(t)); - // update p - if (!param.pipeline || updateR) {// need to update if not pipeline or did a reliable update - if (abs(rho*alpha) == 0.0) beta = 0.0; - else beta = (rho/rho0) * (alpha/omega); - blas::cxpaypbz(rSloppy, -beta*omega, v, beta, p); + converged = convergence(r2, heavy_quark_res, stop, param.tol_hq); + + if (converged) { + // make sure we've truly converged + if (!just_updated) { + if (x.Precision() != x_sloppy.Precision()) blas::copy(x, x_sloppy); + blas::xpy(x, y); + mat(r, y); + r2 = blas::xmyNorm(b, r); + + if (param.deflate && sqrt(r2) < param.tol_restart) { + // Deflate and accumulate to solution vector + eig_solve->deflate(y, r, evecs, evals, true); + // Compute r_defl = RHS - A * LHS + mat(r, y); + r2 = blas::xmyNorm(b, r); + } + + if (x.Precision() != r_sloppy.Precision()) blas::copy(r_sloppy, r); + blas::zero(x_sloppy); + + rNorm = sqrt(r2); + maxrr = rNorm; + maxrx = rNorm; + // r0Norm = rNorm; + rUpdate++; + + just_updated = true; + } + + // explicitly compute the HQ residual if need be + heavy_quark_res = use_heavy_quark_res ? sqrt(blas::HeavyQuarkResidualNorm(y, r).z) : 0.0; + + // Update convergence check + converged = convergence(r2, heavy_quark_res, stop, param.tol_hq); } + // update p + if ((!param.pipeline || updateR) && !converged) { // need to update if not pipeline or did a reliable update + if (abs(rho * alpha) == 0.0) + beta = 0.0; + else + beta = (rho / rho0) * (alpha / omega); + blas::cxpaypbz(r_sloppy, -beta * omega, v, beta, p); + } } - if (x.Precision() != xSloppy.Precision()) blas::copy(x, xSloppy); - blas::xpy(y, x); + // We have a guarantee that we just converged via the true residual + // y has already been updated + blas::copy(x, y); - profile.TPSTOP(QUDA_PROFILE_COMPUTE); - profile.TPSTART(QUDA_PROFILE_EPILOGUE); + if (!param.is_preconditioner) { + profile.TPSTOP(QUDA_PROFILE_COMPUTE); + profile.TPSTART(QUDA_PROFILE_EPILOGUE); - param.iter += k; + param.iter += k; - if (k==param.maxiter) warningQuda("Exceeded maximum iterations %d", param.maxiter); + if (k == param.maxiter) warningQuda("Exceeded maximum iterations %d", param.maxiter); + } if (getVerbosity() >= QUDA_VERBOSE) printfQuda("BiCGstab: Reliable updates = %d\n", rUpdate); if (!param.is_preconditioner) { // do not do the below if we this is an inner solver - // Calculate the true residual - mat(r, x); - param.true_res = sqrt(blas::xmyNorm(b, r) / b2); + // r2 was freshly computed + param.true_res = sqrt(r2 / b2); param.true_res_hq = use_heavy_quark_res ? sqrt(blas::HeavyQuarkResidualNorm(x,r).z) : 0.0; PrintSummary("BiCGstab", k, r2, b2, stop, param.tol_hq); - } - - profile.TPSTOP(QUDA_PROFILE_EPILOGUE); - profile.TPSTART(QUDA_PROFILE_FREE); - if (param.precision_sloppy != x.Precision()) { - delete r_0; - delete r_sloppy; + profile.TPSTOP(QUDA_PROFILE_EPILOGUE); } - else if(param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_YES) - { - delete r_0; - } - - if (&x != &xSloppy) delete x_sloppy; - profile.TPSTOP(QUDA_PROFILE_FREE); } } // namespace quda diff --git a/lib/inv_ca_gcr.cpp b/lib/inv_ca_gcr.cpp index 5b893bd3fc..bfb42119cc 100644 --- a/lib/inv_ca_gcr.cpp +++ b/lib/inv_ca_gcr.cpp @@ -141,6 +141,7 @@ namespace quda create(x, b); if (!param.is_preconditioner) profile.TPSTART(QUDA_PROFILE_PREAMBLE); + if (param.is_preconditioner) commGlobalReductionPush(param.global_reduction); // compute b2, but only if we need to bool fixed_iteration = param.sloppy_converge && n_krylov == param.maxiter && !param.compute_true_res; @@ -374,6 +375,8 @@ namespace quda } PrintSummary("CA-GCR", total_iter, r2, b2, stop, param.tol_hq); + + if (param.is_preconditioner) commGlobalReductionPop(); } } // namespace quda diff --git a/lib/spinor_dilute.in.cu b/lib/spinor_dilute.in.cu index 9b57458e8a..eadd519c5b 100644 --- a/lib/spinor_dilute.in.cu +++ b/lib/spinor_dilute.in.cu @@ -87,7 +87,12 @@ namespace quda const lat_dim_t &local_block, IntList) { if (src.Ncolor() == Nc) { - SpinorDilute(src, v, type, local_block); + if constexpr (Nc <= 32) { + SpinorDilute(src, v, type, local_block); + } else { + errorQuda( + "nColor = %d is too large to compile, see QUDA issue #1422 (https://github.com/lattice/quda/issues/1422)"); + } } else { if constexpr (sizeof...(N) > 0) spinorDilute(src, v, type, local_block, IntList()); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 472d6f63db..10088fc05b 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -230,6 +230,11 @@ if(QUDA_DIRAC_STAGGERED) quda_checkbuildtest(hisq_stencil_test QUDA_BUILD_ALL_TESTS) install(TARGETS hisq_stencil_test ${QUDA_EXCLUDE_FROM_INSTALL} DESTINATION ${CMAKE_INSTALL_BINDIR}) + add_executable(hisq_stencil_ctest hisq_stencil_ctest.cpp) + target_link_libraries(hisq_stencil_ctest ${TEST_LIBS}) + quda_checkbuildtest(hisq_stencil_ctest QUDA_BUILD_ALL_TESTS) + install(TARGETS hisq_stencil_ctest ${QUDA_EXCLUDE_FROM_INSTALL} DESTINATION ${CMAKE_INSTALL_BINDIR}) + add_executable(hisq_paths_force_test hisq_paths_force_test.cpp) target_link_libraries(hisq_paths_force_test ${TEST_LIBS}) quda_checkbuildtest(hisq_paths_force_test QUDA_BUILD_ALL_TESTS) @@ -873,7 +878,7 @@ endif() add_test(NAME benchmark_dslash_${DIRAC_NAME}_policy${pol2} COMMAND ${QUDA_CTEST_LAUNCH} $ ${MPIEXEC_POSTFLAGS} --dslash-type ${DIRAC_NAME} - --test 0 + --test MatPC --dim 20 20 20 20 --gtest_output=json:dslash_${DIRAC_NAME}_benchmark_pol${pol2}.json --gtest_filter=*benchmark/*n0) @@ -909,7 +914,7 @@ endif() COMMAND ${QUDA_CTEST_LAUNCH} $ ${MPIEXEC_POSTFLAGS} --dslash-type ${DIRAC_NAME} --all-partitions 1 - --test 0 + --test MatPC --dim 20 20 20 20 --gtest_output=json:dslash_${DIRAC_NAME}_benchmark_pol${pol2}.json --gtest_filter=*benchmark/*n0) @@ -930,6 +935,19 @@ endif() if(polenv) set_tests_properties(dslash_${DIRAC_NAME}_build_policy${pol2} PROPERTIES ENVIRONMENT QUDA_ENABLE_DSLASH_POLICY=${pol2}) endif() + + if(QUDA_LAPLACE) + set(DIRAC_NAME laplace) + add_test(NAME dslash_${DIRAC_NAME}_mat_policy${pol2} + COMMAND ${QUDA_CTEST_LAUNCH} $ ${MPIEXEC_POSTFLAGS} + --dslash-type ${DIRAC_NAME} + --test Mat + --dim 2 4 6 8 + --gtest_output=xml:dslash_${DIRAC_NAME}_mat_test_pol${pol2}.xml) + if(polenv) + set_tests_properties(dslash_${DIRAC_NAME}_mat_policy${pol2} PROPERTIES ENVIRONMENT QUDA_ENABLE_DSLASH_POLICY=${pol2}) + endif() + endif() endif() if(QUDA_COVDEV) @@ -955,7 +973,7 @@ elseif(single_prec) set(TEST_PRECS single) endif() -# Inversions +# Wilson-type Inversions foreach(prec IN LISTS TEST_PRECS) if(${prec} STREQUAL "double") @@ -1130,7 +1148,73 @@ foreach(prec IN LISTS TEST_PRECS) endif() endforeach(prec) -# Eigensolves +# Staggered-type Inversions +foreach(prec IN LISTS TEST_PRECS) + + # These require looser tolerances to keep iterations to solution in check + if(${prec} STREQUAL "double") + set(tol 1e-6) + elseif(${prec} STREQUAL "single") + set(tol 1e-5) + endif() + + if(QUDA_DIRAC_STAGGERED) + # --compute-fat-long true is necessary to get well-behaved fields + + add_test(NAME invert_test_staggered_${prec} + COMMAND ${QUDA_CTEST_LAUNCH} $ ${MPIEXEC_POSTFLAGS} + --dslash-type staggered --ngcrkrylov 8 --compute-fat-long true + --dim 2 4 6 8 --prec ${prec} --tol ${tol} --tolhq ${tol} --niter 1000 + --enable-testing true + --gtest_output=xml:invert_test_staggered_${prec}.xml) + + if(DEFINED ENV{QUDA_ENABLE_TUNING}) + if($ENV{QUDA_ENABLE_TUNING} EQUAL 0) + add_test(NAME invert_test_splitgrid_staggered_${prec} + COMMAND ${QUDA_CTEST_LAUNCH} $ ${MPIEXEC_POSTFLAGS} + --dslash-type staggered --ngcrkrylov 8 --compute-fat-long true + --dim 2 4 6 8 --prec ${prec} --tol ${tol} --tolhq ${tol} --niter 1000 + --nsrc ${QUDA_TEST_NUM_PROCS} + --enable-testing true + --gtest_output=xml:invert_test_splitgrid_staggered_${prec}.xml) + + set_tests_properties(invert_test_splitgrid_staggered_${prec} PROPERTIES ENVIRONMENT QUDA_TEST_GRID_PARTITION=$ENV{QUDA_TEST_GRID_SIZE}) + endif() + endif() + + add_test(NAME invert_test_asqtad_${prec} + COMMAND ${QUDA_CTEST_LAUNCH} $ ${MPIEXEC_POSTFLAGS} + --dslash-type asqtad --ngcrkrylov 8 --compute-fat-long true + --dim 6 6 6 8 --prec ${prec} --tol ${tol} --tolhq ${tol} --niter 1000 + --enable-testing true + --gtest_output=xml:invert_test_asqtad_${prec}.xml) + + if(DEFINED ENV{QUDA_ENABLE_TUNING}) + if($ENV{QUDA_ENABLE_TUNING} EQUAL 0) + add_test(NAME invert_test_splitgrid_asqtad_${prec} + COMMAND ${QUDA_CTEST_LAUNCH} $ ${MPIEXEC_POSTFLAGS} + --dslash-type asqtad --ngcrkrylov 8 --compute-fat-long true + --dim 6 6 6 8 --prec ${prec} --tol ${tol} --tolhq ${tol} --niter 1000 + --nsrc ${QUDA_TEST_NUM_PROCS} + --enable-testing true + --gtest_output=xml:invert_test_splitgrid_asqtad_${prec}.xml) + + set_tests_properties(invert_test_splitgrid_asqtad_${prec} PROPERTIES ENVIRONMENT QUDA_TEST_GRID_PARTITION=$ENV{QUDA_TEST_GRID_SIZE}) + endif() + endif() + + if (QUDA_LAPLACE) + add_test(NAME invert_test_laplace_${prec} + COMMAND ${QUDA_CTEST_LAUNCH} $ ${MPIEXEC_POSTFLAGS} + --dslash-type laplace --ngcrkrylov 8 --compute-fat-long true + --dim 2 4 6 8 --prec ${prec} --tol ${tol} --tolhq ${tol} --niter 1000 + --enable-testing true + --gtest_output=xml:invert_test_laplace_${prec}.xml) + endif() + endif() +endforeach(prec) + +# Wilson-type eigensolves foreach(prec IN LISTS TEST_PRECS) if(${prec} STREQUAL "double") @@ -1296,7 +1380,59 @@ foreach(prec IN LISTS TEST_PRECS) --gtest_output=xml:eigensolve_test_mobius_eofa_asym_${prec}.xml) endif() endforeach(prec) - + +# Staggered-type eigensolves +foreach(prec IN LISTS TEST_PRECS) + + # These require looser tolerances to keep iterations to solution in check + if(${prec} STREQUAL "double") + set(tol 1e-5) + elseif(${prec} STREQUAL "single") + set(tol 1e-4) + endif() + + if(QUDA_DIRAC_STAGGERED) + # --compute-fat-long true is necessary to get well-behaved fields + + add_test(NAME eigensolve_test_staggered_${prec} + COMMAND ${QUDA_CTEST_LAUNCH} $ ${MPIEXEC_POSTFLAGS} + --dslash-type staggered --compute-fat-long true + --eig-n-conv 32 --eig-n-ev 32 --eig-n-kr 256 + --dim 2 4 6 8 --prec ${prec} --eig-tol ${tol} --eig-max-restarts 1000 + --enable-testing true + --gtest_output=xml:staggered_eigensolve_test_staggered_${prec}.xml) + + # These tests are particularly expensive so they are disabled by default + if(QUDA_IMPROVED_STAGGERED_EIGENSOLVER_CTEST) + add_test(NAME eigensolve_test_asqtad_${prec} + COMMAND ${QUDA_CTEST_LAUNCH} $ ${MPIEXEC_POSTFLAGS} + --dslash-type asqtad --compute-fat-long true + --eig-n-conv 32 --eig-n-ev 32 --eig-n-kr 256 + --dim 6 6 6 8 --prec ${prec} --eig-tol ${tol} --eig-max-restarts 1000 + --enable-testing true + --gtest_output=xml:staggered_eigensolve_test_staggered_${prec}.xml) + endif() + + if (QUDA_LAPLACE) + add_test(NAME eigensolve_test_laplace_${prec} + COMMAND ${QUDA_CTEST_LAUNCH} $ ${MPIEXEC_POSTFLAGS} + --dslash-type laplace --compute-fat-long true + --eig-n-conv 32 --eig-n-ev 32 --eig-n-kr 256 + --dim 2 4 6 8 --prec ${prec} --eig-tol ${tol} --eig-max-restarts 1000 + --enable-testing true + --gtest_output=xml:staggered_eigensolve_test_laplace_${prec}.xml) + endif() + endif() +endforeach(prec) + + +if(QUDA_DIRAC_STAGGERED) + add_test(NAME hisq_stencil + COMMAND ${QUDA_CTEST_LAUNCH} $ ${MPIEXEC_POSTFLAGS} + --dim 8 8 8 8 + --gtest_output=xml:hisq_stencil_test.xml) +endif() + foreach(prec IN LISTS TEST_PRECS) add_test(NAME gauge_path_${prec} diff --git a/tests/eigensolve_test_gtest.hpp b/tests/eigensolve_test_gtest.hpp index cd07ca401f..a872963413 100644 --- a/tests/eigensolve_test_gtest.hpp +++ b/tests/eigensolve_test_gtest.hpp @@ -11,17 +11,6 @@ class EigensolveTest : public ::testing::TestWithParam EigensolveTest() : param(GetParam()) { } }; -bool is_chiral(QudaDslashType type) -{ - switch (type) { - case QUDA_DOMAIN_WALL_DSLASH: - case QUDA_DOMAIN_WALL_4D_DSLASH: - case QUDA_MOBIUS_DWF_DSLASH: - case QUDA_MOBIUS_DWF_EOFA_DSLASH: return true; - default: return false; - } -} - bool skip_test(test_t param) { // dwf-style solves must use a normal solver diff --git a/tests/hisq_stencil_ctest.cpp b/tests/hisq_stencil_ctest.cpp new file mode 100644 index 0000000000..6df6d1d977 --- /dev/null +++ b/tests/hisq_stencil_ctest.cpp @@ -0,0 +1,173 @@ +#include "hisq_stencil_test_utils.h" + +using namespace quda; + +bool ctest_all_partitions = false; + +using ::testing::Bool; +using ::testing::Combine; +using ::testing::Range; +using ::testing::TestWithParam; +using ::testing::Values; + +class HisqStencilTest : public ::testing::TestWithParam<::testing::tuple> +{ +protected: + ::testing::tuple param; + + HisqStencilTestWrapper hisq_stencil_test_wrapper; + + bool skip() + { + QudaPrecision precision = static_cast(::testing::get<0>(GetParam())); + QudaReconstructType recon = static_cast(::testing::get<1>(GetParam())); + + if ((QUDA_PRECISION & precision) == 0 || (QUDA_RECONSTRUCT & getReconstructNibble(recon)) == 0) return true; + + const std::array partition_enabled {true, true, true, false, true, false, false, false, + true, false, false, false, true, false, true, true}; + if (!ctest_all_partitions && !partition_enabled[::testing::get<3>(GetParam())]) return true; + + return false; + } + + void display_test_info(QudaPrecision prec, QudaReconstructType link_recon, bool has_naik) + { + printfQuda("running the following test:\n"); + printfQuda( + "link_precision link_reconstruct space_dimension T_dimension Ordering\n"); + printfQuda("%s %s %d/%d/%d/ %d %s \n", + get_prec_str(prec), get_recon_str(link_recon), xdim, ydim, zdim, tdim, get_gauge_order_str(gauge_order)); + printfQuda("Grid partition info: X Y Z T\n"); + printfQuda(" %d %d %d %d\n", dimPartitioned(0), dimPartitioned(1), dimPartitioned(2), + dimPartitioned(3)); + printfQuda("Number of Naiks: %d\n", has_naik ? 2 : 1); + } + +public: + virtual void SetUp() + { + QudaPrecision prec = static_cast(::testing::get<0>(GetParam())); + QudaReconstructType recon = static_cast(::testing::get<1>(GetParam())); + bool has_naik = ::testing::get<2>(GetParam()); + + if (skip()) GTEST_SKIP(); + + int partition = ::testing::get<3>(GetParam()); + for (int j = 0; j < 4; j++) { + if (partition & (1 << j)) { commDimPartitionedSet(j); } + } + updateR(); + + hisq_stencil_test_wrapper.init_ctest(prec, recon, has_naik); + display_test_info(prec, recon, has_naik); + } + + virtual void TearDown() + { + if (skip()) GTEST_SKIP(); + hisq_stencil_test_wrapper.end(); + } + + static void SetUpTestCase() { initQuda(device_ordinal); } + + // Per-test-case tear-down. + // Called after the last test in this test case. + // Can be omitted if not needed. + static void TearDownTestCase() + { + HisqStencilTestWrapper::destroy(); + endQuda(); + } +}; + +TEST_P(HisqStencilTest, benchmark) { hisq_stencil_test_wrapper.run_test(niter, /**show_metrics =*/true); } + +TEST_P(HisqStencilTest, verify) +{ + hisq_stencil_test_wrapper.run_test(2); + + std::array res = hisq_stencil_test_wrapper.verify(); + + // extra factor of 10 b/c the norm isn't normalized + double max_dev = 10. * getTolerance(prec); + + // fat link + EXPECT_LE(res[0], max_dev) << "Reference CPU and QUDA implementations of fat link do not agree"; + + // long link + EXPECT_LE(res[1], max_dev) << "Reference CPU and QUDA implementations of long link do not agree"; +} + +int main(int argc, char **argv) +{ + // initalize google test + ::testing::InitGoogleTest(&argc, argv); + + // for speed + xdim = ydim = zdim = tdim = 8; + + // default to 18 reconstruct + link_recon = QUDA_RECONSTRUCT_NO; + cpu_prec = prec = QUDA_DOUBLE_PRECISION; + + // Parse command line options + auto app = make_app(); + app->add_option("--all-partitions", ctest_all_partitions, "Test all instead of reduced combination of partitions"); + try { + app->parse(argc, argv); + } catch (const CLI::ParseError &e) { + return app->exit(e); + } + + if (prec == QUDA_HALF_PRECISION || prec == QUDA_QUARTER_PRECISION) + errorQuda("Precision %d is unsupported in some link fattening routines\n", prec); + + if (link_recon != QUDA_RECONSTRUCT_NO) + errorQuda("Reconstruct %d is unsupported in some link fattening routines\n", link_recon); + + if (gauge_order != QUDA_MILC_GAUGE_ORDER) errorQuda("Unsupported gauge order %d", gauge_order); + + if (eps_naik != 0.0) { n_naiks = 2; } + + setVerbosity(verbosity); + initComms(argc, argv, gridsize_from_cmdline); + + // Ensure gtest prints only from rank 0 + ::testing::TestEventListeners &listeners = ::testing::UnitTest::GetInstance()->listeners(); + if (comm_rank() != 0) { delete listeners.Release(listeners.default_result_printer()); } + + int test_rc = RUN_ALL_TESTS(); + + finalizeComms(); + + return test_rc; +} + +std::string +gethisqstenciltestname(testing::TestParamInfo<::testing::tuple> param) +{ + const QudaPrecision prec = static_cast(::testing::get<0>(param.param)); + const QudaReconstructType recon = static_cast(::testing::get<1>(param.param)); + const bool has_naik = ::testing::get<2>(param.param); + const int part = ::testing::get<3>(param.param); + std::stringstream ss; + // ss << get_dslash_str(dslash_type) << "_"; + ss << get_prec_str(prec); + ss << "_r" << recon; + if (has_naik) ss << "_naik"; + ss << "_partition" << part; + return ss.str(); +} + +#ifdef MULTI_GPU +INSTANTIATE_TEST_SUITE_P(QUDA, HisqStencilTest, + Combine(::testing::Values(QUDA_DOUBLE_PRECISION, QUDA_SINGLE_PRECISION), + ::testing::Values(QUDA_RECONSTRUCT_NO), ::testing::Bool(), Range(0, 16)), + gethisqstenciltestname); +#else +INSTANTIATE_TEST_SUITE_P(QUDA, HisqStencilTest, + Combine(::testing::Values(QUDA_DOUBLE_PRECISION, QUDA_SINGLE_PRECISION), + ::testing::Values(QUDA_RECONSTRUCT_NO), ::testing::Bool(), ::testing::Values(0)), + gethisqstenciltestname); +#endif diff --git a/tests/hisq_stencil_test.cpp b/tests/hisq_stencil_test.cpp index 98c2ae91d3..3b20287d3b 100644 --- a/tests/hisq_stencil_test.cpp +++ b/tests/hisq_stencil_test.cpp @@ -1,418 +1,71 @@ -#include -#include -#include -#include - -#include "quda.h" -#include "gauge_field.h" -#include "host_utils.h" -#include -#include "misc.h" -#include "util_quda.h" -#include "malloc_quda.h" -#include -#include "ks_improved_force.h" - -#ifdef MULTI_GPU -#include "comm_quda.h" -#endif - -#define TDIFF(a, b) (b.tv_sec - a.tv_sec + 0.000001 * (b.tv_usec - a.tv_usec)) +#include "hisq_stencil_test_utils.h" using namespace quda; -// Number of naiks. If eps_naik is 0.0, we only need -// to construct one naik. -static QudaGaugeFieldOrder gauge_order = QUDA_MILC_GAUGE_ORDER; - -// The file "generic_ks/fermion_links_hisq_load_milc.c" -// within MILC is the ultimate reference for what's going on here. - -// Unitarization coefficients -static double unitarize_eps = 1e-6; -static bool reunit_allow_svd = true; -static bool reunit_svd_only = false; -static double svd_rel_error = 1e-4; -static double svd_abs_error = 1e-4; -static double max_allowed_error = 1e-11; - -/*--------------------------------------------------------------------*/ -// Some notation: -// U -- original link, SU(3), copied to "field" from "site" -// V -- after 1st level of smearing, non-SU(3) -// W -- unitarized, SU(3) -// X -- after 2nd level of smearing, non-SU(3) -/*--------------------------------------------------------------------*/ - -static void hisq_test() +class HisqStencilTest : public ::testing::Test { +protected: + HisqStencilTestWrapper hisq_stencil_test_wrapper; - QudaGaugeParam qudaGaugeParam; - - initQuda(device_ordinal); - - if (prec == QUDA_HALF_PRECISION || prec == QUDA_QUARTER_PRECISION) { - errorQuda("Precision %d is unsupported in some link fattening routines\n", prec); - } - - cpu_prec = prec; - host_gauge_data_type_size = cpu_prec; - qudaGaugeParam = newQudaGaugeParam(); - - qudaGaugeParam.anisotropy = 1.0; - - // Fix me: must always be set to 1.0 for reasons not yet discerned. - // The tadpole coefficient gets encoded directly into the fat link - // construct coefficents. - qudaGaugeParam.tadpole_coeff = 1.0; - - qudaGaugeParam.X[0] = xdim; - qudaGaugeParam.X[1] = ydim; - qudaGaugeParam.X[2] = zdim; - qudaGaugeParam.X[3] = tdim; - - setDims(qudaGaugeParam.X); - - qudaGaugeParam.cpu_prec = cpu_prec; - qudaGaugeParam.cuda_prec = qudaGaugeParam.cuda_prec_sloppy = prec; - - if (gauge_order != QUDA_MILC_GAUGE_ORDER) errorQuda("Unsupported gauge order %d", gauge_order); - - qudaGaugeParam.gauge_order = gauge_order; - qudaGaugeParam.type = QUDA_WILSON_LINKS; - qudaGaugeParam.reconstruct = qudaGaugeParam.reconstruct_sloppy = link_recon; - qudaGaugeParam.t_boundary = QUDA_ANTI_PERIODIC_T; - qudaGaugeParam.staggered_phase_type = QUDA_STAGGERED_PHASE_MILC; - qudaGaugeParam.gauge_fix = QUDA_GAUGE_FIXED_NO; - qudaGaugeParam.ga_pad = 0; - - // Needed for unitarization, following "unitarize_link_test.cpp" - GaugeFieldParam gParam(qudaGaugeParam); - gParam.link_type = QUDA_GENERAL_LINKS; - gParam.ghostExchange = QUDA_GHOST_EXCHANGE_NO; - gParam.order = gauge_order; - - /////////////////////////////////////////////////////////////// - // Set up the coefficients for each part of the HISQ stencil // - /////////////////////////////////////////////////////////////// - - // Reference: "generic_ks/imp_actions/hisq/hisq_action.h", - // in QHMC: https://github.com/jcosborn/qhmc/blob/master/lib/qopqdp/hisq.c - - double u1 = 1.0 / tadpole_factor; - double u2 = u1 * u1; - double u4 = u2 * u2; - double u6 = u4 * u2; - - std::array, 3> act_paths; - - // First path: create V, W links - act_paths[0] = { - (1.0 / 8.0), /* one link */ - u2 * (0.0), /* Naik */ - u2 * (-1.0 / 8.0) * 0.5, /* simple staple */ - u4 * (1.0 / 8.0) * 0.25 * 0.5, /* displace link in two directions */ - u6 * (-1.0 / 8.0) * 0.125 * (1.0 / 6.0), /* displace link in three directions */ - u4 * (0.0) /* Lepage term */ - }; - - // Second path: create X, long links - act_paths[1] = { - ((1.0 / 8.0) + (2.0 * 6.0 / 16.0) + (1.0 / 8.0)), /* one link */ - /* One link is 1/8 as in fat7 + 2*3/8 for Lepage + 1/8 for Naik */ - (-1.0 / 24.0), /* Naik */ - (-1.0 / 8.0) * 0.5, /* simple staple */ - (1.0 / 8.0) * 0.25 * 0.5, /* displace link in two directions */ - (-1.0 / 8.0) * 0.125 * (1.0 / 6.0), /* displace link in three directions */ - (-2.0 / 16.0) /* Lepage term, correct O(a^2) 2x ASQTAD */ - }; - - // Paths for epsilon corrections. Not used if n_naiks = 1. - act_paths[2] = { - (1.0 / 8.0), /* one link b/c of Naik */ - (-1.0 / 24.0), /* Naik */ - 0.0, /* simple staple */ - 0.0, /* displace link in two directions */ - 0.0, /* displace link in three directions */ - 0.0 /* Lepage term */ - }; - - //////////////////////////////////// - // Set unitarization coefficients // - //////////////////////////////////// - - setUnitarizeLinksConstants(unitarize_eps, max_allowed_error, reunit_allow_svd, reunit_svd_only, svd_rel_error, - svd_abs_error); - - ///////////////// - // Input links // - ///////////////// - - void *sitelink[4]; - for (int i = 0; i < 4; i++) sitelink[i] = pinned_malloc(V * gauge_site_size * host_gauge_data_type_size); - - void *milc_sitelink; - milc_sitelink = (void *)safe_malloc(4 * V * gauge_site_size * host_gauge_data_type_size); - - // Note: this could be replaced with loading a gauge field - createSiteLinkCPU(sitelink, qudaGaugeParam.cpu_prec, 0); // 0 -> no phases - for (int i = 0; i < V; ++i) { - for (int dir = 0; dir < 4; ++dir) { - char *src = (char *)sitelink[dir]; - memcpy((char *)milc_sitelink + (i * 4 + dir) * gauge_site_size * host_gauge_data_type_size, - src + i * gauge_site_size * host_gauge_data_type_size, gauge_site_size * host_gauge_data_type_size); - } - } - - ////////////////////// - // Perform GPU test // - ////////////////////// - - // Paths for step 1: - void *vlink = pinned_malloc(4 * V * gauge_site_size * host_gauge_data_type_size); // V links - void *wlink = pinned_malloc(4 * V * gauge_site_size * host_gauge_data_type_size); // W links - - // Paths for step 2: - void *fatlink = pinned_malloc(4 * V * gauge_site_size * host_gauge_data_type_size); // final fat ("X") links - void *longlink = pinned_malloc(4 * V * gauge_site_size * host_gauge_data_type_size); // final long links - - // Place to accumulate Naiks - void *fatlink_eps = nullptr; - void *longlink_eps = nullptr; - if (n_naiks > 1) { - fatlink_eps = pinned_malloc(4 * V * gauge_site_size * host_gauge_data_type_size); // epsilon fat links - longlink_eps = pinned_malloc(4 * V * gauge_site_size * host_gauge_data_type_size); // epsilon long naiks - } - - // Tuning run... + void display_test_info() { - printfQuda("Tuning...\n"); - computeKSLinkQuda(vlink, longlink, wlink, milc_sitelink, act_paths[1].data(), &qudaGaugeParam); - } - - struct timeval t0, t1; - printfQuda("Running %d iterations of computation\n", niter); - gettimeofday(&t0, NULL); - for (int n = 0; n < niter; n++) { - - // If we create cudaGaugeField objs, we can do this 100% on the GPU, no copying! - - // Create V links (fat7 links) and W links (unitarized V links), 1st path table set - computeKSLinkQuda(vlink, nullptr, wlink, milc_sitelink, act_paths[0].data(), &qudaGaugeParam); - - if (n_naiks > 1) { - // Create Naiks, 3rd path table set - computeKSLinkQuda(fatlink, longlink, nullptr, wlink, act_paths[2].data(), &qudaGaugeParam); - - // Rescale+copy Naiks into Naik field - cpu_axy(prec, eps_naik, fatlink, fatlink_eps, V * 4 * gauge_site_size); - cpu_axy(prec, eps_naik, longlink, longlink_eps, V * 4 * gauge_site_size); - } else { - memset(fatlink, 0, V * 4 * gauge_site_size * host_gauge_data_type_size); - memset(longlink, 0, V * 4 * gauge_site_size * host_gauge_data_type_size); - } - - // Create X and long links, 2nd path table set - computeKSLinkQuda(fatlink, longlink, nullptr, wlink, act_paths[1].data(), &qudaGaugeParam); - - if (n_naiks > 1) { - // Add into Naik field - cpu_xpy(prec, fatlink, fatlink_eps, V * 4 * gauge_site_size); - cpu_xpy(prec, longlink, longlink_eps, V * 4 * gauge_site_size); - } + printfQuda("running the following test:\n"); + printfQuda( + "link_precision link_reconstruct space_dimension T_dimension Ordering\n"); + printfQuda("%s %s %d/%d/%d/ %d %s \n", + get_prec_str(prec), get_recon_str(link_recon), xdim, ydim, zdim, tdim, get_gauge_order_str(gauge_order)); + printfQuda("Grid partition info: X Y Z T\n"); + printfQuda(" %d %d %d %d\n", dimPartitioned(0), dimPartitioned(1), dimPartitioned(2), + dimPartitioned(3)); + printfQuda("Number of Naiks: %d\n", n_naiks); } - gettimeofday(&t1, NULL); - double secs = TDIFF(t0, t1); - - /////////////////////// - // Perform CPU Build // - /////////////////////// - - void *long_reflink[4]; // Long link for fermion with zero epsilon - void *fat_reflink[4]; // Fat link for fermion with zero epsilon - for (int i = 0; i < 4; i++) { - long_reflink[i] = safe_malloc(V * gauge_site_size * host_gauge_data_type_size); - fat_reflink[i] = safe_malloc(V * gauge_site_size * host_gauge_data_type_size); - } - - void *long_reflink_eps[4]; // Long link for fermion with non-zero epsilon - void *fat_reflink_eps[4]; // Fat link for fermion with non-zero epsilon - if (n_naiks > 1) { - for (int i = 0; i < 4; i++) { - long_reflink_eps[i] = safe_malloc(V * gauge_site_size * host_gauge_data_type_size); - fat_reflink_eps[i] = safe_malloc(V * gauge_site_size * host_gauge_data_type_size); - } - } - - if (verify_results) { - computeHISQLinksCPU(fat_reflink, long_reflink, fat_reflink_eps, long_reflink_eps, sitelink, &qudaGaugeParam, - act_paths, eps_naik); - } - - //////////////////////////////////////////////////////////////////// - // Layout change for fatlink, fatlink_eps, longlink, longlink_eps // - //////////////////////////////////////////////////////////////////// - - void *myfatlink[4]; - void *mylonglink[4]; - void *myfatlink_eps[4]; - void *mylonglink_eps[4]; - for (int i = 0; i < 4; i++) { - - myfatlink[i] = safe_malloc(V * gauge_site_size * host_gauge_data_type_size); - mylonglink[i] = safe_malloc(V * gauge_site_size * host_gauge_data_type_size); - memset(myfatlink[i], 0, V * gauge_site_size * host_gauge_data_type_size); - memset(mylonglink[i], 0, V * gauge_site_size * host_gauge_data_type_size); - - if (n_naiks > 1) { - myfatlink_eps[i] = safe_malloc(V * gauge_site_size * host_gauge_data_type_size); - mylonglink_eps[i] = safe_malloc(V * gauge_site_size * host_gauge_data_type_size); - memset(myfatlink_eps[i], 0, V * gauge_site_size * host_gauge_data_type_size); - memset(mylonglink_eps[i], 0, V * gauge_site_size * host_gauge_data_type_size); - } - } - - for (int i = 0; i < V; i++) { - for (int dir = 0; dir < 4; dir++) { - char *src = ((char *)fatlink) + (4 * i + dir) * gauge_site_size * host_gauge_data_type_size; - char *dst = ((char *)myfatlink[dir]) + i * gauge_site_size * host_gauge_data_type_size; - memcpy(dst, src, gauge_site_size * host_gauge_data_type_size); - - src = ((char *)longlink) + (4 * i + dir) * gauge_site_size * host_gauge_data_type_size; - dst = ((char *)mylonglink[dir]) + i * gauge_site_size * host_gauge_data_type_size; - memcpy(dst, src, gauge_site_size * host_gauge_data_type_size); - - if (n_naiks > 1) { - src = ((char *)fatlink_eps) + (4 * i + dir) * gauge_site_size * host_gauge_data_type_size; - dst = ((char *)myfatlink_eps[dir]) + i * gauge_site_size * host_gauge_data_type_size; - memcpy(dst, src, gauge_site_size * host_gauge_data_type_size); - - src = ((char *)longlink_eps) + (4 * i + dir) * gauge_site_size * host_gauge_data_type_size; - dst = ((char *)mylonglink_eps[dir]) + i * gauge_site_size * host_gauge_data_type_size; - memcpy(dst, src, gauge_site_size * host_gauge_data_type_size); - } - } - } - - ////////////////////////////// - // Perform the verification // - ////////////////////////////// - - if (verify_results) { - printfQuda("Checking fat links...\n"); - int res = 1; - for (int dir = 0; dir < 4; dir++) { - res &= compare_floats(fat_reflink[dir], myfatlink[dir], V * gauge_site_size, 1e-3, qudaGaugeParam.cpu_prec); - } - - strong_check_link(myfatlink, "GPU results: ", fat_reflink, "CPU reference results:", V, qudaGaugeParam.cpu_prec); - - printfQuda("Fat-link test %s\n\n", (1 == res) ? "PASSED" : "FAILED"); - - printfQuda("Checking long links...\n"); - res = 1; - for (int dir = 0; dir < 4; ++dir) { - res &= compare_floats(long_reflink[dir], mylonglink[dir], V * gauge_site_size, 1e-3, qudaGaugeParam.cpu_prec); - } - - strong_check_link(mylonglink, "GPU results: ", long_reflink, "CPU reference results:", V, qudaGaugeParam.cpu_prec); - - printfQuda("Long-link test %s\n\n", (1 == res) ? "PASSED" : "FAILED"); - - if (n_naiks > 1) { - - printfQuda("Checking fat eps_naik links...\n"); - res = 1; - for (int dir = 0; dir < 4; dir++) { - res &= compare_floats(fat_reflink_eps[dir], myfatlink_eps[dir], V * gauge_site_size, 1e-3, - qudaGaugeParam.cpu_prec); - } - - strong_check_link(myfatlink_eps, "GPU results: ", fat_reflink_eps, "CPU reference results:", V, - qudaGaugeParam.cpu_prec); - - printfQuda("Fat-link eps_naik test %s\n\n", (1 == res) ? "PASSED" : "FAILED"); - - printfQuda("Checking long eps_naik links...\n"); - res = 1; - for (int dir = 0; dir < 4; ++dir) { - res &= compare_floats(long_reflink_eps[dir], mylonglink_eps[dir], V * gauge_site_size, 1e-3, - qudaGaugeParam.cpu_prec); - } - - strong_check_link(mylonglink_eps, "GPU results: ", long_reflink_eps, "CPU reference results:", V, - qudaGaugeParam.cpu_prec); - - printfQuda("Long-link eps_naik test %s\n\n", (1 == res) ? "PASSED" : "FAILED"); - } +public: + virtual void SetUp() + { + hisq_stencil_test_wrapper.init_test(); + display_test_info(); } - // FIXME: does not include unitarization, extra naiks - int volume = qudaGaugeParam.X[0] * qudaGaugeParam.X[1] * qudaGaugeParam.X[2] * qudaGaugeParam.X[3]; - long long flops = 61632 * (long long)niter; // Constructing V field - // Constructing W field? - // Constructing separate Naiks - flops += 61632 * (long long)niter; // Constructing X field - flops += (252 * 4) * (long long)niter; // long-link contribution - - double perf = flops * volume / (secs * 1024 * 1024 * 1024); - printfQuda("link computation time =%.2f ms, flops= %.2f Gflops\n", (secs * 1000) / niter, perf); + virtual void TearDown() { hisq_stencil_test_wrapper.end(); } - for (int i = 0; i < 4; i++) { - host_free(myfatlink[i]); - host_free(mylonglink[i]); - if (n_naiks > 1) { - host_free(myfatlink_eps[i]); - host_free(mylonglink_eps[i]); - } - } + static void SetUpTestCase() { initQuda(device_ordinal); } - for (int i = 0; i < 4; i++) { - host_free(sitelink[i]); - host_free(fat_reflink[i]); - host_free(long_reflink[i]); - if (n_naiks > 1) { - host_free(fat_reflink_eps[i]); - host_free(long_reflink_eps[i]); - } + // Per-test-case tear-down. + // Called after the last test in this test case. + // Can be omitted if not needed. + static void TearDownTestCase() + { + HisqStencilTestWrapper::destroy(); + endQuda(); } +}; - // Clean up GPU compute links - host_free(vlink); - host_free(wlink); - host_free(fatlink); - host_free(longlink); +TEST_F(HisqStencilTest, benchmark) { hisq_stencil_test_wrapper.run_test(niter, /**show_metrics =*/true); } - if (n_naiks > 1) { - host_free(fatlink_eps); - host_free(longlink_eps); - } +TEST_F(HisqStencilTest, verify) +{ + if (!verify_results) GTEST_SKIP(); - if (milc_sitelink) host_free(milc_sitelink); -#ifdef MULTI_GPU - exchange_llfat_cleanup(); -#endif - endQuda(); -} + hisq_stencil_test_wrapper.run_test(2); -static void display_test_info() -{ - printfQuda("running the following test:\n"); + std::array res = hisq_stencil_test_wrapper.verify(); - printfQuda("link_precision link_reconstruct space_dimension T_dimension Ordering\n"); - printfQuda("%s %s %d/%d/%d/ %d %s \n", - get_prec_str(prec), get_recon_str(link_recon), xdim, ydim, zdim, tdim, get_gauge_order_str(gauge_order)); + // extra factor of 10 b/c the norm isn't normalized + double max_dev = 10. * getTolerance(prec); - printfQuda("Grid partition info: X Y Z T\n"); - printfQuda(" %d %d %d %d\n", dimPartitioned(0), dimPartitioned(1), dimPartitioned(2), - dimPartitioned(3)); + // fat link + EXPECT_LE(res[0], max_dev) << "Reference CPU and QUDA implementations of fat link do not agree"; - printfQuda("Number of Naiks: %d\n", n_naiks); + // long link + EXPECT_LE(res[1], max_dev) << "Reference CPU and QUDA implementations of long link do not agree"; } int main(int argc, char **argv) { + // initalize google test + ::testing::InitGoogleTest(&argc, argv); + // for speed xdim = ydim = zdim = tdim = 8; @@ -420,21 +73,34 @@ int main(int argc, char **argv) link_recon = QUDA_RECONSTRUCT_NO; cpu_prec = prec = QUDA_DOUBLE_PRECISION; + // Parse command line options auto app = make_app(); - // app->get_formatter()->column_width(40); - // add_eigen_option_group(app); - // add_deflation_option_group(app); - // add_multigrid_option_group(app); try { app->parse(argc, argv); } catch (const CLI::ParseError &e) { return app->exit(e); } + if (prec == QUDA_HALF_PRECISION || prec == QUDA_QUARTER_PRECISION) + errorQuda("Precision %d is unsupported in some link fattening routines\n", prec); + + if (link_recon != QUDA_RECONSTRUCT_NO) + errorQuda("Reconstruct %d is unsupported in some link fattening routines\n", link_recon); + + if (gauge_order != QUDA_MILC_GAUGE_ORDER) errorQuda("Unsupported gauge order %d", gauge_order); + if (eps_naik != 0.0) { n_naiks = 2; } + setVerbosity(verbosity); initComms(argc, argv, gridsize_from_cmdline); - display_test_info(); - hisq_test(); + + // Ensure gtest prints only from rank 0 + ::testing::TestEventListeners &listeners = ::testing::UnitTest::GetInstance()->listeners(); + if (comm_rank() != 0) { delete listeners.Release(listeners.default_result_printer()); } + + int test_rc = RUN_ALL_TESTS(); + finalizeComms(); + + return test_rc; } diff --git a/tests/hisq_stencil_test_utils.h b/tests/hisq_stencil_test_utils.h new file mode 100644 index 0000000000..d340c0edda --- /dev/null +++ b/tests/hisq_stencil_test_utils.h @@ -0,0 +1,459 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +using namespace quda; + +// Number of naiks. If eps_naik is 0.0, we only need +// to construct one naik. +static QudaGaugeFieldOrder gauge_order = QUDA_MILC_GAUGE_ORDER; + +// The file "generic_ks/fermion_links_hisq_load_milc.c" +// within MILC is the ultimate reference for what's going on here. + +// Unitarization coefficients +static double unitarize_eps = 1e-6; +static bool reunit_allow_svd = true; +static bool reunit_svd_only = false; +static double svd_rel_error = 1e-4; +static double svd_abs_error = 1e-4; +static double max_allowed_error = 1e-11; + +struct HisqStencilTestWrapper { + + static inline QudaGaugeParam gauge_param; + + // staple coefficients for different portions of the HISQ stencil build + static inline std::array, 3> act_paths; + + // initial links in MILC order + static inline void *milc_sitelink = nullptr; + + // storage for CPU reference fat and long links w/zero Naik + static inline void *fat_reflink[4] = {nullptr, nullptr, nullptr, nullptr}; + static inline void *long_reflink[4] = {nullptr, nullptr, nullptr, nullptr}; + + // storage for CPU reference fat and long links w/non-zero Naik + static inline void *fat_reflink_eps[4] = {nullptr, nullptr, nullptr, nullptr}; + static inline void *long_reflink_eps[4] = {nullptr, nullptr, nullptr, nullptr}; + + // Paths for step 1: + static inline void *vlink = nullptr; + static inline void *wlink = nullptr; + + // Paths for step 2: + static inline void *fatlink = nullptr; + static inline void *longlink = nullptr; + + // Place to accumulate Naiks + static inline void *fatlink_eps = nullptr; + static inline void *longlink_eps = nullptr; + + static inline void *qdp_sitelink[4] = {nullptr, nullptr, nullptr, nullptr}; + static inline void *qdp_fatlink[4] = {nullptr, nullptr, nullptr, nullptr}; + static inline void *qdp_longlink[4] = {nullptr, nullptr, nullptr, nullptr}; + static inline void *qdp_fatlink_eps[4] = {nullptr, nullptr, nullptr, nullptr}; + static inline void *qdp_longlink_eps[4] = {nullptr, nullptr, nullptr, nullptr}; + + void set_naik(bool has_naik) + { + if (has_naik) { + eps_naik = -0.03; // semi-arbitrary + n_naiks = 2; + } else { + eps_naik = 0.0; + n_naiks = 1; + } + } + + void init_ctest(QudaPrecision prec_, QudaReconstructType link_recon_, bool has_naik) + { + prec = prec_; + link_recon = link_recon_; + + set_naik(has_naik); + + gauge_param = newQudaGaugeParam(); + setStaggeredGaugeParam(gauge_param); + + gauge_param.cuda_prec = prec; + + static bool first_time = true; + if (first_time) { + // force the Naik build up front, it doesn't effect the non-naik fields + set_naik(true); + init_host(); + set_naik(has_naik); + first_time = false; + } + init(); + } + + void init_test() + { + gauge_param = newQudaGaugeParam(); + setStaggeredGaugeParam(gauge_param); + + static bool first_time = true; + if (first_time) { + init_host(); + first_time = false; + } + init(); + } + + void init_host() + { + setDims(gauge_param.X); + dw_setDims(gauge_param.X, 1); + + /////////////////////////////////////////////////////////////// + // Set up the coefficients for each part of the HISQ stencil // + /////////////////////////////////////////////////////////////// + + // Reference: "generic_ks/imp_actions/hisq/hisq_action.h", + // in QHMC: https://github.com/jcosborn/qhmc/blob/master/lib/qopqdp/hisq.c + + double u1 = 1.0 / tadpole_factor; + double u2 = u1 * u1; + double u4 = u2 * u2; + double u6 = u4 * u2; + + // First path: create V, W links + act_paths[0] = { + (1.0 / 8.0), /* one link */ + u2 * (0.0), /* Naik */ + u2 * (-1.0 / 8.0) * 0.5, /* simple staple */ + u4 * (1.0 / 8.0) * 0.25 * 0.5, /* displace link in two directions */ + u6 * (-1.0 / 8.0) * 0.125 * (1.0 / 6.0), /* displace link in three directions */ + u4 * (0.0) /* Lepage term */ + }; + + // Second path: create X, long links + act_paths[1] = { + ((1.0 / 8.0) + (2.0 * 6.0 / 16.0) + (1.0 / 8.0)), /* one link */ + /* One link is 1/8 as in fat7 + 2*3/8 for Lepage + 1/8 for Naik */ + (-1.0 / 24.0), /* Naik */ + (-1.0 / 8.0) * 0.5, /* simple staple */ + (1.0 / 8.0) * 0.25 * 0.5, /* displace link in two directions */ + (-1.0 / 8.0) * 0.125 * (1.0 / 6.0), /* displace link in three directions */ + (-2.0 / 16.0) /* Lepage term, correct O(a^2) 2x ASQTAD */ + }; + + // Paths for epsilon corrections. Not used if n_naiks = 1. + act_paths[2] = { + (1.0 / 8.0), /* one link b/c of Naik */ + (-1.0 / 24.0), /* Naik */ + 0.0, /* simple staple */ + 0.0, /* displace link in two directions */ + 0.0, /* displace link in three directions */ + 0.0 /* Lepage term */ + }; + + //////////////////////////////////// + // Set unitarization coefficients // + //////////////////////////////////// + + setUnitarizeLinksConstants(unitarize_eps, max_allowed_error, reunit_allow_svd, reunit_svd_only, svd_rel_error, + svd_abs_error); + + ///////////////// + // Input links // + ///////////////// + + for (int i = 0; i < 4; i++) qdp_sitelink[i] = pinned_malloc(V * gauge_site_size * host_gauge_data_type_size); + + // Note: this could be replaced with loading a gauge field + createSiteLinkCPU(qdp_sitelink, gauge_param.cpu_prec, 0); // 0 -> no phases + + /////////////////////// + // Perform CPU Build // + /////////////////////// + + for (int i = 0; i < 4; i++) { + // fat and long links for fermions with zero epsilon + fat_reflink[i] = safe_malloc(V * gauge_site_size * host_gauge_data_type_size); + long_reflink[i] = safe_malloc(V * gauge_site_size * host_gauge_data_type_size); + + // fat and long links for fermions with non-zero epsilon + if (n_naiks > 1) { + fat_reflink_eps[i] = safe_malloc(V * gauge_site_size * host_gauge_data_type_size); + long_reflink_eps[i] = safe_malloc(V * gauge_site_size * host_gauge_data_type_size); + } + } + + computeHISQLinksCPU(fat_reflink, long_reflink, fat_reflink_eps, long_reflink_eps, qdp_sitelink, &gauge_param, + act_paths, eps_naik); + + ///////////////////////////////////////////////////////////////////// + // Allocate CPU-precision host storage for fields built on the GPU // + ///////////////////////////////////////////////////////////////////// + + // QDP order fields + for (int i = 0; i < 4; i++) { + qdp_fatlink[i] = safe_malloc(V * gauge_site_size * host_gauge_data_type_size); + qdp_longlink[i] = safe_malloc(V * gauge_site_size * host_gauge_data_type_size); + if (n_naiks > 1) { + qdp_fatlink_eps[i] = safe_malloc(V * gauge_site_size * host_gauge_data_type_size); + qdp_longlink_eps[i] = safe_malloc(V * gauge_site_size * host_gauge_data_type_size); + } + } + +#ifdef MULTI_GPU + exchange_llfat_cleanup(); +#endif + } + + void init() + { + + // reset the reconstruct in gauge param + gauge_param.reconstruct = link_recon; + + ///////////////////////////////////////////////////////////////// + // Create a CPU copy of the initial field in the GPU precision // + ///////////////////////////////////////////////////////////////// + + milc_sitelink = (void *)safe_malloc(4 * V * gauge_site_size * gauge_param.cuda_prec); + reorderQDPtoMILC(milc_sitelink, qdp_sitelink, V, gauge_site_size, gauge_param.cuda_prec, gauge_param.cpu_prec); + + /////////////////////////////////////////////////////// + // Allocate host storage for fields built on the GPU // + /////////////////////////////////////////////////////// + + // Paths for step 1: + vlink = pinned_malloc(4 * V * gauge_site_size * gauge_param.cuda_prec); // V links + wlink = pinned_malloc(4 * V * gauge_site_size * gauge_param.cuda_prec); // W links + + // Paths for step 2: + fatlink = pinned_malloc(4 * V * gauge_site_size * gauge_param.cuda_prec); // final fat ("X") links + longlink = pinned_malloc(4 * V * gauge_site_size * gauge_param.cuda_prec); // final long links + + // Place to accumulate Naiks + if (n_naiks > 1) { + fatlink_eps = pinned_malloc(4 * V * gauge_site_size * gauge_param.cuda_prec); // epsilon fat links + longlink_eps = pinned_malloc(4 * V * gauge_site_size * gauge_param.cuda_prec); // epsilon long naiks + } + } + + static void end() + { + if (milc_sitelink) host_free(milc_sitelink); + + // Clean up GPU compute links + if (vlink) host_free(vlink); + if (wlink) host_free(wlink); + if (fatlink) host_free(fatlink); + if (longlink) host_free(longlink); + + if (n_naiks > 1) { + if (fatlink_eps) host_free(fatlink_eps); + if (longlink_eps) host_free(longlink_eps); + } + + freeGaugeQuda(); + } + + static void destroy() + { + + for (int i = 0; i < 4; i++) { + host_free(fat_reflink[i]); + host_free(long_reflink[i]); + if (n_naiks > 1) { + host_free(fat_reflink_eps[i]); + host_free(long_reflink_eps[i]); + } + } + + for (int i = 0; i < 4; i++) { + host_free(qdp_sitelink[i]); + host_free(qdp_fatlink[i]); + host_free(qdp_longlink[i]); + if (n_naiks > 1) { + host_free(qdp_fatlink_eps[i]); + host_free(qdp_longlink_eps[i]); + } + } + } + + /*--------------------------------------------------------------------*/ + // Some notation: + // U -- original link, SU(3), copied to "field" from "site" + // V -- after 1st level of smearing, non-SU(3) + // W -- unitarized, SU(3) + // X -- after 2nd level of smearing, non-SU(3) + /*--------------------------------------------------------------------*/ + + double llfatCUDA(int niter) + { + host_timer_t host_timer; + + comm_barrier(); + host_timer.start(); + + // manually override precision of input fields + auto cpu_param_backup = gauge_param.cpu_prec; + gauge_param.cpu_prec = gauge_param.cuda_prec; + + for (int i = 0; i < niter; i++) { + // If we create cudaGaugeField objs, we can do this 100% on the GPU, no copying! + + // Create V links (fat7 links) and W links (unitarized V links), 1st path table set + computeKSLinkQuda(vlink, nullptr, wlink, milc_sitelink, act_paths[0].data(), &gauge_param); + + if (n_naiks > 1) { + // Create Naiks, 3rd path table set + computeKSLinkQuda(fatlink, longlink, nullptr, wlink, act_paths[2].data(), &gauge_param); + + // Rescale+copy Naiks into Naik field + cpu_axy(gauge_param.cuda_prec, eps_naik, fatlink, fatlink_eps, V * 4 * gauge_site_size); + cpu_axy(gauge_param.cuda_prec, eps_naik, longlink, longlink_eps, V * 4 * gauge_site_size); + } else { + memset(fatlink, 0, V * 4 * gauge_site_size * gauge_param.cuda_prec); + memset(longlink, 0, V * 4 * gauge_site_size * gauge_param.cuda_prec); + } + + // Create X and long links, 2nd path table set + computeKSLinkQuda(fatlink, longlink, nullptr, wlink, act_paths[1].data(), &gauge_param); + + if (n_naiks > 1) { + // Add into Naik field + cpu_xpy(gauge_param.cuda_prec, fatlink, fatlink_eps, V * 4 * gauge_site_size); + cpu_xpy(gauge_param.cuda_prec, longlink, longlink_eps, V * 4 * gauge_site_size); + } + } + + gauge_param.cpu_prec = cpu_param_backup; + + host_timer.stop(); + + return host_timer.last(); + } + + void run_test(int niter, bool print_metrics = false) + { + ////////////////////// + // Perform GPU test // + ////////////////////// + + printfQuda("Tuning...\n"); + llfatCUDA(1); + + auto flops0 = quda::Tunable::flops_global(); + auto bytes0 = quda::Tunable::bytes_global(); + + printfQuda("Running %d iterations of computation\n", niter); + double secs = llfatCUDA(niter); + + unsigned long long flops = (quda::Tunable::flops_global() - flops0); + unsigned long long bytes = (quda::Tunable::bytes_global() - bytes0); + + if (print_metrics) { + // FIXME: does not include unitarization, extra naiks + int volume = gauge_param.X[0] * gauge_param.X[1] * gauge_param.X[2] * gauge_param.X[3]; + // long long flops = 61632 * (long long)niter; // Constructing V field + // Constructing W field? + // Constructing separate Naiks + // flops += 61632 * (long long)niter; // Constructing X field + // flops += (252 * 4) * (long long)niter; // long-link contribution + + printfQuda("%fus per HISQ link build\n", 1e6 * secs / niter); + + printfQuda("%llu flops per HISQ link build, %llu flops per site %llu bytes per site\n", flops / niter, + (flops / niter) / volume, (bytes / niter) / volume); + + double gflops = 1.0e-9 * flops / secs; + printfQuda("GFLOPS = %f\n", gflops); + + double gbytes = 1.0e-9 * bytes / secs; + printfQuda("GBYTES = %f\n", gbytes); + + // Old metric + // double perf = flops / (secs * 1024 * 1024 * 1024); + // printfQuda("link computation time =%.2f ms, flops= %.2f Gflops\n", (secs * 1000) / niter, perf); + } + } + + std::array verify() + { + //////////////////////////////////////////////////////////////////// + // Layout change for fatlink, fatlink_eps, longlink, longlink_eps // + //////////////////////////////////////////////////////////////////// + + reorderMILCtoQDP(qdp_fatlink, fatlink, V, gauge_site_size, gauge_param.cpu_prec, gauge_param.cuda_prec); + reorderMILCtoQDP(qdp_longlink, longlink, V, gauge_site_size, gauge_param.cpu_prec, gauge_param.cuda_prec); + + if (n_naiks > 1) { + reorderMILCtoQDP(qdp_fatlink_eps, fatlink_eps, V, gauge_site_size, gauge_param.cpu_prec, gauge_param.cuda_prec); + reorderMILCtoQDP(qdp_longlink_eps, longlink_eps, V, gauge_site_size, gauge_param.cpu_prec, gauge_param.cuda_prec); + } + + ////////////////////////////// + // Perform the verification // + ////////////////////////////// + + std::array res = {0., 0.}; + + // extra factor of 10 b/c the norm isn't normalized + double max_dev = 10. * getTolerance(prec); + + // Non-zero epsilon check + if (n_naiks > 1) { + for (int dir = 0; dir < 4; dir++) { + res[0] = std::max(res[0], + compare_floats_v2(fat_reflink_eps[dir], qdp_fatlink_eps[dir], V * gauge_site_size, max_dev, + gauge_param.cpu_prec)); + } + + strong_check_link(qdp_fatlink_eps, "Fat link GPU results: ", fat_reflink_eps, "CPU reference results:", V, + gauge_param.cpu_prec); + + for (int dir = 0; dir < 4; ++dir) { + res[1] = std::max(res[1], + compare_floats_v2(long_reflink_eps[dir], qdp_longlink_eps[dir], V * gauge_site_size, max_dev, + gauge_param.cpu_prec)); + } + + strong_check_link(qdp_longlink_eps, "Long link GPU results: ", long_reflink_eps, "CPU reference results:", V, + gauge_param.cpu_prec); + } else { + for (int dir = 0; dir < 4; dir++) { + res[0] = std::max( + res[0], + compare_floats_v2(fat_reflink[dir], qdp_fatlink[dir], V * gauge_site_size, max_dev, gauge_param.cpu_prec)); + } + + strong_check_link(qdp_fatlink, "Fat link GPU results: ", fat_reflink, "CPU reference results:", V, + gauge_param.cpu_prec); + + for (int dir = 0; dir < 4; ++dir) { + res[1] = std::max( + res[1], + compare_floats_v2(long_reflink[dir], qdp_longlink[dir], V * gauge_site_size, max_dev, gauge_param.cpu_prec)); + } + + strong_check_link(qdp_longlink, "Long link GPU results: ", long_reflink, "CPU reference results:", V, + gauge_param.cpu_prec); + } + + printfQuda("Fat link test %s\n", (res[0] < max_dev) ? "PASSED" : "FAILED"); + printfQuda("Long link test %s\n", (res[1] < max_dev) ? "PASSED" : "FAILED"); + + return res; + } +}; diff --git a/tests/host_reference/dslash_reference.cpp b/tests/host_reference/dslash_reference.cpp index 4edc471143..bb7efa83a4 100644 --- a/tests/host_reference/dslash_reference.cpp +++ b/tests/host_reference/dslash_reference.cpp @@ -743,78 +743,183 @@ double verifyWilsonTypeSingularVector(void *spinor_left, void *spinor_right, dou return l2r; } -double verifyStaggeredInversion(quda::ColorSpinorField &tmp, quda::ColorSpinorField &ref, quda::ColorSpinorField &in, - quda::ColorSpinorField &out, double mass, quda::GaugeField &fatlink, - quda::GaugeField &longlink, QudaGaugeParam &gauge_param, QudaInvertParam &inv_param, - int shift) +std::array verifyStaggeredInversion(quda::ColorSpinorField &in, quda::ColorSpinorField &out, + quda::GaugeField &fat_link, quda::GaugeField &long_link, + QudaInvertParam &inv_param) { - void *qdp_fatlink[] = {fatlink.data(0), fatlink.data(1), fatlink.data(2), fatlink.data(3)}; - void *qdp_longlink[] = {longlink.data(0), longlink.data(1), longlink.data(2), longlink.data(3)}; - void *ghost_fatlink[] - = {fatlink.Ghost()[0].data(), fatlink.Ghost()[1].data(), fatlink.Ghost()[2].data(), fatlink.Ghost()[3].data()}; - void *ghost_longlink[] - = {longlink.Ghost()[0].data(), longlink.Ghost()[1].data(), longlink.Ghost()[2].data(), longlink.Ghost()[3].data()}; - - switch (test_type) { - case 0: // full parity solution, full parity system - case 1: // full parity solution, solving EVEN EVEN prec system - case 2: // full parity solution, solving ODD ODD prec system - - // In QUDA, the full staggered operator has the sign convention - // {{m, -D_eo},{-D_oe,m}}, while the CPU verify function does not - // have the minus sign. Passing in QUDA_DAG_YES solves this - // discrepancy. - staggeredDslash(ref.Even(), qdp_fatlink, qdp_longlink, ghost_fatlink, ghost_longlink, out.Odd(), QUDA_EVEN_PARITY, - QUDA_DAG_YES, inv_param.cpu_prec, gauge_param.cpu_prec, dslash_type); - staggeredDslash(ref.Odd(), qdp_fatlink, qdp_longlink, ghost_fatlink, ghost_longlink, out.Even(), QUDA_ODD_PARITY, - QUDA_DAG_YES, inv_param.cpu_prec, gauge_param.cpu_prec, dslash_type); - - if (dslash_type == QUDA_LAPLACE_DSLASH) { - xpay(out.data(), kappa, ref.data(), ref.Length(), gauge_param.cpu_prec); - ax(0.5 / kappa, ref.data(), ref.Length(), gauge_param.cpu_prec); - } else { - axpy(2 * mass, out.data(), ref.data(), ref.Length(), gauge_param.cpu_prec); + std::vector out_vector(1); + out_vector[0] = out; + return verifyStaggeredInversion(in, out_vector, fat_link, long_link, inv_param); +} + +std::array verifyStaggeredInversion(quda::ColorSpinorField &in, + std::vector &out_vector, + quda::GaugeField &fat_link, quda::GaugeField &long_link, + QudaInvertParam &inv_param) +{ + int dagger = inv_param.dagger == QUDA_DAG_YES ? 1 : 0; + double l2r_max = 0.0; + double hqr_max = 0.0; + + // Create temporary spinors + quda::ColorSpinorParam csParam(in); + quda::ColorSpinorField ref(csParam); + + if (multishift > 1) { + if (dslash_type == QUDA_LAPLACE_DSLASH) errorQuda("Multishift solves do not support the laplace operator (yet)"); + + if (inv_param.solution_type != QUDA_MATPC_SOLUTION) + errorQuda("Invalid staggered multishift solution type %d, expected QUDA_MATPC_SOLUTION", inv_param.solution_type); + + // Check the mat_pc type and make sure it's sane + QudaParity parity = QUDA_INVALID_PARITY; + switch (inv_param.matpc_type) { + case QUDA_MATPC_EVEN_EVEN: parity = QUDA_EVEN_PARITY; break; + case QUDA_MATPC_ODD_ODD: parity = QUDA_ODD_PARITY; break; + default: errorQuda("Unexpected matpc_type %s", get_matpc_str(inv_param.matpc_type)); break; } - break; - case 3: // even parity solution, solving EVEN system - case 4: // odd parity solution, solving ODD system - case 5: // multi mass CG, even parity solution, solving EVEN system - case 6: // multi mass CG, odd parity solution, solving ODD system + for (int i = 0; i < multishift; i++) { + auto &out = out_vector[i]; + double mass = 0.5 * sqrt(inv_param.offset[i]); + stag_matpc(ref, fat_link, long_link, out, mass, 0, parity, dslash_type); - staggeredMatDagMat(ref, qdp_fatlink, qdp_longlink, ghost_fatlink, ghost_longlink, out, mass, 0, inv_param.cpu_prec, - gauge_param.cpu_prec, tmp, - (test_type == 3 || test_type == 5) ? QUDA_EVEN_PARITY : QUDA_ODD_PARITY, dslash_type); - break; - } + mxpy(in.data(), ref.data(), in.Volume() * stag_spinor_site_size, inv_param.cpu_prec); + double nrm2 = norm_2(ref.data(), ref.Volume() * stag_spinor_site_size, inv_param.cpu_prec); + double src2 = norm_2(in.data(), in.Volume() * stag_spinor_site_size, inv_param.cpu_prec); + double hqr = sqrt(quda::blas::HeavyQuarkResidualNorm(out, ref).z); + double l2r = sqrt(nrm2 / src2); + + printfQuda("%dth solution: mass=%f, ", i, mass); + printfQuda("Shift %2d residuals: (L2 relative) tol %9.6e, QUDA = %9.6e, host = %9.6e; (heavy-quark) tol %9.6e, " + "QUDA = %9.6e, host = %9.6e\n", + i, inv_param.tol_offset[i], inv_param.true_res_offset[i], l2r, inv_param.tol_hq_offset[i], + inv_param.true_res_hq_offset[i], hqr); + // Empirical: if the cpu residue is more than 1 order the target accuracy, then it fails to converge + if (sqrt(nrm2 / src2) > 10 * inv_param.tol_offset[i]) { + printfQuda("Shift %2d has empirically failed to converge\n", i); + } + + l2r_max = std::max(l2r_max, l2r); + hqr_max = std::max(hqr_max, hqr); + } - int len = 0; - if (solution_type == QUDA_MAT_SOLUTION || solution_type == QUDA_MATDAG_MAT_SOLUTION) { - len = V; } else { - len = Vh; - } + auto &out = out_vector[0]; + double mass = inv_param.mass; + if (inv_param.solution_type == QUDA_MAT_SOLUTION) { + stag_mat(ref, fat_link, long_link, out, mass, dagger, dslash_type); - mxpy(in.data(), ref.data(), len * stag_spinor_site_size, inv_param.cpu_prec); - double nrm2 = norm_2(ref.data(), len * stag_spinor_site_size, inv_param.cpu_prec); - double src2 = norm_2(in.data(), len * stag_spinor_site_size, inv_param.cpu_prec); - double hqr = sqrt(quda::blas::HeavyQuarkResidualNorm(out, ref).z); - double l2r = sqrt(nrm2 / src2); + // correct for the massRescale function inside invertQuda + if (is_laplace(dslash_type)) ax(0.5 / kappa, ref.data(), ref.Length(), ref.Precision()); + } else if (inv_param.solution_type == QUDA_MATPC_SOLUTION) { + QudaParity parity = QUDA_INVALID_PARITY; + switch (inv_param.matpc_type) { + case QUDA_MATPC_EVEN_EVEN: parity = QUDA_EVEN_PARITY; break; + case QUDA_MATPC_ODD_ODD: parity = QUDA_ODD_PARITY; break; + default: errorQuda("Unexpected matpc_type %s", get_matpc_str(inv_param.matpc_type)); break; + } + stag_matpc(ref, fat_link, long_link, out, mass, 0, parity, dslash_type); + } else if (inv_param.solution_type == QUDA_MATDAG_MAT_SOLUTION) { + stag_matdag_mat(ref, fat_link, long_link, out, mass, dagger, dslash_type); + } else { + errorQuda("Invalid staggered solution type %d", inv_param.solution_type); + } + + mxpy(in.data(), ref.data(), in.Volume() * stag_spinor_site_size, inv_param.cpu_prec); + double nrm2 = norm_2(ref.data(), ref.Volume() * stag_spinor_site_size, inv_param.cpu_prec); + double src2 = norm_2(in.data(), in.Volume() * stag_spinor_site_size, inv_param.cpu_prec); + double hqr = sqrt(quda::blas::HeavyQuarkResidualNorm(out, ref).z); + double l2r = sqrt(nrm2 / src2); - if (multishift == 1) { printfQuda("Residuals: (L2 relative) tol %9.6e, QUDA = %9.6e, host = %9.6e; (heavy-quark) tol %9.6e, QUDA = %9.6e, " "host = %9.6e\n", inv_param.tol, inv_param.true_res, l2r, inv_param.tol_hq, inv_param.true_res_hq, hqr); + + l2r_max = l2r; + hqr_max = hqr; + } + + return {l2r_max, hqr_max}; +} + +double verifyStaggeredTypeEigenvector(quda::ColorSpinorField &spinor, double _Complex lambda, int i, + QudaEigParam &eig_param, quda::GaugeField &fat_link, quda::GaugeField &long_link) +{ + QudaInvertParam &inv_param = *(eig_param.invert_param); + int dagger = inv_param.dagger == QUDA_DAG_YES ? 1 : 0; + bool use_pc = (eig_param.use_pc == QUDA_BOOLEAN_TRUE ? true : false); + bool normop = (eig_param.use_norm_op == QUDA_BOOLEAN_TRUE ? true : false); + double mass = inv_param.mass; + + // Reverse engineer a "solution_type" to help determine which host dslash needs to be applied + QudaSolutionType sol_type = QUDA_INVALID_SOLUTION; + if (normop) { + if (use_pc) + errorQuda("The normal preconditioned staggered op is not supported"); + else + sol_type = QUDA_MATDAG_MAT_SOLUTION; } else { - printfQuda("Shift %2d residuals: (L2 relative) tol %9.6e, QUDA = %9.6e, host = %9.6e; (heavy-quark) tol %9.6e, " - "QUDA = %9.6e, host = %9.6e\n", - shift, inv_param.tol_offset[shift], inv_param.true_res_offset[shift], l2r, - inv_param.tol_hq_offset[shift], inv_param.true_res_hq_offset[shift], hqr); - // Empirical: if the cpu residue is more than 1 order the target accuracy, then it fails to converge - if (sqrt(nrm2 / src2) > 10 * inv_param.tol_offset[shift]) { - printfQuda("Shift %2d has empirically failed to converge\n", shift); + if (use_pc) + sol_type = QUDA_MATPC_SOLUTION; + else + sol_type = QUDA_MAT_SOLUTION; + } + + // Create temporary spinors + quda::ColorSpinorParam csParam(spinor); + quda::ColorSpinorField ref(csParam); + + if (sol_type == QUDA_MAT_SOLUTION) { + stag_mat(ref, fat_link, long_link, spinor, mass, dagger, dslash_type); + } else if (sol_type == QUDA_MATPC_SOLUTION) { + QudaParity parity = QUDA_INVALID_PARITY; + switch (inv_param.matpc_type) { + case QUDA_MATPC_EVEN_EVEN: parity = QUDA_EVEN_PARITY; break; + case QUDA_MATPC_ODD_ODD: parity = QUDA_ODD_PARITY; break; + default: errorQuda("Unexpected matpc_type %s", get_matpc_str(inv_param.matpc_type)); break; } + stag_matpc(ref, fat_link, long_link, spinor, mass, 0, parity, dslash_type); + } else if (sol_type == QUDA_MATDAG_MAT_SOLUTION) { + stag_matdag_mat(ref, fat_link, long_link, spinor, mass, dagger, dslash_type); } + // Compute M * x - \lambda * x + caxpy(-lambda, spinor.data(), ref.data(), spinor.Volume() * stag_spinor_site_size, inv_param.cpu_prec); + double nrm2 = norm_2(ref.data(), ref.Volume() * stag_spinor_site_size, inv_param.cpu_prec); + double src2 = norm_2(spinor.data(), spinor.Volume() * stag_spinor_site_size, inv_param.cpu_prec); + double l2r = sqrt(nrm2 / src2); + + printfQuda("Eigenvector %4d: tol %.2e, host residual = %.15e\n", i, eig_param.tol, l2r); + + return l2r; +} + +double verifyStaggeredTypeSingularVector(quda::ColorSpinorField &spinor_left, quda::ColorSpinorField &spinor_right, + double _Complex sigma, int i, QudaEigParam &eig_param, + quda::GaugeField &fat_link, quda::GaugeField &long_link) +{ + QudaInvertParam &inv_param = *(eig_param.invert_param); + int dagger = inv_param.dagger == QUDA_DAG_YES ? 1 : 0; + bool use_pc = (eig_param.use_pc == QUDA_BOOLEAN_TRUE ? true : false); + double mass = inv_param.mass; + + if (use_pc) errorQuda("The SVD of the preconditioned staggered op is not supported"); + + // Create temporary spinors + quda::ColorSpinorParam csParam(spinor_left); + quda::ColorSpinorField ref(csParam); + + // Only `mat` is used here + stag_mat(ref, fat_link, long_link, spinor_left, mass, dagger, dslash_type); + + // Compute M * x_left - \sigma * x_right + caxpy(-sigma, spinor_right.data(), ref.data(), spinor_right.Volume() * stag_spinor_site_size, inv_param.cpu_prec); + double nrm2 = norm_2(ref.data(), ref.Volume() * stag_spinor_site_size, inv_param.cpu_prec); + double src2 = norm_2(spinor_left.data(), spinor_left.Volume() * stag_spinor_site_size, inv_param.cpu_prec); + double l2r = sqrt(nrm2 / src2); + + printfQuda("Singular vector pair %4d: tol %.2e, host residual = %.15e\n", i, eig_param.tol, l2r); + return l2r; } diff --git a/tests/host_reference/dslash_reference.h b/tests/host_reference/dslash_reference.h index 82745008fc..c464836f71 100644 --- a/tests/host_reference/dslash_reference.h +++ b/tests/host_reference/dslash_reference.h @@ -109,10 +109,65 @@ std::array verifyWilsonTypeInversion(void *spinorOut, void **spinorOu void *spinorCheck, QudaGaugeParam &gauge_param, QudaInvertParam &inv_param, void **gauge, void *clover, void *clover_inv); -double verifyStaggeredInversion(quda::ColorSpinorField &tmp, quda::ColorSpinorField &ref, quda::ColorSpinorField &in, - quda::ColorSpinorField &out, double mass, quda::GaugeField &fatlink, - quda::GaugeField &longlink, QudaGaugeParam &gauge_param, QudaInvertParam &inv_param, - int shift); +/** + * @brief Verify a staggered inversion on the host. This version is a thin wrapper around a version that takes + * an array of outputs as is necessary for handling both single- and multi-shift solves. + * + * @param in The initial rhs + * @param out The solution to A out = in + * @param fat_link The fat links in the context of an ASQTAD solve; otherwise the base gauge links with phases applied + * @param long_link The long links; null for naive staggered and Laplace + * @param inv_param Invert params, used to query the solve type, etc + * @return The residual and HQ residual (if requested) + */ +std::array verifyStaggeredInversion(quda::ColorSpinorField &in, quda::ColorSpinorField &out, + quda::GaugeField &fat_link, quda::GaugeField &long_link, + QudaInvertParam &inv_param); + +/** + * @brief Verify a single- or multi-shift staggered inversion on the host + * + * @param in The initial rhs + * @param out The solutions to (A + shift) out = in for multiple shifts; shift == 0 for a single shift solve + * @param fat_link The fat links in the context of an ASQTAD solve; otherwise the base gauge links with phases applied + * @param long_link The long links; null for naive staggered and Laplace + * @param inv_param Invert params, used to query the solve type, etc, also includes the shifts + * @return The residual and HQ residual (if requested) + */ +std::array verifyStaggeredInversion(quda::ColorSpinorField &in, + std::vector &out_vector, + quda::GaugeField &fat_link, quda::GaugeField &long_link, + QudaInvertParam &inv_param); + +/** + * @brief Verify a staggered-type eigenvector + * + * @param spinor The host eigenvector to be verified + * @param lambda The host eigenvalue to be verified + * @param i The number of the eigenvalue, only used when printing outputs + * @param eig_param Eigensolve params, used to query the operator type, etc + * @param fat_link The fat links in the context of an ASQTAD solve; otherwise the base gauge links with phases applied + * @param long_link The long links; null for naive staggered and Laplace + * @return The residual norm + */ +double verifyStaggeredTypeEigenvector(quda::ColorSpinorField &spinor, double _Complex lambda, int i, + QudaEigParam &eig_param, quda::GaugeField &fat_link, quda::GaugeField &long_link); + +/** + * @brief Verify a staggered-type singular vector + * + * @param spinor The host left singular vector to be verified + * @param spinor_right The host right singular vector to be verified + * @param lambda The host singular value to be verified + * @param i The number of the singular value, only used when printing outputs + * @param eig_param Eigensolve params, used to query the operator type, etc + * @param fat_link The fat links in the context of an ASQTAD solve; otherwise the base gauge links with phases applied + * @param long_link The long links; null for naive staggered and Laplace + * @return The residual norm + */ +double verifyStaggeredTypeSingularVector(quda::ColorSpinorField &spinor_left, quda::ColorSpinorField &spinor_right, + double _Complex sigma, int i, QudaEigParam &eig_param, + quda::GaugeField &fat_link, quda::GaugeField &long_link); // i represents a "half index" into an even or odd "half lattice". // when oddBit={0,1} the half lattice is {even,odd}. diff --git a/tests/host_reference/staggered_dslash_reference.cpp b/tests/host_reference/staggered_dslash_reference.cpp index 14852a9f22..bf6bcf8b92 100644 --- a/tests/host_reference/staggered_dslash_reference.cpp +++ b/tests/host_reference/staggered_dslash_reference.cpp @@ -11,6 +11,7 @@ #include #include "misc.h" #include +#include #include @@ -32,25 +33,25 @@ template void display_link_internal(Float *link) // if oddBit is one: calculate odd parity spinor elements // if daggerBit is zero: perform ordinary dslash operator // if daggerBit is one: perform hermitian conjugate of dslash -template +template #ifdef MULTI_GPU -void staggeredDslashReference(sFloat *res, gFloat **fatlink, gFloat **longlink, gFloat **ghostFatlink, - gFloat **ghostLonglink, sFloat *spinorField, sFloat **fwd_nbr_spinor, - sFloat **back_nbr_spinor, int oddBit, int daggerBit, QudaDslashType dslash_type) +void staggeredDslashReference(real_t *res, real_t **fatlink, real_t **longlink, real_t **ghostFatlink, + real_t **ghostLonglink, real_t *spinorField, real_t **fwd_nbr_spinor, + real_t **back_nbr_spinor, int oddBit, int daggerBit, QudaDslashType dslash_type) #else -void staggeredDslashReference(sFloat *res, gFloat **fatlink, gFloat **longlink, gFloat **, gFloat **, sFloat *spinorField, - sFloat **, sFloat **, int oddBit, int daggerBit, QudaDslashType dslash_type) +void staggeredDslashReference(real_t *res, real_t **fatlink, real_t **longlink, real_t **, real_t **, real_t *spinorField, + real_t **, real_t **, int oddBit, int daggerBit, QudaDslashType dslash_type) #endif { #pragma omp parallel for for (auto i = 0lu; i < Vh * stag_spinor_site_size; i++) res[i] = 0.0; - gFloat *fatlinkEven[4], *fatlinkOdd[4]; - gFloat *longlinkEven[4], *longlinkOdd[4]; + real_t *fatlinkEven[4], *fatlinkOdd[4]; + real_t *longlinkEven[4], *longlinkOdd[4]; #ifdef MULTI_GPU - gFloat *ghostFatlinkEven[4], *ghostFatlinkOdd[4]; - gFloat *ghostLonglinkEven[4], *ghostLonglinkOdd[4]; + real_t *ghostFatlinkEven[4], *ghostFatlinkOdd[4]; + real_t *ghostLonglinkEven[4], *ghostLonglinkOdd[4]; #endif for (int dir = 0; dir < 4; dir++) { @@ -74,28 +75,28 @@ void staggeredDslashReference(sFloat *res, gFloat **fatlink, gFloat **longlink, for (int dir = 0; dir < 8; dir++) { #ifdef MULTI_GPU const int nFace = dslash_type == QUDA_ASQTAD_DSLASH ? 3 : 1; - gFloat *fatlnk + real_t *fatlnk = gaugeLink_mg4dir(sid, dir, oddBit, fatlinkEven, fatlinkOdd, ghostFatlinkEven, ghostFatlinkOdd, 1, 1); - gFloat *longlnk = dslash_type == QUDA_ASQTAD_DSLASH ? + real_t *longlnk = dslash_type == QUDA_ASQTAD_DSLASH ? gaugeLink_mg4dir(sid, dir, oddBit, longlinkEven, longlinkOdd, ghostLonglinkEven, ghostLonglinkOdd, 3, 3) : nullptr; - sFloat *first_neighbor_spinor = spinorNeighbor_5d_mgpu( + real_t *first_neighbor_spinor = spinorNeighbor_5d_mgpu( sid, dir, oddBit, spinorField, fwd_nbr_spinor, back_nbr_spinor, 1, nFace, stag_spinor_site_size); - sFloat *third_neighbor_spinor = dslash_type == QUDA_ASQTAD_DSLASH ? + real_t *third_neighbor_spinor = dslash_type == QUDA_ASQTAD_DSLASH ? spinorNeighbor_5d_mgpu(sid, dir, oddBit, spinorField, fwd_nbr_spinor, back_nbr_spinor, 3, nFace, stag_spinor_site_size) : nullptr; #else - gFloat *fatlnk = gaugeLink(sid, dir, oddBit, fatlinkEven, fatlinkOdd, 1); - gFloat *longlnk + real_t *fatlnk = gaugeLink(sid, dir, oddBit, fatlinkEven, fatlinkOdd, 1); + real_t *longlnk = dslash_type == QUDA_ASQTAD_DSLASH ? gaugeLink(sid, dir, oddBit, longlinkEven, longlinkOdd, 3) : nullptr; - sFloat *first_neighbor_spinor + real_t *first_neighbor_spinor = spinorNeighbor_5d(sid, dir, oddBit, spinorField, 1, stag_spinor_site_size); - sFloat *third_neighbor_spinor = dslash_type == QUDA_ASQTAD_DSLASH ? + real_t *third_neighbor_spinor = dslash_type == QUDA_ASQTAD_DSLASH ? spinorNeighbor_5d(sid, dir, oddBit, spinorField, 3, stag_spinor_site_size) : nullptr; #endif - sFloat gaugedSpinor[stag_spinor_site_size]; + real_t gaugedSpinor[stag_spinor_site_size]; if (dir % 2 == 0) { su3Mul(gaugedSpinor, fatlnk, first_neighbor_spinor); @@ -124,10 +125,18 @@ void staggeredDslashReference(sFloat *res, gFloat **fatlink, gFloat **longlink, } // 4-d volume } -void staggeredDslash(ColorSpinorField &out, void *const *fatlink, void *const *longlink, void *const *ghost_fatlink, - void *const *ghost_longlink, const ColorSpinorField &in, int oddBit, int daggerBit, - QudaPrecision sPrecision, QudaPrecision gPrecision, QudaDslashType dslash_type) +void stag_dslash(ColorSpinorField &out, const GaugeField &fat_link, const GaugeField &long_link, + const ColorSpinorField &in, int oddBit, int daggerBit, QudaDslashType dslash_type) { + // assert sPrecision and gPrecision must be the same + if (in.Precision() != fat_link.Precision()) { + errorQuda("The spinor precision and gauge precision are not the same"); + } + + // assert we have single-parity spinors + if (out.SiteSubset() != QUDA_PARITY_SITE_SUBSET || in.SiteSubset() != QUDA_PARITY_SITE_SUBSET) + errorQuda("Unexpected site subsets for stag_dslash, out %d in %d", out.SiteSubset(), in.SiteSubset()); + QudaParity otherparity = QUDA_INVALID_PARITY; if (oddBit == QUDA_EVEN_PARITY) { otherparity = QUDA_ODD_PARITY; @@ -140,39 +149,85 @@ void staggeredDslash(ColorSpinorField &out, void *const *fatlink, void *const *l in.exchangeGhost(otherparity, nFace, daggerBit); - auto fwd_nbr_spinor = in.fwdGhostFaceBuffer; - auto back_nbr_spinor = in.backGhostFaceBuffer; - - if (sPrecision == QUDA_DOUBLE_PRECISION) { - if (gPrecision == QUDA_DOUBLE_PRECISION) { - staggeredDslashReference((double *)out.data(), (double **)fatlink, (double **)longlink, (double **)ghost_fatlink, - (double **)ghost_longlink, (double *)in.data(), (double **)fwd_nbr_spinor, - (double **)back_nbr_spinor, oddBit, daggerBit, dslash_type); - } else { - staggeredDslashReference((double *)out.data(), (float **)fatlink, (float **)longlink, (float **)ghost_fatlink, - (float **)ghost_longlink, (double *)in.data(), (double **)fwd_nbr_spinor, - (double **)back_nbr_spinor, oddBit, daggerBit, dslash_type); - } + void *qdp_fatlink[] = {fat_link.data(0), fat_link.data(1), fat_link.data(2), fat_link.data(3)}; + void *qdp_longlink[] = {long_link.data(0), long_link.data(1), long_link.data(2), long_link.data(3)}; + void *ghost_fatlink[] + = {fat_link.Ghost()[0].data(), fat_link.Ghost()[1].data(), fat_link.Ghost()[2].data(), fat_link.Ghost()[3].data()}; + void *ghost_longlink[] = {long_link.Ghost()[0].data(), long_link.Ghost()[1].data(), long_link.Ghost()[2].data(), + long_link.Ghost()[3].data()}; + + if (in.Precision() == QUDA_DOUBLE_PRECISION) { + staggeredDslashReference(static_cast(out.data()), reinterpret_cast(qdp_fatlink), + reinterpret_cast(qdp_longlink), reinterpret_cast(ghost_fatlink), + reinterpret_cast(ghost_longlink), static_cast(in.data()), + reinterpret_cast(in.fwdGhostFaceBuffer), + reinterpret_cast(in.backGhostFaceBuffer), oddBit, daggerBit, dslash_type); + } else if (in.Precision() == QUDA_SINGLE_PRECISION) { + staggeredDslashReference(static_cast(out.data()), reinterpret_cast(qdp_fatlink), + reinterpret_cast(qdp_longlink), reinterpret_cast(ghost_fatlink), + reinterpret_cast(ghost_longlink), static_cast(in.data()), + reinterpret_cast(in.fwdGhostFaceBuffer), + reinterpret_cast(in.backGhostFaceBuffer), oddBit, daggerBit, dslash_type); + } +} + +void stag_mat(ColorSpinorField &out, const GaugeField &fat_link, const GaugeField &long_link, + const ColorSpinorField &in, double mass, int daggerBit, QudaDslashType dslash_type) +{ + // assert sPrecision and gPrecision must be the same + if (in.Precision() != fat_link.Precision()) { + errorQuda("The spinor precision and gauge precision are not the same"); + } + + // assert we have full-parity spinors + if (out.SiteSubset() != QUDA_FULL_SITE_SUBSET || in.SiteSubset() != QUDA_FULL_SITE_SUBSET) + errorQuda("Unexpected site subsets for stag_mat, out %d in %d", out.SiteSubset(), in.SiteSubset()); + + // In QUDA, the full staggered operator has the sign convention + // {{m, -D_eo},{-D_oe,m}}, while the CPU verify function does not + // have the minus sign. Inverting the expected dagger convention + // solves this discrepancy. + stag_dslash(out.Even(), fat_link, long_link, in.Odd(), QUDA_EVEN_PARITY, 1 - daggerBit, dslash_type); + stag_dslash(out.Odd(), fat_link, long_link, in.Even(), QUDA_ODD_PARITY, 1 - daggerBit, dslash_type); + + if (dslash_type == QUDA_LAPLACE_DSLASH) { + double kappa = 1.0 / (8 + mass); + xpay(in.data(), kappa, out.data(), out.Length(), out.Precision()); } else { - if (gPrecision == QUDA_DOUBLE_PRECISION) { - staggeredDslashReference((float *)out.data(), (double **)fatlink, (double **)longlink, (double **)ghost_fatlink, - (double **)ghost_longlink, (float *)in.data(), (float **)fwd_nbr_spinor, - (float **)back_nbr_spinor, oddBit, daggerBit, dslash_type); - } else { - staggeredDslashReference((float *)out.data(), (float **)fatlink, (float **)longlink, (float **)ghost_fatlink, - (float **)ghost_longlink, (float *)in.data(), (float **)fwd_nbr_spinor, - (float **)back_nbr_spinor, oddBit, daggerBit, dslash_type); - } + axpy(2 * mass, in.data(), out.data(), out.Length(), out.Precision()); } } -void staggeredMatDagMat(ColorSpinorField &out, void *const *fatlink, void *const *longlink, void *const *ghost_fatlink, - void *const *ghost_longlink, const ColorSpinorField &in, double mass, int dagger_bit, - QudaPrecision sPrecision, QudaPrecision gPrecision, ColorSpinorField &tmp, QudaParity parity, - QudaDslashType dslash_type) +void stag_matdag_mat(ColorSpinorField &out, const GaugeField &fat_link, const GaugeField &long_link, + const ColorSpinorField &in, double mass, int daggerBit, QudaDslashType dslash_type) { // assert sPrecision and gPrecision must be the same - if (sPrecision != gPrecision) { errorQuda("Spinor precision and gPrecison is not the same"); } + if (in.Precision() != fat_link.Precision()) { + errorQuda("The spinor precision and gauge precision are not the same"); + } + + // assert we have full-parity spinors + if (out.SiteSubset() != QUDA_FULL_SITE_SUBSET || in.SiteSubset() != QUDA_FULL_SITE_SUBSET) + errorQuda("Unexpected site subsets for stag_matdagmat, out %d in %d", out.SiteSubset(), in.SiteSubset()); + + // Create temporary spinors + quda::ColorSpinorParam csParam(in); + quda::ColorSpinorField tmp(csParam); + + // Apply mat in sequence + stag_mat(tmp, fat_link, long_link, in, mass, daggerBit, dslash_type); + stag_mat(out, fat_link, long_link, tmp, mass, 1 - daggerBit, dslash_type); +} + +void stag_matpc(ColorSpinorField &out, const GaugeField &fat_link, const GaugeField &long_link, + const ColorSpinorField &in, double mass, int, QudaParity parity, QudaDslashType dslash_type) +{ + // assert sPrecision and gPrecision must be the same + if (in.Precision() != fat_link.Precision()) { errorQuda("The spinor precision and gauge precison are not the same"); } + + // assert we have single-parity spinors + if (out.SiteSubset() != QUDA_PARITY_SITE_SUBSET || in.SiteSubset() != QUDA_PARITY_SITE_SUBSET) + errorQuda("Unexpected site subsets for stag_matpc, out %d in %d", out.SiteSubset(), in.SiteSubset()); QudaParity otherparity = QUDA_INVALID_PARITY; if (parity == QUDA_EVEN_PARITY) { @@ -183,16 +238,19 @@ void staggeredMatDagMat(ColorSpinorField &out, void *const *fatlink, void *const errorQuda("full parity not supported in function"); } - staggeredDslash(tmp, fatlink, longlink, ghost_fatlink, ghost_longlink, in, otherparity, dagger_bit, sPrecision, - gPrecision, dslash_type); + // Create temporary spinors + quda::ColorSpinorParam csParam(in); + quda::ColorSpinorField tmp(csParam); - staggeredDslash(out, fatlink, longlink, ghost_fatlink, ghost_longlink, tmp, parity, dagger_bit, sPrecision, - gPrecision, dslash_type); + // dagger bit does not matter + stag_dslash(tmp, fat_link, long_link, in, otherparity, 0, dslash_type); + stag_dslash(out, fat_link, long_link, tmp, parity, 0, dslash_type); double msq_x4 = mass * mass * 4; - if (sPrecision == QUDA_DOUBLE_PRECISION) { - axmy((double *)in.data(), (double)msq_x4, (double *)out.data(), Vh * stag_spinor_site_size); + if (in.Precision() == QUDA_DOUBLE_PRECISION) { + axmy(static_cast(in.data()), msq_x4, static_cast(out.data()), Vh * stag_spinor_site_size); } else { - axmy((float *)in.data(), (float)msq_x4, (float *)out.data(), Vh * stag_spinor_site_size); + axmy(static_cast(in.data()), static_cast(msq_x4), static_cast(out.data()), + Vh * stag_spinor_site_size); } } diff --git a/tests/host_reference/staggered_dslash_reference.h b/tests/host_reference/staggered_dslash_reference.h index 2d47138dc0..b39287bfb1 100644 --- a/tests/host_reference/staggered_dslash_reference.h +++ b/tests/host_reference/staggered_dslash_reference.h @@ -11,16 +11,80 @@ using namespace quda; void setDims(int *); -template -void staggeredDslashReference(sFloat *res, gFloat **fatlink, gFloat **longlink, gFloat **ghostFatlink, - gFloat **ghostLonglink, sFloat *spinorField, sFloat **fwd_nbr_spinor, - sFloat **back_nbr_spinor, int oddBit, int daggerBit, int nSrc, QudaDslashType dslash_type); - -void staggeredDslash(ColorSpinorField &out, void *const *fatlink, void *const *longlink, void *const *ghost_fatlink, - void *const *ghost_longlink, const ColorSpinorField &in, int oddBit, int daggerBit, - QudaPrecision sPrecision, QudaPrecision gPrecision, QudaDslashType dslash_type); - -void staggeredMatDagMat(ColorSpinorField &out, void *const *fatlink, void *const *longlink, void *const *ghost_fatlink, - void *const *ghost_longlink, const ColorSpinorField &in, double mass, int dagger_bit, - QudaPrecision sPrecision, QudaPrecision gPrecision, ColorSpinorField &tmp, QudaParity parity, - QudaDslashType dslash_type); +/** + * @brief Base host routine to apply the even-odd or odd-even component of a staggered-type dslash + * + * @tparam real_t Datatype used in the host dslash + * @param res Host output result + * @param fatlink Fat links for an asqtad dslash, or the gauge links for a staggered or Laplace dslash + * @param longlink Long links for an asqtad dslash, or an empty GaugeField for staggered or Laplace dslash + * @param ghostFatlink Ghost zones for the host fat links + * @param ghostLonglink Ghost zones for the host long links + * @param spinorField Host input spinor + * @param fwd_nbr_spinor Forward ghost zones for the host input spinor + * @param back_nbr_spinor Backwards ghost zones for the host input spinor + * @param oddBit 0 for D_eo, 1 for D_oe + * @param daggerBit 0 for the regular operator, 1 for the dagger operator + * @param dslash_type Dslash type + */ +template +void staggeredDslashReference(real_t *res, real_t **fatlink, real_t **longlink, real_t **ghostFatlink, + real_t **ghostLonglink, real_t *spinorField, real_t **fwd_nbr_spinor, + real_t **back_nbr_spinor, int oddBit, int daggerBit, QudaDslashType dslash_type); + +/** + * @brief Apply even-odd or odd-even component of a staggered-type dslash + * + * @param out Host output rhs + * @param fat_link Fat links for an asqtad dslash, or the gauge links for a staggered or Laplace dslash + * @param long_link Long links for an asqtad dslash, or an empty GaugeField for staggered or Laplace dslash + * @param in Host input spinor + * @param oddBit 0 for D_eo, 1 for D_oe + * @param daggerBit 0 for the regular operator, 1 for the dagger operator + * @param dslash_type Dslash type + */ +void stag_dslash(ColorSpinorField &out, const GaugeField &fat_link, const GaugeField &long_link, + const ColorSpinorField &in, int oddBit, int daggerBit, QudaDslashType dslash_type); + +/** + * @brief Apply the full parity staggered-type dslash + * + * @param out Host output rhs + * @param fat_link Fat links for an asqtad dslash, or the gauge links for a staggered or Laplace dslash + * @param long_link Long links for an asqtad dslash, or an empty GaugeField for staggered or Laplace dslash + * @param in Host input spinor + * @param mass Mass for the dslash operator + * @param daggerBit 0 for the regular operator, 1 for the dagger operator + * @param dslash_type Dslash type + */ +void stag_mat(ColorSpinorField &out, const GaugeField &fat_link, const GaugeField &long_link, + const ColorSpinorField &in, double mass, int daggerBit, QudaDslashType dslash_type); + +/** + * @brief Apply the full parity staggered-type matdag_mat + * + * @param out Host output rhs + * @param fat_link Fat links for an asqtad dslash, or the gauge links for a staggered or Laplace dslash + * @param long_link Long links for an asqtad dslash, or an empty GaugeField for staggered or Laplace dslash + * @param in Host input spinor + * @param mass Mass for the dslash operator + * @param daggerBit 0 for the regular operator, 1 for the dagger operator + * @param dslash_type Dslash type + */ +void stag_matdag_mat(ColorSpinorField &out, const GaugeField &fat_link, const GaugeField &long_link, + const ColorSpinorField &in, double mass, int daggerBit, QudaDslashType dslash_type); + +/** + * @brief Apply the even-even or odd-odd preconditioned staggered dslash + * + * @param out Host output rhs + * @param fat_link Fat links for an asqtad dslash, or the gauge links for a staggered or Laplace dslash + * @param long_link Long links for an asqtad dslash, or an empty GaugeField for staggered or Laplace dslash + * @param in Host input spinor + * @param mass Mass for the dslash operator + * @param dagger_bit 0 for the regular operator, 1 for the dagger operator --- irrelevant for the HPD preconditioned operator + * @param parity Parity of preconditioned dslash + * @param dslash_type Dslash type + */ +void stag_matpc(ColorSpinorField &out, const GaugeField &fat_link, const GaugeField &long_link, + const ColorSpinorField &in, double mass, int dagger_bit, QudaParity parity, QudaDslashType dslash_type); diff --git a/tests/invert_test_gtest.hpp b/tests/invert_test_gtest.hpp index fbbf6aebe0..55c1c3f788 100644 --- a/tests/invert_test_gtest.hpp +++ b/tests/invert_test_gtest.hpp @@ -17,79 +17,11 @@ class InvertTest : public ::testing::TestWithParam InvertTest() : param(GetParam()) { } }; -bool is_normal_residual(QudaInverterType type) -{ - switch (type) { - case QUDA_CGNR_INVERTER: - case QUDA_CA_CGNR_INVERTER: return true; - default: return false; - } -} - -bool is_preconditioned_solve(QudaSolveType type) -{ - switch (type) { - case QUDA_DIRECT_PC_SOLVE: - case QUDA_NORMOP_PC_SOLVE: return true; - default: return false; - } -} - -bool is_full_solution(QudaSolutionType type) -{ - switch (type) { - case QUDA_MAT_SOLUTION: - case QUDA_MATDAG_MAT_SOLUTION: return true; - default: return false; - } -} - -bool is_normal_solve(test_t param) -{ - auto inv_type = ::testing::get<0>(param); - auto solve_type = ::testing::get<2>(param); - - switch (solve_type) { - case QUDA_NORMOP_SOLVE: - case QUDA_NORMOP_PC_SOLVE: return true; - default: - switch (inv_type) { - case QUDA_CGNR_INVERTER: - case QUDA_CGNE_INVERTER: - case QUDA_CA_CGNR_INVERTER: - case QUDA_CA_CGNE_INVERTER: return true; - default: return false; - } - } -} - -bool is_chiral(QudaDslashType type) -{ - switch (type) { - case QUDA_DOMAIN_WALL_DSLASH: - case QUDA_DOMAIN_WALL_4D_DSLASH: - case QUDA_MOBIUS_DWF_DSLASH: - case QUDA_MOBIUS_DWF_EOFA_DSLASH: return true; - default: return false; - } -} - -bool support_solution_accumulator_pipeline(QudaInverterType type) -{ - switch (type) { - case QUDA_CG_INVERTER: - case QUDA_CA_CG_INVERTER: - case QUDA_CGNR_INVERTER: - case QUDA_CGNE_INVERTER: - case QUDA_PCG_INVERTER: return true; - default: return false; - } -} - bool skip_test(test_t param) { auto inverter_type = ::testing::get<0>(param); auto solution_type = ::testing::get<1>(param); + auto solve_type = ::testing::get<2>(param); auto prec_sloppy = ::testing::get<3>(param); auto multishift = ::testing::get<4>(param); auto solution_accumulator_pipeline = ::testing::get<5>(param); @@ -103,7 +35,7 @@ bool skip_test(test_t param) if (prec_sloppy < prec_precondition) return true; // sloppy precision >= preconditioner precision // dwf-style solves must use a normal solver - if (is_chiral(dslash_type) && !is_normal_solve(param)) return true; + if (is_chiral(dslash_type) && !is_normal_solve(inverter_type, solve_type)) return true; // FIXME this needs to be added to dslash_reference.cpp if (is_chiral(dslash_type) && multishift > 1) return true; // FIXME this needs to be added to dslash_reference.cpp @@ -111,14 +43,14 @@ bool skip_test(test_t param) // Skip if the inverter does not support batched update and batched update is greater than one if (!support_solution_accumulator_pipeline(inverter_type) && solution_accumulator_pipeline > 1) return true; // MdagMLocal only support for Mobius at present - if (is_normal_solve(param) && ::testing::get<0>(schwarz_param) != QUDA_INVALID_SCHWARZ) { + if (is_normal_solve(inverter_type, solve_type) && ::testing::get<0>(schwarz_param) != QUDA_INVALID_SCHWARZ) { #ifdef QUDA_MMA_AVAILABLE if (dslash_type != QUDA_MOBIUS_DWF_DSLASH) return true; #else return true; #endif } - // split-grid doesn't support split-grid at present + // split-grid doesn't support multishift at present if (use_split_grid && multishift > 1) return true; return false; @@ -137,12 +69,7 @@ TEST_P(InvertTest, verify) if (res_t & QUDA_HEAVY_QUARK_RESIDUAL) inv_param.tol_hq = tol_hq; auto tol = inv_param.tol; - if (inv_param.dslash_type == QUDA_DOMAIN_WALL_DSLASH || - inv_param.dslash_type == QUDA_DOMAIN_WALL_4D_DSLASH || - inv_param.dslash_type == QUDA_MOBIUS_DWF_DSLASH || - inv_param.dslash_type == QUDA_MOBIUS_DWF_EOFA_DSLASH) { - tol *= std::sqrt(static_cast(inv_param.Ls)); - } + if (is_chiral(inv_param.dslash_type)) { tol *= std::sqrt(static_cast(inv_param.Ls)); } // FIXME eventually we should build in refinement to the *NR solvers to remove the need for this if (is_normal_residual(::testing::get<0>(GetParam()))) tol *= 50; // Slight loss of precision possible when reconstructing full solution diff --git a/tests/staggered_dslash_ctest.cpp b/tests/staggered_dslash_ctest.cpp index 78a735ad13..65edd69124 100644 --- a/tests/staggered_dslash_ctest.cpp +++ b/tests/staggered_dslash_ctest.cpp @@ -2,9 +2,6 @@ using namespace quda; -// For loading the gauge fields -int argc_copy; -char **argv_copy; bool ctest_all_partitions = false; using ::testing::Bool; @@ -23,25 +20,11 @@ class StaggeredDslashTest : public ::testing::TestWithParam<::testing::tuple(::testing::get<1>(GetParam())); if ((QUDA_PRECISION & getPrecision(::testing::get<0>(GetParam()))) == 0 - || (QUDA_RECONSTRUCT & getReconstructNibble(recon)) == 0) { + || (QUDA_RECONSTRUCT & getReconstructNibble(recon)) == 0) return true; - } - if (dslash_type == QUDA_ASQTAD_DSLASH && compute_fatlong - && (::testing::get<0>(GetParam()) == 0 || ::testing::get<0>(GetParam()) == 1)) { - warningQuda("Fixed precision unsupported in fat/long compute, skipping..."); + if (is_laplace(dslash_type) && (::testing::get<0>(GetParam()) == 0 || ::testing::get<0>(GetParam()) == 1)) return true; - } - - if (dslash_type == QUDA_ASQTAD_DSLASH && compute_fatlong && (getReconstructNibble(recon) & 1)) { - warningQuda("Reconstruct 9 unsupported in fat/long compute, skipping..."); - return true; - } - - if (dslash_type == QUDA_LAPLACE_DSLASH && (::testing::get<0>(GetParam()) == 0 || ::testing::get<0>(GetParam()) == 1)) { - warningQuda("Fixed precision unsupported for Laplace operator, skipping..."); - return true; - } const std::array partition_enabled {true, true, true, false, true, false, false, false, true, false, false, false, true, false, true, true}; @@ -75,7 +58,7 @@ class StaggeredDslashTest : public ::testing::TestWithParam<::testing::tuple= QUDA_HALF_PRECISION) - tol *= 10; // if recon 8, we tolerate a greater deviation + tol *= 10; // if recon 9, we tolerate a greater deviation + ASSERT_LE(deviation, tol) << "Reference CPU and QUDA implementations do not agree"; } @@ -117,6 +101,10 @@ int main(int argc, char **argv) { // initalize google test ::testing::InitGoogleTest(&argc, argv); + + // override the default dslash from Wilson + dslash_type = QUDA_ASQTAD_DSLASH; + auto app = make_app(); app->add_option("--test", dtest_type, "Test method")->transform(CLI::CheckedTransformer(dtest_type_map)); app->add_option("--all-partitions", ctest_all_partitions, "Test all instead of reduced combination of partitions"); @@ -129,55 +117,35 @@ int main(int argc, char **argv) initComms(argc, argv, gridsize_from_cmdline); - // The 'SetUp()' method of the Google Test class from which DslashTest - // in derived has no arguments, but QUDA's implementation requires the - // use of argc and argv to set up the test via the function 'init'. - // As a workaround, we declare argc_copy and argv_copy as global pointers - // so that they are visible inside the 'init' function. - argc_copy = argc; - argv_copy = argv; - // Ensure gtest prints only from rank 0 ::testing::TestEventListeners &listeners = ::testing::UnitTest::GetInstance()->listeners(); if (comm_rank() != 0) { delete listeners.Release(listeners.default_result_printer()); } - // Only these fermions are supported in this file. Ensure a reasonable default, - // ensure that the default is improved staggered - if (dslash_type != QUDA_STAGGERED_DSLASH && dslash_type != QUDA_ASQTAD_DSLASH && dslash_type != QUDA_LAPLACE_DSLASH) { - printfQuda("dslash_type %s not supported, defaulting to %s\n", get_dslash_str(dslash_type), - get_dslash_str(QUDA_ASQTAD_DSLASH)); - dslash_type = QUDA_ASQTAD_DSLASH; + // Only these fermions are supported in this file + if constexpr (is_enabled_laplace()) { + if (!is_staggered(dslash_type) && !is_laplace(dslash_type)) + errorQuda("dslash_type %s not supported", get_dslash_str(dslash_type)); + } else { + if (is_laplace(dslash_type)) errorQuda("The Laplace dslash is not enabled, cmake configure with -DQUDA_LAPLACE=ON"); + if (!is_staggered(dslash_type)) errorQuda("dslash_type %s not supported", get_dslash_str(dslash_type)); } // Sanity check: if you pass in a gauge field, want to test the asqtad/hisq dslash, and don't // ask to build the fat/long links... it doesn't make sense. - if (latfile.size() > 0 && !compute_fatlong && dslash_type == QUDA_ASQTAD_DSLASH) { + if (latfile.size() > 0 && !compute_fatlong && dslash_type == QUDA_ASQTAD_DSLASH) errorQuda( "Cannot load a gauge field and test the ASQTAD/HISQ operator without setting \"--compute-fat-long true\".\n"); - compute_fatlong = true; - } // Set n_naiks to 2 if eps_naik != 0.0 - if (dslash_type == QUDA_ASQTAD_DSLASH) { - if (eps_naik != 0.0) { - if (compute_fatlong) { - n_naiks = 2; - printfQuda("Note: epsilon-naik != 0, testing epsilon correction links.\n"); - } else { - eps_naik = 0.0; - printfQuda("Not computing fat-long, ignoring epsilon correction.\n"); - } - } else { - printfQuda("Note: epsilon-naik = 0, testing original HISQ links.\n"); - } + if (eps_naik != 0.0) { + if (compute_fatlong) + n_naiks = 2; + else + eps_naik = 0.0; // to avoid potential headaches } - if (dslash_type == QUDA_LAPLACE_DSLASH) { - if (dtest_type != dslash_test_type::Mat) { - errorQuda("Test type %s is not supported for the Laplace operator.\n", - get_string(dtest_type_map, dtest_type).c_str()); - } - } + if (is_laplace(dslash_type) && dtest_type != dslash_test_type::Mat) + errorQuda("Test type %s is not supported for the Laplace operator", get_string(dtest_type_map, dtest_type).c_str()); int test_rc = RUN_ALL_TESTS(); diff --git a/tests/staggered_dslash_test.cpp b/tests/staggered_dslash_test.cpp index 0a3a063e45..82c84c3225 100644 --- a/tests/staggered_dslash_test.cpp +++ b/tests/staggered_dslash_test.cpp @@ -2,9 +2,6 @@ using namespace quda; -int argc_copy; -char **argv_copy; - class StaggeredDslashTest : public ::testing::Test { protected: @@ -24,7 +21,7 @@ class StaggeredDslashTest : public ::testing::Test public: virtual void SetUp() { - dslash_test_wrapper.init_test(argc_copy, argv_copy); + dslash_test_wrapper.init_test(); display_test_info(); } @@ -53,6 +50,12 @@ TEST_F(StaggeredDslashTest, verify) double deviation = dslash_test_wrapper.verify(); double tol = getTolerance(dslash_test_wrapper.inv_param.cuda_prec); + + // give it a tiny bump for fixed precision, recon 8 + if (dslash_test_wrapper.inv_param.cuda_prec <= QUDA_HALF_PRECISION + && dslash_test_wrapper.gauge_param.reconstruct == QUDA_RECONSTRUCT_9) + tol *= 1.1; + ASSERT_LE(deviation, tol) << "reference and QUDA implementations do not agree"; } @@ -61,6 +64,9 @@ int main(int argc, char **argv) // initalize google test ::testing::InitGoogleTest(&argc, argv); + // override the default dslash from Wilson + dslash_type = QUDA_ASQTAD_DSLASH; + // command line options auto app = make_app(); app->add_option("--test", dtest_type, "Test method")->transform(CLI::CheckedTransformer(dtest_type_map)); @@ -74,59 +80,35 @@ int main(int argc, char **argv) initComms(argc, argv, gridsize_from_cmdline); - // The 'SetUp()' method of the Google Test class from which DslashTest - // in derived has no arguments, but QUDA's implementation requires the - // use of argc and argv to set up the test via the function 'init'. - // As a workaround, we declare argc_copy and argv_copy as global pointers - // so that they are visible inside the 'init' function. - argc_copy = argc; - argv_copy = argv; - // Ensure gtest prints only from rank 0 ::testing::TestEventListeners &listeners = ::testing::UnitTest::GetInstance()->listeners(); if (comm_rank() != 0) { delete listeners.Release(listeners.default_result_printer()); } - // Only these fermions are supported in this file. Ensure a reasonable default, - // ensure that the default is improved staggered - if (dslash_type != QUDA_STAGGERED_DSLASH && dslash_type != QUDA_ASQTAD_DSLASH && dslash_type != QUDA_LAPLACE_DSLASH) { - printfQuda("dslash_type %s not supported, defaulting to %s\n", get_dslash_str(dslash_type), - get_dslash_str(QUDA_ASQTAD_DSLASH)); - dslash_type = QUDA_ASQTAD_DSLASH; + // Only these fermions are supported in this file + if constexpr (is_enabled_laplace()) { + if (!is_staggered(dslash_type) && !is_laplace(dslash_type)) + errorQuda("dslash_type %s not supported", get_dslash_str(dslash_type)); + } else { + if (is_laplace(dslash_type)) errorQuda("The Laplace dslash is not enabled, cmake configure with -DQUDA_LAPLACE=ON"); + if (!is_staggered(dslash_type)) errorQuda("dslash_type %s not supported", get_dslash_str(dslash_type)); } // Sanity check: if you pass in a gauge field, want to test the asqtad/hisq dslash, // and don't ask to build the fat/long links... it doesn't make sense. - if (latfile.size() > 0 && !compute_fatlong && dslash_type == QUDA_ASQTAD_DSLASH) { + if (latfile.size() > 0 && !compute_fatlong && dslash_type == QUDA_ASQTAD_DSLASH) errorQuda( "Cannot load a gauge field and test the ASQTAD/HISQ operator without setting \"--compute-fat-long true\"."); - } // Set n_naiks to 2 if eps_naik != 0.0 - if (dslash_type == QUDA_ASQTAD_DSLASH) { - if (eps_naik != 0.0) { - if (compute_fatlong) { - n_naiks = 2; - printfQuda("Note: epsilon-naik != 0, testing epsilon correction links.\n"); - } else { - eps_naik = 0.0; - printfQuda("Not computing fat-long, ignoring epsilon correction.\n"); - } - } else { - printfQuda("Note: epsilon-naik = 0, testing original HISQ links.\n"); - } - } - - if (dslash_type == QUDA_LAPLACE_DSLASH) { - if (dtest_type != dslash_test_type::Mat) { - errorQuda("Test type %s is not supported for the Laplace operator", get_string(dtest_type_map, dtest_type).c_str()); - } + if (eps_naik != 0.0) { + if (compute_fatlong) + n_naiks = 2; + else + eps_naik = 0.0; // to avoid potential headaches } - // If we're building fat/long links, there are some - // tests we have to skip. - if (dslash_type == QUDA_ASQTAD_DSLASH && compute_fatlong) { - if (prec < QUDA_SINGLE_PRECISION) { errorQuda("Fixed-point precision unsupported in fat/long compute"); } - } + if (is_laplace(dslash_type) && dtest_type != dslash_test_type::Mat) + errorQuda("Test type %s is not supported for the Laplace operator", get_string(dtest_type_map, dtest_type).c_str()); int test_rc = RUN_ALL_TESTS(); diff --git a/tests/staggered_dslash_test_utils.h b/tests/staggered_dslash_test_utils.h index a4aaac347b..0a3d589ca1 100644 --- a/tests/staggered_dslash_test_utils.h +++ b/tests/staggered_dslash_test_utils.h @@ -25,10 +25,12 @@ using namespace quda; dslash_test_type dtest_type = dslash_test_type::Dslash; CLI::TransformPairs dtest_type_map { - {"Dslash", dslash_test_type::Dslash}, {"MatPC", dslash_test_type::MatPC}, {"Mat", dslash_test_type::Mat} - // left here for completeness but not support in staggered dslash test + {"Dslash", dslash_test_type::Dslash}, + {"MatPC", dslash_test_type::MatPC}, + {"Mat", dslash_test_type::Mat}, + {"MatDagMat", dslash_test_type::MatDagMat}, + // left here for completeness but not supported in staggered dslash test // {"MatPCDagMatPC", dslash_test_type::MatPCDagMatPC}, - // {"MatDagMat", dslash_test_type::MatDagMat}, // {"M5", dslash_test_type::M5}, // {"M5inv", dslash_test_type::M5inv}, // {"Dslash4pre", dslash_test_type::Dslash4pre} @@ -45,41 +47,31 @@ struct DslashTime { struct StaggeredDslashTestWrapper { - static inline void *qdp_inlink[4] = {nullptr, nullptr, nullptr, nullptr}; - // In the HISQ case, we include building fat/long links in this unit test - static inline void *qdp_fatlink_cpu[4] = {}; - static inline void *qdp_longlink_cpu[4] = {}; - QudaGaugeParam gauge_param; QudaInvertParam inv_param; - void *milc_fatlink_gpu; - void *milc_longlink_gpu; - - GaugeField *cpuFat = nullptr; - GaugeField *cpuLong = nullptr; - static inline ColorSpinorField spinor; static inline ColorSpinorField spinorOut; static inline ColorSpinorField spinorRef; - static inline ColorSpinorField tmpCpu; + ColorSpinorField cudaSpinor; ColorSpinorField cudaSpinorOut; static inline std::vector vp_spinor; static inline std::vector vp_spinor_out; - void *ghost_fatlink_cpu[4] = {}; - void *ghost_longlink_cpu[4] = {}; + static inline void *qdp_inlink[4] = {nullptr, nullptr, nullptr, nullptr}; + static inline void *qdp_fatlink[4] = {nullptr, nullptr, nullptr, nullptr}; + static inline void *qdp_longlink[4] = {nullptr, nullptr, nullptr, nullptr}; + static inline void *milc_fatlink = nullptr; + static inline void *milc_longlink = nullptr; + static inline GaugeField cpuFat; + static inline GaugeField cpuLong; QudaParity parity = QUDA_EVEN_PARITY; Dirac *dirac; - // For loading the gauge fields - int argc_copy; - char **argv_copy; - // Split grid options static inline bool test_split_grid = false; int num_src = 1; @@ -89,31 +81,17 @@ struct StaggeredDslashTestWrapper { // compare to dslash reference implementation printfQuda("Calculating reference implementation..."); switch (dtest_type) { - case dslash_test_type::Dslash: - staggeredDslash(spinorRef, qdp_fatlink_cpu, qdp_longlink_cpu, ghost_fatlink_cpu, ghost_longlink_cpu, spinor, - parity, dagger, inv_param.cpu_prec, gauge_param.cpu_prec, dslash_type); - break; - case dslash_test_type::MatPC: - staggeredMatDagMat(spinorRef, qdp_fatlink_cpu, qdp_longlink_cpu, ghost_fatlink_cpu, ghost_longlink_cpu, spinor, - mass, 0, inv_param.cpu_prec, gauge_param.cpu_prec, tmpCpu, parity, dslash_type); - break; - case dslash_test_type::Mat: - // the !dagger is to reconcile the QUDA convention of D_stag = {{ 2m, -D_{eo}}, -D_{oe}, 2m}} vs the host convention without the minus signs - staggeredDslash(spinorRef.Even(), qdp_fatlink_cpu, qdp_longlink_cpu, ghost_fatlink_cpu, ghost_longlink_cpu, - spinor.Odd(), QUDA_EVEN_PARITY, !dagger, inv_param.cpu_prec, gauge_param.cpu_prec, dslash_type); - staggeredDslash(spinorRef.Odd(), qdp_fatlink_cpu, qdp_longlink_cpu, ghost_fatlink_cpu, ghost_longlink_cpu, - spinor.Even(), QUDA_ODD_PARITY, !dagger, inv_param.cpu_prec, gauge_param.cpu_prec, dslash_type); - if (dslash_type == QUDA_LAPLACE_DSLASH) { - xpay(spinor.data(), kappa, spinorRef.data(), spinor.Length(), gauge_param.cpu_prec); - } else { - axpy(2 * mass, spinor.data(), spinorRef.data(), spinor.Length(), gauge_param.cpu_prec); - } + case dslash_test_type::Dslash: stag_dslash(spinorRef, cpuFat, cpuLong, spinor, parity, dagger, dslash_type); break; + case dslash_test_type::MatPC: stag_matpc(spinorRef, cpuFat, cpuLong, spinor, mass, 0, parity, dslash_type); break; + case dslash_test_type::Mat: stag_mat(spinorRef, cpuFat, cpuLong, spinor, mass, dagger, dslash_type); break; + case dslash_test_type::MatDagMat: + stag_matdag_mat(spinorRef, cpuFat, cpuLong, spinor, mass, dagger, dslash_type); break; default: errorQuda("Test type %d not defined", static_cast(dtest_type)); } } - void init_ctest(int argc, char **argv, int precision, QudaReconstructType link_recon_) + void init_ctest(int precision, QudaReconstructType link_recon_) { gauge_param = newQudaGaugeParam(); inv_param = newQudaInvertParam(); @@ -122,7 +100,6 @@ struct StaggeredDslashTestWrapper { setStaggeredInvertParam(inv_param); auto prec = getPrecision(precision); - setVerbosity(QUDA_SUMMARIZE); gauge_param.cuda_prec = prec; gauge_param.cuda_prec_sloppy = prec; @@ -135,13 +112,13 @@ struct StaggeredDslashTestWrapper { static bool first_time = true; if (first_time) { - init_host(argc, argv); + init_host(); first_time = false; } init(); } - void init_test(int argc, char **argv) + void init_test() { gauge_param = newQudaGaugeParam(); inv_param = newQudaInvertParam(); @@ -151,13 +128,13 @@ struct StaggeredDslashTestWrapper { static bool first_time = true; if (first_time) { - init_host(argc, argv); + init_host(); first_time = false; } init(); } - void init_host(int argc, char **argv) + void init_host() { setDims(gauge_param.X); dw_setDims(gauge_param.X, 1); @@ -173,14 +150,21 @@ struct StaggeredDslashTestWrapper { for (int dir = 0; dir < 4; dir++) { qdp_inlink[dir] = safe_malloc(V * gauge_site_size * host_gauge_data_type_size); - qdp_fatlink_cpu[dir] = safe_malloc(V * gauge_site_size * host_gauge_data_type_size); - qdp_longlink_cpu[dir] = safe_malloc(V * gauge_site_size * host_gauge_data_type_size); + qdp_fatlink[dir] = safe_malloc(V * gauge_site_size * host_gauge_data_type_size); + qdp_longlink[dir] = safe_malloc(V * gauge_site_size * host_gauge_data_type_size); } bool compute_on_gpu = false; // reference fat/long fields should be computed on cpu - constructStaggeredHostGaugeField(qdp_inlink, qdp_longlink_cpu, qdp_fatlink_cpu, gauge_param, argc, argv, - compute_on_gpu); + constructStaggeredHostGaugeField(qdp_inlink, qdp_longlink, qdp_fatlink, gauge_param, 0, nullptr, compute_on_gpu); + + // create the reordered MILC-layout fields + milc_fatlink = safe_malloc(4 * V * gauge_site_size * host_gauge_data_type_size); + milc_longlink = safe_malloc(4 * V * gauge_site_size * host_gauge_data_type_size); + reorderQDPtoMILC(milc_fatlink, qdp_fatlink, V, gauge_site_size, gauge_param.cpu_prec, gauge_param.cpu_prec); + reorderQDPtoMILC(milc_longlink, qdp_longlink, V, gauge_site_size, gauge_param.cpu_prec, gauge_param.cpu_prec); + + // create some host-side spinors up front ColorSpinorParam csParam; csParam.nColor = 3; csParam.nSpin = 1; @@ -189,9 +173,8 @@ struct StaggeredDslashTestWrapper { csParam.x[4] = 1; csParam.setPrecision(inv_param.cpu_prec); - // inv_param.solution_type = QUDA_MAT_SOLUTION; csParam.pad = 0; - if (dtest_type != dslash_test_type::Mat && dslash_type != QUDA_LAPLACE_DSLASH) { + if (dtest_type != dslash_test_type::Mat && dtest_type != dslash_test_type::MatDagMat) { csParam.siteSubset = QUDA_PARITY_SITE_SUBSET; csParam.x[0] /= 2; inv_param.solution_type = QUDA_MATPC_SOLUTION; @@ -210,7 +193,6 @@ struct StaggeredDslashTestWrapper { spinor = ColorSpinorField(csParam); spinorOut = ColorSpinorField(csParam); spinorRef = ColorSpinorField(csParam); - tmpCpu = ColorSpinorField(csParam); spinor.Source(QUDA_RANDOM_SOURCE); @@ -230,95 +212,40 @@ struct StaggeredDslashTestWrapper { void init() { - // Prepare the fields to be used for the GPU computation - void *qdp_fatlink_gpu[4]; - void *qdp_longlink_gpu[4]; - for (int dir = 0; dir < 4; dir++) { - qdp_fatlink_gpu[dir] = safe_malloc(V * gauge_site_size * host_gauge_data_type_size); - qdp_longlink_gpu[dir] = safe_malloc(V * gauge_site_size * host_gauge_data_type_size); - } - // QUDA_STAGGERED_DSLASH follows the same codepath whether or not you - // "compute" the fat/long links or not. - if (dslash_type == QUDA_STAGGERED_DSLASH || dslash_type == QUDA_LAPLACE_DSLASH) { - for (int dir = 0; dir < 4; dir++) { - memcpy(qdp_fatlink_gpu[dir], qdp_inlink[dir], V * gauge_site_size * host_gauge_data_type_size); - memset(qdp_longlink_gpu[dir], 0, V * gauge_site_size * host_gauge_data_type_size); - } - } else { - // QUDA_ASQTAD_DSLASH - if (compute_fatlong) { - computeFatLongGPU(qdp_fatlink_gpu, qdp_longlink_gpu, qdp_inlink, gauge_param, host_gauge_data_type_size, - n_naiks, eps_naik); - } else { - // Not computing FatLong - for (int dir = 0; dir < 4; dir++) { - memcpy(qdp_fatlink_gpu[dir], qdp_inlink[dir], V * gauge_site_size * host_gauge_data_type_size); - memcpy(qdp_longlink_gpu[dir], qdp_longlink_cpu[dir], V * gauge_site_size * host_gauge_data_type_size); - } - } - } - - // Create ghost zones for CPU fields, - // prepare and load the GPU fields - void *milc_fatlink_cpu = safe_malloc(4 * V * gauge_site_size * host_gauge_data_type_size); - void *milc_longlink_cpu = safe_malloc(4 * V * gauge_site_size * host_gauge_data_type_size); - milc_fatlink_gpu = safe_malloc(4 * V * gauge_site_size * host_gauge_data_type_size); - milc_longlink_gpu = safe_malloc(4 * V * gauge_site_size * host_gauge_data_type_size); - - // Alright, we've created all the void** links. - // Create the void* pointers - reorderQDPtoMILC(milc_fatlink_gpu, qdp_fatlink_gpu, V, gauge_site_size, gauge_param.cpu_prec, gauge_param.cpu_prec); - reorderQDPtoMILC(milc_fatlink_cpu, qdp_fatlink_cpu, V, gauge_site_size, gauge_param.cpu_prec, gauge_param.cpu_prec); - reorderQDPtoMILC(milc_longlink_gpu, qdp_longlink_gpu, V, gauge_site_size, gauge_param.cpu_prec, gauge_param.cpu_prec); - reorderQDPtoMILC(milc_longlink_cpu, qdp_longlink_cpu, V, gauge_site_size, gauge_param.cpu_prec, gauge_param.cpu_prec); + // For load, etc + gauge_param.reconstruct = QUDA_RECONSTRUCT_NO; -#ifdef MULTI_GPU - gauge_param.type = (dslash_type == QUDA_ASQTAD_DSLASH) ? QUDA_ASQTAD_FAT_LINKS : QUDA_SU3_LINKS; + // Create ghost gauge fields in case of multi GPU builds. + gauge_param.type = (dslash_type == QUDA_STAGGERED_DSLASH || dslash_type == QUDA_LAPLACE_DSLASH) ? + QUDA_SU3_LINKS : + QUDA_ASQTAD_FAT_LINKS; gauge_param.reconstruct = QUDA_RECONSTRUCT_NO; - GaugeFieldParam cpuFatParam(gauge_param, milc_fatlink_cpu); - cpuFatParam.ghostExchange = QUDA_GHOST_EXCHANGE_PAD; - cpuFat = new GaugeField(cpuFatParam); - for (int i = 0; i < 4; i++) ghost_fatlink_cpu[i] = cpuFat->Ghost()[i].data(); - - if (dslash_type == QUDA_ASQTAD_DSLASH) { - gauge_param.type = QUDA_ASQTAD_LONG_LINKS; - GaugeFieldParam cpuLongParam(gauge_param, milc_longlink_cpu); - cpuLongParam.ghostExchange = QUDA_GHOST_EXCHANGE_PAD; - cpuLong = new GaugeField(cpuLongParam); - for (int i = 0; i < 4; i++) ghost_longlink_cpu[i] = cpuLong ? cpuLong->Ghost()[i].data() : nullptr; - } -#endif - - gauge_param.type = (dslash_type == QUDA_ASQTAD_DSLASH) ? QUDA_ASQTAD_FAT_LINKS : QUDA_SU3_LINKS; - if (dslash_type == QUDA_STAGGERED_DSLASH) { - gauge_param.reconstruct = gauge_param.reconstruct_sloppy = (link_recon == QUDA_RECONSTRUCT_12) ? - QUDA_RECONSTRUCT_13 : - (link_recon == QUDA_RECONSTRUCT_8) ? QUDA_RECONSTRUCT_9 : - link_recon; - } else { - gauge_param.reconstruct = gauge_param.reconstruct_sloppy = QUDA_RECONSTRUCT_NO; - } + gauge_param.location = QUDA_CPU_FIELD_LOCATION; - printfQuda("Sending fat links to GPU\n"); - loadGaugeQuda(milc_fatlink_gpu, &gauge_param); + GaugeFieldParam cpuFatParam(gauge_param, qdp_fatlink); + cpuFatParam.order = QUDA_QDP_GAUGE_ORDER; + cpuFatParam.ghostExchange = QUDA_GHOST_EXCHANGE_PAD; + cpuFat = GaugeField(cpuFatParam); gauge_param.type = QUDA_ASQTAD_LONG_LINKS; - -#ifdef MULTI_GPU - gauge_param.ga_pad *= 3; -#endif - - if (dslash_type == QUDA_ASQTAD_DSLASH) { - gauge_param.staggered_phase_type = QUDA_STAGGERED_PHASE_NO; - gauge_param.reconstruct = gauge_param.reconstruct_sloppy = (link_recon == QUDA_RECONSTRUCT_12) ? - QUDA_RECONSTRUCT_13 : - (link_recon == QUDA_RECONSTRUCT_8) ? QUDA_RECONSTRUCT_9 : - link_recon; - printfQuda("Sending long links to GPU\n"); - loadGaugeQuda(milc_longlink_gpu, &gauge_param); + GaugeFieldParam cpuLongParam(gauge_param, qdp_longlink); + cpuLongParam.order = QUDA_QDP_GAUGE_ORDER; + cpuLongParam.ghostExchange = QUDA_GHOST_EXCHANGE_PAD; + cpuLong = GaugeField(cpuLongParam); + + // Override link reconstruct as appropriate for staggered or asqtad + if (is_staggered(dslash_type)) { + if (link_recon == QUDA_RECONSTRUCT_12) link_recon = QUDA_RECONSTRUCT_13; + if (link_recon == QUDA_RECONSTRUCT_8) link_recon = QUDA_RECONSTRUCT_9; } + loadFatLongGaugeQuda(milc_fatlink, milc_longlink, gauge_param); + + // reset the reconstruct in gauge param + gauge_param.reconstruct = link_recon; + + // create device-size spinors ColorSpinorParam csParam(spinor); csParam.fieldOrder = colorspinor::getNative(inv_param.cuda_prec, 1); csParam.pad = 0; @@ -335,19 +262,6 @@ struct StaggeredDslashTestWrapper { setDiracParam(diracParam, &inv_param, pc); dirac = Dirac::create(diracParam); - host_free(milc_fatlink_cpu); - host_free(milc_longlink_cpu); - - for (int dir = 0; dir < 4; dir++) { - if (qdp_fatlink_gpu[dir] != nullptr) { - host_free(qdp_fatlink_gpu[dir]); - qdp_fatlink_gpu[dir] = nullptr; - } - if (qdp_longlink_gpu[dir] != nullptr) { - host_free(qdp_longlink_gpu[dir]); - qdp_longlink_gpu[dir] = nullptr; - } - } } void end() @@ -356,22 +270,9 @@ struct StaggeredDslashTestWrapper { delete dirac; dirac = nullptr; } - - host_free(milc_fatlink_gpu); - milc_fatlink_gpu = nullptr; - host_free(milc_longlink_gpu); - milc_longlink_gpu = nullptr; - freeGaugeQuda(); - - if (cpuFat) { - delete cpuFat; - cpuFat = nullptr; - } - if (cpuLong) { - delete cpuLong; - cpuLong = nullptr; - } + cpuFat = {}; + cpuLong = {}; commDimPartitionedReset(); } @@ -379,14 +280,23 @@ struct StaggeredDslashTestWrapper { { for (int dir = 0; dir < 4; dir++) { if (qdp_inlink[dir]) host_free(qdp_inlink[dir]); - if (qdp_fatlink_cpu[dir]) host_free(qdp_fatlink_cpu[dir]); - if (qdp_longlink_cpu[dir]) host_free(qdp_longlink_cpu[dir]); + if (qdp_fatlink[dir]) host_free(qdp_fatlink[dir]); + if (qdp_longlink[dir]) host_free(qdp_longlink[dir]); + } + + if (milc_fatlink) { + host_free(milc_fatlink); + milc_fatlink = nullptr; + } + + if (milc_longlink) { + host_free(milc_longlink); + milc_longlink = nullptr; } spinor = {}; spinorOut = {}; spinorRef = {}; - tmpCpu = {}; if (test_split_grid) { vp_spinor.clear(); @@ -412,7 +322,7 @@ struct StaggeredDslashTestWrapper { _hp_x[i] = vp_spinor_out[i].data(); _hp_b[i] = vp_spinor[i].data(); } - dslashMultiSrcStaggeredQuda(_hp_x.data(), _hp_b.data(), &inv_param, parity, milc_fatlink_gpu, milc_longlink_gpu, + dslashMultiSrcStaggeredQuda(_hp_x.data(), _hp_b.data(), &inv_param, parity, milc_fatlink, milc_longlink, &gauge_param); } else { @@ -421,11 +331,21 @@ struct StaggeredDslashTestWrapper { host_timer.start(); - switch (dtest_type) { - case dslash_test_type::Dslash: dirac->Dslash(cudaSpinorOut, cudaSpinor, parity); break; - case dslash_test_type::MatPC: dirac->M(cudaSpinorOut, cudaSpinor); break; - case dslash_test_type::Mat: dirac->M(cudaSpinorOut, cudaSpinor); break; - default: errorQuda("Test type %d not defined on staggered dslash", static_cast(dtest_type)); + if (is_laplace(dslash_type)) { + switch (dtest_type) { + case dslash_test_type::Mat: dirac->M(cudaSpinorOut, cudaSpinor); break; + default: errorQuda("Test type %d not defined on Laplace operator", static_cast(dtest_type)); + } + } else if (is_staggered(dslash_type)) { + switch (dtest_type) { + case dslash_test_type::Dslash: dirac->Dslash(cudaSpinorOut, cudaSpinor, parity); break; + case dslash_test_type::MatPC: dirac->M(cudaSpinorOut, cudaSpinor); break; + case dslash_test_type::Mat: dirac->M(cudaSpinorOut, cudaSpinor); break; + case dslash_test_type::MatDagMat: dirac->MdagM(cudaSpinorOut, cudaSpinor); break; + default: errorQuda("Test type %d not defined on staggered dslash", static_cast(dtest_type)); + } + } else { + errorQuda("Invalid dslash type %d", dslash_type); } host_timer.stop(); diff --git a/tests/staggered_eigensolve_test.cpp b/tests/staggered_eigensolve_test.cpp index a495a11251..6e717437fe 100644 --- a/tests/staggered_eigensolve_test.cpp +++ b/tests/staggered_eigensolve_test.cpp @@ -15,126 +15,116 @@ #include #include -#define MAX(a, b) ((a) > (b) ? (a) : (b)) +QudaGaugeParam gauge_param; +QudaInvertParam eig_inv_param; +QudaEigParam eig_param; -void display_test_info() +// if "--enable-testing true" is passed, we run the tests defined in here +#include + +void display_test_info(QudaEigParam ¶m) { printfQuda("running the following test:\n"); - printfQuda("prec sloppy_prec link_recon sloppy_link_recon test_type S_dimension T_dimension\n"); - printfQuda("%s %s %s %s %s %d/%d/%d %d \n", get_prec_str(prec), - get_prec_str(prec_sloppy), get_recon_str(link_recon), get_recon_str(link_recon_sloppy), - get_staggered_test_type(test_type), xdim, ydim, zdim, tdim); + + printfQuda("prec sloppy_prec link_recon sloppy_link_recon S_dimension T_dimension Ls_dimension\n"); + printfQuda("%s %s %s %s %d/%d/%d %d %d\n", get_prec_str(prec), + get_prec_str(prec_sloppy), get_recon_str(link_recon), get_recon_str(link_recon_sloppy), xdim, ydim, zdim, + tdim, Lsdim); printfQuda("\n Eigensolver parameters\n"); - printfQuda(" - solver mode %s\n", get_eig_type_str(eig_type)); - printfQuda(" - spectrum requested %s\n", get_eig_spectrum_str(eig_spectrum)); - if (eig_type == QUDA_EIG_BLK_TR_LANCZOS) printfQuda(" - eigenvector block size %d\n", eig_block_size); - printfQuda(" - number of eigenvectors requested %d\n", eig_n_conv); - printfQuda(" - size of eigenvector search space %d\n", eig_n_ev); - printfQuda(" - size of Krylov space %d\n", eig_n_kr); - printfQuda(" - solver tolerance %e\n", eig_tol); - printfQuda(" - convergence required (%s)\n", eig_require_convergence ? "true" : "false"); - if (eig_compute_svd) { + printfQuda(" - solver mode %s\n", get_eig_type_str(param.eig_type)); + printfQuda(" - spectrum requested %s\n", get_eig_spectrum_str(param.spectrum)); + if (param.eig_type == QUDA_EIG_BLK_TR_LANCZOS) printfQuda(" - eigenvector block size %d\n", param.block_size); + printfQuda(" - number of eigenvectors requested %d\n", param.n_conv); + printfQuda(" - size of eigenvector search space %d\n", param.n_ev); + printfQuda(" - size of Krylov space %d\n", param.n_kr); + printfQuda(" - solver tolerance %e\n", param.tol); + printfQuda(" - convergence required (%s)\n", param.require_convergence ? "true" : "false"); + if (param.compute_svd) { printfQuda(" - Operator: MdagM. Will compute SVD of M\n"); printfQuda(" - ***********************************************************\n"); printfQuda(" - **** Overriding any previous choices of operator type. ****\n"); printfQuda(" - **** SVD demands normal operator, will use MdagM ****\n"); printfQuda(" - ***********************************************************\n"); } else { - printfQuda(" - Operator: daggered (%s) , norm-op (%s)\n", eig_use_dagger ? "true" : "false", - eig_use_normop ? "true" : "false"); + printfQuda(" - Operator: daggered (%s) , norm-op (%s), even-odd pc (%s)\n", param.use_dagger ? "true" : "false", + param.use_norm_op ? "true" : "false", param.use_pc ? "true" : "false"); } - if (eig_use_poly_acc) { - printfQuda(" - Chebyshev polynomial degree %d\n", eig_poly_deg); - printfQuda(" - Chebyshev polynomial minumum %e\n", eig_amin); - if (eig_amax < 0) + if (param.use_poly_acc) { + printfQuda(" - Chebyshev polynomial degree %d\n", param.poly_deg); + printfQuda(" - Chebyshev polynomial minumum %e\n", param.a_min); + if (param.a_max <= 0) printfQuda(" - Chebyshev polynomial maximum will be computed\n"); else - printfQuda(" - Chebyshev polynomial maximum %e\n\n", eig_amax); + printfQuda(" - Chebyshev polynomial maximum %e\n\n", param.a_max); } - printfQuda("Grid partition info: X Y Z T\n"); printfQuda(" %d %d %d %d\n", dimPartitioned(0), dimPartitioned(1), dimPartitioned(2), dimPartitioned(3)); } -int main(int argc, char **argv) -{ - // Set a default - solve_type = QUDA_INVALID_SOLVE; - - auto app = make_app(); - add_eigen_option_group(app); - CLI::TransformPairs test_type_map {{"full", 0}, {"even", 3}, {"odd", 4}}; - app->add_option("--test", test_type, "Test method")->transform(CLI::CheckedTransformer(test_type_map)); - - try { - app->parse(argc, argv); - } catch (const CLI::ParseError &e) { - return app->exit(e); - } - - // initialize QMP/MPI, QUDA comms grid and RNG (host_utils.cpp) - initComms(argc, argv, gridsize_from_cmdline); - - // Set values for precisions via the command line. - setQudaPrecisions(); - - // Only these fermions are supported in this file. Ensure a reasonable default, - // ensure that the default is improved staggered - if (dslash_type != QUDA_STAGGERED_DSLASH && dslash_type != QUDA_ASQTAD_DSLASH && dslash_type != QUDA_LAPLACE_DSLASH) { - printfQuda("dslash_type %s not supported, defaulting to %s\n", get_dslash_str(dslash_type), - get_dslash_str(QUDA_ASQTAD_DSLASH)); - dslash_type = QUDA_ASQTAD_DSLASH; - } - - setQudaStaggeredEigTestParams(); - - display_test_info(); +GaugeField cpuFatQDP = {}; +GaugeField cpuLongQDP = {}; +GaugeField cpuFatMILC = {}; +GaugeField cpuLongMILC = {}; +void init() +{ // Set QUDA internal parameters - QudaGaugeParam gauge_param = newQudaGaugeParam(); + gauge_param = newQudaGaugeParam(); setStaggeredGaugeParam(gauge_param); + // Though no inversions are performed, the inv_param // structure contains all the information we need to - // construct the dirac operator. We encapsualte the - // inv_param structure inside the eig_param structure - // to avoid any confusion - QudaInvertParam eig_inv_param = newQudaInvertParam(); + // construct the dirac operator. + eig_inv_param = newQudaInvertParam(); setStaggeredInvertParam(eig_inv_param); - QudaEigParam eig_param = newQudaEigParam(); - setEigParam(eig_param); - // We encapsulate the eigensolver parameters inside the invert parameter structure - eig_param.invert_param = &eig_inv_param; - if (eig_param.arpack_check && !(prec == QUDA_DOUBLE_PRECISION)) { - errorQuda("ARPACK check only available in double precision"); - } - - initQuda(device_ordinal); + eig_param = newQudaEigParam(); + // We encapsualte the inv_param structure inside the eig_param structure + eig_param.invert_param = &eig_inv_param; + setEigParam(eig_param); setDims(gauge_param.X); - dw_setDims(gauge_param.X, 1); // so we can use 5-d indexing from dwf + dw_setDims(gauge_param.X, 1); // Staggered Gauge construct START //----------------------------------------------------------------------------------- - void *qdp_inlink[4] = {nullptr, nullptr, nullptr, nullptr}; - void *qdp_fatlink[4] = {nullptr, nullptr, nullptr, nullptr}; - void *qdp_longlink[4] = {nullptr, nullptr, nullptr, nullptr}; - void *milc_fatlink = nullptr; - void *milc_longlink = nullptr; - - for (int dir = 0; dir < 4; dir++) { - qdp_inlink[dir] = safe_malloc(V * gauge_site_size * host_gauge_data_type_size); - qdp_fatlink[dir] = safe_malloc(V * gauge_site_size * host_gauge_data_type_size); - qdp_longlink[dir] = safe_malloc(V * gauge_site_size * host_gauge_data_type_size); - } - milc_fatlink = safe_malloc(4 * V * gauge_site_size * host_gauge_data_type_size); - milc_longlink = safe_malloc(4 * V * gauge_site_size * host_gauge_data_type_size); + // Allocate host staggered gauge fields + gauge_param.type = (dslash_type == QUDA_STAGGERED_DSLASH || dslash_type == QUDA_LAPLACE_DSLASH) ? + QUDA_SU3_LINKS : + QUDA_ASQTAD_FAT_LINKS; + gauge_param.reconstruct = QUDA_RECONSTRUCT_NO; + gauge_param.location = QUDA_CPU_FIELD_LOCATION; + + GaugeFieldParam cpuParam(gauge_param); + cpuParam.order = QUDA_QDP_GAUGE_ORDER; + cpuParam.ghostExchange = QUDA_GHOST_EXCHANGE_PAD; + cpuParam.create = QUDA_NULL_FIELD_CREATE; + GaugeField cpuIn = GaugeField(cpuParam); + cpuFatQDP = GaugeField(cpuParam); + cpuParam.order = QUDA_MILC_GAUGE_ORDER; + cpuFatMILC = GaugeField(cpuParam); + + cpuParam.link_type = QUDA_ASQTAD_LONG_LINKS; + cpuParam.nFace = 3; + cpuParam.order = QUDA_QDP_GAUGE_ORDER; + cpuLongQDP = GaugeField(cpuParam); + cpuParam.order = QUDA_MILC_GAUGE_ORDER; + cpuLongMILC = GaugeField(cpuParam); + + void *qdp_inlink[4] = {cpuIn.data(0), cpuIn.data(1), cpuIn.data(2), cpuIn.data(3)}; + void *qdp_fatlink[4] = {cpuFatQDP.data(0), cpuFatQDP.data(1), cpuFatQDP.data(2), cpuFatQDP.data(3)}; + void *qdp_longlink[4] = {cpuLongQDP.data(0), cpuLongQDP.data(1), cpuLongQDP.data(2), cpuLongQDP.data(3)}; + constructStaggeredHostGaugeField(qdp_inlink, qdp_longlink, qdp_fatlink, gauge_param, 0, nullptr, true); - constructStaggeredHostGaugeField(qdp_inlink, qdp_longlink, qdp_fatlink, gauge_param, argc, argv, true); + // Reorder gauge fields to MILC order + cpuFatMILC = cpuFatQDP; + cpuLongMILC = cpuLongQDP; // Compute plaquette. Routine is aware that the gauge fields already have the phases on them. + // This needs to be called before `loadFatLongGaugeQuda` because this routine also loads the + // gauge fields with different parameters. double plaq[3]; computeStaggeredPlaquetteQDPOrder(qdp_inlink, plaq, gauge_param, dslash_type); printfQuda("Computed plaquette is %e (spatial = %e, temporal = %e)\n", plaq[0], plaq[1], plaq[2]); @@ -145,62 +135,220 @@ int main(int argc, char **argv) printfQuda("Computed fat link plaquette is %e (spatial = %e, temporal = %e)\n", plaq[0], plaq[1], plaq[2]); } - // Reorder gauge fields to MILC order - reorderQDPtoMILC(milc_fatlink, qdp_fatlink, V, gauge_site_size, gauge_param.cpu_prec, gauge_param.cpu_prec); - reorderQDPtoMILC(milc_longlink, qdp_longlink, V, gauge_site_size, gauge_param.cpu_prec, gauge_param.cpu_prec); + freeGaugeQuda(); + + loadFatLongGaugeQuda(cpuFatMILC.data(), cpuLongMILC.data(), gauge_param); - loadFatLongGaugeQuda(milc_fatlink, milc_longlink, gauge_param); + // now copy back to QDP aliases, since these are used for the reference dslash + cpuFatQDP = cpuFatMILC; + cpuLongQDP = cpuLongMILC; + // ensure QDP alias has exchanged ghosts + cpuFatQDP.exchangeGhost(); + cpuLongQDP.exchangeGhost(); // Staggered Gauge construct END //----------------------------------------------------------------------------------- +} + +std::vector eigensolve(test_t test_param) +{ + // Collect testing parameters from gtest + eig_param.eig_type = ::testing::get<0>(test_param); + eig_param.use_norm_op = ::testing::get<1>(test_param); + eig_param.use_pc = ::testing::get<2>(test_param); + eig_param.compute_svd = ::testing::get<3>(test_param); + eig_param.spectrum = ::testing::get<4>(test_param); + + if (eig_param.use_pc) + eig_inv_param.solution_type = QUDA_MATPC_SOLUTION; + else + eig_inv_param.solution_type = QUDA_MAT_SOLUTION; + + // For gtest testing, we prohibit the use of polynomial acceleration as + // the fine tuning required can inhibit convergence of an otherwise + // perfectly good algorithm. We also have a default value of 4 + // for the block size in Block TRLM, and 4 for the batched rotation. + // The user may change these values via the command line: + // --eig-block-size + // --eig-batched-rotate + if (enable_testing) { + eig_use_poly_acc = false; + eig_param.use_poly_acc = QUDA_BOOLEAN_FALSE; + eig_batched_rotate != 0 ? eig_param.batched_rotate = eig_batched_rotate : eig_param.batched_rotate = 0; + } + + logQuda(QUDA_SUMMARIZE, "Action = %s, Solver = %s, norm-op = %s, even-odd = %s, with SVD = %s, spectrum = %s\n", + get_dslash_str(dslash_type), get_eig_type_str(eig_param.eig_type), + eig_param.use_norm_op == QUDA_BOOLEAN_TRUE ? "true" : "false", + eig_param.use_pc == QUDA_BOOLEAN_TRUE ? "true" : "false", + eig_param.compute_svd == QUDA_BOOLEAN_TRUE ? "true" : "false", get_eig_spectrum_str(eig_param.spectrum)); + + if (!enable_testing || (enable_testing && getVerbosity() >= QUDA_VERBOSE)) display_test_info(eig_param); + // Vector construct START + //---------------------------------------------------------------------------- // Host side arrays to store the eigenpairs computed by QUDA - void **host_evecs = (void **)safe_malloc(eig_n_conv * sizeof(void *)); - for (int i = 0; i < eig_n_conv; i++) { - host_evecs[i] = (void *)safe_malloc(V * stag_spinor_site_size * eig_inv_param.cpu_prec); + int n_eig = eig_n_conv; + if (eig_param.compute_svd == QUDA_BOOLEAN_TRUE) n_eig *= 2; + std::vector evecs(n_eig); + quda::ColorSpinorParam cs_param; + constructStaggeredTestSpinorParam(&cs_param, &eig_inv_param, &gauge_param); + // Void pointers to host side arrays, compatible with the QUDA interface. + std::vector host_evecs_ptr(n_eig); + // Allocate host side memory and pointers + for (int i = 0; i < n_eig; i++) { + evecs[i] = quda::ColorSpinorField(cs_param); + host_evecs_ptr[i] = evecs[i].data(); } - double _Complex *host_evals = (double _Complex *)safe_malloc(eig_param.n_ev * sizeof(double _Complex)); - double time = 0.0; + // Complex eigenvalues + std::vector<__complex__ double> evals(eig_n_conv); + // Vector construct END + //---------------------------------------------------------------------------- - // QUDA eigensolver test + // QUDA eigensolver test BEGIN + //---------------------------------------------------------------------------- + // This function returns the host_evecs and host_evals pointers, populated with the + // requested data, at the requested prec. All the information needed to perfom the + // solve is in the eig_param container. If eig_param.arpack_check == true and + // precision is double, the routine will use ARPACK rather than the GPU. + quda::host_timer_t host_timer; + host_timer.start(); + eigensolveQuda(host_evecs_ptr.data(), evals.data(), &eig_param); + host_timer.stop(); + printfQuda("Time for %s solution = %f\n", eig_param.arpack_check ? "ARPACK" : "QUDA", host_timer.last()); + + // Perform host side verification of eigenvector if requested. + // ... + + std::vector residua(eig_n_conv, 0.0); + // Perform host side verification of eigenvector if requested. + if (verify_results) { + for (int i = 0; i < eig_n_conv; i++) { + if (eig_param.compute_svd == QUDA_BOOLEAN_TRUE) { + double _Complex sigma = evals[i]; + residua[i] = verifyStaggeredTypeSingularVector(evecs[i], evecs[i + eig_n_conv], sigma, i, eig_param, cpuFatQDP, + cpuLongQDP); + } else { + double _Complex lambda = evals[i]; + residua[i] = verifyStaggeredTypeEigenvector(evecs[i], lambda, i, eig_param, cpuFatQDP, cpuLongQDP); + } + } + } + return residua; + // QUDA eigensolver test COMPLETE //---------------------------------------------------------------------------- - switch (test_type) { - case 0: // full parity solution - case 3: // even - case 4: // odd - // This function returns the host_evecs and host_evals pointers, populated with - // the requested data, at the requested prec. All the information needed to - // perfom the solve is in the eig_param container. - // If eig_param.arpack_check == true and precision is double, the routine will - // use ARPACK rather than the GPU. - - time = -((double)clock()); - eigensolveQuda(host_evecs, host_evals, &eig_param); - time += (double)clock(); - - printfQuda("Time for %s solution = %f\n", eig_param.arpack_check ? "ARPACK" : "QUDA", time / CLOCKS_PER_SEC); - break; - - default: errorQuda("Unsupported test type"); - - } // switch - - // Deallocate host memory - for (int i = 0; i < eig_n_conv; i++) host_free(host_evecs[i]); - host_free(host_evecs); - host_free(host_evals); - - // Clean up gauge fields. - for (int dir = 0; dir < 4; dir++) { - host_free(qdp_inlink[dir]); - host_free(qdp_fatlink[dir]); - host_free(qdp_longlink[dir]); +} + +void cleanup() +{ + cpuFatQDP = {}; + cpuLongQDP = {}; + cpuFatMILC = {}; + cpuLongMILC = {}; +} + +int main(int argc, char **argv) +{ + ::testing::InitGoogleTest(&argc, argv); + // Set defaults + setQudaStaggeredDefaultInvTestParams(); + + auto app = make_app(); + add_eigen_option_group(app); + add_testing_option_group(app); + try { + app->parse(argc, argv); + } catch (const CLI::ParseError &e) { + return app->exit(e); + } + setVerbosity(verbosity); + + // Set values for precisions via the command line. + setQudaPrecisions(); + + // initialize QMP/MPI, QUDA comms grid and RNG (host_utils.cpp) + initComms(argc, argv, gridsize_from_cmdline); + + initRand(); + + // Only these fermions are supported in this file + if constexpr (is_enabled_laplace()) { + if (!is_staggered(dslash_type) && !is_laplace(dslash_type)) + errorQuda("dslash_type %s not supported", get_dslash_str(dslash_type)); + } else { + if (is_laplace(dslash_type)) errorQuda("The Laplace dslash is not enabled, cmake configure with -DQUDA_LAPLACE=ON"); + if (!is_staggered(dslash_type)) errorQuda("dslash_type %s not supported", get_dslash_str(dslash_type)); + } + + if (eig_param.arpack_check && !(prec == QUDA_DOUBLE_PRECISION)) { + errorQuda("ARPACK check only available in double precision"); + } + + // Sanity check combinations of solve type and solution type + if ((solve_type == QUDA_DIRECT_SOLVE && solution_type != QUDA_MAT_SOLUTION) + || (solve_type == QUDA_DIRECT_PC_SOLVE && solution_type != QUDA_MATPC_SOLUTION) + || (solve_type == QUDA_NORMOP_SOLVE && solution_type != QUDA_MATDAG_MAT_SOLUTION)) { + errorQuda("Invalid combination of solve_type %s and solution_type %s", get_solve_str(solve_type), + get_solution_str(solution_type)); + } + + initQuda(device_ordinal); + + if (enable_testing) { + // We need to force a well-behaved operator + reasonable convergence, otherwise + // the staggered tests will fail. These checks are designed to be consistent + // with what's in [src]/tests/CMakeFiles.txt, which have been "sanity checked" + bool changes = false; + if (!compute_fatlong) { + compute_fatlong = true; + changes = true; + } + + double expected_tol = (prec == QUDA_SINGLE_PRECISION) ? 1e-4 : 1e-5; + if (eig_tol != expected_tol) { + eig_tol = expected_tol; + changes = true; + } + if (niter != 1000) { + niter = 1000; + changes = true; + } + if (eig_n_kr != 256) { + eig_n_kr = 256; + changes = true; + } + if (eig_block_size != 4) { eig_block_size = 4; } + + if (changes) { + printfQuda("For gtest, various defaults are changed:\n"); + printfQuda(" --compute-fat-long true\n"); + printfQuda(" --eig-tol (1e-5 for double, 1e-4 for single)\n"); + printfQuda(" --niter 1000\n"); + printfQuda(" --eig-n-kr 256\n"); + } } - host_free(milc_fatlink); - host_free(milc_longlink); + init(); + int result = 0; + if (enable_testing) { // tests are defined in invert_test_gtest.hpp + ::testing::TestEventListeners &listeners = ::testing::UnitTest::GetInstance()->listeners(); + if (quda::comm_rank() != 0) { delete listeners.Release(listeners.default_result_printer()); } + result = RUN_ALL_TESTS(); + } else { + eigensolve( + test_t {eig_param.eig_type, eig_param.use_norm_op, eig_param.use_pc, eig_param.compute_svd, eig_param.spectrum}); + } + + cleanup(); + + // Memory clean-up + freeGaugeQuda(); + + // Finalize the QUDA library endQuda(); finalizeComms(); + + return result; } diff --git a/tests/staggered_eigensolve_test_gtest.hpp b/tests/staggered_eigensolve_test_gtest.hpp new file mode 100644 index 0000000000..382510f74b --- /dev/null +++ b/tests/staggered_eigensolve_test_gtest.hpp @@ -0,0 +1,174 @@ +#include + +using test_t = ::testing::tuple; + +class StaggeredEigensolveTest : public ::testing::TestWithParam +{ +protected: + test_t param; + +public: + StaggeredEigensolveTest() : param(GetParam()) { } +}; + +// Get the solve type that this combination corresponds to +QudaSolveType get_solve_type(QudaBoolean use_norm_op, QudaBoolean use_pc, QudaBoolean compute_svd) +{ + if (use_norm_op == QUDA_BOOLEAN_FALSE && use_pc == QUDA_BOOLEAN_TRUE && compute_svd == QUDA_BOOLEAN_FALSE) + return QUDA_DIRECT_PC_SOLVE; + else if (use_norm_op == QUDA_BOOLEAN_TRUE && use_pc == QUDA_BOOLEAN_FALSE && compute_svd == QUDA_BOOLEAN_TRUE) + return QUDA_NORMOP_SOLVE; + else if (use_norm_op == QUDA_BOOLEAN_FALSE && use_pc == QUDA_BOOLEAN_FALSE && compute_svd == QUDA_BOOLEAN_FALSE) + return QUDA_DIRECT_SOLVE; + else + return QUDA_INVALID_SOLVE; +} + +bool skip_test(test_t test_param) +{ + auto eig_type = ::testing::get<0>(test_param); + auto use_norm_op = ::testing::get<1>(test_param); + auto use_pc = ::testing::get<2>(test_param); + auto compute_svd = ::testing::get<3>(test_param); + auto spectrum = ::testing::get<4>(test_param); + + // Reverse engineer the operator type + QudaSolveType combo_solve_type = get_solve_type(use_norm_op, use_pc, compute_svd); + if (combo_solve_type == QUDA_DIRECT_PC_SOLVE) { + // matpc + + // this is only legal for the staggered and asqtad op + if (!is_staggered(dslash_type)) return true; + + // we can only compute the real part for Lanczos, and real or magnitude for Arnoldi + switch (eig_type) { + case QUDA_EIG_TR_LANCZOS: + case QUDA_EIG_BLK_TR_LANCZOS: + if (spectrum != QUDA_SPECTRUM_LR_EIG && spectrum != QUDA_SPECTRUM_SR_EIG) return true; + break; + case QUDA_EIG_IR_ARNOLDI: + if (spectrum == QUDA_SPECTRUM_LI_EIG || spectrum == QUDA_SPECTRUM_SI_EIG) return true; + break; + default: break; + } + } else if (combo_solve_type == QUDA_NORMOP_SOLVE) { + // matdag_mat + + // this is only legal for the staggered and asqtad op + if (!is_staggered(dslash_type)) return true; + + switch (eig_type) { + case QUDA_EIG_TR_LANCZOS: + case QUDA_EIG_BLK_TR_LANCZOS: + if (spectrum != QUDA_SPECTRUM_LR_EIG && spectrum != QUDA_SPECTRUM_SR_EIG) return true; + break; + case QUDA_EIG_IR_ARNOLDI: + // if (spectrum == QUDA_SPECTRUM_LI_EIG || spectrum == QUDA_SPECTRUM_SI_EIG) return true; + return true; // we skip this because it takes an unnecessarily long time and it's covered elsewhere + break; + default: return true; break; + } + } else if (combo_solve_type == QUDA_DIRECT_SOLVE) { + // mat + + switch (dslash_type) { + case QUDA_STAGGERED_DSLASH: + // only Arnoldi, imaginary part or magnitude works (real part is degenerate) + // We skip SM because it takes an unnecessarily long time and it's + // covered by HISQ + if (eig_type != QUDA_EIG_IR_ARNOLDI) return true; + if (spectrum != QUDA_SPECTRUM_LI_EIG && spectrum != QUDA_SPECTRUM_SI_EIG && spectrum != QUDA_SPECTRUM_LM_EIG) + return true; + break; + case QUDA_ASQTAD_DSLASH: + // only Arnoldi, imaginary part or magnitude works (real part is degenerate) + if (eig_type != QUDA_EIG_IR_ARNOLDI) return true; + if (spectrum == QUDA_SPECTRUM_LR_EIG || spectrum == QUDA_SPECTRUM_SR_EIG) return true; + break; + case QUDA_LAPLACE_DSLASH: + switch (eig_type) { + case QUDA_EIG_TR_LANCZOS: + case QUDA_EIG_BLK_TR_LANCZOS: + if (spectrum != QUDA_SPECTRUM_LR_EIG && spectrum != QUDA_SPECTRUM_SR_EIG) return true; + break; + case QUDA_EIG_IR_ARNOLDI: + if (spectrum == QUDA_SPECTRUM_LI_EIG || spectrum == QUDA_SPECTRUM_SI_EIG) return true; + break; + default: return true; break; + } + break; + default: return true; break; + } + } + return false; +} + +std::vector eigensolve(test_t test_param); + +TEST_P(StaggeredEigensolveTest, verify) +{ + if (skip_test(GetParam())) GTEST_SKIP(); + double factor = 1.0; + // The IRAM eigensolver will sometimes report convergence with tolerances slightly + // higher than requested. The same phenomenon occurs in ARPACK. This factor + // prevents failure when IRAM has solved to say 2e-6 when 1e-6 is requested. + // The solution to avoid this is to use a Krylov space (eig-n-kr) about 3-4 times the + // size of the search space (eig-n-ev), or use a well chosen Chebyshev polynomial, + // or use a tighter than necessary tolerance. + auto eig_type = ::testing::get<0>(GetParam()); + if (eig_type == QUDA_EIG_IR_ARNOLDI || eig_type == QUDA_EIG_BLK_IR_ARNOLDI) factor *= 10; + auto tol = factor * eig_param.tol; + for (auto rsd : eigensolve(GetParam())) EXPECT_LE(rsd, tol); +} + +std::string gettestname(::testing::TestParamInfo param) +{ + std::string name; + name += get_eig_type_str(::testing::get<0>(param.param)) + std::string("_"); + name += (::testing::get<1>(param.param) == QUDA_BOOLEAN_TRUE ? std::string("normop") : std::string("direct")) + + std::string("_"); + name += (::testing::get<2>(param.param) == QUDA_BOOLEAN_TRUE ? std::string("evenodd") : std::string("full")) + + std::string("_"); + name += (::testing::get<3>(param.param) == QUDA_BOOLEAN_TRUE ? std::string("withSVD") : std::string("noSVD")) + + std::string("_"); + name += get_eig_spectrum_str(::testing::get<4>(param.param)); + return name; +} + +using ::testing::Combine; +using ::testing::Values; + +// Can solve hermitian systems +auto hermitian_solvers = Values(QUDA_EIG_TR_LANCZOS, QUDA_EIG_BLK_TR_LANCZOS, QUDA_EIG_IR_ARNOLDI); + +// Can solve non-hermitian systems +auto non_hermitian_solvers = Values(QUDA_EIG_IR_ARNOLDI); + +// Eigensolver spectrum types +auto hermitian_spectrum = Values(QUDA_SPECTRUM_LR_EIG, QUDA_SPECTRUM_SR_EIG); +auto non_hermitian_spectrum = Values(QUDA_SPECTRUM_LR_EIG, QUDA_SPECTRUM_SR_EIG, QUDA_SPECTRUM_LM_EIG, + QUDA_SPECTRUM_SM_EIG, QUDA_SPECTRUM_LI_EIG, QUDA_SPECTRUM_SI_EIG); + +// using test_t = ::testing::tuple; // Largest real, smallest real, etc + +// Preconditioned direct operators, which are HPD for staggered! +INSTANTIATE_TEST_SUITE_P(DirectEvenOdd, StaggeredEigensolveTest, + ::testing::Combine(hermitian_solvers, Values(QUDA_BOOLEAN_FALSE), Values(QUDA_BOOLEAN_TRUE), + Values(QUDA_BOOLEAN_FALSE), hermitian_spectrum), + gettestname); + +// full system normal solve +INSTANTIATE_TEST_SUITE_P(NormalFull, StaggeredEigensolveTest, + ::testing::Combine(hermitian_solvers, Values(QUDA_BOOLEAN_TRUE), Values(QUDA_BOOLEAN_FALSE), + Values(QUDA_BOOLEAN_TRUE), hermitian_spectrum), + gettestname); + +// full system direct solve +INSTANTIATE_TEST_SUITE_P(DirectFull, StaggeredEigensolveTest, + ::testing::Combine(hermitian_solvers, Values(QUDA_BOOLEAN_FALSE), Values(QUDA_BOOLEAN_FALSE), + Values(QUDA_BOOLEAN_FALSE), non_hermitian_spectrum), + gettestname); diff --git a/tests/staggered_invert_test.cpp b/tests/staggered_invert_test.cpp index 5fcd338b4b..113d909fe7 100644 --- a/tests/staggered_invert_test.cpp +++ b/tests/staggered_invert_test.cpp @@ -18,7 +18,19 @@ #include #include -#define MAX(a, b) ((a) > (b) ? (a) : (b)) +QudaGaugeParam gauge_param; +QudaInvertParam inv_param; +QudaMultigridParam mg_param; +QudaInvertParam mg_inv_param; +QudaEigParam mg_eig_param[QUDA_MAX_MG_LEVEL]; +QudaEigParam eig_param; +bool use_split_grid = false; + +// print instructions on how to run the old tests +bool print_legacy_info = false; + +// if --enable-testing true is passed, we run the tests defined in here +#include void display_test_info() { @@ -102,32 +114,42 @@ void display_test_info() dimPartitioned(3)); } -void test(int argc, char **argv) +void display_legacy_info() +{ + printfQuda("Instructions for running legacy tests:\n"); + printfQuda("--test 0 -> --solve-type direct --solution-type mat --inv-type bicgstab\n"); + printfQuda("--test 1 -> --solve-type direct-pc --solution-type mat --inv-type cg --matpc even-even\n"); + printfQuda("--test 2 -> --solve-type direct-pc --solution-type mat --inv-type cg --matpc odd-odd\n"); + printfQuda("--test 3 -> --solve-type direct-pc --solution-type mat-pc --inv-type cg --matpc even-even\n"); + printfQuda("--test 4 -> --solve-type direct-pc --solution-type mat-pc --inv-type cg --matpc odd-odd\n"); + printfQuda( + "--test 5 -> --solve-type direct-pc --solution-type mat-pc --inv-type cg --matpc even-even --multishift 8\n"); + printfQuda( + "--test 6 -> --solve-type direct-pc --solution-type mat-pc --inv-type cg --matpc odd-odd --multishift 8\n"); +} + +GaugeField cpuFatQDP = {}; +GaugeField cpuLongQDP = {}; +GaugeField cpuFatMILC = {}; +GaugeField cpuLongMILC = {}; + +void init() { // Set QUDA internal parameters - QudaGaugeParam gauge_param = newQudaGaugeParam(); - QudaInvertParam inv_param = newQudaInvertParam(); + gauge_param = newQudaGaugeParam(); setStaggeredGaugeParam(gauge_param); - if (!inv_multigrid) setStaggeredInvertParam(inv_param); - QudaInvertParam mg_inv_param = newQudaInvertParam(); - QudaMultigridParam mg_param = newQudaMultigridParam(); - QudaEigParam mg_eig_param[mg_levels]; - - // params related to split grid. - for (int i = 0; i < 4; i++) inv_param.split_grid[i] = grid_partition[i]; - int num_sub_partition = grid_partition[0] * grid_partition[1] * grid_partition[2] * grid_partition[3]; - bool use_split_grid = num_sub_partition > 1; + inv_param = newQudaInvertParam(); + mg_inv_param = newQudaInvertParam(); + mg_param = newQudaMultigridParam(); + eig_param = newQudaEigParam(); if (inv_multigrid) { - // Set some default values for MG solve types setQudaMgSolveTypes(); - setStaggeredMGInvertParam(inv_param); // Set sub structures mg_param.invert_param = &mg_inv_param; - for (int i = 0; i < mg_levels; i++) { if (mg_eig[i]) { mg_eig_param[i] = newQudaEigParam(); @@ -137,10 +159,12 @@ void test(int argc, char **argv) mg_param.eig_param[i] = nullptr; } } + // Set MG setStaggeredMultigridParam(mg_param); + } else { + setStaggeredInvertParam(inv_param); } - QudaEigParam eig_param = newQudaEigParam(); if (inv_deflate) { setEigParam(eig_param); inv_param.eig_param = &eig_param; @@ -150,7 +174,6 @@ void test(int argc, char **argv) } setDims(gauge_param.X); - // Hack: use the domain wall dimensions so we may use the 5th dim for multi indexing dw_setDims(gauge_param.X, 1); // Staggered Gauge construct START @@ -163,25 +186,25 @@ void test(int argc, char **argv) gauge_param.location = QUDA_CPU_FIELD_LOCATION; GaugeFieldParam cpuParam(gauge_param); - cpuParam.create = QUDA_NULL_FIELD_CREATE; - cpuParam.ghostExchange = QUDA_GHOST_EXCHANGE_PAD; cpuParam.order = QUDA_QDP_GAUGE_ORDER; + cpuParam.ghostExchange = QUDA_GHOST_EXCHANGE_PAD; + cpuParam.create = QUDA_NULL_FIELD_CREATE; GaugeField cpuIn = GaugeField(cpuParam); - GaugeField cpuFatQDP = GaugeField(cpuParam); + cpuFatQDP = GaugeField(cpuParam); cpuParam.order = QUDA_MILC_GAUGE_ORDER; - GaugeField cpuFatMILC = GaugeField(cpuParam); + cpuFatMILC = GaugeField(cpuParam); cpuParam.link_type = QUDA_ASQTAD_LONG_LINKS; cpuParam.nFace = 3; cpuParam.order = QUDA_QDP_GAUGE_ORDER; - GaugeField cpuLongQDP = GaugeField(cpuParam); + cpuLongQDP = GaugeField(cpuParam); cpuParam.order = QUDA_MILC_GAUGE_ORDER; - GaugeField cpuLongMILC = GaugeField(cpuParam); + cpuLongMILC = GaugeField(cpuParam); void *qdp_inlink[4] = {cpuIn.data(0), cpuIn.data(1), cpuIn.data(2), cpuIn.data(3)}; void *qdp_fatlink[4] = {cpuFatQDP.data(0), cpuFatQDP.data(1), cpuFatQDP.data(2), cpuFatQDP.data(3)}; void *qdp_longlink[4] = {cpuLongQDP.data(0), cpuLongQDP.data(1), cpuLongQDP.data(2), cpuLongQDP.data(3)}; - constructStaggeredHostGaugeField(qdp_inlink, qdp_longlink, qdp_fatlink, gauge_param, argc, argv, true); + constructStaggeredHostGaugeField(qdp_inlink, qdp_longlink, qdp_fatlink, gauge_param, 0, nullptr, true); // Reorder gauge fields to MILC order cpuFatMILC = cpuFatQDP; @@ -200,6 +223,8 @@ void test(int argc, char **argv) printfQuda("Computed fat link plaquette is %e (spatial = %e, temporal = %e)\n", plaq[0], plaq[1], plaq[2]); } + freeGaugeQuda(); + loadFatLongGaugeQuda(cpuFatMILC.data(), cpuLongMILC.data(), gauge_param); // now copy back to QDP aliases, since these are used for the reference dslash @@ -211,6 +236,36 @@ void test(int argc, char **argv) // Staggered Gauge construct END //----------------------------------------------------------------------------------- +} + +std::vector> solve(test_t param) +{ + inv_param.inv_type = ::testing::get<0>(param); + inv_param.solution_type = ::testing::get<1>(param); + inv_param.solve_type = ::testing::get<2>(param); + inv_param.cuda_prec_sloppy = ::testing::get<3>(param); + multishift = ::testing::get<4>(param); + inv_param.solution_accumulator_pipeline = ::testing::get<5>(param); + + // schwarz parameters + auto schwarz_param = ::testing::get<6>(param); + inv_param.schwarz_type = ::testing::get<0>(schwarz_param); + inv_param.inv_type_precondition = ::testing::get<1>(schwarz_param); + inv_param.cuda_prec_precondition = ::testing::get<2>(schwarz_param); + + inv_param.residual_type = ::testing::get<7>(param); + + // reset lambda_max if we're doing a testing loop to ensure correct lambma_max + if (enable_testing) inv_param.ca_lambda_max = -1.0; + + logQuda(QUDA_SUMMARIZE, "Solution = %s, Solve = %s, Solver = %s, Sloppy precision = %s\n", + get_solution_str(inv_param.solution_type), get_solve_str(inv_param.solve_type), + get_solver_str(inv_param.inv_type), get_prec_str(inv_param.cuda_prec_sloppy)); + + // params related to split grid. + for (int i = 0; i < 4; i++) inv_param.split_grid[i] = grid_partition[i]; + int num_sub_partition = grid_partition[0] * grid_partition[1] * grid_partition[2] * grid_partition[3]; + use_split_grid = num_sub_partition > 1; // Setup the multigrid preconditioner void *mg_preconditioner = nullptr; @@ -226,88 +281,26 @@ void test(int argc, char **argv) //----------------------------------------------------------------------------------- std::vector in(Nsrc); std::vector out(Nsrc); + std::vector out_multishift(Nsrc * multishift); quda::ColorSpinorParam cs_param; constructStaggeredTestSpinorParam(&cs_param, &inv_param, &gauge_param); - for (int k = 0; k < Nsrc; k++) { - in[k] = quda::ColorSpinorField(cs_param); - out[k] = quda::ColorSpinorField(cs_param); - } - ColorSpinorField ref(cs_param); - ColorSpinorField tmp(cs_param); + std::vector> _hp_multi_x(Nsrc, std::vector(multishift)); + // Staggered vector construct END //----------------------------------------------------------------------------------- - // Prepare rng - quda::RNG rng(ref, 1234); - - // Performance measuring - std::vector time(Nsrc); - std::vector gflops(Nsrc); - std::vector iter(Nsrc); - - // QUDA invert test - //---------------------------------------------------------------------------- - - if (test_type >= 0 && test_type <= 4) { - // case 0: // full parity solution, full parity system - // case 1: // full parity solution, solving EVEN EVEN prec system - // case 2: // full parity solution, solving ODD ODD prec system - // case 3: // even parity solution, solving EVEN system - // case 4: // odd parity solution, solving ODD system - - if (multishift != 1) errorQuda("Multishift not supported for test %d\n", test_type); - - for (int k = 0; k < Nsrc; k++) { quda::spinorNoise(in[k], rng, QUDA_NOISE_UNIFORM); } - - if (!use_split_grid) { - for (int k = 0; k < Nsrc; k++) { - if (inv_deflate) eig_param.preserve_deflation = k < Nsrc - 1 ? QUDA_BOOLEAN_TRUE : QUDA_BOOLEAN_FALSE; - invertQuda(out[k].data(), in[k].data(), &inv_param); - time[k] = inv_param.secs; - gflops[k] = inv_param.gflops / inv_param.secs; - iter[k] = inv_param.iter; - printfQuda("Done: %i iter / %g secs = %g Gflops\n\n", inv_param.iter, inv_param.secs, - inv_param.gflops / inv_param.secs); - } - } else { - std::vector _hp_x(Nsrc); - std::vector _hp_b(Nsrc); - for (int k = 0; k < Nsrc; k++) { - _hp_x[k] = out[k].data(); - _hp_b[k] = in[k].data(); - } - inv_param.num_src = Nsrc; - inv_param.num_src_per_sub_partition = Nsrc / num_sub_partition; - invertMultiSrcStaggeredQuda(_hp_x.data(), _hp_b.data(), &inv_param, cpuFatMILC.data(), cpuLongMILC.data(), - &gauge_param); - quda::comm_allreduce_int(inv_param.iter); - inv_param.iter /= comm_size() / num_sub_partition; - quda::comm_allreduce_sum(inv_param.gflops); - inv_param.gflops /= comm_size() / num_sub_partition; - quda::comm_allreduce_max(inv_param.secs); - printfQuda("Done: %d sub-partitions - %i iter / %g secs = %g Gflops\n\n", num_sub_partition, inv_param.iter, - inv_param.secs, inv_param.gflops / inv_param.secs); - } + // Setup multishift parameters (if needed) + //--------------------------------------------------------------------------- - for (int k = 0; k < Nsrc; k++) { - if (verify_results) - verifyStaggeredInversion(tmp, ref, in[k], out[k], mass, cpuFatQDP, cpuLongQDP, gauge_param, inv_param, 0); - } - } else if (test_type == 5 || test_type == 6) { - // case 5: // multi mass CG, even parity solution, solving EVEN system - // case 6: // multi mass CG, odd parity solution, solving ODD system + // Masses + std::vector masses(multishift); + if (multishift > 1) { if (use_split_grid) errorQuda("Multishift currently doesn't support split grid.\n"); - if (multishift < 2) - errorQuda("Multishift inverter requires more than one shift, multishift = %d\n", multishift); - inv_param.num_offset = multishift; - // Prepare vectors for masses - std::vector masses(multishift); - // Consistency check for masses, tols, tols_hq size if we're setting custom values if (multishift_shifts.size() != 0) errorQuda("Multishift shifts are not supported for Wilson-type fermions"); @@ -318,51 +311,129 @@ void test(int argc, char **argv) if (multishift_tols_hq.size() != 0 && multishift_tols_hq.size() != static_cast(multishift)) errorQuda("Multishift hq tolerance count %d does not agree with number of masses passed in %lu\n", multishift, multishift_tols_hq.size()); - // Allocate storage of output arrays - std::vector outArray(multishift); - std::vector qudaOutArray(multishift, cs_param); - - // Copy offsets and tolerances into inv_param; copy data pointers into outArray + // Copy offsets and tolerances into inv_param; allocate and copy data pointers for (int i = 0; i < multishift; i++) { masses[i] = (multishift_masses.size() == 0 ? (mass + i * i * 0.01) : multishift_masses[i]); inv_param.offset[i] = 4 * masses[i] * masses[i]; inv_param.tol_offset[i] = (multishift_tols.size() == 0 ? inv_param.tol : multishift_tols[i]); inv_param.tol_hq_offset[i] = (multishift_tols_hq.size() == 0 ? inv_param.tol_hq : multishift_tols_hq[i]); - outArray[i] = qudaOutArray[i].data(); + // Allocate memory and set pointers + for (int n = 0; n < Nsrc; n++) { + out_multishift[n * multishift + i] = quda::ColorSpinorField(cs_param); + _hp_multi_x[n][i] = out_multishift[n * multishift + i].data(); + } logQuda(QUDA_VERBOSE, "Multishift mass %d = %e ; tolerance %e ; hq tolerance %e\n", i, masses[i], inv_param.tol_offset[i], inv_param.tol_hq_offset[i]); } + } - for (int k = 0; k < Nsrc; k++) { - quda::spinorNoise(in[k], rng, QUDA_NOISE_UNIFORM); - invertMultiShiftQuda((void **)outArray.data(), in[k].data(), &inv_param); + // Setup multishift parameters END + //----------------------------------------------------------------------------------- - time[k] = inv_param.secs; - gflops[k] = inv_param.gflops / inv_param.secs; - iter[k] = inv_param.iter; - printfQuda("Done: %i iter / %g secs = %g Gflops\n\n", inv_param.iter, inv_param.secs, - inv_param.gflops / inv_param.secs); + // Prepare rng, fill host spinors with random numbers + //----------------------------------------------------------------------------------- - for (int i = 0; i < multishift; i++) { - printfQuda("%dth solution: mass=%f, ", i, masses[i]); - verifyStaggeredInversion(tmp, ref, in[k], qudaOutArray[i], masses[i], cpuFatQDP, cpuLongQDP, gauge_param, - inv_param, i); + std::vector time(Nsrc); + std::vector gflops(Nsrc); + std::vector iter(Nsrc); + + // Create a temporary spinor just to seed the rng + quda::ColorSpinorField tmp(cs_param); + quda::RNG rng(tmp, 1234); + tmp = quda::ColorSpinorField(); + + for (int n = 0; n < Nsrc; n++) { + // Populate the host spinor with random numbers. + in[n] = quda::ColorSpinorField(cs_param); + quda::spinorNoise(in[n], rng, QUDA_NOISE_UNIFORM); + out[n] = quda::ColorSpinorField(cs_param); + } + + // Prepare rng, fill host spinors with random numbers END + //----------------------------------------------------------------------------------- + + // QUDA invert test + //---------------------------------------------------------------------------- + + if (!use_split_grid) { + + for (int n = 0; n < Nsrc; n++) { + // If deflating, preserve the deflation space between solves + if (inv_deflate) eig_param.preserve_deflation = n < Nsrc - 1 ? QUDA_BOOLEAN_TRUE : QUDA_BOOLEAN_FALSE; + // Perform QUDA inversions + if (multishift > 1) { + invertMultiShiftQuda(_hp_multi_x[n].data(), in[n].data(), &inv_param); + } else { + invertQuda(out[n].data(), in[n].data(), &inv_param); } + + time[n] = inv_param.secs; + gflops[n] = inv_param.gflops / inv_param.secs; + iter[n] = inv_param.iter; + printfQuda("Done: %i iter / %g secs = %g Gflops\n\n", inv_param.iter, inv_param.secs, + inv_param.gflops / inv_param.secs); } } else { - errorQuda("Unsupported test type"); - } // switch + inv_param.num_src = Nsrc; + inv_param.num_src_per_sub_partition = Nsrc / num_sub_partition; + // Host arrays for solutions, sources, and check + std::vector _hp_x(Nsrc); + std::vector _hp_b(Nsrc); + for (int n = 0; n < Nsrc; n++) { + _hp_x[n] = out[n].data(); + _hp_b[n] = in[n].data(); + } + // Run split grid + invertMultiSrcStaggeredQuda(_hp_x.data(), _hp_b.data(), &inv_param, cpuFatMILC.data(), cpuLongMILC.data(), + &gauge_param); + + quda::comm_allreduce_int(inv_param.iter); + inv_param.iter /= comm_size() / num_sub_partition; + quda::comm_allreduce_sum(inv_param.gflops); + inv_param.gflops /= comm_size() / num_sub_partition; + quda::comm_allreduce_max(inv_param.secs); + printfQuda("Done: %d sub-partitions - %i iter / %g secs = %g Gflops\n\n", num_sub_partition, inv_param.iter, + inv_param.secs, inv_param.gflops / inv_param.secs); + } + + // Free the multigrid solver + if (inv_multigrid) destroyMultigridQuda(mg_preconditioner); // Compute timings if (Nsrc > 1 && !use_split_grid) performanceStats(time, gflops, iter); - // Free the multigrid solver - if (inv_multigrid) destroyMultigridQuda(mg_preconditioner); + std::vector> res(Nsrc); + // Perform host side verification of inversion if requested + if (verify_results) { + for (int n = 0; n < Nsrc; n++) { + if (multishift > 1) { + printfQuda("\nSource %d:\n", n); + // Create an appropriate subset of the full out_multishift vector + std::vector out_subset + = {out_multishift.begin() + n * multishift, out_multishift.begin() + (n + 1) * multishift}; + res[n] = verifyStaggeredInversion(in[n], out_subset, cpuFatQDP, cpuLongQDP, inv_param); + } else { + res[n] = verifyStaggeredInversion(in[n], out[n], cpuFatQDP, cpuLongQDP, inv_param); + } + } + } + + return res; +} + +void cleanup() +{ + cpuFatQDP = {}; + cpuLongQDP = {}; + cpuFatMILC = {}; + cpuLongMILC = {}; } int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + setQudaStaggeredDefaultInvTestParams(); setQudaDefaultMgTestParams(); // Parse command line options auto app = make_app(); @@ -370,20 +441,15 @@ int main(int argc, char **argv) add_deflation_option_group(app); add_multigrid_option_group(app); add_comms_option_group(app); - CLI::TransformPairs test_type_map {{"full", 0}, {"full_ee_prec", 1}, {"full_oo_prec", 2}, {"even", 3}, - {"odd", 4}, {"mcg_even", 5}, {"mcg_odd", 6}}; - app->add_option("--test", test_type, "Test method")->transform(CLI::CheckedTransformer(test_type_map)); + add_testing_option_group(app); + app->add_option("--legacy-test-info", print_legacy_info, + "Print info on how to reproduce the old '--test #' behavior with flags, then exit"); try { app->parse(argc, argv); } catch (const CLI::ParseError &e) { return app->exit(e); } setVerbosity(verbosity); - if (!inv_multigrid) solve_type = QUDA_INVALID_SOLVE; - - if (inv_deflate && inv_multigrid) { - errorQuda("Error: Cannot use both deflation and multigrid preconditioners on top level solve"); - } // Set values for precisions via the command line. setQudaPrecisions(); @@ -391,37 +457,88 @@ int main(int argc, char **argv) // initialize QMP/MPI, QUDA comms grid and RNG (host_utils.cpp) initComms(argc, argv, gridsize_from_cmdline); + if (print_legacy_info) { + display_legacy_info(); + errorQuda("Exiting..."); + } + + if (inv_deflate && inv_multigrid) + errorQuda("Error: Cannot use both deflation and multigrid preconditioners on top level solve"); + initRand(); - // Only these fermions are supported in this file. Ensure a reasonable default, - // ensure that the default is improved staggered - if (dslash_type != QUDA_STAGGERED_DSLASH && dslash_type != QUDA_ASQTAD_DSLASH && dslash_type != QUDA_LAPLACE_DSLASH) { - printfQuda("dslash_type %s not supported, defaulting to %s\n", get_dslash_str(dslash_type), - get_dslash_str(QUDA_ASQTAD_DSLASH)); - dslash_type = QUDA_ASQTAD_DSLASH; + // Only these fermions are supported in this file + if constexpr (is_enabled_laplace()) { + if (!is_staggered(dslash_type) && !is_laplace(dslash_type)) + errorQuda("dslash_type %s not supported", get_dslash_str(dslash_type)); + } else { + if (is_laplace(dslash_type)) errorQuda("The Laplace dslash is not enabled, cmake configure with -DQUDA_LAPLACE=ON"); + if (!is_staggered(dslash_type)) errorQuda("dslash_type %s not supported", get_dslash_str(dslash_type)); } // Need to add support for LAPLACE MG? if (inv_multigrid) { - if (dslash_type != QUDA_STAGGERED_DSLASH && dslash_type != QUDA_ASQTAD_DSLASH) { + if (!is_staggered(dslash_type)) { errorQuda("dslash_type %s not supported for multigrid preconditioner", get_dslash_str(dslash_type)); } } - // Deduce operator, solution, and operator preconditioning types - if (!inv_multigrid) setQudaStaggeredInvTestParams(); - display_test_info(); initQuda(device_ordinal); - test(argc, argv); + if (enable_testing) { + // We need to force a well-behaved operator + reasonable convergence, otherwise + // the staggered tests will fail. These checks are designed to be consistent + // with what's in [src]/tests/CMakeFiles.txt, which have been "sanity checked" + bool changes = false; + if (!compute_fatlong) { + compute_fatlong = true; + changes = true; + } + + double expected_tol = (prec == QUDA_SINGLE_PRECISION) ? 1e-5 : 1e-6; + if (tol != expected_tol) { + tol = expected_tol; + changes = true; + } + if (tol_hq != expected_tol) { + tol_hq = expected_tol; + changes = true; + } + if (niter != 1000) { + niter = 1000; + changes = true; + } + + if (changes) { + printfQuda("For gtest, various defaults are changed:\n"); + printfQuda(" --compute-fat-long true\n"); + printfQuda(" --tol (1e-6 for double, 1e-5 for single)\n"); + printfQuda(" --tol-hq (1e-6 for double, 1e-5 for single)\n"); + printfQuda(" --niter 1000\n"); + } + } + + init(); + + int result = 0; + if (enable_testing) { // tests are defined in staggered_invert_test_gtest.hpp + ::testing::TestEventListeners &listeners = ::testing::UnitTest::GetInstance()->listeners(); + if (quda::comm_rank() != 0) { delete listeners.Release(listeners.default_result_printer()); } + result = RUN_ALL_TESTS(); + } else { + solve(test_t {inv_type, solution_type, solve_type, prec_sloppy, multishift, solution_accumulator_pipeline, + schwarz_t {precon_schwarz_type, inv_multigrid ? QUDA_MG_INVERTER : precon_type, prec_precondition}, + inv_param.residual_type}); + } + + cleanup(); // Finalize the QUDA library + freeGaugeQuda(); endQuda(); - - // Finalize the communications layer finalizeComms(); - return 0; + return result; } diff --git a/tests/staggered_invert_test_gtest.hpp b/tests/staggered_invert_test_gtest.hpp new file mode 100644 index 0000000000..a4e7bcda90 --- /dev/null +++ b/tests/staggered_invert_test_gtest.hpp @@ -0,0 +1,213 @@ +#include +#include + +// tuple containing parameters for Schwarz solver +using schwarz_t = ::testing::tuple; + +using test_t + = ::testing::tuple; + +class StaggeredInvertTest : public ::testing::TestWithParam +{ +protected: + test_t param; + +public: + StaggeredInvertTest() : param(GetParam()) { } +}; + +bool skip_test(test_t param) +{ + auto inverter_type = ::testing::get<0>(param); + auto solution_type = ::testing::get<1>(param); + auto solve_type = ::testing::get<2>(param); + auto prec_sloppy = ::testing::get<3>(param); + auto multishift = ::testing::get<4>(param); + auto solution_accumulator_pipeline = ::testing::get<5>(param); + auto schwarz_param = ::testing::get<6>(param); + auto prec_precondition = ::testing::get<2>(schwarz_param); + + if (prec < prec_sloppy) return true; // outer precision >= sloppy precision + if (!(QUDA_PRECISION & prec_sloppy)) return true; // precision not enabled so skip it + if (!(QUDA_PRECISION & prec_precondition) && prec_precondition != QUDA_INVALID_PRECISION) + return true; // precision not enabled so skip it + if (prec_sloppy < prec_precondition) return true; // sloppy precision >= preconditioner precision + + // Skip if the inverter does not support batched update and batched update is greater than one + if (!support_solution_accumulator_pipeline(inverter_type) && solution_accumulator_pipeline > 1) return true; + // There's no MLocal or MdagMLocal support yet, this is left in for reference + // if (is_normal_solve(param) && ::testing::get<0>(schwarz_param) != QUDA_INVALID_SCHWARZ) + // if (dslash_type != QUDA_MOBIUS_DWF_DSLASH) return true; + + if (is_laplace(dslash_type)) { + if (multishift > 1) return true; // Laplace doesn't support multishift + if (solution_type != QUDA_MAT_SOLUTION || solve_type != QUDA_DIRECT_SOLVE) + return true; // Laplace only supports direct solves + } + + if (is_staggered(dslash_type)) { + // the staggered and asqtad operators aren't HPD + if (solution_type == QUDA_MAT_SOLUTION && solve_type == QUDA_DIRECT_SOLVE && is_hermitian_solver(inverter_type)) + return true; + + // MR struggles with the staggered and asqtad spectrum, it's not MR's fault + if (solution_type == QUDA_MAT_SOLUTION && solve_type == QUDA_DIRECT_SOLVE && inverter_type == QUDA_MR_INVERTER) + return true; + } + + // split-grid doesn't support multigrid at present + if (use_split_grid && multishift > 1) return true; + + return false; +} + +std::vector> solve(test_t param); + +TEST_P(StaggeredInvertTest, verify) +{ + if (skip_test(GetParam())) GTEST_SKIP(); + + inv_param.tol = 0.0; + inv_param.tol_hq = 0.0; + auto res_t = ::testing::get<7>(GetParam()); + if (res_t & QUDA_L2_RELATIVE_RESIDUAL) inv_param.tol = tol; + if (res_t & QUDA_HEAVY_QUARK_RESIDUAL) inv_param.tol_hq = tol_hq; + + auto inverter_type = ::testing::get<0>(param); + auto solution_type = ::testing::get<1>(param); + auto solve_type = ::testing::get<2>(param); + + // Make a local copy of "tol" for modification in place + auto verify_tol = tol; + + // FIXME eventually we should build in refinement to the *NR solvers to remove the need for this + // The mass squared is a proxy for the condition number + if (is_normal_residual(inverter_type)) verify_tol /= (0.25 * mass * mass); + + // To solve the direct operator to a given tolerance, grind the preconditioned + // operator to 0.5 * mass * tol... to keep the target tolerance in inv_param + // in check, we shift the requirement to the verified tolerance instead. + if (solution_type == QUDA_MAT_SOLUTION) { + if (solve_type == QUDA_DIRECT_PC_SOLVE) + verify_tol /= (0.5 * mass); // to solve the full operator to eps, solve the preconditioned to mass * eps + if (solve_type == QUDA_NORMOP_SOLVE) verify_tol /= (0.5 * mass); // a proxy for the condition number + } + + // The power iterations method of determining the Chebyshev window + // breaks down due to the nature of the spectrum of the direct operator + auto ca_basis_tmp = inv_param.ca_basis; + if (solve_type == QUDA_DIRECT_SOLVE && inverter_type == QUDA_CA_GCR_INVERTER) inv_param.ca_basis = QUDA_POWER_BASIS; + + // Single precision needs a tiny bump due to small host/device precision deviations + if (prec == QUDA_SINGLE_PRECISION) verify_tol *= 1.01; + + for (auto rsd : solve(GetParam())) { + if (res_t & QUDA_L2_RELATIVE_RESIDUAL) { EXPECT_LE(rsd[0], verify_tol); } + if (res_t & QUDA_HEAVY_QUARK_RESIDUAL) { EXPECT_LE(rsd[1], tol_hq); } + } + + inv_param.ca_basis = ca_basis_tmp; +} + +std::string gettestname(::testing::TestParamInfo param) +{ + std::string name; + name += get_solver_str(::testing::get<0>(param.param)) + std::string("_"); + name += get_solution_str(::testing::get<1>(param.param)) + std::string("_"); + name += get_solve_str(::testing::get<2>(param.param)) + std::string("_"); + name += get_prec_str(::testing::get<3>(param.param)); + if (::testing::get<4>(param.param) > 1) + name += std::string("_shift") + std::to_string(::testing::get<4>(param.param)); + if (::testing::get<5>(param.param) > 1) + name += std::string("_solution_accumulator_pipeline") + std::to_string(::testing::get<5>(param.param)); + auto &schwarz_param = ::testing::get<6>(param.param); + if (::testing::get<0>(schwarz_param) != QUDA_INVALID_SCHWARZ) { + name += std::string("_") + get_schwarz_str(::testing::get<0>(schwarz_param)); + name += std::string("_") + get_solver_str(::testing::get<1>(schwarz_param)); + name += std::string("_") + get_prec_str(::testing::get<2>(schwarz_param)); + } + auto res_t = ::testing::get<7>(param.param); + if (res_t & QUDA_L2_RELATIVE_RESIDUAL) name += std::string("_l2"); + if (res_t & QUDA_HEAVY_QUARK_RESIDUAL) name += std::string("_heavy_quark"); + return name; +} + +using ::testing::Combine; +using ::testing::Values; + +auto staggered_pc_solvers + = Values(QUDA_CG_INVERTER, QUDA_CA_CG_INVERTER, QUDA_PCG_INVERTER, QUDA_GCR_INVERTER, QUDA_CA_GCR_INVERTER, + QUDA_BICGSTAB_INVERTER, QUDA_BICGSTABL_INVERTER, QUDA_MR_INVERTER); + +auto normal_solvers = Values(QUDA_CG_INVERTER, QUDA_CA_CG_INVERTER, QUDA_PCG_INVERTER); + +auto direct_solvers = Values(QUDA_CG_INVERTER, QUDA_CA_CG_INVERTER, QUDA_CGNE_INVERTER, QUDA_CGNR_INVERTER, + QUDA_CA_CGNE_INVERTER, QUDA_CA_CGNR_INVERTER, QUDA_GCR_INVERTER, QUDA_CA_GCR_INVERTER, + QUDA_BICGSTAB_INVERTER, QUDA_BICGSTABL_INVERTER, QUDA_MR_INVERTER); + +auto sloppy_precisions + = Values(QUDA_DOUBLE_PRECISION, QUDA_SINGLE_PRECISION, QUDA_HALF_PRECISION, QUDA_QUARTER_PRECISION); + +auto solution_accumulator_pipelines = Values(1, 8); + +auto no_schwarz = Combine(Values(QUDA_INVALID_SCHWARZ), Values(QUDA_INVALID_INVERTER), Values(QUDA_INVALID_PRECISION)); + +auto no_heavy_quark = Values(QUDA_L2_RELATIVE_RESIDUAL); + +// the staggered PC op doesn't support "normal" operators since it's already +// Hermitian positive definite + +// preconditioned solves +INSTANTIATE_TEST_SUITE_P(EvenOdd, StaggeredInvertTest, + Combine(staggered_pc_solvers, Values(QUDA_MATPC_SOLUTION, QUDA_MAT_SOLUTION), + Values(QUDA_DIRECT_PC_SOLVE), sloppy_precisions, Values(1), + solution_accumulator_pipelines, no_schwarz, no_heavy_quark), + gettestname); + +// full system normal solve +INSTANTIATE_TEST_SUITE_P(NormalFull, StaggeredInvertTest, + Combine(normal_solvers, Values(QUDA_MATDAG_MAT_SOLUTION, QUDA_MAT_SOLUTION), + Values(QUDA_NORMOP_SOLVE), sloppy_precisions, Values(1), + solution_accumulator_pipelines, no_schwarz, no_heavy_quark), + gettestname); + +// full system direct solve +INSTANTIATE_TEST_SUITE_P(Full, StaggeredInvertTest, + Combine(direct_solvers, Values(QUDA_MAT_SOLUTION), Values(QUDA_DIRECT_SOLVE), sloppy_precisions, + Values(1), solution_accumulator_pipelines, no_schwarz, no_heavy_quark), + gettestname); + +// preconditioned multi-shift solves +INSTANTIATE_TEST_SUITE_P(MultiShiftEvenOdd, StaggeredInvertTest, + Combine(Values(QUDA_CG_INVERTER), Values(QUDA_MATPC_SOLUTION), Values(QUDA_DIRECT_PC_SOLVE), + sloppy_precisions, Values(10), solution_accumulator_pipelines, no_schwarz, + no_heavy_quark), + gettestname); + +// Heavy-Quark preconditioned solves +INSTANTIATE_TEST_SUITE_P(HeavyQuarkEvenOdd, StaggeredInvertTest, + Combine(Values(QUDA_CG_INVERTER), Values(QUDA_MATPC_SOLUTION), Values(QUDA_DIRECT_PC_SOLVE), + sloppy_precisions, Values(1), solution_accumulator_pipelines, no_schwarz, + Values(QUDA_L2_RELATIVE_RESIDUAL | QUDA_HEAVY_QUARK_RESIDUAL, QUDA_HEAVY_QUARK_RESIDUAL)), + gettestname); + +// These are left in but commented out for future reference + +// Schwarz-preconditioned normal solves +// INSTANTIATE_TEST_SUITE_P(SchwarzNormal, StaggeredInvertTest, +// Combine(Values(QUDA_PCG_INVERTER), Values(QUDA_MATPCDAG_MATPC_SOLUTION), +// Values(QUDA_NORMOP_PC_SOLVE), sloppy_precisions, Values(1), +// solution_accumulator_pipelines, +// Combine(Values(QUDA_ADDITIVE_SCHWARZ), Values(QUDA_CG_INVERTER, QUDA_CA_CG_INVERTER), +// Values(QUDA_HALF_PRECISION, QUDA_QUARTER_PRECISION)), +// no_heavy_quark), +// gettestname); + +// Schwarz-preconditioned direct solves +// INSTANTIATE_TEST_SUITE_P(SchwarzEvenOdd, StaggeredInvertTest, +// Combine(Values(QUDA_GCR_INVERTER), Values(QUDA_MATPC_SOLUTION), Values(QUDA_DIRECT_PC_SOLVE), +// sloppy_precisions, Values(1), solution_accumulator_pipelines, +// Combine(Values(QUDA_ADDITIVE_SCHWARZ), Values(QUDA_MR_INVERTER, QUDA_CA_GCR_INVERTER), +// Values(QUDA_HALF_PRECISION, QUDA_QUARTER_PRECISION)), +// no_heavy_quark), +// gettestname); diff --git a/tests/utils/command_line_params.cpp b/tests/utils/command_line_params.cpp index ea7c779f05..17b12edcea 100644 --- a/tests/utils/command_line_params.cpp +++ b/tests/utils/command_line_params.cpp @@ -474,7 +474,7 @@ std::shared_ptr make_app(std::string app_description, std::string app_n quda_app->add_option("--device", device_ordinal, "Set the CUDA device to use (default 0, single GPU only)") ->check(CLI::Range(0, 16)); - quda_app->add_option("--dslash-type", dslash_type, "Set the dslash type") + quda_app->add_option("--dslash-type", dslash_type, "Set the dslash type (default wilson or asqtad as appropriate)") ->transform(CLI::QUDACheckedTransformer(dslash_type_map)); quda_app->add_option("--epsilon", epsilon, "Twisted-Mass flavor twist of Dirac operator (default 0.01)"); @@ -502,7 +502,8 @@ std::shared_ptr make_app(std::string app_description, std::string app_n ->transform(CLI::QUDACheckedTransformer(mass_normalization_map)); quda_app - ->add_option("--matpc", matpc_type, "Matrix preconditioning type (even-even, odd-odd, even-even-asym, odd-odd-asym)") + ->add_option("--matpc", matpc_type, + "Matrix preconditioning type (even-even (default), odd-odd, even-even-asym, odd-odd-asym)") ->transform(CLI::QUDACheckedTransformer(matpc_type_map)); quda_app->add_option("--msrc", Msrc, "Used for testing non-square block blas routines where nsrc defines the other dimension"); @@ -601,9 +602,9 @@ std::shared_ptr make_app(std::string app_description, std::string app_n "The pipeline length for fused solution accumulation (default 0, no pipelining)"); quda_app - ->add_option( - "--solution-type", solution_type, - "The solution we desire (mat (default), mat-dag-mat, mat-pc, mat-pc-dag-mat-pc (default for multi-shift))") + ->add_option("--solution-type", solution_type, + "The solution we desire (mat (default for Wilson-type), mat-dag-mat, mat-pc (default for " + "staggered-type), mat-pc-dag-mat-pc (default for Wilson-type multi-shift))") ->transform(CLI::QUDACheckedTransformer(solution_type_map)); quda_app @@ -617,8 +618,9 @@ std::shared_ptr make_app(std::string app_description, std::string app_n ->expected(4); quda_app - ->add_option("--solve-type", solve_type, - "The type of solve to do (direct, direct-pc, normop, normop-pc, normerr, normerr-pc)") + ->add_option( + "--solve-type", + solve_type, "The type of solve to do (direct, direct-pc (default for staggered-type), normop, normop-pc (default for Wilson-type), normerr, normerr-pc)") ->transform(CLI::QUDACheckedTransformer(solve_type_map)); quda_app ->add_option("--solver-ext-lib-type", solver_ext_lib, "Set external library for the solvers (default Eigen library)") @@ -759,9 +761,12 @@ void add_eigen_option_group(std::shared_ptr quda_app) opgroup->add_option("--eig-use-dagger", eig_use_dagger, "Solve the Mdag problem instead of M (MMdag if eig-use-normop == true) (default false)"); - opgroup->add_option("--eig-use-normop", eig_use_normop, - "Solve the MdagM problem instead of M (MMdag if eig-use-dagger == true) (default false)"); - opgroup->add_option("--eig-use-pc", eig_use_pc, "Solve the Even-Odd preconditioned problem (default false)"); + opgroup->add_option( + "--eig-use-normop", + eig_use_normop, "Solve the MdagM problem instead of M (MMdag if eig-use-dagger == true) (default false for Wilson-type, true for staggered-type)"); + opgroup->add_option( + "--eig-use-pc", eig_use_pc, + "Solve the Even-Odd preconditioned problem (default false for Wilson-type, true for staggered-type)"); opgroup->add_option("--eig-use-poly-acc", eig_use_poly_acc, "Use Chebyshev polynomial acceleration in the eigensolver"); } diff --git a/tests/utils/host_utils.cpp b/tests/utils/host_utils.cpp index ecf748976f..24659b23e1 100644 --- a/tests/utils/host_utils.cpp +++ b/tests/utils/host_utils.cpp @@ -245,7 +245,7 @@ void constructWilsonTestSpinorParam(quda::ColorSpinorParam *cs_param, const Quda } cs_param->pc_type = inv_param->dslash_type == QUDA_DOMAIN_WALL_DSLASH ? QUDA_5D_PC : QUDA_4D_PC; for (int d = 0; d < 4; d++) cs_param->x[d] = gauge_param->X[d]; - bool pc = isPCSolution(inv_param->solution_type); + bool pc = is_pc_solution(inv_param->solution_type); if (pc) cs_param->x[0] /= 2; cs_param->siteSubset = pc ? QUDA_PARITY_SITE_SUBSET : QUDA_FULL_SITE_SUBSET; @@ -271,15 +271,130 @@ void constructRandomSpinorSource(void *v, int nSpin, int nColor, QudaPrecision p param.fieldOrder = QUDA_SPACE_SPIN_COLOR_FIELD_ORDER; param.nDim = nDim; param.pc_type = QUDA_4D_PC; - param.siteSubset = isPCSolution(sol_type) ? QUDA_PARITY_SITE_SUBSET : QUDA_FULL_SITE_SUBSET; + param.siteSubset = is_pc_solution(sol_type) ? QUDA_PARITY_SITE_SUBSET : QUDA_FULL_SITE_SUBSET; param.siteOrder = QUDA_EVEN_ODD_SITE_ORDER; param.location = QUDA_CPU_FIELD_LOCATION; // DMH FIXME so one can construct device noise for (int d = 0; d < nDim; d++) param.x[d] = x[d]; - if (isPCSolution(sol_type)) param.x[0] /= 2; + if (is_pc_solution(sol_type)) param.x[0] /= 2; quda::ColorSpinorField spinor_in(param); quda::spinorNoise(spinor_in, rng, QUDA_NOISE_UNIFORM); } +// Helper functions +bool is_pc_solution(QudaSolutionType type) +{ + switch (type) { + case QUDA_MATPC_SOLUTION: + case QUDA_MATPC_DAG_SOLUTION: + case QUDA_MATPCDAG_MATPC_SOLUTION: + case QUDA_MATPCDAG_MATPC_SHIFT_SOLUTION: return true; + default: return false; + } +} + +bool is_full_solution(QudaSolutionType type) +{ + switch (type) { + case QUDA_MAT_SOLUTION: + case QUDA_MATDAG_MAT_SOLUTION: return true; + default: return false; + } +} + +bool is_full_solve(QudaSolveType type) +{ + switch (type) { + case QUDA_DIRECT_SOLVE: + case QUDA_NORMOP_SOLVE: + case QUDA_NORMERR_SOLVE: return true; + default: return false; + } +} + +bool is_preconditioned_solve(QudaSolveType type) +{ + switch (type) { + case QUDA_DIRECT_PC_SOLVE: + case QUDA_NORMOP_PC_SOLVE: + case QUDA_NORMERR_PC_SOLVE: return true; + default: return false; + } +} + +bool is_normal_solve(QudaInverterType inv_type, QudaSolveType solve_type) +{ + switch (solve_type) { + case QUDA_NORMOP_SOLVE: + case QUDA_NORMOP_PC_SOLVE: return true; + default: + switch (inv_type) { + case QUDA_CGNR_INVERTER: + case QUDA_CGNE_INVERTER: + case QUDA_CA_CGNR_INVERTER: + case QUDA_CA_CGNE_INVERTER: return true; + default: return false; + } + } +} + +bool is_hermitian_solver(QudaInverterType type) +{ + switch (type) { + case QUDA_CG_INVERTER: + case QUDA_CA_CG_INVERTER: return true; + default: return false; + } +} + +bool support_solution_accumulator_pipeline(QudaInverterType type) +{ + switch (type) { + case QUDA_CG_INVERTER: + case QUDA_CA_CG_INVERTER: + case QUDA_CGNR_INVERTER: + case QUDA_CGNE_INVERTER: + case QUDA_PCG_INVERTER: return true; + default: return false; + } +} + +bool is_normal_residual(QudaInverterType type) +{ + switch (type) { + case QUDA_CGNR_INVERTER: + case QUDA_CA_CGNR_INVERTER: return true; + default: return false; + } +} + +bool is_staggered(QudaDslashType type) +{ + switch (type) { + case QUDA_STAGGERED_DSLASH: + case QUDA_ASQTAD_DSLASH: return true; + default: return false; + } +} + +bool is_chiral(QudaDslashType type) +{ + switch (type) { + case QUDA_DOMAIN_WALL_DSLASH: + case QUDA_DOMAIN_WALL_4D_DSLASH: + case QUDA_MOBIUS_DWF_DSLASH: + case QUDA_MOBIUS_DWF_EOFA_DSLASH: return true; + default: return false; + } +} + +bool is_laplace(QudaDslashType type) +{ + switch (type) { + case QUDA_LAPLACE_DSLASH: return true; + default: return false; + } +} + void initComms(int argc, char **argv, std::array &commDims) { initComms(argc, argv, commDims.data()); } #if defined(QMP_COMMS) || defined(MPI_COMMS) diff --git a/tests/utils/host_utils.h b/tests/utils/host_utils.h index 9465da0d41..24a8668e7d 100644 --- a/tests/utils/host_utils.h +++ b/tests/utils/host_utils.h @@ -40,10 +40,19 @@ extern QudaPrecision &cuda_prec_eigensolver; extern QudaPrecision &cuda_prec_refinement_sloppy; extern QudaPrecision &cuda_prec_ritz; +// Determine if the Laplace operator has been defined +constexpr bool is_enabled_laplace() +{ +#ifdef QUDA_LAPLACE + return true; +#else + return false; +#endif +} + // Set some basic parameters via command line or use defaults // Implemented in set_params.cpp -void setQudaStaggeredEigTestParams(); -void setQudaStaggeredInvTestParams(); +void setQudaStaggeredDefaultInvTestParams(); // Staggered gauge field utils //------------------------------------------------------ @@ -112,11 +121,20 @@ void constructRandomSpinorSource(void *v, int nSpin, int nColor, QudaPrecision p // Helper functions //------------------------------------------------------ -inline bool isPCSolution(QudaSolutionType solution_type) -{ - return (solution_type == QUDA_MATPC_SOLUTION || solution_type == QUDA_MATPC_DAG_SOLUTION - || solution_type == QUDA_MATPCDAG_MATPC_SOLUTION); -} +bool is_pc_solution(QudaSolutionType solution_type); +bool is_full_solution(QudaSolutionType type); + +bool is_preconditioned_solve(QudaSolveType type); +bool is_normal_solve(QudaInverterType inv_type, QudaSolveType solve_type); + +bool is_hermitian_solver(QudaInverterType type); +bool support_solution_accumulator_pipeline(QudaInverterType type); +bool is_normal_residual(QudaInverterType type); + +bool is_staggered(QudaDslashType type); +bool is_chiral(QudaDslashType type); +bool is_laplace(QudaDslashType type); + //------------------------------------------------------ // Reports basic statistics of flops and solver iterations diff --git a/tests/utils/misc.cpp b/tests/utils/misc.cpp index fd920e5e16..f2b8f54bcc 100644 --- a/tests/utils/misc.cpp +++ b/tests/utils/misc.cpp @@ -96,23 +96,6 @@ const char *get_test_type(int t) return ret; } -const char *get_staggered_test_type(int t) -{ - const char *ret; - switch (t) { - case 0: ret = "full"; break; - case 1: ret = "full_ee_prec"; break; - case 2: ret = "full_oo_prec"; break; - case 3: ret = "even"; break; - case 4: ret = "odd"; break; - case 5: ret = "mcg_even"; break; - case 6: ret = "mcg_odd"; break; - default: ret = "unknown"; break; - } - - return ret; -} - const char *get_dslash_str(QudaDslashType type) { const char *ret; diff --git a/tests/utils/misc.h b/tests/utils/misc.h index bac9cf69c9..bf9a8d3039 100644 --- a/tests/utils/misc.h +++ b/tests/utils/misc.h @@ -7,7 +7,6 @@ const char *get_recon_str(QudaReconstructType recon); const char *get_prec_str(QudaPrecision prec); const char *get_gauge_order_str(QudaGaugeFieldOrder order); const char *get_test_type(int t); -const char *get_staggered_test_type(int t); const char *get_unitarization_str(bool svd_only); const char *get_mass_normalization_str(QudaMassNormalization); const char *get_verbosity_str(QudaVerbosity); diff --git a/tests/utils/set_params.cpp b/tests/utils/set_params.cpp index b862c3ced4..404401c2d6 100644 --- a/tests/utils/set_params.cpp +++ b/tests/utils/set_params.cpp @@ -937,9 +937,11 @@ void setStaggeredInvertParam(QudaInvertParam &inv_param) // domain decomposition preconditioner parameters inv_param.inv_type_precondition = precon_type; + inv_param.schwarz_type = precon_schwarz_type; + inv_param.precondition_cycle = precon_schwarz_cycle; inv_param.tol_precondition = tol_precondition; inv_param.maxiter_precondition = maxiter_precondition; - inv_param.verbosity_precondition = QUDA_SILENT; + inv_param.verbosity_precondition = verbosity_precondition; inv_param.cuda_prec_precondition = prec_precondition; inv_param.cuda_prec_eigensolver = prec_eigensolver; @@ -952,6 +954,11 @@ void setStaggeredInvertParam(QudaInvertParam &inv_param) inv_param.ca_lambda_min = ca_lambda_min; inv_param.ca_lambda_max = ca_lambda_max; + // Set preconditioner CA info + inv_param.ca_basis_precondition = ca_basis_precondition; + inv_param.ca_lambda_min_precondition = ca_lambda_min_precondition; + inv_param.ca_lambda_max_precondition = ca_lambda_max_precondition; + inv_param.solution_type = solution_type; inv_param.solve_type = solve_type; inv_param.matpc_type = matpc_type; @@ -1372,112 +1379,60 @@ void setDeflationParam(QudaEigParam &df_param) df_param.partfile = eig_partfile ? QUDA_BOOLEAN_TRUE : QUDA_BOOLEAN_FALSE; } -void setQudaStaggeredInvTestParams() +/**********/ +// The enumerated staggered tests have been removed, but for reference: +// +// Test 0: +// solve_type = QUDA_DIRECT_SOLVE +// matpc_type = QUDA_MATPC_EVEN_EVEN (doesn't matter) +// solution_type = QUDA_MAT_SOLUTION +// +// Test 1: +// solve_type = QUDA_DIRECT_PC_SOLVE +// matpc_type = QUDA_MATPC_EVEN_EVEN +// solution_type = QUDA_MAT_SOLUTION +// +// Test 2: +// solve_type = QUDA_DIRECT_PC_SOLVE +// matpc_type = QUDA_MATPC_ODD_ODD +// solution_type = QUDA_MAT_SOLUTION +// +// Test 3: +// solve_type = QUDA_DIRECT_PC_SOLVE +// matpc_type = QUDA_MATPC_EVEN_EVEN +// solution_type = QUDA_MATPC_SOLUTION +// +// Test 4: +// solve_type = QUDA_DIRECT_PC_SOLVE +// matpc_type = QUDA_MATPC_ODD_ODD +// solution_type = QUDA_MATPC_SOLUTION +// +// Test 5: multi-shift +// solve_type = QUDA_DIRECT_PC_SOLVE +// matpc_type = QUDA_MATPC_EVEN_EVEN +// solution_type = QUDA_MATPC_SOLUTION +// +// Test 6: multi-shift +// solve_type = QUDA_DIRECT_PC_SOLVE +// matpc_type = QUDA_MATPC_ODD_ODD +// solution_type = QUDA_MATPC_SOLUTION +/**********/ + +void setQudaStaggeredDefaultInvTestParams() { - if (dslash_type == QUDA_LAPLACE_DSLASH) { - if (test_type != 0) { errorQuda("Test type %d is not supported for the Laplace operator.\n", test_type); } - - solve_type = QUDA_DIRECT_SOLVE; - solution_type = QUDA_MAT_SOLUTION; - matpc_type = QUDA_MATPC_EVEN_EVEN; // doesn't matter - - } else { - - if (test_type == 0 && (inv_type == QUDA_CG_INVERTER || inv_type == QUDA_PCG_INVERTER) - && solve_type != QUDA_NORMOP_SOLVE && solve_type != QUDA_DIRECT_PC_SOLVE) { - warningQuda("The full spinor staggered operator (test 0) can't be inverted with (P)CG. Switching to BiCGstab.\n"); - inv_type = QUDA_BICGSTAB_INVERTER; - } + // Set some meaningful defaults for staggered tests - if (solve_type == QUDA_INVALID_SOLVE) { - if (test_type == 0) { - solve_type = QUDA_DIRECT_SOLVE; - } else { - solve_type = QUDA_DIRECT_PC_SOLVE; - } - } - - if (test_type == 1 || test_type == 3 || test_type == 5) { - matpc_type = QUDA_MATPC_EVEN_EVEN; - } else if (test_type == 2 || test_type == 4 || test_type == 6) { - matpc_type = QUDA_MATPC_ODD_ODD; - } else if (test_type == 0) { - matpc_type = QUDA_MATPC_EVEN_EVEN; // it doesn't matter - } - - if (test_type == 0 || test_type == 1 || test_type == 2) { - solution_type = QUDA_MAT_SOLUTION; - } else { - solution_type = QUDA_MATPC_SOLUTION; - } - } + // Default to the ASQTAD dslash + dslash_type = QUDA_ASQTAD_DSLASH; - if (prec_sloppy == QUDA_INVALID_PRECISION) { prec_sloppy = prec; } - - if (prec_refinement_sloppy == QUDA_INVALID_PRECISION) { prec_refinement_sloppy = prec_sloppy; } - if (link_recon_sloppy == QUDA_RECONSTRUCT_INVALID) { link_recon_sloppy = link_recon; } - - if (inv_type != QUDA_CG_INVERTER && (test_type == 5 || test_type == 6)) { - errorQuda("Preconditioning is currently not supported in multi-shift solver solvers"); - } + // Default to a Schur-preconditioned CG solve + solve_type = QUDA_DIRECT_PC_SOLVE; + solution_type = QUDA_MATPC_SOLUTION; + matpc_type = QUDA_MATPC_EVEN_EVEN; + inv_type = QUDA_CG_INVERTER; - // Set n_naiks to 2 if eps_naik != 0.0 - if (dslash_type == QUDA_ASQTAD_DSLASH) { - if (eps_naik != 0.0) { - if (compute_fatlong) { - n_naiks = 2; - printfQuda("Note: epsilon-naik != 0, testing epsilon correction links.\n"); - } else { - eps_naik = 0.0; - printfQuda("Not computing fat-long, ignoring epsilon correction.\n"); - } - } else { - printfQuda("Note: epsilon-naik = 0, testing original HISQ links.\n"); - } - } -} - -void setQudaStaggeredEigTestParams() -{ - if (dslash_type == QUDA_LAPLACE_DSLASH) { - // LAPLACE operator path, only DIRECT solves feasible. - if (test_type != 0) { errorQuda("Test type %d is not supported for the Laplace operator.\n", test_type); } - solve_type = QUDA_DIRECT_SOLVE; - solution_type = QUDA_MAT_SOLUTION; - } else { - // STAGGERED operator path - if (solve_type == QUDA_INVALID_SOLVE) { - if (test_type == 0) { - solve_type = QUDA_DIRECT_SOLVE; - } else { - solve_type = QUDA_DIRECT_PC_SOLVE; - } - } - // If test type is not 3, it is 4 or 0. If 0, the matpc type is irrelevant - if (test_type == 3) - matpc_type = QUDA_MATPC_EVEN_EVEN; - else - matpc_type = QUDA_MATPC_ODD_ODD; - - if (test_type == 0) { - solution_type = QUDA_MAT_SOLUTION; - } else { - solution_type = QUDA_MATPC_SOLUTION; - } - } - - // Set n_naiks to 2 if eps_naik != 0.0 - if (dslash_type == QUDA_ASQTAD_DSLASH) { - if (eps_naik != 0.0) { - if (compute_fatlong) { - n_naiks = 2; - printfQuda("Note: epsilon-naik != 0, testing epsilon correction links.\n"); - } else { - eps_naik = 0.0; - printfQuda("Not computing fat-long, ignoring epsilon correction.\n"); - } - } else { - printfQuda("Note: epsilon-naik = 0, testing original HISQ links.\n"); - } - } + // For an eigensolve, default to using the "regular" operator instead of the normal + // operator because the Schur operator is already HPD + eig_use_normop = QUDA_BOOLEAN_FALSE; + eig_use_pc = true; } diff --git a/tests/utils/staggered_gauge_utils.cpp b/tests/utils/staggered_gauge_utils.cpp index b020939c5f..85e7993ba5 100644 --- a/tests/utils/staggered_gauge_utils.cpp +++ b/tests/utils/staggered_gauge_utils.cpp @@ -27,8 +27,11 @@ void computeHISQLinksGPU(void **qdp_fatlink, void **qdp_longlink, void **qdp_fat std::array, 3> &act_path_coeffs, double eps_naik, size_t gSize, int n_naiks) { - // since a lot of intermediaries can be general matrices, override the recon in `gauge_param_in` + // Intermediates can be general matrices, so override the reconstruct. + // Similarly, gauge links can only be built in single or double, so upscale the build precision + // if neccessary. auto gauge_param = gauge_param_in; + if (gauge_param.cuda_prec < QUDA_SINGLE_PRECISION) gauge_param.cuda_prec = QUDA_SINGLE_PRECISION; gauge_param.reconstruct = QUDA_RECONSTRUCT_NO; gauge_param.reconstruct_sloppy = QUDA_RECONSTRUCT_NO; // probably irrelevant diff --git a/tests/utils/staggered_host_utils.cpp b/tests/utils/staggered_host_utils.cpp index 95efb0d85a..fcc0b1d697 100644 --- a/tests/utils/staggered_host_utils.cpp +++ b/tests/utils/staggered_host_utils.cpp @@ -815,7 +815,7 @@ void constructStaggeredTestSpinorParam(quda::ColorSpinorParam *cs_param, const Q cs_param->nSpin = 1; cs_param->nDim = 4; for (int d = 0; d < 4; d++) cs_param->x[d] = gauge_param->X[d]; - bool pc = isPCSolution(inv_param->solution_type); + bool pc = is_pc_solution(inv_param->solution_type); if (pc) cs_param->x[0] /= 2; cs_param->pc_type = QUDA_4D_PC; cs_param->siteSubset = pc ? QUDA_PARITY_SITE_SUBSET : QUDA_FULL_SITE_SUBSET;