Skip to content

Commit

Permalink
Fix compile issue for Marin qqq on sm<8.0 (#1651)
Browse files Browse the repository at this point in the history
* fix compile guard

* remove guard on header file
  • Loading branch information
gau-nernst authored Feb 5, 2025
1 parent 1a4c8f9 commit 8afd10e
Showing 1 changed file with 10 additions and 45 deletions.
55 changes: 10 additions & 45 deletions torchao/csrc/cuda/marlin_qqq/marlin_qqq_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@
#include <iostream>

#include "base.h"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include "mem.h"
#endif
#include "mem.h"

template <typename T>
inline std::string str(T x) {
Expand All @@ -41,8 +39,6 @@ inline std::string str(T x) {

namespace torchao {

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800

using I4 = Vec<int, 4>;
// Matrix fragments for tensor core instructions; their precise layout is
// documented here:
Expand Down Expand Up @@ -208,6 +204,8 @@ __global__ void Marlin_QQQ(
int prob_k, // reduction dimension k
int* locks // extra global storage for barrier synchronization
) {
// host code or device code with SM >= 80. Marlin only supports SM >= 80.
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
// same size, which might involve multiple column "slices" (of width 16 *
// `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
Expand Down Expand Up @@ -855,47 +853,8 @@ __global__ void Marlin_QQQ(
}
}
}
}

#else

template <const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
// threadblock
const int thread_n_blocks, // same for n dimension (output)
const int thread_k_blocks, // same for k dimension (reduction)
const int stages, // number of stages for the async global->shared
// fetch pipeline
const int group_blocks = -1 // number of consecutive 16x16 blocks
// with a separate quantization scale
>
__global__ void Marlin_QQQ(
const int4* __restrict__ A, // int8 input matrix of shape mxk
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
int4* __restrict__ C, // int32 global_reduce buffer of shape
// (max_par*16*4)xn, as int8 tensor core's output is
// int32 dtype
int4* __restrict__ D, // fp16 output buffer of shape mxn
const float* __restrict__ s_tok, // fp32 activation per-token quantization
// scales of shape mx1
const int4* __restrict__ s_ch, // fp32 weight per-channel quantization
// scales of shape 1xn
const int4* __restrict__ s_group, // fp16 weight per-group quantization
// scales of shape (k/groupsize)xn, when
// group_blocks=-1, it should be nullptr
int prob_m, // batch dimension m
int prob_n, // output dimension n
int prob_k, // reduction dimension k
int* locks // extra global storage for barrier synchronization
) {
// Marlin is not implemented yet for SM < 8.0
TORCH_CHECK_NOT_IMPLEMENTED(
false, "marlin_qqq_gemm(..) requires CUDA_ARCH >= 8.0");
return;
}

#endif
}

// 8 warps are a good choice since every SM has 4 schedulers and having more
// than 1 warp per schedule allows some more latency hiding. At the same time,
Expand Down Expand Up @@ -1132,6 +1091,12 @@ torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
torch::Tensor const& s_group,
torch::Tensor& workspace, int64_t size_m,
int64_t size_n, int64_t size_k) {
const auto dprops = at::cuda::getCurrentDeviceProperties();
if (dprops->major < 8) {
TORCH_CHECK(false, __func__, "requires SM >= 8.0. Current device is SM",
dprops->major, ".", dprops->minor);
}

// Verify M
TORCH_CHECK(size_m == a.size(0),
"Shape mismatch: a.size(0) = " + str(a.size(0)) +
Expand Down

0 comments on commit 8afd10e

Please sign in to comment.