Skip to content

Commit

Permalink
try again
Browse files Browse the repository at this point in the history
  • Loading branch information
micmelesse committed Jan 17, 2025
1 parent a811071 commit 5037533
Show file tree
Hide file tree
Showing 3 changed files with 275 additions and 337 deletions.
91 changes: 10 additions & 81 deletions flash_attn/flash_attn_triton_amd/fwd_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
tl_DROPOUT_USE_PYTORCH: tl.constexpr = DROPOUT_USE_PYTORCH
tl_DROPOUT_DUMP: tl.constexpr = DROPOUT_DUMP

tl_FP8_DUMP: tl.constexpr = False

# Convenience function to load with optional boundary checks.
# "First" is the major dim, "second" is the minor dim.
@triton.jit
Expand Down Expand Up @@ -64,7 +62,7 @@ def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpo
@triton.jit
def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, start_m,
actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed, philox_ptrs, sd_mask_ptrs, dropout_mask_ptrs,
q_fp8_ptrs, k_fp8_ptrs, qk_fp8_ptrs, p_fp8_ptrs, acc_fp8_ptrs, block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope,
block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope,
descale_q, descale_k, descale_v, descale_s, IS_FP8: tl.constexpr,
IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr,
Expand Down Expand Up @@ -110,14 +108,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri

# -- compute qk ----
if IS_FP8 :
if tl_FP8_DUMP:
tl.store(q_fp8_ptrs, q, mask=q_mask)
tl.store(k_fp8_ptrs, k, mask=k_mask)

qk += (tl.dot(q, k) * descale_q * descale_k)

if tl_FP8_DUMP:
tl.store(qk_fp8_ptrs, qk, mask=p_mask)
else:
qk += tl.dot(q, k)
qk_scaled = qk * SM_SCALE
Expand Down Expand Up @@ -189,13 +180,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri

if IS_FP8:
p *= (1.0/ descale_s) # put p into fp8 range
if tl_FP8_DUMP:
tl.store(p_fp8_ptrs, p, mask=p_mask)

acc += (tl.dot(p.to(v.type.element_ty), v) * descale_s * descale_v)

if tl_FP8_DUMP:
tl.store(acc_fp8_ptrs, acc)
else:
acc += tl.dot(p.to(v.type.element_ty), v)

Expand All @@ -209,10 +194,6 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
if ENABLE_DROPOUT:
dropout_mask_ptrs += BLOCK_N * stride_sn
philox_ptrs += BLOCK_N * stride_sn
if IS_FP8:
qk_fp8_ptrs += BLOCK_N * stride_sn
p_fp8_ptrs += BLOCK_N * stride_sn

return acc, l_i, m_i


