Skip to content

Commit

Permalink
Merge pull request #1434 from lattice/feature/mrhs-transfer
Browse files Browse the repository at this point in the history
Multi-RHS support for Prolongator and Restrictor
  • Loading branch information
weinbe2 authored Feb 1, 2024
2 parents a979a4d + a6aafaf commit ed6160e
Show file tree
Hide file tree
Showing 27 changed files with 832 additions and 732 deletions.
4 changes: 4 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,9 @@ if(QUDA_MAX_MULTI_BLAS_N GREATER 32)
message(SEND_ERROR "Maximum QUDA_MAX_MULTI_BLAS_N is 32.")
endif()

# Set the maximum multi-RHS per kernel
set(QUDA_MAX_MULTI_RHS "32" CACHE STRING "maximum number of simultaneous RHS in a kernel")

set(QUDA_PRECISION
"14"
CACHE STRING "which precisions to instantiate in QUDA (4-bit number - double, single, half, quarter)")
Expand Down Expand Up @@ -275,6 +278,7 @@ mark_as_advanced(QUDA_FAST_COMPILE_DSLASH)
mark_as_advanced(QUDA_ALTERNATIVE_I_TO_F)

mark_as_advanced(QUDA_MAX_MULTI_BLAS_N)
mark_as_advanced(QUDA_MAX_MULTI_RHS)
mark_as_advanced(QUDA_PRECISION)
mark_as_advanced(QUDA_RECONSTRUCT)
mark_as_advanced(QUDA_CLOVER_CHOLESKY_PROMOTE)
Expand Down
12 changes: 1 addition & 11 deletions include/blas_quda.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,7 @@ namespace quda {
for (auto i = 0u; i < x.size(); i++) x[i].zero();
}

inline void copy(ColorSpinorField &dst, const ColorSpinorField &src)
{
if (dst.data() == src.data()) {
// check the fields are equivalent else error
if (ColorSpinorField::are_compatible(dst, src))
return;
else
errorQuda("Aliasing pointers with incompatible fields");
}
dst.copy(src);
}
inline void copy(ColorSpinorField &dst, const ColorSpinorField &src) { dst.copy(src); }

/**
@brief Apply the operation y = a * x
Expand Down
15 changes: 3 additions & 12 deletions include/color_spinor_field.h
Original file line number Diff line number Diff line change
Expand Up @@ -695,6 +695,7 @@ namespace quda
QudaSiteOrder SiteOrder() const { return siteOrder; }
QudaFieldOrder FieldOrder() const { return fieldOrder; }
QudaGammaBasis GammaBasis() const { return gammaBasis; }
void GammaBasis(QudaGammaBasis new_basis) { gammaBasis = new_basis; }

const int *GhostFace() const { return ghostFace.data; }
const int *GhostFaceCB() const { return ghostFaceCB.data; }
Expand Down Expand Up @@ -781,16 +782,6 @@ namespace quda
*/
ColorSpinorField create_alias(const ColorSpinorParam &param = ColorSpinorParam());

/**
@brief Create a field that aliases this field's storage. The
alias field can use a different precision than this field,
though it cannot be greater. This functionality is useful for
the case where we have multiple temporaries in different
precisions, but do not need them simultaneously. Use this functionality with caution.
@param[in] param Parameters for the alias field
*/
ColorSpinorField *CreateAlias(const ColorSpinorParam &param);

