start tagging kernels #1427

Merged: 9 commits, May 17, 2024
Changes from 8 commits
24 changes: 14 additions & 10 deletions include/kernels/clover_deriv.cuh
@@ -33,9 +33,11 @@ namespace quda
}
};

-template <typename Link, typename Force, typename Arg>
-__device__ __host__ void computeForce(Force &force_total, const Arg &arg, int xIndex, int parity, int mu, int nu)
+using computeForceOps = KernelOps<thread_array<int, 4>>;
+template <typename Link, typename Force, typename Ftor>
+__device__ __host__ void computeForce(Force &force_total, const Ftor &ftor, int xIndex, int parity, int mu, int nu)
{
+const auto &arg = ftor.arg;
const int otherparity = (1 - parity);
const int tidx = mu > nu ? (mu - 1) * mu / 2 + nu : (nu - 1) * nu / 2 + mu;

@@ -44,7 +46,7 @@

// U[mu](x) U[nu](x+mu) U[*mu](x+nu) U[*nu](x) Oprod(x)
{
-thread_array<int, 4> d = {};
+thread_array<int, 4> d {ftor};

// load U(x)_(+mu)
Link U1 = arg.gauge(mu, linkIndexShift(x, d, arg.E), parity);
@@ -78,7 +80,7 @@
}

{
-thread_array<int, 4> d = {};
+thread_array<int, 4> d {ftor};

// load U(x-nu)(+nu)
d[nu]--;
@@ -117,7 +119,7 @@
}

{
-thread_array<int, 4> d = {};
+thread_array<int, 4> d {ftor};

// load U(x)_(+mu)
Link U1 = arg.gauge(mu, linkIndexShift(x, d, arg.E), parity);
@@ -155,7 +157,7 @@
// Lower leaf
// U[nu*](x-nu) U[mu](x-nu) U[nu](x+mu-nu) Oprod(x+mu) U[*mu](x)
{
-thread_array<int, 4> d = {};
+thread_array<int, 4> d {ftor};

// load U(x-nu)(+nu)
d[nu]--;
@@ -194,10 +196,12 @@
}
}

-template <typename Arg> struct CloverDerivative
-{
+template <typename Arg> struct CloverDerivative : computeForceOps {
const Arg &arg;
-constexpr CloverDerivative(const Arg &arg) : arg(arg) {}
+template <typename... OpsArgs>
+constexpr CloverDerivative(const Arg &arg, const OpsArgs &...ops) : KernelOpsT(ops...), arg(arg)
+{
+}
static constexpr const char *filename() { return KERNEL_FILE; }

__host__ __device__ void operator()(int x_cb, int parity, int mu)
@@ -210,7 +214,7 @@

for (int nu = 0; nu < 4; nu++) {
if (nu == mu) continue;
-computeForce<Link>(force, arg, x_cb, parity, mu, nu);
+computeForce<Link>(force, *this, x_cb, parity, mu, nu);
}

// Write to array
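
The pattern this PR applies across kernels is easiest to see in isolation: a functor declares the kernel ops it needs (here a `thread_array<int, 4>`) by inheriting a `KernelOps<...>` tag, forwards any ops arguments to `KernelOpsT` in its constructor, and passes itself to helpers so the `thread_array` can be constructed from the functor. Below is a minimal host-only sketch of that shape; the `KernelOps` and `thread_array` definitions are simplified stand-ins, not QUDA's real ones.

```cpp
#include <array>

// Stand-in ops tag: records the ops types in the functor's type and accepts
// forwarded constructor arguments (in QUDA this is where launch-time state,
// e.g. shared-memory bookkeeping, would be wired in).
template <typename... Ops> struct KernelOps {
  using KernelOpsT = KernelOps<Ops...>; // target of the forwarding constructor
  constexpr KernelOps() = default;
  template <typename... Args> constexpr KernelOps(const Args &...) { }
};

// Stand-in thread_array: constructed from a tagged functor instead of `= {}`.
template <typename T, int n> struct thread_array : std::array<T, n> {
  template <typename Ftor> constexpr thread_array(const Ftor &) : std::array<T, n> {} { }
};

using computeForceOps = KernelOps<thread_array<int, 4>>;

struct Arg { int E[4]; };

// Helpers now take the functor and recover the argument struct from it.
template <typename Ftor> void computeForce(const Ftor &ftor, int mu, int nu)
{
  const auto &arg = ftor.arg;
  thread_array<int, 4> d {ftor}; // constructed against the functor's ops tag
  d[mu] = arg.E[nu];             // then used as an ordinary array
}

template <typename Arg> struct CloverDerivative : computeForceOps {
  const Arg &arg;
  template <typename... OpsArgs>
  constexpr CloverDerivative(const Arg &arg, const OpsArgs &...ops) : KernelOpsT(ops...), arg(arg) { }
  void operator()(int mu, int nu) const { computeForce(*this, mu, nu); }
};

int main()
{
  Arg arg {};
  CloverDerivative<Arg> functor(arg);
  functor(0, 1);
}
```

The payoff, as the PR title suggests, is that a kernel's special-ops requirements become part of the functor's type, so a launcher can presumably account for them (e.g. size shared memory) without inspecting the kernel body.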
27 changes: 17 additions & 10 deletions include/kernels/field_strength_tensor.cuh
@@ -35,10 +35,13 @@ namespace quda
}
};

-template <typename Arg>
-__device__ __host__ inline void computeFmunuCore(const Arg &arg, int idx, int parity, int mu, int nu)
+using computeFmunuCoreOps = KernelOps<thread_array<int, 4>>;
+template <typename Ftor>
+__device__ __host__ inline void computeFmunuCore(const Ftor &ftor, int idx, int parity, int mu, int nu)
{
+using Arg = typename Ftor::Arg;
using Link = Matrix<complex<typename Arg::Float>, 3>;
+auto &arg = ftor.arg;

int x[4];
int X[4];
@@ -53,7 +56,7 @@
{ // U(x,mu) U(x+mu,nu) U[dagger](x+nu,mu) U[dagger](x,nu)

// load U(x)_(+mu)
-thread_array<int, 4> dx = {0, 0, 0, 0};
+thread_array<int, 4> dx {ftor};
Link U1 = arg.u(mu, linkIndexShift(x, dx, X), parity);

// load U(x+mu)_(+nu)
@@ -76,7 +79,7 @@
{ // U(x,nu) U[dagger](x+nu-mu,mu) U[dagger](x-mu,nu) U(x-mu, mu)

// load U(x)_(+nu)
-thread_array<int, 4> dx = {0, 0, 0, 0};
+thread_array<int, 4> dx {ftor};
Link U1 = arg.u(nu, linkIndexShift(x, dx, X), parity);

// load U(x+nu)_(-mu) = U(x+nu-mu)_(+mu)
@@ -103,7 +106,7 @@
{ // U[dagger](x-nu,nu) U(x-nu,mu) U(x+mu-nu,nu) U[dagger](x,mu)

// load U(x)_(-nu)
-thread_array<int, 4> dx = {0, 0, 0, 0};
+thread_array<int, 4> dx {ftor};
dx[nu]--;
Link U1 = arg.u(nu, linkIndexShift(x, dx, X), 1 - parity);
dx[nu]++;
@@ -130,7 +133,7 @@
{ // U[dagger](x-mu,mu) U[dagger](x-mu-nu,nu) U(x-mu-nu,mu) U(x-nu,nu)

// load U(x)_(-mu)
-thread_array<int, 4> dx = {0, 0, 0, 0};
+thread_array<int, 4> dx {ftor};
dx[mu]--;
Link U1 = arg.u(mu, linkIndexShift(x, dx, X), 1 - parity);
dx[mu]++;
@@ -169,14 +172,18 @@
F *= static_cast<typename Arg::Float>(0.125); // 18 real multiplications
// 36 floating point operations here
}

int munu_idx = (mu * (mu - 1)) / 2 + nu; // lower-triangular indexing
arg.f(munu_idx, idx, parity) = F;
}

-template <typename Arg> struct ComputeFmunu {
+template <typename Arg_> struct ComputeFmunu : computeFmunuCoreOps {
+using Arg = Arg_;
const Arg &arg;
-constexpr ComputeFmunu(const Arg &arg) : arg(arg) {}
+template <typename... OpsArgs>
+constexpr ComputeFmunu(const Arg &arg, const OpsArgs &...ops) : KernelOpsT(ops...), arg(arg)
+{
+}
static constexpr const char* filename() { return KERNEL_FILE; }

__device__ __host__ inline void operator()(int x_cb, int parity, int mu_nu)
@@ -190,7 +197,7 @@
case 4: mu = 3, nu = 1; break;
case 5: mu = 3, nu = 2; break;
}
-computeFmunuCore(arg, x_cb, parity, mu, nu);
+computeFmunuCore(*this, x_cb, parity, mu, nu);
}
};

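
A detail specific to this file: `computeFmunuCore` used to be templated on `Arg`, so the functor now re-exports its template parameter (`using Arg = Arg_;`) and the helper recovers it with `typename Ftor::Arg`. A small self-contained sketch of that re-export idiom; `FmunuArg` and the `Link` typedef are illustrative, not QUDA's definitions.

```cpp
#include <complex>

struct FmunuArg {
  using Float = double; // stand-in for the kernel argument's precision typedef
};

// The functor re-exports its template parameter under a stable name.
template <typename Arg_> struct ComputeFmunu {
  using Arg = Arg_; // helpers can now write `typename Ftor::Arg`
  const Arg &arg;
  constexpr ComputeFmunu(const Arg &arg) : arg(arg) { }
};

// Templated only on the functor, as in the diff above.
template <typename Ftor> void computeFmunuCore(const Ftor &ftor)
{
  using Arg = typename Ftor::Arg; // recovered without a second template parameter
  using Link = std::complex<typename Arg::Float>; // QUDA uses Matrix<complex<Float>, 3>
  Link u {1.0, 0.0};
  (void)ftor.arg;
  (void)u;
}

int main()
{
  FmunuArg arg;
  ComputeFmunu<FmunuArg> functor(arg);
  computeFmunuCore(functor);
}
```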
18 changes: 10 additions & 8 deletions include/kernels/gauge_ape.cuh
@@ -15,12 +15,12 @@ namespace quda
static_assert(nColor == 3, "Only nColor=3 enabled at this time");
static constexpr QudaReconstructType recon = recon_;
static constexpr int apeDim = apeDim_;
-typedef typename gauge_mapper<Float,recon>::type Gauge;
+typedef typename gauge_mapper<Float, recon>::type Gauge;

Gauge out;
const Gauge in;

-int X[4]; // grid dimensions
+int X[4]; // grid dimensions
int border[4];
const Float alpha;
const int dir_ignore;
@@ -40,11 +40,13 @@
}
}
};
-template <typename Arg> struct APE {
+
+template <typename Arg> struct APE : computeStapleOps {
const Arg &arg;
-constexpr APE(const Arg &arg) : arg(arg) {}
-static constexpr const char* filename() { return KERNEL_FILE; }
+template <typename... OpsArgs> constexpr APE(const Arg &arg, const OpsArgs &...ops) : KernelOpsT(ops...), arg(arg)
+{
+}
+static constexpr const char *filename() { return KERNEL_FILE; }

__device__ __host__ inline void operator()(int x_cb, int parity, int dir)
{
@@ -65,7 +67,7 @@
int dx[4] = {0, 0, 0, 0};
Link U, Stap, TestU, I;
// This function gets stap = S_{mu,nu} i.e., the staple of length 3,
-computeStaple(arg, x, X, parity, dir, Stap, arg.dir_ignore);
+computeStaple(*this, x, X, parity, dir, Stap, arg.dir_ignore);

// Get link U
U = arg.in(dir, linkIndexShift(x, dx, X), parity);
@@ -76,7 +78,7 @@
TestU = I * (static_cast<real>(1.0) - arg.alpha) + Stap * conj(U);
polarSu3<real>(TestU, arg.tolerance);
U = TestU * U;

arg.out(dir, linkIndexShift(x, dx, X), parity) = U;
}
};
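
`computeStapleOps` is not defined in this file: the convention appears to be that the ops alias lives next to the shared `computeStaple` helper in another header, and every kernel calling the helper inherits the alias. A compressed runnable sketch of that convention, reusing the simplified stand-ins from the clover_deriv sketch above (the single-functor `APE` here is illustrative):

```cpp
#include <array>

// Simplified stand-ins, as in the earlier sketch.
template <typename... Ops> struct KernelOps {
  using KernelOpsT = KernelOps<Ops...>;
  constexpr KernelOps() = default;
  template <typename... Args> constexpr KernelOps(const Args &...) { }
};
template <typename T, int n> struct thread_array : std::array<T, n> {
  template <typename F> constexpr thread_array(const F &) : std::array<T, n> {} { }
};

// Declared once, next to the shared helper...
using computeStapleOps = KernelOps<thread_array<int, 4>>;

template <typename Ftor> void computeStaple(const Ftor &ftor)
{
  thread_array<int, 4> dx {ftor}; // the helper's requirement, satisfied by any tagged caller
  dx[0]++;
}

// ...and inherited by each kernel that calls it, so a helper's requirements
// propagate to all of its call sites by construction.
struct APE : computeStapleOps {
  void operator()() const { computeStaple(*this); }
};

int main() { APE {}(); }
```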
13 changes: 7 additions & 6 deletions include/kernels/gauge_force.cuh
@@ -44,11 +44,13 @@ namespace quda {
}
};

-template <typename Arg> struct GaugeForce
-{
+template <typename Arg> struct GaugeForce : KernelOps<thread_array<int, 4>> {
const Arg &arg;
-constexpr GaugeForce(const Arg &arg) : arg(arg) {}
-static constexpr const char *filename() { return KERNEL_FILE; }
+template <typename... OpsArgs>
+constexpr GaugeForce(const Arg &arg, const OpsArgs &...ops) : KernelOpsT(ops...), arg(arg)
+{
+}
+static constexpr const char *filename() { return KERNEL_FILE; }

__device__ __host__ void operator()(int x_cb, int parity, int dir)
{
@@ -62,7 +64,7 @@
// prod: current matrix product
// accum: accumulator matrix
Link link_prod, accum;
-thread_array<int, 4> dx{0};
+thread_array<int, 4> dx {*this};

for (int i=0; i<arg.p.num_paths; i++) {
real coeff = arg.p.path_coeff[i];
@@ -95,5 +97,4 @@
arg.mom(dir, x_cb, parity) = mom;
}
};
-
}
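
The visible cost of the tagging is the changed declaration: `thread_array<int, 4> dx{0}` becomes `dx {*this}`, constructed from the functor rather than from an initializer list. A sketch of the new form with a simplified `thread_array`; the assumption that functor-construction still zero-initializes the elements (mirroring the old `dx{0}`) is mine, not taken from QUDA's sources.

```cpp
#include <array>
#include <cassert>

// Simplified model of the functor-constructed form (illustration only).
template <typename T, int n> struct thread_array : std::array<T, n> {
  template <typename F> constexpr thread_array(const F &) : std::array<T, n> {} { }
};

struct GaugeForceFunctor { // hypothetical tagged functor
  void operator()() const
  {
    thread_array<int, 4> dx {*this}; // tied to the functor; assumed to start zeroed
    assert(dx[0] == 0 && dx[3] == 0);
    dx[1]++; // afterwards it is used exactly as before
    assert(dx[1] == 1);
  }
};

int main() { GaugeForceFunctor {}(); }
```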
18 changes: 11 additions & 7 deletions include/kernels/gauge_hyp.cuh
@@ -197,9 +197,11 @@ namespace quda
}
}

-template <typename Arg> struct HYP {
+template <typename Arg> struct HYP : KernelOps<thread_array<int, 4>> {
const Arg &arg;
-constexpr HYP(const Arg &arg) : arg(arg) { }
+template <typename... OpsArgs> constexpr HYP(const Arg &arg, const OpsArgs &...ops) : KernelOpsT(ops...), arg(arg)
+{
+}
static constexpr const char *filename() { return KERNEL_FILE; }

__device__ __host__ inline void operator()(int x_cb, int parity, int dir)
@@ -213,7 +215,7 @@
#pragma unroll
for (int dr = 0; dr < 4; ++dr) x[dr] += arg.border[dr]; // extended grid coordinates

-thread_array<int, 4> dx = {0, 0, 0, 0};
+thread_array<int, 4> dx {*this};

Link U, Stap[3], TestU, I;

@@ -300,9 +302,11 @@
}
}

-template <typename Arg> struct HYP3D {
+template <typename Arg> struct HYP3D : KernelOps<thread_array<int, 4>> {
const Arg &arg;
-constexpr HYP3D(const Arg &arg) : arg(arg) { }
+template <typename... OpsArgs> constexpr HYP3D(const Arg &arg, const OpsArgs &...ops) : KernelOpsT(ops...), arg(arg)
+{
+}
static constexpr const char *filename() { return KERNEL_FILE; }

__device__ __host__ inline void operator()(int x_cb, int parity, int dir)
@@ -316,7 +320,7 @@
#pragma unroll
for (int dr = 0; dr < 4; ++dr) x[dr] += arg.border[dr]; // extended grid coordinates

-thread_array<int, 4> dx = {0, 0, 0, 0};
+thread_array<int, 4> dx {*this};

int dir_ = dir;
dir = dir + (dir >= arg.dir_ignore);
@@ -344,4 +348,4 @@
}
}
};
-} // namespace quda
+} // namespace quda
11 changes: 6 additions & 5 deletions include/kernels/gauge_loop_trace.cuh
@@ -52,13 +52,15 @@ namespace quda {
}
};

-template <typename Arg> struct GaugeLoop : plus<typename Arg::reduce_t>
-{
+template <typename Arg> struct GaugeLoop : plus<typename Arg::reduce_t>, KernelOps<thread_array<int, 4>> {
using reduce_t = typename Arg::reduce_t;
using plus<reduce_t>::operator();
static constexpr int reduce_block_dim = 2; // x_cb and parity are mapped to x
const Arg &arg;
-constexpr GaugeLoop(const Arg &arg) : arg(arg) {}
+template <typename... OpsArgs>
+constexpr GaugeLoop(const Arg &arg, const OpsArgs &...ops) : KernelOpsT(ops...), arg(arg)
+{
+}
static constexpr const char *filename() { return KERNEL_FILE; }

__device__ __host__ inline reduce_t operator()(reduce_t &value, int x_cb, int parity, int path_id)
@@ -71,7 +73,7 @@
getCoords(x, x_cb, arg.X, parity);
for (int dr=0; dr<4; ++dr) x[dr] += arg.border[dr]; // extended grid coordinates

-thread_array<int, 4> dx{0};
+thread_array<int, 4> dx {*this};

double coeff_loop = arg.factor * arg.p.path_coeff[path_id];
if (coeff_loop == 0) return operator()(loop_trace, value);
@@ -90,5 +92,4 @@
return operator()(loop_trace, value);
}
};
-
}
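
`GaugeLoop` is the one functor in this diff that already had a base class, and it shows the tag composes by multiple inheritance: `plus<reduce_t>` keeps supplying the reduction `operator()` (re-exposed with a using-declaration), while `KernelOps<thread_array<int, 4>>` carries the tag. A condensed sketch with stand-in bases; `LoopArg` is hypothetical.

```cpp
#include <array>

template <typename T> struct plus { // stand-in reduction functor
  T operator()(T a, T b) const { return a + b; }
};

template <typename... Ops> struct KernelOps { // stand-in ops tag
  using KernelOpsT = KernelOps<Ops...>;
  constexpr KernelOps() = default;
  template <typename... Args> constexpr KernelOps(const Args &...) { }
};

template <typename T, int n> struct thread_array : std::array<T, n> { };

struct LoopArg { using reduce_t = double; };

template <typename Arg> struct GaugeLoop : plus<typename Arg::reduce_t>, KernelOps<thread_array<int, 4>> {
  using reduce_t = typename Arg::reduce_t;
  using plus<reduce_t>::operator(); // keep the reduction op visible beside the ops base
  const Arg &arg;
  template <typename... OpsArgs>
  constexpr GaugeLoop(const Arg &arg, const OpsArgs &...ops) : KernelOpsT(ops...), arg(arg) { }
};

int main()
{
  LoopArg arg;
  GaugeLoop<LoopArg> loop(arg);
  double r = loop(1.0, 2.0); // reduction behavior unchanged: r == 3.0
  (void)r;
}
```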