From ce51aacd8f8fcb818ffb15312d9a7e2520614e31 Mon Sep 17 00:00:00 2001
From: Michael Melesse
Date: Thu, 16 Jan 2025 11:37:38 -0600
Subject: [PATCH] fix backward bug

---
 flash_attn/flash_attn_interface.py              | 4 ++--
 flash_attn/flash_attn_triton_amd/fwd_prefill.py | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py
index 2c8819760..602754e22 100644
--- a/flash_attn/flash_attn_interface.py
+++ b/flash_attn/flash_attn_interface.py
@@ -802,7 +802,7 @@ def backward(ctx, dout, *args):
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dkv = dkv[..., : dout.shape[-1]]
-        return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None
+        return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None


 class FlashAttnFunc(torch.autograd.Function):
@@ -891,7 +891,7 @@ def backward(ctx, dout, *args):
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None,


 class FlashAttnVarlenFunc(torch.autograd.Function):
diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py
index f61ade09e..32d934f25 100644
--- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py
+++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py
@@ -292,7 +292,7 @@ def get_autotune_configs():
 )
 @triton.jit
 def attn_fwd(Q, K, V, bias,
-             DESCALE_Q, DESCALE_K, DESCALE_V, DESCALE_P, stride_q_inv_scale_z, stride_kv_inv_scale_z, stride_p_inv_scale_z,
+             DESCALE_Q, DESCALE_K, DESCALE_V, DESCALE_S, stride_q_inv_scale_z, stride_kv_inv_scale_z, stride_p_inv_scale_z,
              SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh,
              stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om,
              stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah,
@@ -435,7 +435,7 @@ def attn_fwd(Q, K, V, bias,
         descale_q = tl.load(DESCALE_Q + off_z * stride_q_inv_scale_z + off_h_q)
         descale_k = tl.load(DESCALE_K + off_z * stride_kv_inv_scale_z + off_h_k)
         descale_v = tl.load(DESCALE_V + off_z * stride_kv_inv_scale_z + off_h_k)
-        descale_s = tl.load(DESCALE_P + off_z * stride_p_inv_scale_z + off_h_q)
+        descale_s = tl.load(DESCALE_S + off_z * stride_p_inv_scale_z + off_h_q)
         q_fp8_offset = q_fp8 + off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm
         q_fp8_ptrs = q_fp8_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
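
Note on the flash_attn_interface.py hunks: torch.autograd.Function requires backward() to return exactly one gradient (or None) per argument of forward(), and the patch extends the trailing Nones, presumably so the count matches the forward signatures. The sketch below illustrates that contract only; it is not taken from the patch, and the ScaleOnly class and its arguments are illustrative names.

import torch

class ScaleOnly(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale, debug_flag):
        # Non-tensor inputs (scale, debug_flag) still count toward the number
        # of values backward() must return.
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad_out):
        # One return value per forward() input: a gradient for x, then None
        # for each input that does not receive a gradient.
        return grad_out * ctx.scale, None, None

x = torch.randn(4, requires_grad=True)
y = ScaleOnly.apply(x, 2.0, False)
y.sum().backward()
print(x.grad)  # every element is 2.0

If backward() returned too few (or too many) values for the forward() arguments, autograd would raise an error at backward time, which is consistent with the "fix backward bug" subject of this patch.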