save scaling code so far
micmelesse committed Jan 7, 2025
1 parent 957b0e6 commit a290a6d
Showing 1 changed file with 51 additions and 9 deletions.
60 changes: 51 additions & 9 deletions flash_attn/flash_attn_triton_amd/fwd_prefill.py
@@ -1,7 +1,7 @@
import torch
import triton
import triton.language as tl
from .utils import DEBUG, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, AUTOTUNE, get_shape_from_layout, get_strides_from_layout, is_cdna, is_rdna, write_dropout_mask, create_dropout_mask
from .utils import DEBUG, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, AUTOTUNE, MetaData, get_shape_from_layout, get_strides_from_layout, is_cdna, is_rdna, write_dropout_mask, create_dropout_mask

# NOTE: triton fails to import tl.constexprs so create them here for the file
tl_DROPOUT_USE_PYTORCH: tl.constexpr = DROPOUT_USE_PYTORCH
@@ -63,6 +63,7 @@ def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpo
def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, start_m,
actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed, philox_ptrs, sd_mask_ptrs, dropout_mask_ptrs,
block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope,
q_scale, k_scale, v_scale, p_scale, p_inv_scale, IS_FP8: tl.constexpr,
IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr, PADDED_HEAD: tl.constexpr,
@@ -103,6 +104,8 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
# -- compute qk ----
qk += tl.dot(q, k)
qk_scaled = qk * SM_SCALE
if IS_FP8:
qk_scaled *= q_scale * k_scale # descale qk after matmul if quantized

if IS_CAUSAL:
causal_boundary = start_n + offs_n_causal
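
Note on the descaling step in this hunk: with per-tensor FP8 quantization (assumed convention x_fp8 ≈ x / x_scale), the product of two quantized tensors is recovered by multiplying the accumulator by both scales, which is what `qk_scaled *= q_scale * k_scale` does after the dot product. A minimal host-side sketch of that identity, using a hypothetical helper name and torch.float8_e4m3fn as an assumed FP8 format:

import torch

def qk_descale_sketch(q, k, sm_scale):
    # per-tensor scales chosen so the quantized values span the FP8 range
    fp8 = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8).max
    q_scale = q.abs().amax() / fp8_max
    k_scale = k.abs().amax() / fp8_max
    q_fp8 = (q / q_scale).to(fp8)
    k_fp8 = (k / k_scale).to(fp8)
    # emulate the kernel: matmul on the quantized operands (upcast here because
    # torch.matmul has no FP8 path), then apply sm_scale and both descale factors
    qk = q_fp8.to(torch.float32) @ k_fp8.to(torch.float32).T
    return qk * sm_scale * q_scale * k_scale   # ≈ (q @ k.T) * sm_scale
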
@@ -135,7 +138,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
p_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k)

# CAVEAT: Must update l_ij before applying dropout
l_ij = tl.sum(p, 1)
l_ij = tl.sum(p, 1) # p is fp32 at this point
if ENABLE_DROPOUT:
if tl_DROPOUT_USE_PYTORCH:
dropout_mask = tl.load(dropout_mask_ptrs, mask=p_mask)
@@ -170,7 +173,10 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
l_i = l_i * alpha + l_ij
# update m_i and l_i
m_i = m_ij
acc += tl.dot(p.to(v.type.element_ty), v)
if IS_FP8:
acc += tl.dot((p * p_inv_scale).to(v.type.element_ty), v) * p_scale * v_scale
else:
acc += tl.dot(p.to(v.type.element_ty), v)
k_ptrs += BLOCK_N * stride_kn
v_ptrs += BLOCK_N * stride_vk
if bias_ptrs is not None:
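
The PV update in this hunk applies the same idea in the other direction: the fp32 softmax block p is pushed into the FP8 range with p_inv_scale before the cast, and the dot-product result is brought back with p_scale * v_scale. A sketch of the arithmetic, assuming p_inv_scale is the reciprocal of p_scale, that softmax entries are bounded by 1, and that v carries a per-tensor v_scale:

import torch

def pv_descale_sketch(p, v):
    # p: fp32 softmax block in [0, 1]; v: fp32 values quantized per tensor
    fp8 = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8).max
    v_scale = v.abs().amax() / fp8_max
    v_fp8 = (v / v_scale).to(fp8)
    p_scale = 1.0 / fp8_max            # softmax entries are <= 1
    p_inv_scale = 1.0 / p_scale
    p_fp8 = (p * p_inv_scale).to(fp8)  # mirrors (p * p_inv_scale).to(v.type.element_ty)
    acc = p_fp8.to(torch.float32) @ v_fp8.to(torch.float32)
    return acc * p_scale * v_scale     # ≈ p @ v in fp32
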
@@ -259,15 +265,17 @@ def get_autotune_configs():
use_cuda_graph=True,
)
@triton.jit
def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn,
def attn_fwd(Q, K, V, bias,
Q_SCALE, K_SCALE, V_SCALE, P_SCALE, P_INV_SCALE, stride_qscale_z, stride_kvscale_z, stride_pscale_z, stride_pinvscale_z,
SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn,
stride_oz, stride_oh, stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah,
stride_sz, stride_sh, stride_sm, stride_sn, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k,
dropout_p, philox_seed, philox_offset_base, sd_mask, dropout_mask, alibi_slopes, HQ: tl.constexpr,
HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr,
MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr):
ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, IS_FP8: tl.constexpr):
start_m = tl.program_id(0)
off_h_q = tl.program_id(1)
off_z = tl.program_id(2)
@@ -396,6 +404,16 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_
q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL)
q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0)

# Load scale factors if IS_FP8.
if IS_FP8:
q_scale = tl.load(Q_SCALE + off_z * stride_qscale_z + off_h_q)
k_scale = tl.load(K_SCALE + off_z * stride_kvscale_z + off_h_k)
v_scale = tl.load(V_SCALE + off_z * stride_kvscale_z + off_h_k)
p_scale = tl.load(P_SCALE + off_z * stride_pscale_z + off_h_q)
p_inv_scale = tl.load(P_INV_SCALE + off_z * stride_pinvscale_z + off_h_q)
else:
q_scale, k_scale, v_scale, p_scale, p_inv_scale = 1.0, 1.0, 1.0, 1.0, 1.0

# Here we compute how many full and masked blocks we have.
padded_block_k = n_extra_tokens != 0
is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
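
The loads in this hunk index each scale tensor as base + off_z * stride_z + off_h, i.e. one scalar per (batch, head) with a contiguous head dimension. A sketch of how such tensors could be built on the host for a bshd layout (the helper name and the amax-based choice of scales are assumptions, not part of this commit):

import torch

def build_per_head_scales(q, k, v, fp8_dtype=torch.float8_e4m3fn):
    # q, k, v assumed in bshd layout: (batch, seqlen, heads, head_dim)
    fp8_max = torch.finfo(fp8_dtype).max
    q_scale = q.abs().amax(dim=(1, 3)) / fp8_max       # (batch, heads_q)
    k_scale = k.abs().amax(dim=(1, 3)) / fp8_max       # (batch, heads_k)
    v_scale = v.abs().amax(dim=(1, 3)) / fp8_max       # (batch, heads_k)
    p_scale = torch.full_like(q_scale, 1.0 / fp8_max)  # softmax is bounded by 1
    p_inv_scale = 1.0 / p_scale
    return q_scale, k_scale, v_scale, p_scale, p_inv_scale
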
@@ -421,6 +439,7 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_
sd_mask_ptrs, dropout_mask_ptrs,
# _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
block_min, block_max, 0, 0, 0, alibi_slope,
q_scale, k_scale, v_scale, p_scale, p_inv_scale, IS_FP8,
# IS_CAUSAL, ....
False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n,
# _, MASK_STEPS, ...
@@ -449,6 +468,7 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_
start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs,
sd_mask_ptrs, dropout_mask_ptrs, block_min, block_max, offs_n_causal, masked_blocks,
n_extra_tokens, alibi_slope,
q_scale, k_scale, v_scale, p_scale, p_inv_scale, IS_FP8,
IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n,
# _, MASK_STEPS, ...
PRE_LOAD_V, True, ENABLE_DROPOUT, PADDED_HEAD,
@@ -539,13 +559,34 @@ def attention_prefill_forward_triton_impl(
return_softmax,
use_exp2):


if q.dtype in FP8_TYPES:
is_fp8 = True
q_scale = fp8_metadata.q_scale
k_scale = fp8_metadata.k_scale
v_scale = fp8_metadata.v_scale
p_scale = fp8_metadata.p_scale
p_inv_scale = fp8_metadata.p_inv_scale
q_scale_stride_z = q_scale.stride(0)
kv_scale_stride_z = k_scale.stride(0)
p_scale_stride_z = p_scale.stride(0)
p_inv_scale_stride_z = p_inv_scale.stride(0)
else:
q_scale = k_scale = v_scale = p_scale = p_inv_scale = 1
q_scale_stride_z = kv_scale_stride_z = p_scale_stride_z = p_inv_scale_stride_z = 0

if DEBUG:
print()
print("attention_prefill_forward_triton_impl")
print("q:", q, q.shape)
print("k:", k, k.shape)
print("v:", v, v.shape)
print("o:", o, o.shape)
print("q_scale:", q_scale)
print("k_scale:", k_scale)
print("v_scale:", v_scale)
print("p_scale:", p_scale)
print("p_inv_scale:", p_inv_scale)
print("sm_scale:", sm_scale)
print("alibi_slopes:", alibi_slopes)
print("causal:", causal)
@@ -618,15 +659,16 @@ def attention_prefill_forward_triton_impl(
else:
alibi_strides = (0, 0)


attn_fwd[grid](q, k, v, bias, sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides,
attn_fwd[grid](q, k, v, bias,
q_scale, k_scale, v_scale, p_scale, p_inv_scale, q_scale_stride_z, kv_scale_stride_z, p_scale_stride_z, p_inv_scale_stride_z,
sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides,
*bias_strides, *alibi_strides, *scores_strides, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k,
dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, sd_mask=sd_mask, dropout_mask=dropout_mask, alibi_slopes=alibi_slopes,
HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q,
MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, VARLEN=is_varlen,
BLOCK_DMODEL=padded_d_model, USE_BIAS=False if bias is None else True,
USE_ALIBI=False if alibi_slopes is None else True, ENABLE_DROPOUT=dropout_p
> 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax)
> 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax, IS_FP8=is_fp8)

if DEBUG:
print()
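
Putting the pieces together, the scaling scheme can be sanity-checked against an fp32 reference on the host before trusting the kernel. The snippet below is an illustrative check only (single head, no causal mask, no dropout), not part of this commit:

import torch

def fp8_scaling_sanity_check(q, k, v, sm_scale):
    # q, k, v: fp32 tensors of shape (seqlen, head_dim)
    fp8 = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8).max
    q_s, k_s, v_s = (t.abs().amax() / fp8_max for t in (q, k, v))
    q_q = (q / q_s).to(fp8).to(torch.float32)
    k_q = (k / k_s).to(fp8).to(torch.float32)
    v_q = (v / v_s).to(fp8).to(torch.float32)
    # descaled QK^T, softmax in fp32, then the descaled PV product
    p = torch.softmax((q_q @ k_q.T) * sm_scale * q_s * k_s, dim=-1)
    p_s = 1.0 / fp8_max
    p_q = (p / p_s).to(fp8).to(torch.float32)
    out_fp8 = (p_q @ v_q) * p_s * v_s
    out_ref = torch.softmax((q @ k.T) * sm_scale, dim=-1) @ v
    return (out_fp8 - out_ref).abs().max()   # expected to be small but nonzero
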