Commit

clean up more
micmelesse committed Aug 29, 2024
1 parent fe23cc8 commit ce414d6
Showing 2 changed files with 6 additions and 51 deletions.
41 changes: 3 additions & 38 deletions flash_attn/flash_attn_triton_kernel_decode_amd.py
@@ -113,24 +113,16 @@ def _fwd_kernel_splitK(
)

# Quantization
# QUANTIZED: tl.constexpr = PACKED_PER_VAL > 1
# PACKED_D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // PACKED_PER_VAL // N_QUANT_GROUPS
# D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // N_QUANT_GROUPS
tl.static_assert(N_QUANT_GROUPS == 1, "N_QUANT_GROUPS != 1")
tl.static_assert(PACKED_PER_VAL == 1, "PACKED_PER_VAL != 1")
# TODO: enable quantization
tl.static_assert(N_QUANT_GROUPS == 1, "N_QUANT_GROUPS != 1. Quantization is not supported yet.")
tl.static_assert(PACKED_PER_VAL == 1, "PACKED_PER_VAL != 1. Quantization is not supported yet.")
QUANTIZED: tl.constexpr = 0


# Padding
# print("ACTUAL_BLOCK_DMODEL:", ACTUAL_BLOCK_DMODEL)
# print("BLOCK_DMODEL:", BLOCK_DMODEL)
PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL)
# print("PADDED_HEAD:", PADDED_HEAD)
if PADDED_HEAD:
d_mask = tl.arange(0, BLOCK_DMODEL) < ACTUAL_BLOCK_DMODEL
# print("d_mask:", d_mask)



start_m = tl.program_id(0)
off_zhg = tl.program_id(1)
@@ -151,7 +143,6 @@ def _fwd_kernel_splitK(
alibi_slope = tl.load(Alibi_slopes + a_offset)
else:
alibi_slope = None
# print("alibi_slope:", alibi_slope)

lo = splitk_idx * BLOCK_N_PER_SPLIT
if USE_CACHE_SEQLENs:
@@ -177,7 +168,6 @@ def _fwd_kernel_splitK(
v_base = V + v_head_idx * stride_vh + cache_batch_idx * stride_vz + off_g_q * stride_vg

# Copy new Keys and Values into Cache
# print("NEW_KV", NEW_KV)
if NEW_KV:
knew_base = K_new + k_head_idx * stride_kn_h + off_z * stride_kn_z + off_g_q * stride_kn_g

@@ -298,13 +288,9 @@ def _fwd_kernel_splitK(
q = (q * qk_scale).to(q.dtype)
if PADDED_HEAD:
q = tl.where(d_mask[None, :], q, 0.0)
# print("q:", q)

# print("BLOCK_N:", BLOCK_N)
# loop over k, v and update accumulator
for start_n in range(lo, hi, BLOCK_N):
# print("start_n:", start_n)

k, v = load_dequantize_k_v_group(
K_block_ptr,
V_block_ptr,
@@ -324,7 +310,6 @@ def _fwd_kernel_splitK(
# -- compute qk ---
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k) # noqa: F821
# print("qk before:", qk)

if USE_ALIBI:
row_idx = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
@@ -333,53 +318,41 @@ def _fwd_kernel_splitK(
# Compute relative positions
relative_pos = row_idx[:, None] + kv_len - (N_CTX_Q + col_idx[None, :])
relative_pos = tl.abs(relative_pos)
# print("relative_pos:", relative_pos)

# Compute ALiBi bias
alibi_bias = -1 * alibi_slope * relative_pos
# print("alibi_bias:", alibi_bias)
qk += (alibi_bias * 1.44269504)

# Apply causal mask if IS_CAUSAL is True
if IS_CAUSAL:
row_idx = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
# print("row_idx:", row_idx)
col_idx = start_n + tl.arange(0, BLOCK_N)
# print("col_idx:", col_idx)

# create a N_CTX_Q x kv_len causal mask
col_offset = N_CTX_Q - kv_len
causal_mask = row_idx[:, None] >= (col_offset + col_idx[None, :])

# Apply the mask
qk = tl.where(causal_mask, qk, float("-inf"))
# print("qk after causal:", qk)

# TODO: This is slow, and only needed at the last iteration.
# Maybe we can unroll the last iteration instead?
if BOUNDS_CHECKS_N:
qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf"))
# print("qk after BOUNDS_CHECKS_N:", qk)

# -- compute scaling constant ---
# print("m_i:", m_i)
m_i_new = tl.maximum(m_i, tl.max(qk, 1))
# print("m_i_new:", m_i_new)
if IS_CAUSAL:
alpha = tl.math.exp2(tl.where(m_i > float("-inf"), m_i - m_i_new, float("-inf")))
else:
alpha = tl.math.exp2(m_i - m_i_new)

# print("alpha:", alpha)
# print("before qk - m_i_new:", qk)
# cause of nan because subtracting infs
if IS_CAUSAL:
qk = tl.where(qk > float("-inf"), qk - m_i_new[:, None], float("-inf"))
else:
qk = qk - m_i_new[:, None]

p = tl.math.exp2(qk)
# print("p:", p)

# -- update m_i and l_i --
l_i = l_i * alpha + tl.sum(p, 1)
@@ -389,7 +362,6 @@ def _fwd_kernel_splitK(
# -- scale and update acc --
acc *= alpha[:, None]
acc += tl.dot(p.to(v.dtype), v)
# print("acc:", acc)

# update pointers
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
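For reference, the inner loop above is a streaming ("online") softmax: m_i tracks the running row maximum, l_i the running denominator, alpha rescales the previous partial sums whenever the maximum grows, and acc accumulates the unnormalized softmax(QK^T)V product. A minimal NumPy sketch of the same update, assuming the log2(e) factor (1.44269504) is already folded into the scores so that exp2 stands in for exp; names and shapes here are illustrative, not the kernel's:

import numpy as np

def streaming_softmax_block(q, k_blocks, v_blocks, scale):
    # q: (M, D); each k in k_blocks: (N, D); each v in v_blocks: (N, Dv)
    log2e = 1.44269504
    m = np.full(q.shape[0], -np.inf)            # running row max (log2 domain)
    l = np.zeros(q.shape[0])                    # running softmax denominator
    acc = np.zeros((q.shape[0], v_blocks[0].shape[1]))
    for k, v in zip(k_blocks, v_blocks):
        qk = (q @ k.T) * scale * log2e          # scores, scaled so exp2 behaves like exp
        m_new = np.maximum(m, qk.max(axis=1))
        alpha = np.exp2(m - m_new)              # rescale old state to the new max
        p = np.exp2(qk - m_new[:, None])
        l = l * alpha + p.sum(axis=1)
        acc = acc * alpha[:, None] + p @ v
        m = m_new
    # acc is left unnormalized; the caller divides by l (or defers it, as split-K does)
    return acc, m, l

The causal, ALiBi and boundary branches above only change which entries of qk are biased or set to -inf before this update.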
@@ -559,19 +531,13 @@ def _splitK_reduce(
l_sum = tl.load(Metadata_ptr + stride_m2)
acc = tl.load(o_ptr)

# print("l_m:", l_m)
# print("l_sum:", l_sum)
# print("acc:", acc)

g_m = tl.max(l_m, axis=0)

# print("g_m:", g_m)
if IS_CAUSAL:
l_m_offset = l_m - g_m
alpha = tl.where(l_m_offset > float("-inf"), tl.math.exp2(l_m_offset), 0.0)
else:
alpha = tl.math.exp2(l_m - g_m)
# print("alpha:", alpha)

# read sum
l_sum *= alpha
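_splitK_reduce then merges the per-split partial results: each split contributes its running max (l_m), denominator (l_sum) and accumulator (acc), which are rescaled to the global maximum g_m before being summed and normalized. A small NumPy sketch of that merge, assuming the per-split triples come from an update like the one sketched above (names are hypothetical):

import numpy as np

def reduce_splits(partials):
    # partials: list of (m, l, acc) per split; m, l: (M,), acc: (M, Dv), acc unnormalized
    g_m = np.max(np.stack([m for m, _, _ in partials]), axis=0)   # global row max
    l_total = np.zeros_like(partials[0][1])
    acc_total = np.zeros_like(partials[0][2])
    for m, l, acc in partials:
        alpha = np.exp2(m - g_m)               # rescale this split to the global max
        l_total += l * alpha
        acc_total += acc * alpha[:, None]
    return acc_total / l_total[:, None]        # final normalized attention output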
@@ -760,7 +726,6 @@ def forward(cls, q, k, v, input_metadata):
dim_k = (dim_k - cls.NUM_QUANT_GROUPS) * 8

assert dim_k == dim_q, f"Keys have head dim {dim_k} but queries have head dim {dim_q}"
# print(f"batch_size = {batch_size}, seqlen_q = {seqlen_q}, seqlen_k = {seqlen_k}, heads_per_group_q = {heads_per_group_q}, heads_per_group_k = {heads_per_group_k}, dim_q = {dim_q}, dim_k = {dim_k}")

BLOCK_M = cls.BLOCK_M
BLOCK_N = cls.BLOCK_N
16 changes: 3 additions & 13 deletions flash_attn/flash_attn_triton_kernel_prefill_amd.py
@@ -187,15 +187,6 @@ def load_fn(ptrs, offset_first, offset_second, boundary_first, boundary_second):
return tensor


@triton.jit
def print_gpu(prefix, val=None):
if (tl.program_id(0) == 0) and ((tl.program_id(1) == 0) and (tl.program_id(2) == 0)):
if val is not None:
tl.device_print(prefix, val)
else:
tl.device_print(prefix)


@triton.jit
def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False):
# when seqlen_k and seqlen_q are different we want the diagonal to stick to the bottom right of the attention matrix
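The "bottom right" convention in that comment means relative position zero sits on the diagonal that ends at the last query row and last key column, i.e. query row i is aligned with key column i + (seqlen_k - seqlen_q). A short NumPy illustration of a bias tile built that way, mirroring the abs()-based formula used in the decode kernel above (a sketch of the convention, not the Triton helper itself):

import numpy as np

def alibi_bias_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n):
    # offs_m: query row indices of the tile; offs_n: key column indices
    relative_pos = offs_m[:, None] + (seqlen_k - seqlen_q) - offs_n[None, :]
    return -alibi_slope * np.abs(relative_pos)

# With seqlen_q=2, seqlen_k=4 the zeros land on the bottom-right diagonal:
# alibi_bias_block(1.0, 2, 4, np.arange(2), np.arange(4))
#   -> [[-2., -1.,  0., -1.],
#       [-3., -2., -1.,  0.]]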
@@ -348,12 +339,11 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
num_warps=4),
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=8),
# TODO: This config fails with head_size not pow2 with data mismatches. Check why.
# triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
# triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
# num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=4),
# TODO: This configs fails with head_size not pow2 with data mismatches. figure out why
# triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
# triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
],
key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'],
use_cuda_graph=True,
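For context on how the config list above is consumed: triton.autotune benchmarks every triton.Config the first time the kernel runs with a new combination of the key arguments and caches the fastest one. A minimal, self-contained sketch of the pattern (the kernel and names here are illustrative, not this repository's):

import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({'BLOCK': 64}, num_stages=1, num_warps=4),
        triton.Config({'BLOCK': 128}, num_stages=1, num_warps=8),
    ],
    key=['n_elements'],   # re-tune whenever this argument's value changes
)
@triton.jit
def scale_kernel(x_ptr, out_ptr, n_elements, alpha, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x * alpha, mask=mask)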
