Commit: remove quantization

micmelesse committed Sep 4, 2024
1 parent b1e2bd3 commit 3040c6b
Showing 1 changed file with 11 additions and 69 deletions.
80 changes: 11 additions & 69 deletions flash_attn/flash_attn_triton_kernel_decode_amd.py
@@ -84,41 +84,7 @@ def _fwd_kernel_splitK(
IS_GQA: tl.constexpr,
IS_CAUSAL: tl.constexpr,
USE_ALIBI: tl.constexpr,
PACKED_PER_VAL: tl.constexpr = 1,
N_QUANT_GROUPS: tl.constexpr = 1,
):
"""This kernel can accept non-quantized or int4-quantized keys/values.
PACKED_PER_VAL determines the quantization type:
- PACKED_PER_VAL == 1 means no quantization
- PACKED_PER_VAL == 8 means 4-bit quantization (8 packed quantized values inside one int32)
For the quantized case K/V should be int32 tensors.
Quantization can be row-wise (when N_QUANT_GROUPS = 1) or group-wise with N_QUANT_GROUPS = 2, 4, or 8.
Quantization coefficients are stored at the beginning of the row along the last dimension of K/V
So K[B, H, M, :] has a form
[ quant_coef0, quant_coef1, ...|
group0_quant_value0, group0_quant_value1,... |
group1_quant_value0, group1_quant_value1,...]
where each quant_coef is an int32 which should be interpreted as 2 packed float16: scale and offset.
"""
tl.static_assert(
(PACKED_PER_VAL == 1 and tl.constexpr(K.dtype.element_ty != tl.int32))
or (PACKED_PER_VAL == 8 and tl.constexpr(K.dtype.element_ty == tl.int32)),
f"Only 4-bit quantization is supported, K/V should have dtype int32 in "
f"the quantized case: {PACKED_PER_VAL=} {tl.constexpr(K.dtype)=} {tl.constexpr(K.dtype.element_ty)=}",
)
tl.static_assert(
(((N_QUANT_GROUPS == 1 or N_QUANT_GROUPS == 2) or N_QUANT_GROUPS == 4) or N_QUANT_GROUPS == 8),
"Number of quantization groups can be 1 (row-wise quantization), 2, 4, or 8.",
)

# Quantization
# TODO: enable quantization
tl.static_assert(N_QUANT_GROUPS == 1, "N_QUANT_GROUPS != 1. Quantization is not supported yet.")
tl.static_assert(PACKED_PER_VAL == 1, "PACKED_PER_VAL != 1. Quantization is not supported yet.")
QUANTIZED: tl.constexpr = 0


# Padding
PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL)
if PADDED_HEAD:
@@ -181,7 +147,7 @@ def _fwd_kernel_splitK(
for i in range(0, N_CTX_NEW, BLOCK_N):
# Load from K_new
k_new_block = tl.load(
knew_base + stride_kn_d * QUANTIZED * N_QUANT_GROUPS +
knew_base + stride_kn_d +
tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kn_d +
(tl.arange(0, BLOCK_N) + i)[None, :] * stride_kn_n,
mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) &
@@ -191,7 +157,7 @@

# Store to K
tl.store(
k_base + stride_kd * QUANTIZED * N_QUANT_GROUPS +
k_base + stride_kd +
tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kd +
(tl.arange(0, BLOCK_N) + i + start_idx)[None, :] * stride_kn,
k_new_block,
@@ -204,7 +170,7 @@
for i in range(0, N_CTX_NEW, BLOCK_N):
# Load from V_new
v_new_block = tl.load(
vnew_base + stride_vn_d * QUANTIZED * N_QUANT_GROUPS +
vnew_base + stride_vn_d +
(tl.arange(0, BLOCK_N) + i)[:, None] * stride_vn_n +
tl.arange(0, BLOCK_DMODEL)[None, :] * stride_vn_d,
mask=(tl.arange(0, BLOCK_N)[:, None] + i < N_CTX_NEW) &
@@ -214,7 +180,7 @@

# Store to V
tl.store(
v_base + stride_vd * QUANTIZED * N_QUANT_GROUPS +
v_base + stride_vd +
(tl.arange(0, BLOCK_N) + i + start_idx)[:, None] * stride_vn +
tl.arange(0, BLOCK_DMODEL)[None, :] * stride_vd,
v_new_block,
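# Illustrative sketch (not part of this kernel): what the two copy loops above
# do conceptually, written in plain PyTorch. The [batch, heads, seq, dim] cache
# layout and the tensor names are assumptions for illustration only.
def append_new_kv_to_cache(k_cache, v_cache, k_new, v_new, cache_seqlens):
    for b in range(k_new.shape[0]):
        start = int(cache_seqlens[b])   # tokens already present for this sequence
        n_new = k_new.shape[2]          # new tokens being appended
        k_cache[b, :, start:start + n_new] = k_new[b]
        v_cache[b, :, start:start + n_new] = v_new[b]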
@@ -231,46 +197,26 @@
order=(1, 0),
)

# Additional shift by 1 along the last dimension in the quantized case, since
# the first element along that dim contains packed quantization coefficients.
K_block_ptr = tl.make_block_ptr(
base=k_base + stride_kd * QUANTIZED * N_QUANT_GROUPS,
base=k_base + stride_kd ,
shape=(ACTUAL_BLOCK_DMODEL, hi),
strides=(stride_kd, stride_kn),
offsets=(0, lo),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1),
)
V_block_ptr = tl.make_block_ptr(
base=v_base + stride_vd * QUANTIZED * N_QUANT_GROUPS,
base=v_base + stride_vd,
shape=(hi, ACTUAL_BLOCK_DMODEL),
strides=(stride_vn, stride_vd),
offsets=(lo, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0),
)

if QUANTIZED:
# Pointers to quantization coefficients
K_scale_shift_block_ptr = tl.make_block_ptr(
base=k_base,
shape=(1, hi),
strides=(stride_kd, stride_kn),
offsets=(0, lo),
block_shape=(1, BLOCK_N),
order=(0, 1),
)
V_scale_shift_block_ptr = tl.make_block_ptr(
base=v_base,
shape=(hi, 1),
strides=(stride_vn, stride_vd),
offsets=(lo, 0),
block_shape=(BLOCK_N, 1),
order=(1, 0),
)
else:
K_scale_shift_block_ptr = None
V_scale_shift_block_ptr = None

K_scale_shift_block_ptr = None
V_scale_shift_block_ptr = None

# initialize pointer to m and l
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
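# Illustrative sketch (not part of this kernel): the streaming-softmax update
# that the m_i (running row max) and l_i (running row sum) accumulators drive
# for each successive K/V block. This is the standard online-softmax
# recurrence with simplified shapes, not this file's exact code.
import torch

def online_softmax_step(m_i, l_i, acc, scores, v_block):
    # m_i, l_i: scalar tensors; acc: [D]; scores: [BLOCK_N]; v_block: [BLOCK_N, D]
    m_new = torch.maximum(m_i, scores.max())
    alpha = torch.exp(m_i - m_new)      # rescales the previous accumulators
    p = torch.exp(scores - m_new)
    l_new = l_i * alpha + p.sum()
    acc_new = acc * alpha + p @ v_block
    return m_new, l_new, acc_new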
@@ -297,7 +243,7 @@
K_scale_shift_block_ptr,
V_scale_shift_block_ptr,
BOUNDS_CHECKS_N,
PACKED_PER_VAL,
1,
BLOCK_DMODEL,
ACTUAL_BLOCK_DMODEL,
Q.dtype.element_ty,
@@ -366,9 +312,6 @@ def _fwd_kernel_splitK(
# update pointers
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
if PACKED_PER_VAL > 1:
K_scale_shift_block_ptr = tl.advance(K_scale_shift_block_ptr, (0, BLOCK_N))
V_scale_shift_block_ptr = tl.advance(V_scale_shift_block_ptr, (BLOCK_N, 0))

# write back O
O_block_ptr = tl.make_block_ptr(
@@ -745,6 +688,7 @@ def forward(cls, q, k, v, input_metadata):
split_size = (seqlen_k + split_k - 1) // split_k
use_cache_seqlens = cache_seqlens is not None
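# Note on split_size above: it is a ceiling division, so the seqlen_k keys are
# covered by split_k partial kernels of at most split_size keys each. A quick
# illustrative check (numbers are assumptions, not from this repository):
seqlen_k_example, split_k_example = 1000, 8
split_size_example = (seqlen_k_example + split_k_example - 1) // split_k_example  # 125
assert split_size_example * split_k_example >= seqlen_k_example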

# TODO: enable quantization
_fwd_kernel_splitK[grid](
Q=q,
K=k,
@@ -786,8 +730,6 @@ def forward(cls, q, k, v, input_metadata):
USE_ALIBI=False if input_metadata.alibi_slopes is None else True,
num_warps=num_warps,
num_stages=1,
PACKED_PER_VAL=PACKED_PER_VAL,
N_QUANT_GROUPS=cls.NUM_QUANT_GROUPS if PACKED_PER_VAL > 1 else 1,
)

out = torch.empty((batch_size, seqlen_q, n_group_q, heads_per_group_q, dim_padded), device=q.device, dtype=q.dtype)