Skip to content

Commit

Permalink
probably mask application mismatch
Browse files Browse the repository at this point in the history
  • Loading branch information
micmelesse committed Dec 2, 2024
1 parent 3b7f290 commit d008a3c
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions flash_attn/flash_attn_triton_amd/bwd_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ def _bwd_kernel_one_col_block(
p_drop = tl.where(dropout_mask, p, 0.0)

# compute dv
dv += tl.dot(tl.trans(p) , do)
dv += tl.dot(tl.trans(p_drop), do) # dropout scale is applied at the end

# compute dp
dp_drop_scaled = tl.dot(do, tl.trans(v))
Expand Down Expand Up @@ -440,13 +440,14 @@ def _bwd_kernel_one_col_block(
# compute dp
dp = tl.dot(do, tl.trans(v))

# compute ds , ds = p * (dp - delta[:, None])
# compute ds
delta_ptrs = d_offset + offs_m * stride_deltam
delta_i = tl.load(delta_ptrs, mask=mask_m)
ds = (p * (dp - delta_i[:, None])) * sm_scale
dscores_scaled = (p * (dp - delta_i[:, None]))
ds = dscores_scaled * sm_scale
ds = tl.where(p_mask, ds, 0.0)

# compute dk = dot(ds.T, q)
# compute dk
dk += tl.dot(tl.trans(ds), q)

# compute dq
Expand All @@ -463,6 +464,7 @@ def _bwd_kernel_one_col_block(

if DROPOUT:
dv *= dropout_scale
dk *= dropout_scale

# write-back
if GROUP_SIZE != 1:
Expand Down

0 comments on commit d008a3c

Please sign in to comment.