Skip to content

Commit

Permalink
Merge pull request #1421 from lattice/feature/stag-cleanup
Browse files Browse the repository at this point in the history
Staggered invert test cleanup + expanded staggered gtest support
  • Loading branch information
weinbe2 authored Jan 10, 2024
2 parents fd50676 + 431c4ec commit b87195b
Show file tree
Hide file tree
Showing 34 changed files with 2,834 additions and 1,528 deletions.
8 changes: 7 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ option(QUDA_CLOVER_DYNAMIC "Dynamically invert the clover term" ON)
option(QUDA_CLOVER_RECONSTRUCT "set to ON to enable compressed clover storage (requires QUDA_CLOVER_DYNAMIC)" ON)
option(QUDA_CLOVER_CHOLESKY_PROMOTE "Whether to promote the internal precision when inverting the clover term" ON)

option(QUDA_IMPROVED_STAGGERED_EIGENSOLVER_CTEST "Whether to run eigensolver ctests against the improved staggered operator (requires QUDA_DIRAC_STAGGERED)" OFF)

# Set CTest options
option(QUDA_CTEST_SEP_DSLASH_POLICIES "Test Dslash policies separately in ctest instead of only autotuning them." OFF)
option(QUDA_CTEST_DISABLE_BENCHMARKS "Disable benchmark test" ON)
Expand Down Expand Up @@ -391,7 +393,11 @@ set(CMAKE_EXE_LINKER_FLAGS_SANITIZE
CACHE STRING "Flags used by the linker during sanitizer debug builds.")

if(QUDA_CLOVER_RECONSTRUCT AND NOT QUDA_CLOVER_DYNAMIC)
message(SEND_ERROR "QUDA_CLOVER_RECONSTRUCT requires QUDA_CLOVER_DYNAMIC)")
message(SEND_ERROR "QUDA_CLOVER_RECONSTRUCT requires QUDA_CLOVER_DYNAMIC")
endif()

if (QUDA_IMPROVED_STAGGERED_EIGENSOLVER_CTEST AND NOT QUDA_DIRAC_STAGGERED)
message(SEND_ERROR "QUDA_IMPROVED_STAGGERED_EIGENSOLVER_CTEST requires QUDA_DIRAC_STAGGERED")
endif()