Expand Down Expand Up @@ -297,7 +278,7 @@ def attn_fwd(Q, K, V, bias,
stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn,
stride_oz, stride_oh, stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah,
stride_sz, stride_sh, stride_sm, stride_sn, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k,
dropout_p, philox_seed, philox_offset_base, sd_mask, dropout_mask, alibi_slopes, q_fp8, k_fp8, qk_fp8, p_fp8, acc_fp8, HQ: tl.constexpr,
dropout_p, philox_seed, philox_offset_base, sd_mask, dropout_mask, alibi_slopes, HQ: tl.constexpr,
HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr,
MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr,
Expand Down Expand Up @@ -435,31 +416,9 @@ def attn_fwd(Q, K, V, bias,
descale_q = tl.load(DESCALE_Q + off_z * stride_q_inv_scale_z + off_h_q)
descale_k = tl.load(DESCALE_K + off_z * stride_kv_inv_scale_z + off_h_k)
descale_v = tl.load(DESCALE_V + off_z * stride_kv_inv_scale_z + off_h_k)
descale_s = tl.load(DESCALE_S + off_z * stride_p_inv_scale_z + off_h_q)

q_fp8_offset = q_fp8 + off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm
q_fp8_ptrs = q_fp8_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk


k_fp8_offset = k_fp8 + off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn
k_fp8_ptrs = k_fp8_offset + offs_d[:, None] * stride_kk + offs_n[None, :] * stride_kn


qk_fp8_offset = qk_fp8 + off_z * stride_sz + off_h_q * stride_sh #+ cu_seqlens_q_start * stride_sm
qk_fp8_ptrs = qk_fp8_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn #+ cu_seqlens_q_start * stride_sm

p_fp8_offset = p_fp8 + off_z * stride_sz + off_h_q * stride_sh #+ cu_seqlens_q_start * stride_sm
p_fp8_ptrs = p_fp8_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn #+ cu_seqlens_q_start * stride_sm

acc_fp8_offset = acc_fp8 + off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm
acc_fp8_ptrs = acc_fp8_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
descale_s = tl.load(DESCALE_S + off_z * stride_p_inv_scale_z + off_h_q)
else:
descale_q, descale_k, descale_v, descale_s = 1.0, 1.0, 1.0, 1.0
q_fp8_ptrs = None
k_fp8_ptrs = None
qk_fp8_ptrs = None
p_fp8_ptrs = None
acc_fp8_ptrs = None
descale_q, descale_k, descale_v, descale_s = 1.0, 1.0, 1.0, 1.0

# Here we compute how many full and masked blocks we have.
padded_block_k = n_extra_tokens != 0
Expand All @@ -483,7 +442,7 @@ def attn_fwd(Q, K, V, bias,
block_max = (n_blocks - masked_blocks) * BLOCK_N
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn,
start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs,
sd_mask_ptrs, dropout_mask_ptrs, q_fp8_ptrs, k_fp8_ptrs, qk_fp8_ptrs, p_fp8_ptrs, acc_fp8_ptrs,
sd_mask_ptrs, dropout_mask_ptrs,
# _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
block_min, block_max, 0, 0, 0, alibi_slope,
descale_q, descale_k, descale_v, descale_s, IS_FP8,
Expand Down Expand Up @@ -511,14 +470,10 @@ def attn_fwd(Q, K, V, bias,
if ENABLE_DROPOUT:
dropout_mask_ptrs += n_full_blocks * BLOCK_N * stride_sn
philox_ptrs += n_full_blocks * BLOCK_N * stride_sn
if IS_FP8:
qk_fp8_ptrs += n_full_blocks * BLOCK_N * stride_sn
p_fp8_ptrs += n_full_blocks * BLOCK_N * stride_sn
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn,
start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs,
sd_mask_ptrs, dropout_mask_ptrs, q_fp8_ptrs, k_fp8_ptrs, qk_fp8_ptrs, p_fp8_ptrs, acc_fp8_ptrs,
block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope,
descale_q, descale_k, descale_v, descale_s, IS_FP8,
sd_mask_ptrs, dropout_mask_ptrs, block_min, block_max, offs_n_causal, masked_blocks,
n_extra_tokens, alibi_slope, descale_q, descale_k, descale_v, descale_s, IS_FP8,
IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n,
# _, MASK_STEPS, ...
PRE_LOAD_V, True, ENABLE_DROPOUT, PADDED_HEAD,
Expand Down Expand Up @@ -660,24 +615,12 @@ def attention_prefill_forward_triton_impl(
descale_k_stride_z = descale_k.stride(0)
descale_v_stride_z = descale_v.stride(0)
descale_s_stride_z = descale_s.stride(0)

# dump intermediate results
q_fp8 = torch.zeros_like(q)
k_fp8 = torch.zeros_like(k)
# NOTE: the result of fp8 dot is float32
qk_fp8 = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), dtype=torch.float32, device=q.device)
p_fp8 = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), dtype=torch.float32, device=q.device)
acc_fp8 = torch.zeros(o.shape, dtype=torch.float32, device=q.device)
else:
is_fp8 = False
# For non-FP8 types, use dummy values (no scaling needed)
descale_q = descale_k = descale_v = descale_s = 1
descale_q_stride_z = descale_k_stride_z = descale_v_stride_z = descale_s_stride_z = 0
q_fp8 = None
k_fp8 = None
qk_fp8= None
p_fp8 = None
acc_fp8 = None


if DEBUG:
print("is_fp8:", is_fp8)
Expand Down Expand Up @@ -728,10 +671,7 @@ def attention_prefill_forward_triton_impl(
else:
sd_mask = None
dropout_mask = None
if is_fp8:
scores_strides = (qk_fp8.stride(0), qk_fp8.stride(1), qk_fp8.stride(2), qk_fp8.stride(3))
else:
scores_strides = (0, 0, 0, 0)
scores_strides = (0, 0, 0, 0)

# stores LSE, the log of the normalization constant / sum of exponential scores (unnormalized probabilities)
if is_varlen:
Expand All @@ -753,17 +693,13 @@ def attention_prefill_forward_triton_impl(
else:
alibi_strides = (0, 0)

if DEBUG:
print("attn_fwd input")
print("q:", q)
print("k:", k)

attn_fwd[grid](q, k, v, bias,
descale_q, descale_k, descale_v, descale_s, descale_q_stride_z, descale_k_stride_z, descale_s_stride_z,
sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides,
*bias_strides, *alibi_strides, *scores_strides, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k,
dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, sd_mask=sd_mask, dropout_mask=dropout_mask, alibi_slopes=alibi_slopes,
q_fp8=q_fp8, k_fp8 = k_fp8, qk_fp8=qk_fp8, p_fp8=p_fp8, acc_fp8=acc_fp8, HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q,
HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q,
MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, VARLEN=is_varlen,
BLOCK_DMODEL=padded_d_model, USE_BIAS=False if bias is None else True,
USE_ALIBI=False if alibi_slopes is None else True, ENABLE_DROPOUT=dropout_p
Expand All @@ -779,12 +715,5 @@ def attention_prefill_forward_triton_impl(
print("dropout_mask:", dropout_mask, dropout_mask.shape if dropout_mask is not None else None)
print("dropout_fraction fwd:", 1.0 - (dropout_mask.sum()/ dropout_mask.numel()).item())
write_dropout_mask(dropout_mask, "dropout_mask_fwd")
if is_fp8:
print("")
print("q_fp8:", q_fp8)
print("k_fp8:", k_fp8)
print("qk_fp8:", qk_fp8)
print("p_fp8:", p_fp8)
print("acc_fp8:", acc_fp8)

return o, softmax_lse, sd_mask.to(o.dtype) if return_softmax else None
Loading

0 comments on commit 5037533

Please sign in to comment.