diff --git a/include/clover_field.h b/include/clover_field.h index 380a399492..1bba0359a6 100644 --- a/include/clover_field.h +++ b/include/clover_field.h @@ -7,6 +7,19 @@ namespace quda { + /** + @brief Helper function that returns whether we have enabled + clover fermions. + */ + constexpr bool is_enabled_clover() + { +#ifdef GPU_CLOVER_DIRAC + return true; +#else + return false; +#endif + } + namespace clover { @@ -463,6 +476,29 @@ namespace quda { */ void cloverInvert(CloverField &clover, bool computeTraceLog); + /** + @brief Driver for the clover force computation. Eventually the + construction of the x and p fields will be delegated to this + function, but for now, we pre-compute these and pass them in. + @param mom[in,out] Momentum field to be updated + @param gaugeEx[in] Extended gauge field + @param gauge[in] Gauge field + @param clover[in] Clover field + @param x[in] Vector of quark solution fields + @param x0[in] Vector of auxiliary quark fields for determinant ratio + @param coeff[in] Vector of coefficients for the quark field outer + products + @param epsilon[in] Vector of scalar coefficient pairs (one per + parity) for the clover sigma outer product + @param sigma_coeff[in] Coefficient for the tr log clover force + @param detratio[in] Whether to compute determinant ratio + @param parity[in] Which parity do we need compute the tr log clover force + */ + void computeCloverForce(GaugeField &mom, const GaugeField &gaugeEx, const GaugeField &gauge, + const CloverField &clover, cvector_ref &x, cvector_ref &x0, + const std::vector &coeff, const std::vector> &epsilon, + double sigma_coeff, bool detratio, QudaInvertParam &param); + /** @brief Compute the force contribution from the solver solution fields @@ -480,9 +516,8 @@ namespace quda { @param p Intermediate vectors (both parities) @param coeff Multiplicative coefficient (e.g., dt * residue) */ - void computeCloverForce(GaugeField& force, const GaugeField& U, - std::vector &x, std::vector &p, - std::vector 
&coeff); + void computeCloverForce(GaugeField &force, const GaugeField &U, cvector_ref &x, + cvector_ref &p, const std::vector &coeff); /** @brief Compute the outer product from the solver solution fields arising from the diagonal term of the fermion bilinear in @@ -493,10 +528,8 @@ namespace quda { @param p[in] Intermediate vectors (both parities) @coeff coeff[in] Multiplicative coefficient (e.g., dt * residiue), one for each parity */ - void computeCloverSigmaOprod(GaugeField& oprod, - std::vector &x, - std::vector &p, - std::vector< std::vector > &coeff); + void computeCloverSigmaOprod(GaugeField &oprod, cvector_ref &x, + cvector_ref &p, const std::vector> &coeff); /** @brief Compute the matrix tensor field necessary for the force calculation from the clover trace action. This computes a tensor field [mu,nu]. @@ -504,8 +537,9 @@ namespace quda { @param output The computed matrix field (tensor matrix field) @param clover The input clover field @param coeff Scalar coefficient multiplying the result (e.g., stepsize) + @param parity The field parity we are working on */ - void computeCloverSigmaTrace(GaugeField &output, const CloverField &clover, double coeff); + void computeCloverSigmaTrace(GaugeField &output, const CloverField &clover, double coeff, int parity); /** @brief Compute the derivative of the clover matrix in the direction @@ -516,9 +550,8 @@ namespace quda { @param gauge The input gauge field @param oprod The input outer-product field (tensor matrix field) @param coeff Multiplicative coefficient (e.g., clover coefficient) - @param parity The field parity we are working on */ - void cloverDerivative(GaugeField &force, GaugeField &gauge, GaugeField &oprod, double coeff, QudaParity parity); + void cloverDerivative(GaugeField &force, const GaugeField &gauge, const GaugeField &oprod, double coeff); /** @brief This function is used for copying from a source clover field to a destination clover field diff --git a/include/color_spinor_field.h 
b/include/color_spinor_field.h index 4801bdbf48..f6fc42d194 100644 --- a/include/color_spinor_field.h +++ b/include/color_spinor_field.h @@ -333,7 +333,7 @@ namespace quda size_t norm_offset = 0; /** offset to the norm (if applicable) */ // multi-GPU parameters - array_2d ghost = {}; // pointers to the ghost regions - NULL by default + mutable array_2d ghost = {}; // pointers to the ghost regions - NULL by default mutable lat_dim_t ghostFace = {}; // the size of each face mutable lat_dim_t ghostFaceCB = {}; // the size of each checkboarded face mutable array ghost_buf = {}; // wrapper that points to current ghost zone @@ -510,7 +510,7 @@ namespace quda @param[in] nFace Depth of each halo @param[in] spin_project Whether the halos are spin projected (Wilson-type fermions only) */ - void createComms(int nFace, bool spin_project = true); + void createComms(int nFace, bool spin_project = true) const; /** @brief Packs the ColorSpinorField's ghost zone @@ -530,7 +530,7 @@ namespace quda */ void packGhost(const int nFace, const QudaParity parity, const int dagger, const qudaStream_t &stream, MemoryLocation location[2 * QUDA_MAX_DIM], MemoryLocation location_label, bool spin_project, - double a = 0, double b = 0, double c = 0, int shmem = 0); + double a = 0, double b = 0, double c = 0, int shmem = 0) const; /** Pack the field halos in preparation for halo exchange, e.g., for Dslash @@ -550,7 +550,7 @@ namespace quda */ void pack(int nFace, int parity, int dagger, const qudaStream_t &stream, MemoryLocation location[2 * QUDA_MAX_DIM], MemoryLocation location_label, bool spin_project = true, double a = 0, double b = 0, double c = 0, - int shmem = 0); + int shmem = 0) const; /** @brief Initiate the gpu to cpu send of the ghost zone (halo) @@ -559,7 +559,7 @@ namespace quda @param dir The direction (QUDA_BACKWARDS or QUDA_FORWARDS) @param stream The array of streams to use */ - void sendGhost(void *ghost_spinor, const int dim, const QudaDirection dir, const qudaStream_t 
&stream); + void sendGhost(void *ghost_spinor, const int dim, const QudaDirection dir, const qudaStream_t &stream) const; /** Initiate the cpu to gpu send of the ghost zone (halo) @@ -568,7 +568,7 @@ namespace quda @param dir The direction (QUDA_BACKWARDS or QUDA_FORWARDS) @param stream The array of streams to use */ - void unpackGhost(const void *ghost_spinor, const int dim, const QudaDirection dir, const qudaStream_t &stream); + void unpackGhost(const void *ghost_spinor, const int dim, const QudaDirection dir, const qudaStream_t &stream) const; /** @brief Copies the ghost to the host from the device, prior to @@ -577,7 +577,7 @@ namespace quda the scatter-centric direction (0=backwards,1=forwards) @param[in] stream The stream in which to do the copy */ - void gather(int dir, const qudaStream_t &stream); + void gather(int dir, const qudaStream_t &stream) const; /** @brief Initiate halo communication receive @@ -585,7 +585,7 @@ namespace quda the scatter-centric direction (0=backwards,1=forwards) @param[in] gdr Whether we are using GDR on the receive side */ - void recvStart(int dir, const qudaStream_t &stream, bool gdr = false); + void recvStart(int dir, const qudaStream_t &stream, bool gdr = false) const; /** @brief Initiate halo communication sending @@ -596,7 +596,7 @@ namespace quda @param[in] gdr Whether we are using GDR on the send side @param[in] remote_write Whether we are writing direct to remote memory (or using copy engines) */ - void sendStart(int d, const qudaStream_t &stream, bool gdr = false, bool remote_write = false); + void sendStart(int d, const qudaStream_t &stream, bool gdr = false, bool remote_write = false) const; /** @brief Initiate halo communication @@ -606,7 +606,7 @@ namespace quda @param[in] gdr_send Whether we are using GDR on the send side @param[in] gdr_recv Whether we are using GDR on the receive side */ - void commsStart(int d, const qudaStream_t &stream, bool gdr_send = false, bool gdr_recv = false); + void commsStart(int d, 
const qudaStream_t &stream, bool gdr_send = false, bool gdr_recv = false) const; /** @brief Non-blocking query if the halo communication has completed @@ -616,7 +616,7 @@ namespace quda @param[in] gdr_send Whether we are using GDR on the send side @param[in] gdr_recv Whether we are using GDR on the receive side */ - int commsQuery(int d, const qudaStream_t &stream, bool gdr_send = false, bool gdr_recv = false); + int commsQuery(int d, const qudaStream_t &stream, bool gdr_send = false, bool gdr_recv = false) const; /** @brief Wait on halo communication to complete @@ -626,7 +626,7 @@ namespace quda @param[in] gdr_send Whether we are using GDR on the send side @param[in] gdr_recv Whether we are using GDR on the receive side */ - void commsWait(int d, const qudaStream_t &stream, bool gdr_send = false, bool gdr_recv = false); + void commsWait(int d, const qudaStream_t &stream, bool gdr_send = false, bool gdr_recv = false) const; /** @brief Unpacks the ghost from host to device after @@ -636,7 +636,7 @@ namespace quda @param[in] stream The stream in which to do the copy. If -1 is passed then the copy will be issied to the d^th stream */ - void scatter(int d, const qudaStream_t &stream); + void scatter(int d, const qudaStream_t &stream) const; /** Do the exchange between neighbouring nodes of the data in @@ -725,6 +725,9 @@ namespace quda ColorSpinorField &Even(); ColorSpinorField &Odd(); + const ColorSpinorField &operator[](QudaParity parity) const { return parity == QUDA_EVEN_PARITY ? Even() : Odd(); } + ColorSpinorField &operator[](QudaParity parity) { return parity == QUDA_EVEN_PARITY ? 
Even() : Odd(); } + CompositeColorSpinorField &Components() { return components; }; /** diff --git a/include/gauge_field.h b/include/gauge_field.h index 948960b36f..ffea94c22b 100644 --- a/include/gauge_field.h +++ b/include/gauge_field.h @@ -711,7 +711,7 @@ namespace quda { @param recon The reconsturction type @return the pointer to the extended gauge field */ - GaugeField *createExtendedGauge(GaugeField &in, const lat_dim_t &R, TimeProfile &profile, + GaugeField *createExtendedGauge(const GaugeField &in, const lat_dim_t &R, TimeProfile &profile = getProfile(), bool redundant_comms = false, QudaReconstructType recon = QUDA_RECONSTRUCT_INVALID); /** diff --git a/include/kernels/clover_deriv.cuh b/include/kernels/clover_deriv.cuh index 38ab548766..59213a7524 100644 --- a/include/kernels/clover_deriv.cuh +++ b/include/kernels/clover_deriv.cuh @@ -7,7 +7,8 @@ namespace quda { - template struct CloverDerivArg : kernel_param<> { + template struct CloverDerivArg : kernel_param<> { + static constexpr int nColor = nColor_; using Force = typename gauge_mapper::type; using Oprod = typename gauge_mapper::type; using Gauge = typename gauge_mapper::type; @@ -16,19 +17,13 @@ namespace quda int E[4]; int border[4]; real coeff; - int parity; Force force; Gauge gauge; Oprod oprod; - CloverDerivArg(const GaugeField &force, const GaugeField &gauge, const GaugeField &oprod, double coeff, int parity) : - kernel_param(dim3(force.VolumeCB(), 2, 4)), - coeff(coeff), - parity(parity), - force(force), - gauge(gauge), - oprod(oprod) + CloverDerivArg(GaugeField &force, const GaugeField &gauge, const GaugeField &oprod, double coeff) : + kernel_param(dim3(force.VolumeCB(), 2, 4)), coeff(coeff), force(force), gauge(gauge), oprod(oprod) { for (int dir = 0; dir < 4; ++dir) { X[dir] = force.X()[dir]; @@ -39,168 +34,167 @@ namespace quda }; using computeForceOps = SpecialOps>; - template - __device__ __host__ void computeForce(Link &force_total, const Ftor &ftor, int xIndex, int yIndex, int mu, int 
nu) + template + __device__ __host__ void computeForce(Force &force_total, const Ftor &ftor, int xIndex, int parity, int mu, int nu) { const auto &arg = ftor.arg; - const int otherparity = (1 - arg.parity); + const int otherparity = (1 - parity); const int tidx = mu > nu ? (mu - 1) * mu / 2 + nu : (nu - 1) * nu / 2 + mu; - if (yIndex == 0) { // do "this" force + int x[4]; + getCoordsExtended(x, xIndex, arg.X, parity, arg.border); - int x[4]; - getCoordsExtended(x, xIndex, arg.X, arg.parity, arg.border); - - // U[mu](x) U[nu](x+mu) U[*mu](x+nu) U[*nu](x) Oprod(x) - { - thread_array d{ftor}; - - // load U(x)_(+mu) - Link U1 = arg.gauge(mu, linkIndexShift(x, d, arg.E), arg.parity); - - // load U(x+mu)_(+nu) - d[mu]++; - Link U2 = arg.gauge(nu, linkIndexShift(x, d, arg.E), otherparity); - d[mu]--; - - // load U(x+nu)_(+mu) - d[nu]++; - Link U3 = arg.gauge(mu, linkIndexShift(x, d, arg.E), otherparity); - d[nu]--; - - // load U(x)_(+nu) - Link U4 = arg.gauge(nu, linkIndexShift(x, d, arg.E), arg.parity); - - // load Oprod - Link Oprod1 = arg.oprod(tidx, linkIndexShift(x, d, arg.E), arg.parity); - Link force = U1 * U2 * conj(U3) * conj(U4) * Oprod1; - - d[mu]++; - d[nu]++; - Link Oprod2 = arg.oprod(tidx, linkIndexShift(x, d, arg.E), arg.parity); - force += U1 * U2 * Oprod2 * conj(U3) * conj(U4); - - if (nu < mu) force_total -= force; - else force_total += force; - } - - { - thread_array d{ftor}; - - // load U(x-nu)(+nu) - d[nu]--; - Link U1 = arg.gauge(nu, linkIndexShift(x, d, arg.E), otherparity); - d[nu]++; - - // load U(x-nu)(+mu) - d[nu]--; - Link U2 = arg.gauge(mu, linkIndexShift(x, d, arg.E), otherparity); - d[nu]++; - - // load U(x+mu-nu)(nu) - d[mu]++; - d[nu]--; - Link U3 = arg.gauge(nu, linkIndexShift(x, d, arg.E), arg.parity); - d[mu]--; - d[nu]++; - - // load U(x)_(+mu) - Link U4 = arg.gauge(mu, linkIndexShift(x, d, arg.E), arg.parity); - - d[mu]++; - d[nu]--; - Link Oprod1 = arg.oprod(tidx, linkIndexShift(x, d, arg.E), arg.parity); - Link force = conj(U1) * U2 
* Oprod1 * U3 * conj(U4); - - d[mu]--; - d[nu]++; - Link Oprod4 = arg.oprod(tidx, linkIndexShift(x, d, arg.E), arg.parity); - force += Oprod4 * conj(U1) * U2 * U3 * conj(U4); - - if (nu < mu) force_total += force; - else force_total -= force; - } - - } else { // else do other force - - int y[4] = { }; - getCoordsExtended(y, xIndex, arg.X, otherparity, arg.border); + // U[mu](x) U[nu](x+mu) U[*mu](x+nu) U[*nu](x) Oprod(x) + { + thread_array d{ftor}; - { - thread_array d{ftor}; + // load U(x)_(+mu) + Link U1 = arg.gauge(mu, linkIndexShift(x, d, arg.E), parity); - // load U(x)_(+mu) - Link U1 = arg.gauge(mu, linkIndexShift(y, d, arg.E), otherparity); + // load U(x+mu)_(+nu) + d[mu]++; + Link U2 = arg.gauge(nu, linkIndexShift(x, d, arg.E), otherparity); + d[mu]--; - // load U(x+mu)_(+nu) - d[mu]++; - Link U2 = arg.gauge(nu, linkIndexShift(y, d, arg.E), arg.parity); - d[mu]--; + // load U(x+nu)_(+mu) + d[nu]++; + Link U3 = arg.gauge(mu, linkIndexShift(x, d, arg.E), otherparity); + d[nu]--; - // load U(x+nu)_(+mu) - d[nu]++; - Link U3 = arg.gauge(mu, linkIndexShift(y, d, arg.E), arg.parity); - d[nu]--; + // load U(x)_(+nu) + Link U4 = arg.gauge(nu, linkIndexShift(x, d, arg.E), parity); - // load U(x)_(+nu) - Link U4 = arg.gauge(nu, linkIndexShift(y, d, arg.E), otherparity); + // load Oprod + Link Oprod1 = arg.oprod(tidx, linkIndexShift(x, d, arg.E), parity); + Link force = U1 * U2 * conj(U3) * conj(U4) * Oprod1; - // load opposite parity Oprod - d[nu]++; - Link Oprod3 = arg.oprod(tidx, linkIndexShift(y, d, arg.E), arg.parity); - Link force = U1 * U2 * conj(U3) * Oprod3 * conj(U4); + d[mu]++; + d[nu]++; + Link Oprod2 = arg.oprod(tidx, linkIndexShift(x, d, arg.E), parity); + force += U1 * U2 * Oprod2 * conj(U3) * conj(U4); - // load Oprod(x+mu) - d[nu]--; - d[mu]++; - Link Oprod4 = arg.oprod(tidx, linkIndexShift(y, d, arg.E), arg.parity); - force += U1 * Oprod4 * U2 * conj(U3) * conj(U4); + if (nu < mu) + force_total -= force; + else + force_total += force; + } - if (nu < 
mu) force_total -= force; - else force_total += force; - } + { + thread_array d{ftor}; + + // load U(x-nu)(+nu) + d[nu]--; + Link U1 = arg.gauge(nu, linkIndexShift(x, d, arg.E), otherparity); + d[nu]++; + + // load U(x-nu)(+mu) + d[nu]--; + Link U2 = arg.gauge(mu, linkIndexShift(x, d, arg.E), otherparity); + d[nu]++; + + // load U(x+mu-nu)(nu) + d[mu]++; + d[nu]--; + Link U3 = arg.gauge(nu, linkIndexShift(x, d, arg.E), parity); + d[mu]--; + d[nu]++; + + // load U(x)_(+mu) + Link U4 = arg.gauge(mu, linkIndexShift(x, d, arg.E), parity); + + d[mu]++; + d[nu]--; + Link Oprod1 = arg.oprod(tidx, linkIndexShift(x, d, arg.E), parity); + Link force = conj(U1) * U2 * Oprod1 * U3 * conj(U4); + + d[mu]--; + d[nu]++; + Link Oprod4 = arg.oprod(tidx, linkIndexShift(x, d, arg.E), parity); + force += Oprod4 * conj(U1) * U2 * U3 * conj(U4); + + if (nu < mu) + force_total += force; + else + force_total -= force; + } - // Lower leaf - // U[nu*](x-nu) U[mu](x-nu) U[nu](x+mu-nu) Oprod(x+mu) U[*mu](x) - { - thread_array d{ftor}; - - // load U(x-nu)(+nu) - d[nu]--; - Link U1 = arg.gauge(nu, linkIndexShift(y, d, arg.E), arg.parity); - d[nu]++; - - // load U(x-nu)(+mu) - d[nu]--; - Link U2 = arg.gauge(mu, linkIndexShift(y, d, arg.E), arg.parity); - d[nu]++; - - // load U(x+mu-nu)(nu) - d[mu]++; - d[nu]--; - Link U3 = arg.gauge(nu, linkIndexShift(y, d, arg.E), otherparity); - d[mu]--; - d[nu]++; - - // load U(x)_(+mu) - Link U4 = arg.gauge(mu, linkIndexShift(y, d, arg.E), otherparity); - - // load Oprod(x+mu) - d[mu]++; - Link Oprod1 = arg.oprod(tidx, linkIndexShift(y, d, arg.E), arg.parity); - Link force = conj(U1) * U2 * U3 * Oprod1 * conj(U4); - - d[mu]--; - d[nu]--; - Link Oprod2 = arg.oprod(tidx, linkIndexShift(y, d, arg.E), arg.parity); - force += conj(U1) * Oprod2 * U2 * U3 * conj(U4); - - if (nu < mu) force_total += force; - else force_total -= force; - } + { + thread_array d{ftor}; + + // load U(x)_(+mu) + Link U1 = arg.gauge(mu, linkIndexShift(x, d, arg.E), parity); + + // load 
U(x+mu)_(+nu) + d[mu]++; + Link U2 = arg.gauge(nu, linkIndexShift(x, d, arg.E), otherparity); + d[mu]--; + + // load U(x+nu)_(+mu) + d[nu]++; + Link U3 = arg.gauge(mu, linkIndexShift(x, d, arg.E), otherparity); + d[nu]--; + + // load U(x)_(+nu) + Link U4 = arg.gauge(nu, linkIndexShift(x, d, arg.E), parity); + + // load opposite parity Oprod + d[nu]++; + Link Oprod3 = arg.oprod(tidx, linkIndexShift(x, d, arg.E), otherparity); + Link force = U1 * U2 * conj(U3) * Oprod3 * conj(U4); + + // load Oprod(x+mu) + d[nu]--; + d[mu]++; + Link Oprod4 = arg.oprod(tidx, linkIndexShift(x, d, arg.E), otherparity); + force += U1 * Oprod4 * U2 * conj(U3) * conj(U4); + + if (nu < mu) + force_total -= force; + else + force_total += force; } - } // namespace quda + // Lower leaf + // U[nu*](x-nu) U[mu](x-nu) U[nu](x+mu-nu) Oprod(x+mu) U[*mu](x) + { + thread_array d{ftor}; + + // load U(x-nu)(+nu) + d[nu]--; + Link U1 = arg.gauge(nu, linkIndexShift(x, d, arg.E), otherparity); + d[nu]++; + + // load U(x-nu)(+mu) + d[nu]--; + Link U2 = arg.gauge(mu, linkIndexShift(x, d, arg.E), otherparity); + d[nu]++; + + // load U(x+mu-nu)(nu) + d[mu]++; + d[nu]--; + Link U3 = arg.gauge(nu, linkIndexShift(x, d, arg.E), parity); + d[mu]--; + d[nu]++; + + // load U(x)_(+mu) + Link U4 = arg.gauge(mu, linkIndexShift(x, d, arg.E), parity); + + // load Oprod(x+mu) + d[mu]++; + Link Oprod1 = arg.oprod(tidx, linkIndexShift(x, d, arg.E), otherparity); + Link force = conj(U1) * U2 * U3 * Oprod1 * conj(U4); + + d[mu]--; + d[nu]--; + Link Oprod2 = arg.oprod(tidx, linkIndexShift(x, d, arg.E), otherparity); + force += conj(U1) * Oprod2 * U2 * U3 * conj(U4); + + if (nu < mu) + force_total += force; + else + force_total -= force; + } + } template struct CloverDerivative : computeForceOps { @@ -213,20 +207,19 @@ namespace quda { using real = typename Arg::real; using Complex = complex; - using Link = Matrix; + using Link = Matrix; Link force; -#pragma unroll for (int nu = 0; nu < 4; nu++) { if (nu == mu) continue; - 
computeForce(force, *this, x_cb, parity, mu, nu); + computeForce(force, *this, x_cb, parity, mu, nu); } // Write to array - Link F = arg.force(mu, x_cb, parity == 0 ? arg.parity : 1 - arg.parity); - F += arg.coeff * force; - arg.force(mu, x_cb, parity == 0 ? arg.parity : 1 - arg.parity) = F; + Link F = arg.force(mu, x_cb, parity); + F += arg.coeff * static_cast(force); + arg.force(mu, x_cb, parity) = F; } }; diff --git a/include/kernels/clover_outer_product.cuh b/include/kernels/clover_outer_product.cuh index a915521e54..70e505b149 100644 --- a/include/kernels/clover_outer_product.cuh +++ b/include/kernels/clover_outer_product.cuh @@ -32,8 +32,9 @@ namespace quda { real coeff; CloverForceArg(GaugeField &force, const GaugeField &U, const ColorSpinorField &inA, const ColorSpinorField &inB, - const ColorSpinorField &inC, const ColorSpinorField &inD, const unsigned int parity, const double coeff) : - kernel_param(dim3(dim == -1 ? inA.VolumeCB() : inA.GhostFaceCB()[dim])), + const ColorSpinorField &inC, const ColorSpinorField &inD, const unsigned int parity, + const double coeff) : + kernel_param(dim3(dim == -1 ? inA.VolumeCB() : inB.GhostFaceCB()[dim])), // inB since it has a ghost allocated force(force), inA(inA), inB(inB), @@ -46,6 +47,21 @@ namespace quda { { for (int i=0; i<4; ++i) this->X[i] = U.X()[i]; for (int i=0; i<4; ++i) this->partitioned[i] = commDimPartitioned(i) ? 
true : false; + + // need to reset the ghost pointers since default ghost_offset + // (Ghost() method) not set (this is temporary work around) + void *ghost[8] = {}; + for (int dim = 0; dim < 4; dim++) { + for (int dir = 0; dir < 2; dir++) { ghost[2 * dim + dir] = (char *)inB.Ghost2() + inB.GhostOffset(dim, dir); } + } + this->inB.resetGhost(ghost); + inD.bufferIndex = (1 - inD.bufferIndex); + + for (int dim = 0; dim < 4; dim++) { + for (int dir = 0; dir < 2; dir++) { ghost[2 * dim + dir] = (char *)inD.Ghost2() + inD.GhostOffset(dim, dir); } + } + this->inD.resetGhost(ghost); + inD.bufferIndex = (1 - inD.bufferIndex); } }; diff --git a/include/kernels/clover_sigma_outer_product.cuh b/include/kernels/clover_sigma_outer_product.cuh index 91d2de8ef5..e4b901cdc1 100644 --- a/include/kernels/clover_sigma_outer_product.cuh +++ b/include/kernels/clover_sigma_outer_product.cuh @@ -21,21 +21,18 @@ namespace quda using F = typename colorspinor_mapper::type; Oprod oprod; - const F inA[nvector]; - const F inB[nvector]; - real coeff[nvector][2]; + F inA[nvector]; + F inB[nvector]; + array_2d coeff; - CloverSigmaOprodArg(GaugeField &oprod, const std::vector &inA, - const std::vector &inB, - const std::vector> &coeff_) : - kernel_param(dim3(oprod.VolumeCB(), 2, 6)), - oprod(oprod), - inA{*inA[0]}, - inB{*inB[0]} + CloverSigmaOprodArg(GaugeField &oprod, cvector_ref &inA, + cvector_ref &inB, const std::vector> &coeff_) : + kernel_param(dim3(oprod.VolumeCB(), 2, 6)), oprod(oprod) { for (int i = 0; i < nvector; i++) { - coeff[i][0] = coeff_[i][0]; - coeff[i][1] = coeff_[i][1]; + this->inA[i] = inA[i]; + this->inB[i] = inB[i]; + coeff[i] = {static_cast(coeff_[i][0]), static_cast(coeff_[i][1])}; } } }; diff --git a/include/kernels/clover_trace.cuh b/include/kernels/clover_trace.cuh index d7d7370fe3..d5f7514489 100644 --- a/include/kernels/clover_trace.cuh +++ b/include/kernels/clover_trace.cuh @@ -4,148 +4,157 @@ #include #include #include +#include namespace quda { - template - struct 
CloverTraceArg : kernel_param<> { + template struct CloverTraceArg : kernel_param<> { using real = typename mapper::type; + static constexpr bool twist = twist_; static constexpr int nColor = nColor_; + static constexpr int nSpin = 4; + static constexpr bool dynamic_clover = clover::dynamic_inverse(); using C = typename clover_mapper::type; using G = typename gauge_mapper::type; G output; - const C clover1; - const C clover2; + const C clover; + const C clover_inv; real coeff; + real mu2_minus_epsilon2; + const int parity; - CloverTraceArg(GaugeField& output, const CloverField& clover, double coeff) : - kernel_param(dim3(clover.VolumeCB(), 1, 1)), + CloverTraceArg(GaugeField &output, const CloverField &clover, double coeff, int parity) : + kernel_param(dim3(output.VolumeCB(), 1, 1)), output(output), - clover1(clover, 0), - clover2(clover, 1), - coeff(coeff) {} + clover(clover, false), + clover_inv(clover, dynamic_clover ? false : true), + coeff(coeff), + mu2_minus_epsilon2(clover.Mu2() - clover.Epsilon2()), + parity(parity) + { + } }; - template - __device__ __host__ void cloverSigmaTraceCompute(const Arg &arg, const int x, int parity) + template __device__ __host__ inline void cloverSigmaTraceCompute(const Arg &arg, const int x) { + using namespace linalg; // for Cholesky using real = typename Arg::real; - real A[72]; - if (parity==0) arg.clover1.load(A,x,parity); - else arg.clover2.load(A,x,parity); + constexpr int N = Arg::nColor; + using Mat = HMatrix; + Mat A[2]; // load the clover term into memory - for (int mu=0; mu<4; mu++) { - for (int nu=0; nu(2.0); // factor of two is inherent to QUDA clover storage + + if constexpr (Arg::dynamic_clover) { + if constexpr (Arg::twist) { // Compute (T^2 + mu2 - epsilon2) first, then invert + A[ch] = A[ch].square(); + A[ch] += arg.mu2_minus_epsilon2; + } - Matrix, Arg::nColor> mat; - setZero(&mat); + // compute the Cholesky decomposition + Cholesky, N * Arg::nSpin / 2> cholesky(A[ch]); + A[ch] = cholesky.template invert(); 
// return full inverse + } - real diag[2][6]; - complex tri[2][15]; - const int idtab[15]={0,1,3,6,10,2,4,7,11,5,8,12,9,13,14}; - complex ctmp; + if constexpr (Arg::twist) { + Mat A0 = arg.clover(x, arg.parity, ch); + A[ch] = static_cast(0.5) * (A0 * A[ch]); // (1 + T + imu g_5)^{-1} = (1 + T - imu g_5)/((1 + T)^2 + mu^2) + } + } - for (int ch=0; ch<2; ++ch) { - // factor of two is inherent to QUDA clover storage - for (int i=0; i<6; i++) diag[ch][i] = 2.0*A[ch*36+i]; - for (int i=0; i<15; i++) tri[ch][idtab[i]] = complex(2.0*A[ch*36+6+2*i], 2.0*A[ch*36+6+2*i+1]); - } + const Mat &A0 = A[0]; + const Mat &A1 = A[1]; + +#pragma unroll + for (int mu = 0; mu < 4; mu++) { +#pragma unroll + for (int nu = 0; nu < 4; nu++) { + if (nu >= mu) continue; + Matrix, Arg::nColor> mat = {}; // X, Y if (nu == 0) { if (mu == 1) { - for (int j=0; j<3; ++j) { - mat(j,j).y = diag[0][j+3] + diag[1][j+3] - diag[0][j] - diag[1][j]; +#pragma unroll + for (int j = 0; j < N; ++j) { + mat(j, j).imag(A0(j + N, j + N).real() + A1(j + N, j + N).real() - A0(j, j).real() - A1(j, j).real()); } // triangular part - int jk=0; - for (int j=1; j<3; ++j) { - int jk2 = (j+3)*(j+2)/2 + 3; +#pragma unroll + for (int j = 1; j < N; ++j) { +#pragma unroll for (int k=0; k struct CloverSigmaTr @@ -154,11 +163,7 @@ namespace quda { constexpr CloverSigmaTr(const Arg &arg) : arg(arg) {} static constexpr const char* filename() { return KERNEL_FILE; } - __device__ __host__ inline void operator()(int x_cb) - { - // odd parity - cloverSigmaTraceCompute(arg, x_cb, 1); - } + __device__ __host__ inline void operator()(int x_cb) { cloverSigmaTraceCompute(arg, x_cb); } }; } diff --git a/include/kernels/coarse_op_kernel.cuh b/include/kernels/coarse_op_kernel.cuh index a329fbb3dd..c4a038367b 100644 --- a/include/kernels/coarse_op_kernel.cuh +++ b/include/kernels/coarse_op_kernel.cuh @@ -1388,9 +1388,10 @@ namespace quda { }; template <> struct storeCoarseSharedAtomic_impl { - template using CacheT = - 
complex[Arg::max_color_height_per_block][Arg::max_color_width_per_block][4][Arg::coarseSpin][Arg::coarseSpin]; - template using Cache = SharedMemoryCache,opDimsStatic<2,1,1>>; + template + using CacheT = complex[Arg::max_color_height_per_block][Arg::max_color_width_per_block][4] + [Arg::coarseSpin][Arg::coarseSpin]; + template using Cache = SharedMemoryCache, DimsStatic<2, 1, 1>>; template using Ops = SpecialOps>; template @@ -1402,8 +1403,6 @@ namespace quda { using real = typename Arg::Float; using TileType = typename Arg::vuvTileType; const int dim_index = arg.dim_index % arg.Y_atomic.geometry; - //__shared__ complex X[Arg::max_color_height_per_block][Arg::max_color_width_per_block][4][Arg::coarseSpin][Arg::coarseSpin]; - //__shared__ complex Y[Arg::max_color_height_per_block][Arg::max_color_width_per_block][4][Arg::coarseSpin][Arg::coarseSpin]; Cache cache{ftor}; auto &X = cache.data()[0]; auto &Y = cache.data()[1]; @@ -1428,7 +1427,6 @@ namespace quda { } } - //__syncthreads(); cache.sync(); #pragma unroll @@ -1458,7 +1456,6 @@ namespace quda { } } - //__syncthreads(); cache.sync(); if (tx < Arg::coarseSpin*Arg::coarseSpin && (parity == 0 || arg.parity_flip == 1) ) { diff --git a/include/kernels/color_spinor_pack.cuh b/include/kernels/color_spinor_pack.cuh index d675408c88..bf71d1ad02 100644 --- a/include/kernels/color_spinor_pack.cuh +++ b/include/kernels/color_spinor_pack.cuh @@ -174,15 +174,16 @@ namespace quda { }; template <> struct site_max { - template static constexpr int Ms = spins_per_thread(Arg::nSpin); - template static constexpr int Mc = colors_per_thread(Arg::nColor); - template static constexpr int color_spin_threads = (Arg::nSpin/Ms) * (Arg::nColor/Mc); template struct CacheDims { - template static constexpr dim3 dims(dim3 b, A &...) 
{ - dim3 block = b; - if (Arg::is_native) block.x = ((block.x + device::warp_size() - 1) / device::warp_size()) * device::warp_size(); - block.y = color_spin_threads; // state the y block since we know it at compile time - return block; + static constexpr int Ms = spins_per_thread(Arg::nSpin); + static constexpr int Mc = colors_per_thread(Arg::nColor); + static constexpr int color_spin_threads = (Arg::nSpin / Ms) * (Arg::nColor / Mc); + static constexpr dim3 dims(dim3 block) + { + // pad the shared block size to avoid bank conflicts for native ordering + if (Arg::is_native) block.x = ((block.x + device::warp_size() - 1) / device::warp_size()) * device::warp_size(); + block.y = color_spin_threads; // state the y block since we know it at compile time + return block; } }; template using Cache = SharedMemoryCache>; @@ -192,12 +193,13 @@ namespace quda { { using Arg = typename Ftor::Arg; using real = typename Arg::real; + constexpr int color_spin_threads = CacheDims::color_spin_threads; Cache cache{ftor}; cache.save(thread_max); cache.sync(); real this_site_max = static_cast(0); #pragma unroll - for (int sc = 0; sc < color_spin_threads; sc++) { + for (int sc = 0; sc < color_spin_threads; sc++) { auto sc_max = cache.load_y(sc); this_site_max = this_site_max > sc_max ? 
this_site_max : sc_max; } diff --git a/include/kernels/dslash_clover_helper.cuh b/include/kernels/dslash_clover_helper.cuh index a21beead5e..ab985b5b4d 100644 --- a/include/kernels/dslash_clover_helper.cuh +++ b/include/kernels/dslash_clover_helper.cuh @@ -209,7 +209,6 @@ namespace quda { Mat A = arg.clover(x_cb, clover_parity, chirality); - //SharedMemoryCache cache(target::block_dim()); SharedMemoryCache cache{*this}; half_fermion in_chi[n_flavor]; // flavor array of chirally projected fermion diff --git a/include/kernels/dslash_coarse.cuh b/include/kernels/dslash_coarse.cuh index 87e6eb75be..37decb5b56 100644 --- a/include/kernels/dslash_coarse.cuh +++ b/include/kernels/dslash_coarse.cuh @@ -303,7 +303,6 @@ namespace quda { //template __device__ __host__ inline void operator()(T &out, int dir, int dim, const Arg &arg) template __device__ __host__ inline void operator()(T &out, int dir, int dim, const Ftor &ftor) { - //SharedMemoryCache cache(target::block_dim()); SharedMemoryCache cache{ftor}; // only need to write to shared memory if not master thread if (dim > 0 || dir) cache.save(out); diff --git a/include/kernels/gauge_stout.cuh b/include/kernels/gauge_stout.cuh index 534cf3fe2f..0ed11abfe2 100644 --- a/include/kernels/gauge_stout.cuh +++ b/include/kernels/gauge_stout.cuh @@ -117,8 +117,8 @@ namespace quda using real = typename Arg::Float; using Complex = complex; using Link = Matrix, Arg::nColor>; - using StapCacheT = ThreadLocalCache; - using RectCacheT = ThreadLocalCache; + using StapCacheT = ThreadLocalCache; // offset by computeStapleRectangleOps + using RectCacheT = ThreadLocalCache; // offset by StapCacheT using Ops = combineOps>; }; @@ -148,8 +148,6 @@ namespace quda } Link U, Q; - //SharedMemoryCache Stap(target::block_dim()); - //SharedMemoryCache Rect(target::block_dim(), sizeof(Link)); typename OvrImpSTOUTOps::StapCacheT Stap{*this}; typename OvrImpSTOUTOps::RectCacheT Rect{*this}; diff --git a/include/kernels/gauge_wilson_flow.cuh 
b/include/kernels/gauge_wilson_flow.cuh index 743074be4c..c725cb3c30 100644 --- a/include/kernels/gauge_wilson_flow.cuh +++ b/include/kernels/gauge_wilson_flow.cuh @@ -62,8 +62,8 @@ namespace quda using real = typename Arg::real; using Link = Matrix, Arg::nColor>; using WilsonOps = computeStapleOps; - using StapOp = ThreadLocalCache; - using RectOp = ThreadLocalCache; + using StapOp = ThreadLocalCache; // offset by computeStapleRectangleOps + using RectOp = ThreadLocalCache; // offset by StapOp using SymanzikOps = combineOps>; using Ops = std::conditional_t; }; @@ -92,12 +92,8 @@ namespace quda // This function gets stap = S_{mu,nu} i.e., the staple of length 3, // and the 1x2 and 2x1 rectangles of length 5. From the following paper: // https://arxiv.org/abs/0801.1165 - //SharedMemoryCache Stap(target::block_dim()); - //SharedMemoryCache Rect(target::block_dim(), sizeof(Link)); // offset to ensure non-overlapping allocations typename computeStapleOpsWF::StapOp Stap{ftor}; typename computeStapleOpsWF::RectOp Rect{ftor}; - //ThreadLocalCache Stap{ftor}; - //ThreadLocalCache Rect{ftor}; computeStapleRectangle(ftor, x, arg.E, parity, dir, Stap, Rect, Arg::wflow_dim); Z = arg.coeff1x1 * static_cast(Stap) + arg.coeff2x1 * static_cast(Rect); //break; diff --git a/include/kernels/hisq_paths_force.cuh b/include/kernels/hisq_paths_force.cuh index c0156eab0b..5044943b37 100644 --- a/include/kernels/hisq_paths_force.cuh +++ b/include/kernels/hisq_paths_force.cuh @@ -394,7 +394,9 @@ namespace quda { 2 multiplies, 1 add, 1 rescale */ template - __device__ __host__ inline void lepage_force(int x[4], int point_a, int parity_a, Link &force_mu, LinkCache &Uab_cache) { + __device__ __host__ inline void lepage_force(int x[4], int point_a, int parity_a, Link &force_mu, + LinkCache &Uab_cache) + { int point_b = linkExtendedIndexShiftMILC(x, arg.sig, arg); int parity_b = 1 - parity_a; @@ -565,7 +567,6 @@ namespace quda { int point_a = e_cb; int parity_a = parity; - //SharedMemoryCache 
Uab_cache(target::block_dim()); ThreadLocalCache Uab_cache{*this}; // Scoped load of Uab { @@ -723,7 +724,8 @@ namespace quda { 4 multiplies, 2 adds, 2 rescales */ template - __device__ __host__ inline void all_link(int x[4], int point_a, int parity_a, LinkCache &Matrix_cache) { + __device__ __host__ inline void all_link(int x[4], int point_a, int parity_a, LinkCache &Matrix_cache) + { auto mycoeff_seven = parity_sign(parity_a) * coeff_sign(parity_a) * arg.coeff_seven; int point_b = linkExtendedIndexShiftMILC(x, arg.sig, arg); @@ -817,7 +819,6 @@ namespace quda { force_sig = mm_add(mycoeff_seven * Oz, Od * Uda, force_sig); Matrix_cache.save(force_sig, 2); } - } /** @@ -836,7 +837,8 @@ namespace quda { 2 multiplies, 2 adds, 2 rescales */ template - __device__ __host__ inline void side_five(int x[4], int point_a, int parity_a, LinkCache &Matrix_cache) { + __device__ __host__ inline void side_five(int x[4], int point_a, int parity_a, LinkCache &Matrix_cache) + { int y[4] = {x[0], x[1], x[2], x[3]}; int point_h = updateCoordExtendedIndexShiftMILC(y, arg.nu, arg); int parity_h = 1 - parity_a; @@ -889,7 +891,8 @@ namespace quda { 1 multiply, 1 add, 1 rescale */ template - __device__ __host__ inline void middle_five(int x[4], int point_a, int parity_a, LinkCache &Matrix_cache) { + __device__ __host__ inline void middle_five(int x[4], int point_a, int parity_a, LinkCache &Matrix_cache) + { int point_b = linkExtendedIndexShiftMILC(x, arg.sig, arg); int parity_b = 1 - parity_a; @@ -976,8 +979,8 @@ namespace quda { // calculate p5_sig constexpr int cacheLen = sig_positive ? 
3 : 2; - //ThreadLocalCache> Matrix_cache{}; - ThreadLocalCache Matrix_cache{*this}; + ThreadLocalCache Matrix_cache{*this}; + if constexpr (sig_positive) { Link force_sig = arg.force(arg.sig, point_a, parity_a); Matrix_cache.save(force_sig, 2); diff --git a/include/lattice_field.h b/include/lattice_field.h index b92297eabc..462add3bde 100644 --- a/include/lattice_field.h +++ b/include/lattice_field.h @@ -317,82 +317,82 @@ namespace quda { /** Pinned memory buffer used for sending messages */ - array my_face_h = {}; + mutable array my_face_h = {}; /** Mapped version of my_face_h */ - array my_face_hd = {}; + mutable array my_face_hd = {}; /** Device memory buffer for sending messages */ - array my_face_d = {}; + mutable array my_face_d = {}; /** Local pointers to the pinned my_face buffer */ - array_3d my_face_dim_dir_h = {}; + mutable array_3d my_face_dim_dir_h = {}; /** Local pointers to the mapped my_face buffer */ - array_3d my_face_dim_dir_hd = {}; + mutable array_3d my_face_dim_dir_hd = {}; /** Local pointers to the device ghost_send buffer */ - array_3d my_face_dim_dir_d = {}; + mutable array_3d my_face_dim_dir_d = {}; /** Memory buffer used for receiving all messages */ - array from_face_h = {}; + mutable array from_face_h = {}; /** Mapped version of from_face_h */ - array from_face_hd = {}; + mutable array from_face_hd = {}; /** Device memory buffer for receiving messages */ - array from_face_d = {}; + mutable array from_face_d = {}; /** Local pointers to the pinned from_face buffer */ - array_3d from_face_dim_dir_h = {}; + mutable array_3d from_face_dim_dir_h = {}; /** Local pointers to the mapped from_face buffer */ - array_3d from_face_dim_dir_hd = {}; + mutable array_3d from_face_dim_dir_hd = {}; /** Local pointers to the device ghost_recv buffer */ - array_3d from_face_dim_dir_d = {}; + mutable array_3d from_face_dim_dir_d = {}; /** Message handles for receiving */ - array_3d mh_recv = {}; + mutable array_3d mh_recv = {}; /** Message handles for 
sending */ - array_3d mh_send = {}; + mutable array_3d mh_send = {}; /** Message handles for receiving */ - array_3d mh_recv_rdma = {}; + mutable array_3d mh_recv_rdma = {}; /** Message handles for sending */ - array_3d mh_send_rdma = {}; + mutable array_3d mh_send_rdma = {}; /** Message handles for receiving @@ -427,7 +427,7 @@ namespace quda { /** Whether we have initialized communication for this field */ - bool initComms = false; + mutable bool initComms = false; /** Whether we have initialized peer-to-peer communication @@ -543,17 +543,17 @@ namespace quda { @param[in] no_comms_fill Whether to allocate halo buffers for dimensions that are not partitioned */ - void createComms(bool no_comms_fill = false); + void createComms(bool no_comms_fill = false) const; /** Destroy the communication handlers */ - void destroyComms(); + void destroyComms() const; /** Create the inter-process communication handlers */ - void createIPCComms(); + void createIPCComms() const; /** Destroy the statically allocated inter-process communication handlers @@ -774,19 +774,19 @@ namespace quda { */ void *remoteFace_r() const; - virtual void gather(int, const qudaStream_t &) { errorQuda("Not implemented"); } + virtual void gather(int, const qudaStream_t &) const { errorQuda("Not implemented"); } - virtual void commsStart(int, const qudaStream_t &, bool, bool) { errorQuda("Not implemented"); } + virtual void commsStart(int, const qudaStream_t &, bool, bool) const { errorQuda("Not implemented"); } - virtual int commsQuery(int, const qudaStream_t &, bool, bool) + virtual int commsQuery(int, const qudaStream_t &, bool, bool) const { errorQuda("Not implemented"); return 0; } - virtual void commsWait(int, const qudaStream_t &, bool, bool) { errorQuda("Not implemented"); } + virtual void commsWait(int, const qudaStream_t &, bool, bool) const { errorQuda("Not implemented"); } - virtual void scatter(int, const qudaStream_t &) { errorQuda("Not implemented"); } + virtual void scatter(int, const 
qudaStream_t &) const { errorQuda("Not implemented"); } /** Return the volume string used by the autotuner */ auto VolString() const { return vol_string; } diff --git a/include/quda.h b/include/quda.h index c8392b2054..d11fd4cbd0 100644 --- a/include/quda.h +++ b/include/quda.h @@ -1511,9 +1511,9 @@ extern "C" { void createCloverQuda(QudaInvertParam* param); /** - * Compute the clover force contributions in each dimension mu given - * the array of solution fields, and compute the resulting momentum - * field. + * Compute the clover force contributions from a set of partial + * fractions stemming from a rational approximation suitable for use + * within MILC. * * @param mom Force matrix * @param dt Integrating step size @@ -1532,6 +1532,23 @@ extern "C" { int nvector, double multiplicity, void *gauge, QudaGaugeParam *gauge_param, QudaInvertParam *inv_param); + /** + * Compute the force from a clover or twisted clover determinant or + * a set of partial fractions stemming from a rational approximation + * suitable for use from within tmLQCD. + * + * @param h_mom Host force matrix + * @param h_x Array of solution vectors x_i = ( Q^2 + s_i )^{-1} b + * @param h_x0 Array of source vectors necessary to compute the force of a ratio of determinants + * @param coeff Array of coefficients for the rational approximation or {1.0} for the determinant. + * @param nvector Number of solution vectors and coefficients + * @param gauge_param Gauge field meta data + * @param inv_param Dirac and solver meta data + * @param detratio If 0, compute the force of a determinant; otherwise compute the force from a ratio of determinants + */ + void computeTMCloverForceQuda(void *h_mom, void **h_x, void **h_x0, double *coeff, int nvector, + QudaGaugeParam *gauge_param, QudaInvertParam *inv_param, int detratio); + /** * Compute the naive staggered force. All fields must be in the same precision.
* diff --git a/include/reference_wrapper_helper.h b/include/reference_wrapper_helper.h index 3f73709ca6..1ab313df4a 100644 --- a/include/reference_wrapper_helper.h +++ b/include/reference_wrapper_helper.h @@ -189,6 +189,15 @@ namespace quda */ template vector make_set(std::vector &v) { return vector{v.begin(), v.end()}; } + /** + make_set is a helper function that creates a vector of + reference wrapped objects from the input reference argument. + This is the specialized overload that handles a vector_ref of + objects. Used to convert a non-const set to a const set. + @param[in] v Vector argument we wish to wrap + */ + template vector make_set(const vector_ref &v) { return vector {v.begin(), v.end()}; } + /** make_set is a helper function that creates a vector of reference wrapped objects from the input reference argument. diff --git a/include/targets/cuda/block_reduce_helper.h b/include/targets/cuda/block_reduce_helper.h index 9408dd518d..59907822e8 100644 --- a/include/targets/cuda/block_reduce_helper.h +++ b/include/targets/cuda/block_reduce_helper.h @@ -2,7 +2,7 @@ #include #include -#include +#include /** @file block_reduce_helper.h diff --git a/include/targets/cuda/load_store.h b/include/targets/cuda/load_store.h index 4a5420b166..0550ad62dd 100644 --- a/include/targets/cuda/load_store.h +++ b/include/targets/cuda/load_store.h @@ -6,6 +6,12 @@ namespace quda { + /** + @brief Element type used for coalesced storage. 
+ */ + template + using atom_t = std::conditional_t>; + // pre-declaration of vector_load that we wish to specialize template struct vector_load_impl; diff --git a/include/targets/cuda/thread_array.h b/include/targets/cuda/thread_array.h index 01b9bf47be..fd2c724081 100644 --- a/include/targets/cuda/thread_array.h +++ b/include/targets/cuda/thread_array.h @@ -11,7 +11,6 @@ namespace quda { template struct thread_array : array { - //constexpr thread_array() : array() {} template constexpr thread_array(Ops &ops) : array() {} static constexpr unsigned int shared_mem_size(dim3) { return 0; } }; diff --git a/include/targets/cuda/tunable_kernel.h b/include/targets/cuda/tunable_kernel.h index 035b638b1a..152dfd8a61 100644 --- a/include/targets/cuda/tunable_kernel.h +++ b/include/targets/cuda/tunable_kernel.h @@ -46,6 +46,7 @@ namespace quda std::enable_if_t(), qudaError_t> launch_device(const kernel_t &kernel, const TuneParam &tp, const qudaStream_t &stream, const Arg &arg) { + checkSharedBytes(tp); #ifdef JITIFY launch_error = launch_jitify(kernel.name, tp, stream, arg); #else @@ -63,6 +64,7 @@ namespace quda std::enable_if_t(), qudaError_t> launch_device(const kernel_t &kernel, const TuneParam &tp, const qudaStream_t &stream, const Arg &arg) { + checkSharedBytes(tp); #ifdef JITIFY // note we do the copy to constant memory after the kernel has been compiled in launch_jitify launch_error = launch_jitify(kernel.name, tp, stream, arg); @@ -84,6 +86,7 @@ namespace quda template