Skip to content

Commit

Permalink
fix backward bug
Browse files Browse the repository at this point in the history
  • Loading branch information
micmelesse committed Jan 16, 2025
1 parent 4c110bd commit ce51aac
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions flash_attn/flash_attn_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,7 +802,7 @@ def backward(ctx, dout, *args):
)
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dkv = dkv[..., : dout.shape[-1]]
return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None
return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None


class FlashAttnFunc(torch.autograd.Function):
Expand Down Expand Up @@ -891,7 +891,7 @@ def backward(ctx, dout, *args):
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None, None, None, None, None, None, None
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None,


class FlashAttnVarlenFunc(torch.autograd.Function):
Expand Down
4 changes: 2 additions & 2 deletions flash_attn/flash_attn_triton_amd/fwd_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def get_autotune_configs():
)
@triton.jit
def attn_fwd(Q, K, V, bias,
DESCALE_Q, DESCALE_K, DESCALE_V, DESCALE_P, stride_q_inv_scale_z, stride_kv_inv_scale_z, stride_p_inv_scale_z,
DESCALE_Q, DESCALE_K, DESCALE_V, DESCALE_S, stride_q_inv_scale_z, stride_kv_inv_scale_z, stride_p_inv_scale_z,
SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn,
stride_oz, stride_oh, stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah,
Expand Down Expand Up @@ -435,7 +435,7 @@ def attn_fwd(Q, K, V, bias,
descale_q = tl.load(DESCALE_Q + off_z * stride_q_inv_scale_z + off_h_q)
descale_k = tl.load(DESCALE_K + off_z * stride_kv_inv_scale_z + off_h_k)
descale_v = tl.load(DESCALE_V + off_z * stride_kv_inv_scale_z + off_h_k)
descale_s = tl.load(DESCALE_P + off_z * stride_p_inv_scale_z + off_h_q)
descale_s = tl.load(DESCALE_S + off_z * stride_p_inv_scale_z + off_h_q)

q_fp8_offset = q_fp8 + off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm
q_fp8_ptrs = q_fp8_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
Expand Down

0 comments on commit ce51aac

Please sign in to comment.