/**
@brief Create a coarse color-spinor field, using this field to set the meta data
@param[in] geoBlockSize Geometric block size that defines the coarse grid dimensions
Expand All @@ -800,7 +791,7 @@ namespace quda
@param[in] location Optionally set the location of the coarse field
@param[in] mem_type Optionally set the memory type used (e.g., can override with mapped memory)
*/
ColorSpinorField *CreateCoarse(const int *geoBlockSize, int spinBlockSize, int Nvec,
ColorSpinorField create_coarse(const int *geoBlockSize, int spinBlockSize, int Nvec,
QudaPrecision precision = QUDA_INVALID_PRECISION,
QudaFieldLocation location = QUDA_INVALID_FIELD_LOCATION,
QudaMemoryType mem_Type = QUDA_MEMORY_INVALID);
Expand All @@ -814,7 +805,7 @@ namespace quda
@param[in] location Optionally set the location of the fine field
@param[in] mem_type Optionally set the memory type used (e.g., can override with mapped memory)
*/
ColorSpinorField *CreateFine(const int *geoblockSize, int spinBlockSize, int Nvec,
ColorSpinorField create_fine(const int *geoblockSize, int spinBlockSize, int Nvec,
QudaPrecision precision = QUDA_INVALID_PRECISION,
QudaFieldLocation location = QUDA_INVALID_FIELD_LOCATION,
QudaMemoryType mem_type = QUDA_MEMORY_INVALID);
Expand Down
2 changes: 1 addition & 1 deletion include/color_spinor_field_order.h
Original file line number Diff line number Diff line change
Expand Up @@ -914,7 +914,7 @@ namespace quda
*/
template <int nSpinBlock>
__device__ __host__ inline void load(complex<Float> out[nSpinBlock * nColor * nVec], int parity, int x_cb,
int chi) const
int chi = 0) const
{
if (!fixed) {
accessor.template load<nSpinBlock>((complex<storeFloat> *)out, v.v, parity, x_cb, chi, volumeCB);
Expand Down
2 changes: 1 addition & 1 deletion include/field_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ namespace quda {
@param[in] a Vector of fields we wish to create a matching
temporary for
*/
template <typename T> auto getFieldTmp(const vector_ref<T> &a)
template <typename T> auto getFieldTmp(cvector_ref<T> &a)
{
std::vector<FieldTmp<T>> tmp;
tmp.reserve(a.size());
Expand Down
9 changes: 4 additions & 5 deletions include/kernels/block_orthogonalize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,14 @@ namespace quda {
static constexpr bool swizzle = false;
int_fastdiv swizzle_factor; // for transposing blockIdx.x mapping to coarse grid coordinate

const Vector B[nVec];
Vector B[nVec];

static constexpr bool launch_bounds = true;
dim3 grid_dim;
dim3 block_dim;

template <typename... T>
BlockOrthoArg(ColorSpinorField &V, const int *fine_to_coarse, const int *coarse_to_fine, int parity,
const int *geo_bs, const int n_block_ortho, const ColorSpinorField &meta, T... B) :
BlockOrthoArg(ColorSpinorField &V, const std::vector<ColorSpinorField> &B, const int *fine_to_coarse, const int *coarse_to_fine, int parity,
const int *geo_bs, const int n_block_ortho, const ColorSpinorField &meta) :
kernel_param(dim3(meta.VolumeCB() * (fineSpin > 1 ? meta.SiteSubset() : 1), 1, chiral_blocks)),
V(V),
fine_to_coarse(fine_to_coarse),
Expand All @@ -69,10 +68,10 @@ namespace quda {
nParity(meta.SiteSubset()),
nBlockOrtho(n_block_ortho),
fineVolumeCB(meta.VolumeCB()),
B{*B...},
grid_dim(),
block_dim()
{
for (int i = 0; i < nVec; i++) this->B[i] = B[i];
int aggregate_size = 1;
for (int d = 0; d < V.Ndim(); d++) aggregate_size *= geo_bs[d];
aggregate_size_cb = aggregate_size / 2;
Expand Down
2 changes: 1 addition & 1 deletion include/kernels/dslash_coarse.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ namespace quda {
using F = typename colorspinor::FieldOrderCB<real, nSpin, nColor, 1, csOrder, Float, ghostFloat, true>;
using GY = typename gauge::FieldOrder<real, nColor * nSpin, nSpin, gOrder, true, yFloat>;

static constexpr unsigned int max_n_src = 64;
static constexpr unsigned int max_n_src = MAX_MULTI_RHS;
const int_fastdiv n_src;
F out[max_n_src];
F inA[max_n_src];
Expand Down
86 changes: 61 additions & 25 deletions include/kernels/prolongator.cuh
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <color_spinor_field_order.h>
#include <color_spinor.h>
#include <multigrid_helper.cuh>
#include <kernel.h>

Expand All @@ -14,45 +15,64 @@ namespace quda {
/**
Kernel argument struct
*/
template <typename Float, typename vFloat, int fineSpin_, int fineColor_, int coarseSpin_, int coarseColor_>
template <typename Float, typename vFloat, int fineSpin_, int fineColor_, int coarseSpin_, int coarseColor_, bool to_non_rel_>
struct ProlongateArg : kernel_param<> {
using real = Float;
static constexpr int fineSpin = fineSpin_;
static constexpr int coarseSpin = coarseSpin_;
static constexpr int fineColor = fineColor_;
static constexpr int coarseColor = coarseColor_;

FieldOrderCB<Float,fineSpin,fineColor,1, colorspinor::getNative<Float>(fineSpin)> out;;
const FieldOrderCB<Float,coarseSpin,coarseColor,1, colorspinor::getNative<Float>(coarseSpin)> in;
const FieldOrderCB<Float,fineSpin,fineColor,coarseColor, colorspinor::getNative<vFloat>(fineSpin), vFloat> V;
static constexpr bool to_non_rel = to_non_rel_;

// disable ghost to reduce arg size
using F = FieldOrderCB<Float, fineSpin, fineColor, 1, colorspinor::getNative<Float>(fineSpin), Float, Float, true>;
using C = FieldOrderCB<Float, coarseSpin, coarseColor, 1, colorspinor::getNative<Float>(coarseSpin), Float, Float, true>;
using V = FieldOrderCB<Float, fineSpin, fineColor, coarseColor, colorspinor::getNative<vFloat>(fineSpin), vFloat, vFloat>;

static constexpr unsigned int max_n_src = MAX_MULTI_RHS;
const int_fastdiv n_src;
F out[max_n_src];
C in[max_n_src];
const V v;
const int *geo_map; // need to make a device copy of this
const spin_mapper<fineSpin,coarseSpin> spin_map;
const int parity; // the parity of the output field (if single parity)
const int nParity; // number of parities of input fine field

ProlongateArg(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &V,
ProlongateArg(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in, const ColorSpinorField &v,
const int *geo_map, const int parity) :
kernel_param(dim3(out.VolumeCB(), out.SiteSubset(), fineColor/fine_colors_per_thread<fineColor, coarseColor>())),
out(out), in(in), V(V), geo_map(geo_map), spin_map(), parity(parity), nParity(out.SiteSubset())
{ }
kernel_param(dim3(out[0].VolumeCB(), out[0].SiteSubset() * out.size(), fineColor/fine_colors_per_thread<fineColor, coarseColor>())),
n_src(out.size()),
v(v),
geo_map(geo_map),
spin_map(),
parity(parity),
nParity(out[0].SiteSubset())
{
if (out.size() > max_n_src) errorQuda("vector set size %lu greater than max size %d", out.size(), max_n_src);
for (auto i = 0u; i < out.size(); i++) {
this->out[i] = out[i];
this->in[i] = in[i];
}
}
};

/**
Applies the grid prolongation operator (coarse to fine)
*/
template <typename Arg>
__device__ __host__ inline void prolongate(complex<typename Arg::real> out[], const Arg &arg, int parity, int x_cb)
__device__ __host__ inline void prolongate(complex<typename Arg::real> out[], const Arg &arg, int src_idx, int parity, int x_cb)
{
int x = parity*arg.out.VolumeCB() + x_cb;
int x = parity * arg.out[src_idx].VolumeCB() + x_cb;
int x_coarse = arg.geo_map[x];
int parity_coarse = (x_coarse >= arg.in.VolumeCB()) ? 1 : 0;
int x_coarse_cb = x_coarse - parity_coarse*arg.in.VolumeCB();
int parity_coarse = (x_coarse >= arg.in[src_idx].VolumeCB()) ? 1 : 0;
int x_coarse_cb = x_coarse - parity_coarse * arg.in[src_idx].VolumeCB();

#pragma unroll
for (int s=0; s<Arg::fineSpin; s++) {
#pragma unroll
for (int c=0; c<Arg::coarseColor; c++) {
out[s*Arg::coarseColor+c] = arg.in(parity_coarse, x_coarse_cb, arg.spin_map(s,parity), c);
out[s*Arg::coarseColor+c] = arg.in[src_idx](parity_coarse, x_coarse_cb, arg.spin_map(s,parity), c);
}
}
}
Expand All @@ -62,14 +82,16 @@ namespace quda {
is the second step of applying the prolongator.
*/
template <typename Arg>
__device__ __host__ inline void rotateFineColor(const Arg &arg, const complex<typename Arg::real> in[], int parity, int x_cb, int fine_color_block)
__device__ __host__ inline void rotateFineColor(const Arg &arg, const complex<typename Arg::real> in[], int src_idx, int parity, int x_cb, int fine_color_block)
{
constexpr int fine_color_per_thread = fine_colors_per_thread<Arg::fineColor, Arg::coarseColor>();
const int spinor_parity = (arg.nParity == 2) ? parity : 0;
const int v_parity = (arg.V.Nparity() == 2) ? parity : 0;
const int v_parity = (arg.v.Nparity() == 2) ? parity : 0;

constexpr int color_unroll = 2;

ColorSpinor<typename Arg::real, fine_color_per_thread, Arg::fineSpin> out;

#pragma unroll
for (int s=0; s<Arg::fineSpin; s++) {
#pragma unroll
Expand All @@ -82,18 +104,30 @@ namespace quda {

#pragma unroll
for (int j=0; j<Arg::coarseColor; j+=color_unroll) {
// V is a ColorMatrixField with internal dimensions Ns * Nc * Nvec
// v is a ColorMatrixField with internal dimensions Ns * Nc * Nvec
#pragma unroll
for (int k=0; k<color_unroll; k++)
partial[k] = cmac(arg.V(v_parity, x_cb, s, i, j + k), in[s * Arg::coarseColor + j + k], partial[k]);
partial[k] = cmac(arg.v(v_parity, x_cb, s, i, j + k), in[s * Arg::coarseColor + j + k], partial[k]);
}

#pragma unroll
for (int k=1; k<color_unroll; k++) partial[0] += partial[k];
arg.out(spinor_parity, x_cb, s, i) = partial[0];
for (int k = 0; k < color_unroll; k++) out(s, fine_color_local) += partial[k];
}
}

if constexpr (Arg::fineSpin == 4 && Arg::to_non_rel) {
out.toNonRel();
out *= rsqrt(static_cast<typename Arg::real>(2.0));
}

#pragma unroll
for (int s = 0; s < Arg::fineSpin; s++) {
#pragma unroll
for (int fine_color_local = 0; fine_color_local < fine_color_per_thread; fine_color_local++) {
int i = fine_color_block + fine_color_local; // global fine color index
arg.out[src_idx](spinor_parity, x_cb, s, i) = out(s, fine_color_local);
}
}
}

template <typename Arg> struct Prolongator
Expand All @@ -103,13 +137,15 @@ namespace quda {
constexpr Prolongator(const Arg &arg) : arg(arg) {}
static constexpr const char *filename() { return KERNEL_FILE; }

__device__ __host__ inline void operator()(int x_cb, int parity, int fine_color_thread)
__device__ __host__ inline void operator()(int x_cb, int src_parity, int fine_color_thread)
{
if (arg.nParity == 1) parity = arg.parity;
int src_idx = src_parity % arg.n_src;
int parity = (arg.nParity == 2) ? (src_parity / arg.n_src) : arg.parity;
const int fine_color_block = fine_color_thread * fine_color_per_thread;
complex<typename Arg::real> tmp[Arg::fineSpin*Arg::coarseColor];
prolongate(tmp, arg, parity, x_cb);
rotateFineColor(arg, tmp, parity, x_cb, fine_color_block);

prolongate(tmp, arg, src_idx, parity, x_cb);
rotateFineColor(arg, tmp, src_idx, parity, x_cb, fine_color_block);
}
};

Expand Down
Loading

0 comments on commit ed6160e

Please sign in to comment.