diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py
index 8974b2619..4c86bc995 100644
--- a/flash_attn/flash_attn_interface.py
+++ b/flash_attn/flash_attn_interface.py
@@ -273,6 +273,10 @@ def _flash_attn_backward(
     alibi_slopes: Optional[torch.Tensor],
     deterministic: bool,
     rng_state: Optional[torch.Tensor] = None,
+    descale_q=None,
+    descale_k=None,
+    descale_v=None,
+    descale_p=None,
 ) -> torch.Tensor:
     # dq, dk, dv are allocated by us so they should already be contiguous
     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
@@ -301,6 +305,10 @@ def _flash_attn_backward(
         deterministic,
         None,
         rng_state,
+        descale_q,
+        descale_k,
+        descale_v,
+        descale_p
     )
     return softmax_d
 
@@ -849,7 +857,7 @@ def forward(
             descale_v=descale_v,
             descale_p=descale_p,
         )
-        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
+        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_p)
         ctx.dropout_p = dropout_p
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
@@ -862,7 +870,7 @@ def forward(
 
     @staticmethod
     def backward(ctx, dout, *args):
-        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
+        q, k, v, out, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_p = ctx.saved_tensors
         dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
         head_size_og = dout.size(3)
         dout_padded = dout
@@ -887,6 +895,10 @@ def backward(ctx, dout, *args):
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
+            descale_q=descale_q,
+            descale_k=descale_k,
+            descale_v=descale_v,
+            descale_p=descale_p,
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill.py b/flash_attn/flash_attn_triton_amd/bwd_prefill.py
index bea6a3d13..a7e8d0688 100644
--- a/flash_attn/flash_attn_triton_amd/bwd_prefill.py
+++ b/flash_attn/flash_attn_triton_amd/bwd_prefill.py
@@ -126,6 +126,8 @@ def _bwd_kernel_one_col_block(
     dropout_p,
     philox_seed,
     batch_philox_offset,
+    descale_q,
+    descale_k,
     BLOCK_M: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     ACTUAL_BLOCK_DMODEL: tl.constexpr,
@@ -135,6 +137,7 @@ def _bwd_kernel_one_col_block(
     DROPOUT: tl.constexpr,
     USE_EXP2: tl.constexpr,
     GROUP_SIZE: tl.constexpr,
+    IS_FP8: tl.constexpr,
 ):
     if CAUSAL:
         # TODO: Causal can skip more blocks with something like lo = start_m * BLOCK_M
@@ -160,7 +163,8 @@ def _bwd_kernel_one_col_block(
     k_ptrs = k_offset + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
     v_ptrs = v_offset + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vk
     k = tl.load(k_ptrs, mask=kv_mask, other=0.0)
-    v = tl.load(v_ptrs, mask=kv_mask, other=0.0)
+    kT = tl.trans(k)
+    vT = tl.trans(tl.load(v_ptrs, mask=kv_mask, other=0.0))
 
     # loop over rows
     for start_m in range(lo, num_block_m):
@@ -179,7 +183,10 @@ def _bwd_kernel_one_col_block(
 
         # recompute p = softmax(qk, dim=-1).T
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-        qk += tl.dot(q, tl.trans(k))
+        if IS_FP8:
+            qk += (tl.dot(q, kT) * descale_q * descale_k)
+        else:
+            qk += tl.dot(q, kT)
 
         if CAUSAL:
             col_offset = N_CTX_Q - N_CTX_K
@@ -228,7 +235,7 @@ def _bwd_kernel_one_col_block(
             dv += tl.dot(tl.trans(p_drop_scaled.to(do.type.element_ty)), do)
 
             # compute dp
-            dp_drop_scaled = tl.dot(do, tl.trans(v))
+            dp_drop_scaled = tl.dot(do, vT)
             dp = tl.where(dropout_mask, dp_drop_scaled, 0.0) * dropout_scale
 
             # compute ds
@@ -244,7 +251,7 @@ def _bwd_kernel_one_col_block(
             dv += tl.dot(tl.trans(p), do)
 
             # compute dp
-            dp = tl.dot(do, tl.trans(v))
+            dp = tl.dot(do, vT)
 
             # compute ds
             delta_ptrs = delta_offset + offs_m * stride_deltam
@@ -292,6 +299,8 @@ def _bwd_kernel(
     L,
     Delta,
     Dropout_mask,
+    DESCALE_Q,
+    DESCALE_K,
     stride_dq_all,
     stride_qz,
     stride_qh,
@@ -330,6 +339,7 @@ def _bwd_kernel(
     DROPOUT: tl.constexpr,
     USE_EXP2: tl.constexpr,
     IS_VARLEN: tl.constexpr,
+    IS_FP8: tl.constexpr,
 ):
     # program ids
     off_zh = tl.program_id(0)
@@ -374,8 +384,16 @@ def _bwd_kernel(
     else:
         batch_philox_offset = 0
         dropout_offset = 0
-
+    if IS_FP8:
+        stride_descale_q_z = HQ
+        stride_descale_kv_z = HK
+
+        descale_q = tl.load(DESCALE_Q + off_z * stride_descale_q_z + off_hq)
+        descale_k = tl.load(DESCALE_K + off_z * stride_descale_kv_z + off_hk)
+    else:
+        descale_q, descale_k = 1.0, 1.0
+
     # output tensor offsets
     dk_offset = DK + off_z * stride_kz + off_hk * stride_kh + k_start * stride_kn
     dv_offset = DV + off_z * stride_vz + off_hk * stride_vh + k_start * stride_vn
@@ -430,7 +448,11 @@ def _bwd_kernel(
             start_n,
             num_block_m,
             num_block_n,
-            dropout_p, philox_seed, batch_philox_offset,
+            dropout_p,
+            philox_seed,
+            batch_philox_offset,
+            descale_q,
+            descale_k,
             BLOCK_M=BLOCK_M,
             BLOCK_DMODEL=BLOCK_DMODEL,
             ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
@@ -439,7 +461,8 @@ def _bwd_kernel(
             CAUSAL=CAUSAL,
             DROPOUT=DROPOUT,
             USE_EXP2=USE_EXP2,
-            GROUP_SIZE=GROUP_SIZE
+            GROUP_SIZE=GROUP_SIZE,
+            IS_FP8=IS_FP8
         )
     else:
         for start_n in range(0, num_block_n):
@@ -487,7 +510,11 @@ def _bwd_kernel(
                 start_n,
                 num_block_m,
                 num_block_n,
-                dropout_p, philox_seed, batch_philox_offset,
+                dropout_p,
+                philox_seed,
+                batch_philox_offset,
+                descale_q,
+                descale_k,
                 BLOCK_M=BLOCK_M,
                 BLOCK_DMODEL=BLOCK_DMODEL,
                 ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
@@ -496,7 +523,8 @@ def _bwd_kernel(
                 CAUSAL=CAUSAL,
                 DROPOUT=DROPOUT,
                 USE_EXP2=USE_EXP2,
-                GROUP_SIZE=GROUP_SIZE
+                GROUP_SIZE=GROUP_SIZE,
+                IS_FP8=IS_FP8
             )
 
 
@@ -524,6 +552,11 @@ def attention_prefill_backward_triton_impl(
     philox_offset,
     use_exp2: bool,
     sequence_parallel = True,
+    # fp8
+    descale_q=None,
+    descale_k=None,
+    descale_v=None,
+    descale_p=None
 ):
     if DEBUG:
         print()
@@ -695,41 +728,8 @@ def attention_prefill_backward_triton_impl(
         IS_VARLEN=is_varlen
     )
 
-    if False:
-        print("_bwd_kernel inputs")
-        print("do:", do, do.shape)
-        print("q:", q, q.shape)
-        print("k:", k, k.shape)
-        print("v:", v, v.shape)
-        print("sm_scale", sm_scale)
-        print("o:", o, o.shape)
-        print("dq:", dq, dq.shape)
-        print("dk:", dk, dk.shape)
-        print("dv:", dv, dv.shape)
-        print("L:", softmax_lse, softmax_lse.shape)
+    if DEBUG:
         print("delta:", delta, delta.shape)
-        print("stride_qz, stride_qh, stride_qm, stride_qk:", stride_qz, stride_qh, stride_qm, stride_qk)
-        print("stride_kz, stride_kh, stride_kn, stride_kk:", stride_kz, stride_kh, stride_kn, stride_kk)
-        print("stride_vz, stride_vh, stride_vn, stride_vk:", stride_vz, stride_vh, stride_vn, stride_vk)
-        print("batch_q:", batch)
-        print("heads_q:",nheads_q)
-        print("max_seqlen_q:",max_seqlen_q)
-        print("max_seqlen_k:",max_seqlen_k)
-        print("dropout_p:",dropout_p)
-        print("philox_seed:", philox_seed)
-        print("philox_offset:",philox_offset)
-        print("BLOCK_M:",BLOCK_M)
-        print("BLOCK_N:",BLOCK_M)
-        print("BLOCK_DMODEL:",BLOCK_DMODEL)
-        print("ACTUAL_BLOCK_DMODEL:",ACTUAL_BLOCK_DMODEL)
-        print("SEQUENCE_PARALLEL:",sequence_parallel)
-        print("CAUSAL:",causal)
-        print("DROPOUT:", use_dropout)
-        print("num_warps:",num_warps)
-        print("num_stages:", num_stages)
-        print("USE_EXP2:", use_exp2)
-        print("num_blocks_m:", num_blocks_m)
-        print("num_blocks_n:", num_blocks_n)
 
     _bwd_kernel[(batch * nheads_q, num_blocks_n if sequence_parallel else 1)](
         q,
@@ -744,6 +744,8 @@ def attention_prefill_backward_triton_impl(
         softmax_lse,
         delta,
         dropout_mask,
+        descale_q,
+        descale_k,
         stride_dq_all,
         stride_qz, stride_qh, stride_qm, stride_qk,
         stride_kz, stride_kh, stride_kn, stride_kk,
@@ -771,7 +773,8 @@ def attention_prefill_backward_triton_impl(
         num_warps=num_warps,
         num_stages=num_stages,
         waves_per_eu = waves_per_eu,
-        IS_VARLEN=is_varlen
+        IS_VARLEN=is_varlen,
+        IS_FP8=is_fp8
     )
 
     if sequence_parallel:
@@ -779,7 +782,6 @@ def attention_prefill_backward_triton_impl(
 
     if DEBUG:
         print("attention_prefill_backward_triton_impl outputs")
-        print("delta:", delta, delta.shape)
         print("dv:", dv, dv.shape)
         print("dk:", dk, dk.shape)
         print("dq:", dq, dq.shape)
diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py
index 19ae4b139..287613322 100644
--- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py
+++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py
@@ -273,7 +273,7 @@ def get_autotune_configs():
 )
 @triton.jit
 def attn_fwd(Q, K, V, bias,
-             DESCALE_Q, DESCALE_K, DESCALE_V, DESCALE_P, stride_q_inv_scale_z, stride_kv_inv_scale_z, stride_p_inv_scale_z,
+             DESCALE_Q, DESCALE_K, DESCALE_V, DESCALE_P, stride_descale_q_z, stride_descale_kv_z, stride_descale_p_z,
              SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_qh, stride_qm, stride_qk,
              stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn,
              stride_oz, stride_oh, stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah,
@@ -413,10 +413,10 @@ def attn_fwd(Q, K, V, bias,
 
     # Load scale factors if IS_FP8.
     if IS_FP8:
-        descale_q = tl.load(DESCALE_Q + off_z * stride_q_inv_scale_z + off_h_q)
-        descale_k = tl.load(DESCALE_K + off_z * stride_kv_inv_scale_z + off_h_k)
-        descale_v = tl.load(DESCALE_V + off_z * stride_kv_inv_scale_z + off_h_k)
-        descale_p = tl.load(DESCALE_P + off_z * stride_p_inv_scale_z + off_h_q)
+        descale_q = tl.load(DESCALE_Q + off_z * stride_descale_q_z + off_h_q)
+        descale_k = tl.load(DESCALE_K + off_z * stride_descale_kv_z + off_h_k)
+        descale_v = tl.load(DESCALE_V + off_z * stride_descale_kv_z + off_h_k)
+        descale_p = tl.load(DESCALE_P + off_z * stride_descale_p_z + off_h_q)
     else:
         descale_q, descale_k, descale_v, descale_p = 1.0, 1.0, 1.0, 1.0
diff --git a/flash_attn/flash_attn_triton_amd/interface_fa.py b/flash_attn/flash_attn_triton_amd/interface_fa.py
index d0b8e1d04..065d2ecc8 100644
--- a/flash_attn/flash_attn_triton_amd/interface_fa.py
+++ b/flash_attn/flash_attn_triton_amd/interface_fa.py
@@ -149,6 +149,10 @@ def bwd(
     deterministic,
     gen_,
     rng_state,
+    descale_q,
+    descale_k,
+    descale_v,
+    descale_p
 ):
     # NOTE: this might have perf costs
     dq.zero_()
@@ -236,6 +240,10 @@ def bwd(
         philox_seed,
         philox_offset,
         False,
+        descale_q = descale_q,
+        descale_k = descale_k,
+        descale_v = descale_v,
+        descale_p = descale_p
     )
 
     delta = delta_triton
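
Not part of the patch itself, but as a reading aid: a minimal PyTorch sketch of the dequantization math the IS_FP8 branches above implement when the backward kernel recomputes the scores (qk += tl.dot(q, kT) * descale_q * descale_k). The helper name, the (batch, nheads, seqlen, head_dim) layout, and the assumption that query and key head counts match (no GQA) are illustrative only; softmax scaling, masking, and dropout are omitted.

import torch

def fp8_qk_reference(q_fp8, k_fp8, descale_q, descale_k):
    # q_fp8: (batch, nheads, seqlen_q, head_dim), k_fp8: (batch, nheads, seqlen_k, head_dim)
    # descale_q, descale_k: (batch, nheads) float32 per-(batch, head) descale factors
    q = q_fp8.to(torch.float32)   # the Triton kernels accumulate the fp8 dot product in fp32
    k = k_fp8.to(torch.float32)
    qk = q @ k.transpose(-2, -1)  # (batch, nheads, seqlen_q, seqlen_k)
    # undo the quantization scaling, as the IS_FP8 branch does per (batch, head)
    return qk * descale_q[:, :, None, None] * descale_k[:, :, None, None]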