Skip to content

Commit

Permalink
Code cleanup, introducing extended field coordinate E
Browse files Browse the repository at this point in the history
  • Loading branch information
weinbe2 committed Feb 1, 2024
1 parent e0d97b4 commit 9545d28
Showing 1 changed file with 61 additions and 66 deletions.
127 changes: 61 additions & 66 deletions include/kernels/gauge_hyp.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ namespace quda
Gauge tmp[4];
const Gauge in;

int E[4]; // extended grid dimensions

This comment has been minimized.

Copy link
@maddyscientist

maddyscientist Feb 1, 2024

Member

I think both X and E should be int_fastdiv

This comment has been minimized.

Copy link
@weinbe2

weinbe2 Feb 1, 2024

Author Contributor
int X[4]; // grid dimensions
int border[4];
const Float alpha;
Expand All @@ -39,13 +40,14 @@ namespace quda
{
for (int dir = 0; dir < 4; ++dir) {
border[dir] = in.R()[dir];
E[dir] = in.X()[dir];
X[dir] = in.X()[dir] - border[dir] * 2;
}
}
};

template <typename Arg, typename Staple, typename Int>
__host__ __device__ inline void computeStapleLevel1(const Arg &arg, const int *x, const Int *X, const int parity,
template <typename Arg, typename Staple>
__host__ __device__ inline void computeStapleLevel1(const Arg &arg, const int *x, const int parity,
const int mu, Staple staple[3])
{
using Link = typename get_type<Staple>::type;
Expand All @@ -64,16 +66,16 @@ namespace quda

{
// Get link U_{\nu}(x)
Link U1 = arg.in(nu, linkIndexShift(x, dx, X), parity);
Link U1 = arg.in(nu, linkIndexShift(x, dx, arg.E), parity);

// Get link U_{\mu}(x+\nu)
dx[nu]++;
Link U2 = arg.in(mu, linkIndexShift(x, dx, X), 1 - parity);
Link U2 = arg.in(mu, linkIndexShift(x, dx, arg.E), 1 - parity);
dx[nu]--;

// Get link U_{\nu}(x+\mu)
dx[mu]++;
Link U3 = arg.in(nu, linkIndexShift(x, dx, X), 1 - parity);
Link U3 = arg.in(nu, linkIndexShift(x, dx, arg.E), 1 - parity);
dx[mu]--;

// staple += U_{\nu}(x) * U_{\mu}(x+\nu) * U^\dag_{\nu}(x+\mu)
Expand All @@ -83,13 +85,13 @@ namespace quda
{
// Get link U_{\nu}(x-\nu)
dx[nu]--;
Link U1 = arg.in(nu, linkIndexShift(x, dx, X), 1 - parity);
Link U1 = arg.in(nu, linkIndexShift(x, dx, arg.E), 1 - parity);
// Get link U_{\mu}(x-\nu)
Link U2 = arg.in(mu, linkIndexShift(x, dx, X), 1 - parity);
Link U2 = arg.in(mu, linkIndexShift(x, dx, arg.E), 1 - parity);

// Get link U_{\nu}(x-\nu+\mu)
dx[mu]++;
Link U3 = arg.in(nu, linkIndexShift(x, dx, X), parity);
Link U3 = arg.in(nu, linkIndexShift(x, dx, arg.E), parity);

// reset dx
dx[mu]--;
Expand All @@ -101,8 +103,8 @@ namespace quda
}
}

template <typename Arg, typename Staple, typename Int>
__host__ __device__ inline void computeStapleLevel2(const Arg &arg, const int *x, const Int *X, const int parity,
template <typename Arg, typename Staple>
__host__ __device__ inline void computeStapleLevel2(const Arg &arg, const int *x, const int parity,
const int mu, Staple staple[3])
{
using Link = typename get_type<Staple>::type;
Expand All @@ -129,16 +131,16 @@ namespace quda

{
// Get link U_{\rho}(x)
Link U1 = arg.tmp[rho / 2](sigma_with_rho, linkIndexShift(x, dx, X), parity);
Link U1 = arg.tmp[rho / 2](sigma_with_rho, linkIndexShift(x, dx, arg.E), parity);

// Get link U_{\mu}(x+\rho)
dx[rho]++;
Link U2 = arg.tmp[mu / 2](sigma_with_mu, linkIndexShift(x, dx, X), 1 - parity);
Link U2 = arg.tmp[mu / 2](sigma_with_mu, linkIndexShift(x, dx, arg.E), 1 - parity);
dx[rho]--;

// Get link U_{\rho}(x+\mu)
dx[mu]++;
Link U3 = arg.tmp[rho / 2](sigma_with_rho, linkIndexShift(x, dx, X), 1 - parity);
Link U3 = arg.tmp[rho / 2](sigma_with_rho, linkIndexShift(x, dx, arg.E), 1 - parity);
dx[mu]--;

// staple += U_{\rho}(x) * U_{\mu}(x+\rho) * U^\dag_{\rho}(x+\mu)
Expand All @@ -148,13 +150,13 @@ namespace quda
{
// Get link U_{\rho}(x-\rho)
dx[rho]--;
Link U1 = arg.tmp[rho / 2](sigma_with_rho, linkIndexShift(x, dx, X), 1 - parity);
Link U1 = arg.tmp[rho / 2](sigma_with_rho, linkIndexShift(x, dx, arg.E), 1 - parity);
// Get link U_{\mu}(x-\rho)
Link U2 = arg.tmp[mu / 2](sigma_with_mu, linkIndexShift(x, dx, X), 1 - parity);
Link U2 = arg.tmp[mu / 2](sigma_with_mu, linkIndexShift(x, dx, arg.E), 1 - parity);

// Get link U_{\rho}(x-\rho+\mu)
dx[mu]++;
Link U3 = arg.tmp[rho / 2](sigma_with_rho, linkIndexShift(x, dx, X), parity);
Link U3 = arg.tmp[rho / 2](sigma_with_rho, linkIndexShift(x, dx, arg.E), parity);

// reset dx
dx[mu]--;
Expand All @@ -167,8 +169,8 @@ namespace quda
}
}

template <typename Arg, typename Staple, typename Int>
__host__ __device__ inline void computeStapleLevel3(const Arg &arg, const int *x, const Int *X, const int parity,
template <typename Arg, typename Staple>
__host__ __device__ inline void computeStapleLevel3(const Arg &arg, const int *x, const int parity,
const int mu, Staple staple[3])
{
using Link = typename get_type<Staple>::type;
Expand All @@ -187,16 +189,16 @@ namespace quda

{
// Get link U_{\nu}(x)
Link U1 = arg.tmp[nu / 2 + 2](mu_with_nu, linkIndexShift(x, dx, X), parity);
Link U1 = arg.tmp[nu / 2 + 2](mu_with_nu, linkIndexShift(x, dx, arg.E), parity);

// Get link U_{\mu}(x+\nu)
dx[nu]++;
Link U2 = arg.tmp[mu / 2 + 2](nu_with_mu, linkIndexShift(x, dx, X), 1 - parity);
Link U2 = arg.tmp[mu / 2 + 2](nu_with_mu, linkIndexShift(x, dx, arg.E), 1 - parity);
dx[nu]--;

// Get link U_{\nu}(x+\mu)
dx[mu]++;
Link U3 = arg.tmp[nu / 2 + 2](mu_with_nu, linkIndexShift(x, dx, X), 1 - parity);
Link U3 = arg.tmp[nu / 2 + 2](mu_with_nu, linkIndexShift(x, dx, arg.E), 1 - parity);
dx[mu]--;

// staple += U_{\nu}(x) * U_{\mu}(x+\nu) * U^\dag_{\nu}(x+\mu)
Expand All @@ -206,13 +208,13 @@ namespace quda
{
// Get link U_{\nu}(x-\nu)
dx[nu]--;
Link U1 = arg.tmp[nu / 2 + 2](mu_with_nu, linkIndexShift(x, dx, X), 1 - parity);
Link U1 = arg.tmp[nu / 2 + 2](mu_with_nu, linkIndexShift(x, dx, arg.E), 1 - parity);
// Get link U_{\mu}(x-\nu)
Link U2 = arg.tmp[mu / 2 + 2](nu_with_mu, linkIndexShift(x, dx, X), 1 - parity);
Link U2 = arg.tmp[mu / 2 + 2](nu_with_mu, linkIndexShift(x, dx, arg.E), 1 - parity);

// Get link U_{\nu}(x-\nu+\mu)
dx[mu]++;
Link U3 = arg.tmp[nu / 2 + 2](mu_with_nu, linkIndexShift(x, dx, X), parity);
Link U3 = arg.tmp[nu / 2 + 2](mu_with_nu, linkIndexShift(x, dx, arg.E), parity);

// reset dx
dx[mu]--;
Expand All @@ -235,50 +237,46 @@ namespace quda
typedef Matrix<complex<real>, Arg::nColor> Link;

// compute spacetime and local coords
int X[4];
for (int dr = 0; dr < 4; ++dr) X[dr] = arg.X[dr];
int x[4];
getCoords(x, x_cb, X, parity);
for (int dr = 0; dr < 4; ++dr) {
x[dr] += arg.border[dr];
X[dr] += 2 * arg.border[dr];
}
getCoords(x, x_cb, arg.E, parity);
#pragma unroll
for (int dr=0; dr<4; ++dr) x[dr] += arg.border[dr]; // extended grid coordinates

int dx[4] = {0, 0, 0, 0};
Link U, Stap[3], TestU, I;

// Get link U
U = arg.in(dir, linkIndexShift(x, dx, X), parity);
U = arg.in(dir, linkIndexShift(x, dx, arg.X), parity);

This comment has been minimized.

Copy link
@maddyscientist

maddyscientist Feb 1, 2024

Member

@weinbe2 should this line be arg.E?

This comment has been minimized.

Copy link
@weinbe2

weinbe2 Feb 1, 2024

Author Contributor

2640ae9 Thanks :)

setIdentity(&I);

if constexpr (Arg::level == 1) {
computeStapleLevel1(arg, x, X, parity, dir, Stap);
computeStapleLevel1(arg, x, parity, dir, Stap);

for (int i = 0; i < 3; ++i) {
TestU = I * (static_cast<real>(1.0) - arg.alpha) + Stap[i] * conj(U) * (arg.alpha / ((real)2.));
polarSu3<real>(TestU, arg.tolerance);
arg.tmp[dir / 2](dir % 2 * 3 + i, linkIndexShift(x, dx, X), parity) = TestU * U;
arg.tmp[dir / 2](dir % 2 * 3 + i, linkIndexShift(x, dx, arg.E), parity) = TestU * U;
}
} else if constexpr (Arg::level == 2) {
computeStapleLevel2(arg, x, X, parity, dir, Stap);
computeStapleLevel2(arg, x, parity, dir, Stap);

for (int i = 0; i < 3; ++i) {
TestU = I * (static_cast<real>(1.0) - arg.alpha) + Stap[i] * conj(U) * (arg.alpha / ((real)4.));
polarSu3<real>(TestU, arg.tolerance);
arg.tmp[dir / 2 + 2](dir % 2 * 3 + i, linkIndexShift(x, dx, X), parity) = TestU * U;
arg.tmp[dir / 2 + 2](dir % 2 * 3 + i, linkIndexShift(x, dx, arg.E), parity) = TestU * U;
}
} else if constexpr (Arg::level == 3) {
computeStapleLevel3(arg, x, X, parity, dir, Stap);
computeStapleLevel3(arg, x, parity, dir, Stap);

TestU = I * (static_cast<real>(1.0) - arg.alpha) + Stap[0] * conj(U) * (arg.alpha / ((real)6.));
polarSu3<real>(TestU, arg.tolerance);
arg.out(dir, linkIndexShift(x, dx, X), parity) = TestU * U;
arg.out(dir, linkIndexShift(x, dx, arg.E), parity) = TestU * U;
}
}
};

template <typename Arg, typename Staple, typename Int>
__host__ __device__ inline void computeStaple3DLevel1(const Arg &arg, const int *x, const Int *X, const int parity,
template <typename Arg, typename Staple>
__host__ __device__ inline void computeStaple3DLevel1(const Arg &arg, const int *x, const int parity,
const int mu, Staple staple[2], const int dir_ignore)
{
using Link = typename get_type<Staple>::type;
Expand All @@ -297,16 +295,16 @@ namespace quda

{
// Get link U_{\nu}(x)
Link U1 = arg.in(nu, linkIndexShift(x, dx, X), parity);
Link U1 = arg.in(nu, linkIndexShift(x, dx, arg.E), parity);

// Get link U_{\mu}(x+\nu)
dx[nu]++;
Link U2 = arg.in(mu, linkIndexShift(x, dx, X), 1 - parity);
Link U2 = arg.in(mu, linkIndexShift(x, dx, arg.E), 1 - parity);
dx[nu]--;

// Get link U_{\nu}(x+\mu)
dx[mu]++;
Link U3 = arg.in(nu, linkIndexShift(x, dx, X), 1 - parity);
Link U3 = arg.in(nu, linkIndexShift(x, dx, arg.E), 1 - parity);
dx[mu]--;

// staple += U_{\nu}(x) * U_{\mu}(x+\nu) * U^\dag_{\nu}(x+\mu)
Expand All @@ -316,13 +314,13 @@ namespace quda
{
// Get link U_{\nu}(x-\nu)
dx[nu]--;
Link U1 = arg.in(nu, linkIndexShift(x, dx, X), 1 - parity);
Link U1 = arg.in(nu, linkIndexShift(x, dx, arg.E), 1 - parity);
// Get link U_{\mu}(x-\nu)
Link U2 = arg.in(mu, linkIndexShift(x, dx, X), 1 - parity);
Link U2 = arg.in(mu, linkIndexShift(x, dx, arg.E), 1 - parity);

// Get link U_{\nu}(x-\nu+\mu)
dx[mu]++;
Link U3 = arg.in(nu, linkIndexShift(x, dx, X), parity);
Link U3 = arg.in(nu, linkIndexShift(x, dx, arg.E), parity);

// reset dx
dx[mu]--;
Expand All @@ -334,8 +332,8 @@ namespace quda
}
}

template <typename Arg, typename Staple, typename Int>
__host__ __device__ inline void computeStaple3DLevel2(const Arg &arg, const int *x, const Int *X, const int parity,
template <typename Arg, typename Staple>
__host__ __device__ inline void computeStaple3DLevel2(const Arg &arg, const int *x, const int parity,
const int mu, Staple staple[2], int dir_ignore)
{
using Link = typename get_type<Staple>::type;
Expand All @@ -357,16 +355,16 @@ namespace quda

{
// Get link U_{\nu}(x)
Link U1 = arg.tmp[0](rho_with_nu, linkIndexShift(x, dx, X), parity);
Link U1 = arg.tmp[0](rho_with_nu, linkIndexShift(x, dx, arg.E), parity);

// Get link U_{\mu}(x+\nu)
dx[nu]++;
Link U2 = arg.tmp[0](rho_with_mu, linkIndexShift(x, dx, X), 1 - parity);
Link U2 = arg.tmp[0](rho_with_mu, linkIndexShift(x, dx, arg.E), 1 - parity);
dx[nu]--;

// Get link U_{\nu}(x+\mu)
dx[mu]++;
Link U3 = arg.tmp[0](rho_with_nu, linkIndexShift(x, dx, X), 1 - parity);
Link U3 = arg.tmp[0](rho_with_nu, linkIndexShift(x, dx, arg.E), 1 - parity);
dx[mu]--;

// staple += U_{\nu}(x) * U_{\mu}(x+\nu) * U^\dag_{\nu}(x+\mu)
Expand All @@ -376,13 +374,13 @@ namespace quda
{
// Get link U_{\nu}(x-\nu)
dx[nu]--;
Link U1 = arg.tmp[0](rho_with_nu, linkIndexShift(x, dx, X), 1 - parity);
Link U1 = arg.tmp[0](rho_with_nu, linkIndexShift(x, dx, arg.E), 1 - parity);
// Get link U_{\mu}(x-\nu)
Link U2 = arg.tmp[0](rho_with_mu, linkIndexShift(x, dx, X), 1 - parity);
Link U2 = arg.tmp[0](rho_with_mu, linkIndexShift(x, dx, arg.E), 1 - parity);

// Get link U_{\nu}(x-\nu+\mu)
dx[mu]++;
Link U3 = arg.tmp[0](rho_with_nu, linkIndexShift(x, dx, X), parity);
Link U3 = arg.tmp[0](rho_with_nu, linkIndexShift(x, dx, arg.E), parity);

// reset dx
dx[mu]--;
Expand All @@ -406,38 +404,35 @@ namespace quda
typedef Matrix<complex<real>, Arg::nColor> Link;

// compute spacetime and local coords
int X[4];
for (int dr = 0; dr < 4; ++dr) X[dr] = arg.X[dr];
int x[4];
getCoords(x, x_cb, X, parity);
for (int dr = 0; dr < 4; ++dr) {
x[dr] += arg.border[dr];
X[dr] += 2 * arg.border[dr];
}
getCoords(x, x_cb, arg.X, parity);
#pragma unroll
for (int dr=0; dr<4; ++dr) x[dr] += arg.border[dr]; // extended grid coordinates

int dir_ = dir;
dir = dir + (dir >= arg.dir_ignore);

int dx[4] = {0, 0, 0, 0};
Link U, Stap[2], TestU, I;

// Get link U
U = arg.in(dir, linkIndexShift(x, dx, X), parity);
U = arg.in(dir, linkIndexShift(x, dx, arg.E), parity);
setIdentity(&I);

if constexpr (Arg::level == 1) {
computeStaple3DLevel1(arg, x, X, parity, dir, Stap, arg.dir_ignore);
computeStaple3DLevel1(arg, x, parity, dir, Stap, arg.dir_ignore);

for (int i = 0; i < 2; ++i) {
TestU = I * (static_cast<real>(1.0) - arg.alpha) + Stap[i] * conj(U) * (arg.alpha / ((real)2.));
polarSu3<real>(TestU, arg.tolerance);
arg.tmp[0](dir_ * 2 + i, linkIndexShift(x, dx, X), parity) = TestU * U;
arg.tmp[0](dir_ * 2 + i, linkIndexShift(x, dx, arg.E), parity) = TestU * U;
}
} else if constexpr (Arg::level == 2) {
computeStaple3DLevel2(arg, x, X, parity, dir, Stap, arg.dir_ignore);
computeStaple3DLevel2(arg, x, parity, dir, Stap, arg.dir_ignore);

TestU = I * (static_cast<real>(1.0) - arg.alpha) + Stap[0] * conj(U) * (arg.alpha / ((real)4.));
polarSu3<real>(TestU, arg.tolerance);
arg.out(dir, linkIndexShift(x, dx, X), parity) = TestU * U;
arg.out(dir, linkIndexShift(x, dx, arg.E), parity) = TestU * U;
}
}
};
Expand Down

0 comments on commit 9545d28

Please sign in to comment.