diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill.py b/flash_attn/flash_attn_triton_amd/bwd_prefill.py
index 60ab515d7..6608cdf75 100644
--- a/flash_attn/flash_attn_triton_amd/bwd_prefill.py
+++ b/flash_attn/flash_attn_triton_amd/bwd_prefill.py
@@ -410,7 +410,7 @@ def _bwd_kernel_one_col_block(
             p_drop = tl.where(dropout_mask, p, 0.0)
 
             # compute dv
-            dv += tl.dot(tl.trans(p) , do)
+            dv += tl.dot(tl.trans(p_drop), do) # dropout scale is applied at the end
 
             # compute dp
             dp_drop_scaled = tl.dot(do, tl.trans(v))
@@ -440,13 +440,14 @@ def _bwd_kernel_one_col_block(
             # compute dp
             dp = tl.dot(do, tl.trans(v))
 
-            # compute ds , ds = p * (dp - delta[:, None])
+            # compute ds
             delta_ptrs = d_offset + offs_m * stride_deltam
             delta_i = tl.load(delta_ptrs, mask=mask_m)
-            ds = (p * (dp - delta_i[:, None])) * sm_scale
+            dscores_scaled = (p * (dp - delta_i[:, None]))
+            ds = dscores_scaled * sm_scale
             ds = tl.where(p_mask, ds, 0.0)
 
-            # compute dk = dot(ds.T, q)
+            # compute dk
             dk += tl.dot(tl.trans(ds), q)
 
             # compute dq
@@ -463,6 +464,7 @@ def _bwd_kernel_one_col_block(
 
     if DROPOUT:
         dv *= dropout_scale
+        dk *= dropout_scale
 
     # write-back
    if GROUP_SIZE != 1:
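
For reviewers who want to sanity-check the math behind these changes: below is a minimal plain-PyTorch sketch of the attention backward pass with dropout, verified against autograd. It assumes the standard formulation (the names q, k, v, do, keep, dropout_scale and the shapes are illustrative, not the kernel's tensors) and applies the 1/(1 - p_dropout) factor inline rather than deferring it to the end as the kernel does; it is only meant to show why dv must be built from the dropped probabilities and why dk inherits the dropout factor through ds.

import torch

torch.manual_seed(0)
N, D = 8, 16
sm_scale = D ** -0.5
p_dropout = 0.2
dropout_scale = 1.0 / (1.0 - p_dropout)

q = torch.randn(N, D, dtype=torch.float64, requires_grad=True)
k = torch.randn(N, D, dtype=torch.float64, requires_grad=True)
v = torch.randn(N, D, dtype=torch.float64, requires_grad=True)
do = torch.randn(N, D, dtype=torch.float64)          # upstream gradient dO
keep = (torch.rand(N, N) > p_dropout).to(torch.float64)  # dropout keep mask

# forward: scores -> softmax -> dropout -> output
s = (q @ k.T) * sm_scale
p = torch.softmax(s, dim=-1)
p_drop = p * keep * dropout_scale                     # dropout scale applied inline here
o = p_drop @ v

# autograd reference gradients
o.backward(do)

# manual backward, mirroring the structure of the kernel (but with inline scaling)
with torch.no_grad():
    dv = p_drop.T @ do                                # dv uses the dropped probabilities
    dp = (do @ v.T) * keep * dropout_scale            # grad w.r.t. p, through mask and scale
    delta = (do * o).sum(dim=-1, keepdim=True)        # rowsum(dO * O), the precomputed delta
    ds = p * (dp - delta) * sm_scale                  # softmax backward, then score scale
    dq = ds @ k
    dk = ds.T @ q                                     # dk picks up the dropout factor via ds

print(torch.allclose(dv, v.grad), torch.allclose(dk, k.grad), torch.allclose(dq, q.grad))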