
Commit

save
micmelesse committed Nov 19, 2024
1 parent 1410fbc commit f233ac5
Showing 1 changed file with 16 additions and 14 deletions.
30 changes: 16 additions & 14 deletions flash_attn/flash_attn_triton_amd/bwd_ref.py
@@ -71,10 +71,11 @@ def attention_backward_core_ref_impl(

     if dropout_p > 0.0:
         dropout_mask = generate_dropout_mask(p.shape, dropout_p, philox_seed, philox_offset, p.device, p.dtype)

         dropout_scale = (1.0 / (1 - dropout_p))

         p = p * dropout_mask * dropout_scale
         if DEBUG_CORE:
             print("p after dropout:", p, p.shape)

     # compute gradient wrt v
     dv = torch.matmul(p.transpose(-2, -1), do)
     if DEBUG_CORE:
@@ -85,15 +86,17 @@ def attention_backward_core_ref_impl(
     if DEBUG_CORE:
         print("dp:", dp, dp.shape)
     if dropout_p > 0.0:
-        dp = dp * dropout_scale # Add scaling here since reference scales v
+        dp = dp * dropout_mask
         if DEBUG_CORE:
             print("dp after dropout:", dp, dp.shape)

-    # calculate ds using dp
-    delta = torch.sum(o * do, axis=-1) # what OAI kernel uses
-    delta_3d = delta.unsqueeze(-1)
-    if DEBUG_CORE:
-        print("delta_3d:", delta_3d, delta_3d.shape)
-    ds = (p * (dp - delta_3d)) * sm_scale
+    # calculate ds
+    delta = torch.sum(o * do, axis=-1).unsqueeze(-1)
+    dscores_scaled = (p * (dp - delta))
+    ds = dscores_scaled * sm_scale
     if DEBUG_CORE:
+        print("delta:", delta, delta.shape)
+        print("dscores_scaled:", dscores_scaled, dscores_scaled.shape)
         print("ds:", ds, ds.shape)


@@ -111,16 +114,15 @@ def attention_backward_core_ref_impl(
     dq = dq.to(torch.float16)
     dk = dk.to(torch.float16)
     dv = dv.to(torch.float16)

-    # remove d dim with size 1
-    delta = delta_3d.squeeze(-1)
+    delta = delta.squeeze(-1)

     if DEBUG_CORE:
         print("attention_backward_core_ref_impl output")
-        print("dq:", dq, dq.shape)
-        print("dk:", dk, dk.shape)
-        print("dv:", dv, dv.shape)
         print("delta:", delta, delta.shape)
+        print("dv:", dv, dv.shape)
+        print("dk:", dk, dk.shape)
+        print("dq:", dq, dq.shape)

     return dq, dk, dv, delta
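
The ds formulation above is the standard softmax backward: delta = sum(o * do, dim=-1) coincides with sum(p * dp, dim=-1), and ds is then the gradient with respect to the unscaled scores q @ k^T. A minimal standalone check of this identity against autograd for the no-dropout case (illustrative only; the shapes, the sm_scale choice, and the helper names dq_manual/dq_auto are not from bwd_ref.py):

import torch

torch.manual_seed(0)
q = torch.randn(2, 16, 8, dtype=torch.float64, requires_grad=True)
k = torch.randn(2, 16, 8, dtype=torch.float64)
v = torch.randn(2, 16, 8, dtype=torch.float64)
do = torch.randn(2, 16, 8, dtype=torch.float64)
sm_scale = q.shape[-1] ** -0.5

scores = torch.matmul(q, k.transpose(-2, -1))      # unscaled attention scores
p = torch.softmax(scores * sm_scale, dim=-1)
o = torch.matmul(p, v)

dp = torch.matmul(do, v.transpose(-2, -1))         # gradient w.r.t. p
delta = torch.sum(o * do, dim=-1, keepdim=True)    # equals sum(p * dp, dim=-1)
ds = p * (dp - delta) * sm_scale                   # gradient w.r.t. the unscaled scores
dq_manual = torch.matmul(ds, k)

dq_auto, = torch.autograd.grad(o, q, grad_outputs=do)
print(torch.allclose(dq_manual, dq_auto))          # expect True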

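On the forward side, the dropout branch uses the inverted-dropout convention: p is multiplied by a keep-mask and rescaled by 1 / (1 - dropout_p) so that its expectation matches the undropped p. A small illustration of that convention (the file builds its mask via generate_dropout_mask with the philox seed and offset; torch.rand_like below is only a stand-in):

import torch

torch.manual_seed(0)
dropout_p = 0.2
p = torch.softmax(torch.randn(4, 128, 128), dim=-1)

dropout_mask = torch.rand_like(p) >= dropout_p     # keep-mask: True with probability 1 - dropout_p
dropout_scale = 1.0 / (1 - dropout_p)              # same rescaling as in the diff
p_dropped = p * dropout_mask * dropout_scale

print(p.mean().item(), p_dropped.mean().item())    # approximately equal in expectation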
