Merge pull request #1427 from lattice/feature/sycl-merge
start tagging kernels
maddyscientist authored May 17, 2024
2 parents 6198d60 + e1293ba commit 2aa17e2
Showing 20 changed files with 409 additions and 135 deletions.
24 changes: 14 additions & 10 deletions include/kernels/clover_deriv.cuh
@@ -33,9 +33,11 @@ namespace quda
}
};

template <typename Link, typename Force, typename Arg>
__device__ __host__ void computeForce(Force &force_total, const Arg &arg, int xIndex, int parity, int mu, int nu)
using computeForceOps = KernelOps<thread_array<int, 4>>;
template <typename Link, typename Force, typename Ftor>
__device__ __host__ void computeForce(Force &force_total, const Ftor &ftor, int xIndex, int parity, int mu, int nu)
{
const auto &arg = ftor.arg;
const int otherparity = (1 - parity);
const int tidx = mu > nu ? (mu - 1) * mu / 2 + nu : (nu - 1) * nu / 2 + mu;

@@ -44,7 +46,7 @@ namespace quda

// U[mu](x) U[nu](x+mu) U[*mu](x+nu) U[*nu](x) Oprod(x)
{
thread_array<int, 4> d = {};
thread_array<int, 4> d {ftor};

// load U(x)_(+mu)
Link U1 = arg.gauge(mu, linkIndexShift(x, d, arg.E), parity);
@@ -78,7 +80,7 @@ namespace quda
}

{
thread_array<int, 4> d = {};
thread_array<int, 4> d {ftor};

// load U(x-nu)(+nu)
d[nu]--;
@@ -117,7 +119,7 @@ namespace quda
}

{
thread_array<int, 4> d = {};
thread_array<int, 4> d {ftor};

// load U(x)_(+mu)
Link U1 = arg.gauge(mu, linkIndexShift(x, d, arg.E), parity);
@@ -155,7 +157,7 @@ namespace quda
// Lower leaf
// U[nu*](x-nu) U[mu](x-nu) U[nu](x+mu-nu) Oprod(x+mu) U[*mu](x)
{
thread_array<int, 4> d = {};
thread_array<int, 4> d {ftor};

// load U(x-nu)(+nu)
d[nu]--;
@@ -194,10 +196,12 @@ namespace quda
}
}

template <typename Arg> struct CloverDerivative
{
template <typename Arg> struct CloverDerivative : computeForceOps {
const Arg &arg;
constexpr CloverDerivative(const Arg &arg) : arg(arg) {}
template <typename... OpsArgs>
constexpr CloverDerivative(const Arg &arg, const OpsArgs &...ops) : KernelOpsT(ops...), arg(arg)
{
}
static constexpr const char *filename() { return KERNEL_FILE; }

__host__ __device__ void operator()(int x_cb, int parity, int mu)
@@ -210,7 +214,7 @@

for (int nu = 0; nu < 4; nu++) {
if (nu == mu) continue;
computeForce<Link>(force, arg, x_cb, parity, mu, nu);
computeForce<Link>(force, *this, x_cb, parity, mu, nu);
}

// Write to array
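
The clover_deriv.cuh hunks above show the tagging pattern this PR applies across the kernels: the ops a kernel needs (here a thread_array<int, 4>) are declared in a KernelOps alias, the functor derives from that alias, and its constructor forwards any extra arguments to KernelOpsT. A condensed sketch of the same shape, using placeholder names where the diff does not supply them:

using MyKernelOps = KernelOps<thread_array<int, 4>>; // declare the ops this kernel uses

template <typename Arg> struct MyKernel : MyKernelOps {
  const Arg &arg;
  template <typename... OpsArgs>
  constexpr MyKernel(const Arg &arg, const OpsArgs &...ops) : KernelOpsT(ops...), arg(arg) {}
  static constexpr const char *filename() { return KERNEL_FILE; }

  __device__ __host__ void operator()(int x_cb, int parity)
  {
    // per-thread array constructed from the tagged functor rather than as a plain local
    thread_array<int, 4> dx {*this};
    // ... kernel body reads arg and indexes with dx ...
  }
};
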
27 changes: 17 additions & 10 deletions include/kernels/field_strength_tensor.cuh
@@ -35,10 +35,13 @@ namespace quda
}
};

template <typename Arg>
__device__ __host__ inline void computeFmunuCore(const Arg &arg, int idx, int parity, int mu, int nu)
using computeFmunuCoreOps = KernelOps<thread_array<int, 4>>;
template <typename Ftor>
__device__ __host__ inline void computeFmunuCore(const Ftor &ftor, int idx, int parity, int mu, int nu)
{
using Arg = typename Ftor::Arg;
using Link = Matrix<complex<typename Arg::Float>, 3>;
auto &arg = ftor.arg;

int x[4];
int X[4];
@@ -53,7 +56,7 @@ namespace quda
{ // U(x,mu) U(x+mu,nu) U[dagger](x+nu,mu) U[dagger](x,nu)

// load U(x)_(+mu)
thread_array<int, 4> dx = {0, 0, 0, 0};
thread_array<int, 4> dx {ftor};
Link U1 = arg.u(mu, linkIndexShift(x, dx, X), parity);

// load U(x+mu)_(+nu)
@@ -76,7 +79,7 @@ namespace quda
{ // U(x,nu) U[dagger](x+nu-mu,mu) U[dagger](x-mu,nu) U(x-mu, mu)

// load U(x)_(+nu)
thread_array<int, 4> dx = {0, 0, 0, 0};
thread_array<int, 4> dx {ftor};
Link U1 = arg.u(nu, linkIndexShift(x, dx, X), parity);

// load U(x+nu)_(-mu) = U(x+nu-mu)_(+mu)
@@ -103,7 +106,7 @@ namespace quda
{ // U[dagger](x-nu,nu) U(x-nu,mu) U(x+mu-nu,nu) U[dagger](x,mu)

// load U(x)_(-nu)
thread_array<int, 4> dx = {0, 0, 0, 0};
thread_array<int, 4> dx {ftor};
dx[nu]--;
Link U1 = arg.u(nu, linkIndexShift(x, dx, X), 1 - parity);
dx[nu]++;
@@ -130,7 +133,7 @@ namespace quda
{ // U[dagger](x-mu,mu) U[dagger](x-mu-nu,nu) U(x-mu-nu,mu) U(x-nu,nu)

// load U(x)_(-mu)
thread_array<int, 4> dx = {0, 0, 0, 0};
thread_array<int, 4> dx {ftor};
dx[mu]--;
Link U1 = arg.u(mu, linkIndexShift(x, dx, X), 1 - parity);
dx[mu]++;
@@ -169,14 +172,18 @@ namespace quda
F *= static_cast<typename Arg::Float>(0.125); // 18 real multiplications
// 36 floating point operations here
}

int munu_idx = (mu * (mu - 1)) / 2 + nu; // lower-triangular indexing
arg.f(munu_idx, idx, parity) = F;
}

template <typename Arg> struct ComputeFmunu {
template <typename Arg_> struct ComputeFmunu : computeFmunuCoreOps {
using Arg = Arg_;
const Arg &arg;
constexpr ComputeFmunu(const Arg &arg) : arg(arg) {}
template <typename... OpsArgs>
constexpr ComputeFmunu(const Arg &arg, const OpsArgs &...ops) : KernelOpsT(ops...), arg(arg)
{
}
static constexpr const char* filename() { return KERNEL_FILE; }

__device__ __host__ inline void operator()(int x_cb, int parity, int mu_nu)
@@ -190,7 +197,7 @@
case 4: mu = 3, nu = 1; break;
case 5: mu = 3, nu = 2; break;
}
computeFmunuCore(arg, x_cb, parity, mu, nu);
computeFmunuCore(*this, x_cb, parity, mu, nu);
}
};

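
Helper functions follow the same convention: computeForce and computeFmunuCore now take the tagged functor instead of the bare Arg, recover arg (and, where needed, the Arg type) from it, and hand the functor to every thread_array they construct. A minimal sketch of such a helper, with a hypothetical name:

using myHelperOps = KernelOps<thread_array<int, 4>>; // ops the helper requires of its caller

template <typename Ftor>
__device__ __host__ inline void myHelper(const Ftor &ftor, int idx, int parity)
{
  using Arg = typename Ftor::Arg; // Arg type exposed by the calling functor
  const auto &arg = ftor.arg;     // argument struct carried by the functor
  thread_array<int, 4> dx {ftor}; // per-thread scratch tied to the kernel-ops context
  // ... body indexes into the fields held by arg using idx, parity and dx ...
}

The calling functor derives from the helper's ops alias (as CloverDerivative and ComputeFmunu do above) and passes *this at the call site.
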
18 changes: 10 additions & 8 deletions include/kernels/gauge_ape.cuh
@@ -15,12 +15,12 @@ namespace quda
static_assert(nColor == 3, "Only nColor=3 enabled at this time");
static constexpr QudaReconstructType recon = recon_;
static constexpr int apeDim = apeDim_;
typedef typename gauge_mapper<Float,recon>::type Gauge;
typedef typename gauge_mapper<Float, recon>::type Gauge;

Gauge out;
const Gauge in;

int X[4]; // grid dimensions
int X[4]; // grid dimensions
int border[4];
const Float alpha;
const int dir_ignore;
@@ -40,11 +40,13 @@
}
}
};
template <typename Arg> struct APE {

template <typename Arg> struct APE : computeStapleOps {
const Arg &arg;
constexpr APE(const Arg &arg) : arg(arg) {}
static constexpr const char* filename() { return KERNEL_FILE; }
template <typename... OpsArgs> constexpr APE(const Arg &arg, const OpsArgs &...ops) : KernelOpsT(ops...), arg(arg)
{
}
static constexpr const char *filename() { return KERNEL_FILE; }

__device__ __host__ inline void operator()(int x_cb, int parity, int dir)
{
@@ -65,7 +67,7 @@
int dx[4] = {0, 0, 0, 0};
Link U, Stap, TestU, I;
// This function gets stap = S_{mu,nu} i.e., the staple of length 3,
computeStaple(arg, x, X, parity, dir, Stap, arg.dir_ignore);
computeStaple(*this, x, X, parity, dir, Stap, arg.dir_ignore);

// Get link U
U = arg.in(dir, linkIndexShift(x, dx, X), parity);
@@ -76,7 +78,7 @@
TestU = I * (static_cast<real>(1.0) - arg.alpha) + Stap * conj(U);
polarSu3<real>(TestU, arg.tolerance);
U = TestU * U;

arg.out(dir, linkIndexShift(x, dx, X), parity) = U;
}
};
13 changes: 7 additions & 6 deletions include/kernels/gauge_force.cuh
@@ -44,11 +44,13 @@ namespace quda {
}
};

template <typename Arg> struct GaugeForce
{
template <typename Arg> struct GaugeForce : KernelOps<thread_array<int, 4>> {
const Arg &arg;
constexpr GaugeForce(const Arg &arg) : arg(arg) {}
static constexpr const char *filename() { return KERNEL_FILE; }
template <typename... OpsArgs>
constexpr GaugeForce(const Arg &arg, const OpsArgs &...ops) : KernelOpsT(ops...), arg(arg)
{
}
static constexpr const char *filename() { return KERNEL_FILE; }

__device__ __host__ void operator()(int x_cb, int parity, int dir)
{
@@ -62,7 +64,7 @@
// prod: current matrix product
// accum: accumulator matrix
Link link_prod, accum;
thread_array<int, 4> dx{0};
thread_array<int, 4> dx {*this};

for (int i=0; i<arg.p.num_paths; i++) {
real coeff = arg.p.path_coeff[i];
@@ -95,5 +97,4 @@ namespace quda {
arg.mom(dir, x_cb, parity) = mom;
}
};

}
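
Note the initialization change visible in gauge_force.cuh (and echoed in the other files): the per-thread index array used to be a plain value-initialized local and is now constructed from the tagged functor. The surrounding code still relies on the array starting at zero, so the ops-aware constructor is assumed here to preserve that behaviour; the diff itself does not state it.

// before: plain local array, value-initialized to zero
thread_array<int, 4> dx{0};
// after: constructed from the tagged functor, backed by the kernel-ops context
thread_array<int, 4> dx {*this};
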
18 changes: 11 additions & 7 deletions include/kernels/gauge_hyp.cuh
@@ -197,9 +197,11 @@ namespace quda
}
}

template <typename Arg> struct HYP {
template <typename Arg> struct HYP : KernelOps<thread_array<int, 4>> {
const Arg &arg;
constexpr HYP(const Arg &arg) : arg(arg) { }
template <typename... OpsArgs> constexpr HYP(const Arg &arg, const OpsArgs &...ops) : KernelOpsT(ops...), arg(arg)
{
}
static constexpr const char *filename() { return KERNEL_FILE; }

__device__ __host__ inline void operator()(int x_cb, int parity, int dir)
@@ -213,7 +215,7 @@
#pragma unroll
for (int dr = 0; dr < 4; ++dr) x[dr] += arg.border[dr]; // extended grid coordinates

thread_array<int, 4> dx = {0, 0, 0, 0};
thread_array<int, 4> dx {*this};

Link U, Stap[3], TestU, I;

@@ -300,9 +302,11 @@ namespace quda
}
}

template <typename Arg> struct HYP3D {
template <typename Arg> struct HYP3D : KernelOps<thread_array<int, 4>> {
const Arg &arg;
constexpr HYP3D(const Arg &arg) : arg(arg) { }
template <typename... OpsArgs> constexpr HYP3D(const Arg &arg, const OpsArgs &...ops) : KernelOpsT(ops...), arg(arg)
{
}
static constexpr const char *filename() { return KERNEL_FILE; }

__device__ __host__ inline void operator()(int x_cb, int parity, int dir)
@@ -316,7 +320,7 @@
#pragma unroll
for (int dr = 0; dr < 4; ++dr) x[dr] += arg.border[dr]; // extended grid coordinates

thread_array<int, 4> dx = {0, 0, 0, 0};
thread_array<int, 4> dx {*this};

int dir_ = dir;
dir = dir + (dir >= arg.dir_ignore);
@@ -344,4 +348,4 @@ namespace quda
}
}
};
} // namespace quda
} // namespace quda
11 changes: 6 additions & 5 deletions include/kernels/gauge_loop_trace.cuh
@@ -52,13 +52,15 @@ namespace quda {
}
};

template <typename Arg> struct GaugeLoop : plus<typename Arg::reduce_t>
{
template <typename Arg> struct GaugeLoop : plus<typename Arg::reduce_t>, KernelOps<thread_array<int, 4>> {
using reduce_t = typename Arg::reduce_t;
using plus<reduce_t>::operator();
static constexpr int reduce_block_dim = 2; // x_cb and parity are mapped to x
const Arg &arg;
constexpr GaugeLoop(const Arg &arg) : arg(arg) {}
template <typename... OpsArgs>
constexpr GaugeLoop(const Arg &arg, const OpsArgs &...ops) : KernelOpsT(ops...), arg(arg)
{
}
static constexpr const char *filename() { return KERNEL_FILE; }

__device__ __host__ inline reduce_t operator()(reduce_t &value, int x_cb, int parity, int path_id)
@@ -71,7 +73,7 @@ namespace quda {
getCoords(x, x_cb, arg.X, parity);
for (int dr=0; dr<4; ++dr) x[dr] += arg.border[dr]; // extended grid coordinates

thread_array<int, 4> dx{0};
thread_array<int, 4> dx {*this};

double coeff_loop = arg.factor * arg.p.path_coeff[path_id];
if (coeff_loop == 0) return operator()(loop_trace, value);
@@ -90,5 +92,4 @@ namespace quda {
return operator()(loop_trace, value);
}
};

}
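
Reduction kernels pick up the tag in the same way, adding the KernelOps base alongside the existing reducer, as GaugeLoop does above. A sketch of that shape with placeholder names:

template <typename Arg> struct MyLoopReduce : plus<typename Arg::reduce_t>, KernelOps<thread_array<int, 4>> {
  using reduce_t = typename Arg::reduce_t;
  using plus<reduce_t>::operator(); // binary combine inherited from the reducer
  const Arg &arg;
  template <typename... OpsArgs>
  constexpr MyLoopReduce(const Arg &arg, const OpsArgs &...ops) : KernelOpsT(ops...), arg(arg) {}
  static constexpr const char *filename() { return KERNEL_FILE; }

  __device__ __host__ inline reduce_t operator()(reduce_t &value, int x_cb, int parity, int path_id)
  {
    reduce_t partial = {};           // this thread's contribution
    thread_array<int, 4> dx {*this}; // ops-backed scratch indices
    // ... accumulate the traced loop into partial ...
    return operator()(partial, value); // fold into the running reduction
  }
};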