find_package(Threads REQUIRED)
Expand Down
1 change: 0 additions & 1 deletion ci/docker/Dockerfile.build
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ RUN QUDA_TEST_GRID_SIZE="1 1 1 2" cmake -S /quda/src \
-DQUDA_DIRAC_DEFAULT_OFF=ON \
-DQUDA_DIRAC_WILSON=ON \
-DQUDA_DIRAC_CLOVER=ON \
-DQUDA_DIRAC_TWISTED_MASS=ON \
-DQUDA_DIRAC_TWISTED_CLOVER=ON \
-DQUDA_DIRAC_STAGGERED=ON \
-GNinja \
Expand Down
23 changes: 21 additions & 2 deletions include/invert_quda.h
Original file line number Diff line number Diff line change
Expand Up @@ -1048,17 +1048,36 @@ namespace quda {

private:
const DiracMdagM matMdagM; // used by the eigensolver
// pointers to fields to avoid multiple creation overhead
ColorSpinorField *yp, *rp, *pp, *vp, *tmpp, *tp;

ColorSpinorField y; // Full precision solution accumulator
ColorSpinorField r; // Full precision residual vector
ColorSpinorField p; // Sloppy precision search direction
ColorSpinorField v; // Sloppy precision A * p
ColorSpinorField t; // Sloppy precision vector used for minres step
ColorSpinorField r0; // Bi-orthogonalization vector
ColorSpinorField r_sloppy; // Sloppy precision residual vector
ColorSpinorField x_sloppy; // Sloppy solution accumulator vector
bool init = false;

/**
@brief Initialize the fields needed by the solver
@param[in] x Solution vector
@param[in] b Source vector
*/
void create(ColorSpinorField &x, const ColorSpinorField &b);

public:
BiCGstab(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon,
const DiracMatrix &matEig, SolverParam &param, TimeProfile &profile);
virtual ~BiCGstab();

void operator()(ColorSpinorField &out, ColorSpinorField &in) override;

/**
@return Return the residual vector from the prior solve
*/
ColorSpinorField &get_residual() override;

virtual bool hermitian() const override { return false; } /** BiCGStab is for any linear system */

virtual QudaInverterType getInverterType() const final { return QUDA_BICGSTAB_INVERTER; }
Expand Down
19 changes: 14 additions & 5 deletions lib/eig_block_trlm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,16 +105,25 @@ namespace quda
eigensolveFromBlockArrowMat();
profile.TPSTART(QUDA_PROFILE_COMPUTE);

// mat_norm is updated.
// mat_norm is updated and used for LR
for (int i = num_locked; i < n_kr; i++)
if (fabs(alpha[i]) > mat_norm) mat_norm = fabs(alpha[i]);

// Pick the scale used by the locking and convergence tests below:
// mat_norm when the LR part of the spectrum is requested, otherwise the
// value handed in by the caller (the alpha entry for the pair under test)
auto check_norm = [&](double sr_norm) -> double {
  return (eig_param->spectrum == QUDA_SPECTRUM_LR_EIG) ? mat_norm : sr_norm;
};

// Locking check
iter_locked = 0;
for (int i = 1; i < (n_kr - num_locked); i++) {
if (residua[i + num_locked] < epsilon * mat_norm) {
if (residua[i + num_locked] < epsilon * check_norm(alpha[i + num_locked])) {
logQuda(QUDA_DEBUG_VERBOSE, "**** Locking %d resid=%+.6e condition=%.6e ****\n", i, residua[i + num_locked],
epsilon * mat_norm);
epsilon * check_norm(alpha[i + num_locked]));
iter_locked = i;
} else {
// Unlikely to find new locked pairs
Expand All @@ -125,9 +134,9 @@ namespace quda
// Convergence check
iter_converged = iter_locked;
for (int i = iter_locked + 1; i < n_kr - num_locked; i++) {
if (residua[i + num_locked] < tol * mat_norm) {
if (residua[i + num_locked] < tol * check_norm(alpha[i + num_locked])) {
logQuda(QUDA_DEBUG_VERBOSE, "**** Converged %d resid=%+.6e condition=%.6e ****\n", i, residua[i + num_locked],
tol * mat_norm);
tol * check_norm(alpha[i + num_locked]));
iter_converged = i;
} else {
// Unlikely to find new converged pairs
Expand Down
23 changes: 16 additions & 7 deletions lib/eig_trlm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,25 @@ namespace quda
eigensolveFromArrowMat();
profile.TPSTART(QUDA_PROFILE_COMPUTE);

// mat_norm is updated.
// mat_norm is updated and used for LR
for (int i = num_locked; i < n_kr; i++)
if (fabs(alpha[i]) > mat_norm) mat_norm = fabs(alpha[i]);

// Pick the scale used by the locking and convergence tests below:
// mat_norm when the LR part of the spectrum is requested, otherwise the
// value handed in by the caller (the alpha entry for the pair under test)
auto check_norm = [&](double sr_norm) -> double {
  return (eig_param->spectrum == QUDA_SPECTRUM_LR_EIG) ? mat_norm : sr_norm;
};

// Locking check
iter_locked = 0;
for (int i = 1; i < (n_kr - num_locked); i++) {
if (residua[i + num_locked] < epsilon * mat_norm) {
if (residua[i + num_locked] < epsilon * check_norm(alpha[i + num_locked])) {
logQuda(QUDA_DEBUG_VERBOSE, "**** Locking %d resid=%+.6e condition=%.6e ****\n", i, residua[i + num_locked],
epsilon * mat_norm);
epsilon * check_norm(alpha[i + num_locked]));
iter_locked = i;
} else {
// Unlikely to find new locked pairs
Expand All @@ -106,9 +115,9 @@ namespace quda
// Convergence check
iter_converged = iter_locked;
for (int i = iter_locked + 1; i < n_kr - num_locked; i++) {
if (residua[i + num_locked] < tol * mat_norm) {
if (residua[i + num_locked] < tol * check_norm(alpha[i + num_locked])) {
logQuda(QUDA_DEBUG_VERBOSE, "**** Converged %d resid=%+.6e condition=%.6e ****\n", i, residua[i + num_locked],
tol * mat_norm);
tol * check_norm(alpha[i + num_locked]));
iter_converged = i;
} else {
// Unlikely to find new converged pairs
Expand Down Expand Up @@ -165,8 +174,8 @@ namespace quda
logQuda(QUDA_SUMMARIZE, "TRLM computed the requested %d vectors in %d restart steps and %d OP*x operations.\n",
n_conv, restart_iter, iter);

// Dump all Ritz values and residua if using Chebyshev
for (int i = 0; i < n_conv && eig_param->use_poly_acc; i++) {
// Dump all Ritz values and residua
for (int i = 0; i < n_conv; i++) {
logQuda(QUDA_SUMMARIZE, "RitzValue[%04d]: (%+.16e, %+.16e) residual %.16e\n", i, alpha[i], 0.0, residua[i]);
}

Expand Down
123 changes: 62 additions & 61 deletions lib/interface_quda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3014,7 +3014,7 @@ void loadFatLongGaugeQuda(QudaInvertParam *inv_param, QudaGaugeParam *gauge_para
template <class Interface, class... Args>
void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // color spinor field pointers, and inv_param
void *h_gauge, void *milc_fatlinks, void *milc_longlinks,
QudaGaugeParam *gauge_param, // gauge field pointers
QudaGaugeParam *gauge_param_, // gauge field pointers
void *h_clover, void *h_clovinv, // clover field pointers
Interface op, Args... args)
{
Expand All @@ -3034,14 +3034,17 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col
errorQuda("split_key = [%d,%d,%d,%d] is not valid", split_key[0], split_key[1], split_key[2], split_key[3]);
}

// Create a local copy of gauge_param that we can modify without perturbing
// the original one
if (!gauge_param_) errorQuda("Input gauge_param is null");
QudaGaugeParam gauge_param = *gauge_param_;

if (num_sub_partition == 1) { // In this case we don't split the grid.

for (int n = 0; n < param->num_src; n++) { op(_hp_x[n], _hp_b[n], param, args...); }

} else {

if (gauge_param == nullptr) { errorQuda("gauge_param == nullptr"); }

// Doing the sub-partition arithmetic
if (param->num_src_per_sub_partition * num_sub_partition != param->num_src) {
errorQuda("We need to have split_grid[0](=%d) * split_grid[1](=%d) * split_grid[2](=%d) * split_grid[3](=%d) * "
Expand All @@ -3058,44 +3061,50 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col

checkInvertParam(param, _hp_x[0], _hp_b[0]);

bool is_staggered;
bool is_staggered = false;
bool is_asqtad = false;
if (h_gauge) {
is_staggered = false;
} else if (milc_fatlinks) {
is_staggered = true;
if (param->dslash_type == QUDA_ASQTAD_DSLASH) {
if (!milc_longlinks) errorQuda("milc_longlinks is null for an asqtad dslash");
is_asqtad = true;
}
} else {
errorQuda("Both h_gauge and milc_fatlinks are null.");
is_staggered = true; // to suppress compiler warning/error.
}

// Gauge fields/params
GaugeFieldParam *gf_param = nullptr;
GaugeField *in = nullptr;
GaugeFieldParam gf_param;
GaugeField in;
// Staggered gauge fields/params
GaugeFieldParam *milc_fatlink_param = nullptr;
GaugeFieldParam *milc_longlink_param = nullptr;
GaugeField *milc_fatlink_field = nullptr;
GaugeField *milc_longlink_field = nullptr;
GaugeFieldParam milc_fatlink_param;
GaugeFieldParam milc_longlink_param;
quda::GaugeField milc_fatlink_field;
quda::GaugeField milc_longlink_field;

// set up the gauge field params.
if (!is_staggered) { // not staggered
gf_param = new GaugeFieldParam(*gauge_param, h_gauge);
if (gf_param->order <= 4) gf_param->ghostExchange = QUDA_GHOST_EXCHANGE_NO;
in = GaugeField::Create(*gf_param);
gf_param = GaugeFieldParam(gauge_param, h_gauge);
in = GaugeField(gf_param);
} else { // staggered
milc_fatlink_param = new GaugeFieldParam(*gauge_param, milc_fatlinks);
if (milc_fatlink_param->order <= 4) milc_fatlink_param->ghostExchange = QUDA_GHOST_EXCHANGE_NO;
milc_fatlink_field = GaugeField::Create(*milc_fatlink_param);
milc_longlink_param = new GaugeFieldParam(*gauge_param, milc_longlinks);
if (milc_longlink_param->order <= 4) milc_longlink_param->ghostExchange = QUDA_GHOST_EXCHANGE_NO;
milc_longlink_field = GaugeField::Create(*milc_longlink_param);
milc_fatlink_param = GaugeFieldParam(gauge_param, milc_fatlinks);
milc_fatlink_param.order = QUDA_MILC_GAUGE_ORDER;
milc_fatlink_field = GaugeField(milc_fatlink_param);

if (is_asqtad) {
milc_longlink_param = GaugeFieldParam(gauge_param, milc_longlinks);
milc_longlink_param.order = QUDA_MILC_GAUGE_ORDER;
milc_longlink_field = GaugeField(milc_longlink_param);
}
}

// Create the temp host side helper fields, which are just wrappers of the input pointers.
bool pc_solution
= (param->solution_type == QUDA_MATPC_SOLUTION) || (param->solution_type == QUDA_MATPCDAG_MATPC_SOLUTION);

lat_dim_t X = {gauge_param->X[0], gauge_param->X[1], gauge_param->X[2], gauge_param->X[3]};
lat_dim_t X = {gauge_param.X[0], gauge_param.X[1], gauge_param.X[2], gauge_param.X[3]};
ColorSpinorParam cpuParam(_hp_b[0], *param, X, pc_solution, param->input_location);
std::vector<ColorSpinorField *> _h_b(param->num_src);
for (int i = 0; i < param->num_src; i++) {
Expand All @@ -3119,16 +3128,14 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col
errorQuda("Split not possible: %2d %% %2d != 0", comm_dim(d), split_key[d]);
}
if (!is_staggered) {
gf_param->x[d] *= split_key[d];
gf_param->pad *= split_key[d];
gf_param.x[d] *= split_key[d];
gf_param.pad *= split_key[d];
} else {
milc_fatlink_param->x[d] *= split_key[d];
milc_fatlink_param->pad *= split_key[d];
milc_longlink_param->x[d] *= split_key[d];
milc_longlink_param->pad *= split_key[d];
milc_fatlink_param.x[d] *= split_key[d];
if (is_asqtad) milc_longlink_param.x[d] *= split_key[d];
}
gauge_param->X[d] *= split_key[d];
gauge_param->ga_pad *= split_key[d];
gauge_param.X[d] *= split_key[d];
if (!is_staggered) gauge_param.ga_pad *= split_key[d];
}

// Deal with clover field. For Multi source computations, clover field construction is done
Expand Down Expand Up @@ -3171,26 +3178,30 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col
}
}

quda::GaugeField *collected_gauge = nullptr;
quda::GaugeField *collected_milc_fatlink_field = nullptr;
quda::GaugeField *collected_milc_longlink_field = nullptr;
quda::GaugeField collected_gauge;
quda::GaugeField collected_milc_fatlink_field;
quda::GaugeField collected_milc_longlink_field;

if (!is_staggered) {
gf_param->create = QUDA_NULL_FIELD_CREATE;
collected_gauge = new quda::GaugeField(*gf_param);
gf_param.create = QUDA_NULL_FIELD_CREATE;
collected_gauge = quda::GaugeField(gf_param);
std::vector<quda::GaugeField *> v_g(1);
v_g[0] = in;
quda::split_field(*collected_gauge, v_g, split_key);
v_g[0] = &in;
quda::split_field(collected_gauge, v_g, split_key);
} else {
milc_fatlink_param->create = QUDA_NULL_FIELD_CREATE;
milc_longlink_param->create = QUDA_NULL_FIELD_CREATE;
collected_milc_fatlink_field = new quda::GaugeField(*milc_fatlink_param);
collected_milc_longlink_field = new quda::GaugeField(*milc_longlink_param);
std::vector<quda::GaugeField *> v_g(1);
v_g[0] = milc_fatlink_field;
quda::split_field(*collected_milc_fatlink_field, v_g, split_key);
v_g[0] = milc_longlink_field;
quda::split_field(*collected_milc_longlink_field, v_g, split_key);

milc_fatlink_param.create = QUDA_NULL_FIELD_CREATE;
collected_milc_fatlink_field = GaugeField(milc_fatlink_param);
v_g[0] = &milc_fatlink_field;
quda::split_field(collected_milc_fatlink_field, v_g, split_key);

if (is_asqtad) {
milc_longlink_param.create = QUDA_NULL_FIELD_CREATE;
collected_milc_longlink_field = GaugeField(milc_longlink_param);
v_g[0] = &milc_longlink_field;
quda::split_field(collected_milc_longlink_field, v_g, split_key);
}
}

profileInvertMultiSrc.TPSTART(QUDA_PROFILE_PREAMBLE);
Expand Down Expand Up @@ -3223,10 +3234,10 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col
// the split topology.
logQuda(QUDA_DEBUG_VERBOSE, "Split grid loading gauge field...\n");
if (!is_staggered) {
loadGaugeQuda(collected_gauge->raw_pointer(), gauge_param);
loadGaugeQuda(collected_gauge.raw_pointer(), &gauge_param);
} else {
loadFatLongGaugeQuda(param, gauge_param, collected_milc_fatlink_field->raw_pointer(),
collected_milc_longlink_field->raw_pointer());
loadFatLongGaugeQuda(param, &gauge_param, collected_milc_fatlink_field.raw_pointer(),
(is_asqtad) ? collected_milc_longlink_field.raw_pointer() : nullptr);
}
logQuda(QUDA_DEBUG_VERBOSE, "Split grid loaded gauge field...\n");

Expand All @@ -3251,8 +3262,8 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col
comm_barrier();

for (int d = 0; d < CommKey::n_dim; d++) {
gauge_param->X[d] /= split_key[d];
gauge_param->ga_pad /= split_key[d];
gauge_param.X[d] /= split_key[d];
if (!is_staggered) gauge_param.ga_pad /= split_key[d];
}

for (int n = 0; n < param->num_src_per_sub_partition; n++) {
Expand All @@ -3268,27 +3279,17 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col
for (auto p : _h_x) { delete p; }
for (auto p : _h_b) { delete p; }

if (!is_staggered) {
delete in;
delete collected_gauge;
} else {
delete milc_fatlink_field;
delete milc_longlink_field;
delete collected_milc_fatlink_field;
delete collected_milc_longlink_field;
}

if (input_clover) { delete input_clover; }
if (collected_clover) { delete collected_clover; }

profileInvertMultiSrc.TPSTOP(QUDA_PROFILE_EPILOGUE);

// Restore the gauge field
if (!is_staggered) {
loadGaugeQuda(h_gauge, gauge_param);
loadGaugeQuda(h_gauge, gauge_param_);
} else {
freeGaugeQuda();
loadFatLongGaugeQuda(param, gauge_param, milc_fatlinks, milc_longlinks);
loadFatLongGaugeQuda(param, gauge_param_, milc_fatlinks, milc_longlinks);
}

if (param->dslash_type == QUDA_CLOVER_WILSON_DSLASH || param->dslash_type == QUDA_TWISTED_CLOVER_DSLASH) {
Expand Down
Loading

0 comments on commit b87195b

Please sign in to comment.