
dv matches
micmelesse committed Jan 28, 2025
1 parent f8298ae commit 537f11a
Showing 4 changed files with 74 additions and 52 deletions.
16 changes: 14 additions & 2 deletions flash_attn/flash_attn_interface.py
@@ -273,6 +273,10 @@ def _flash_attn_backward(
alibi_slopes: Optional[torch.Tensor],
deterministic: bool,
rng_state: Optional[torch.Tensor] = None,
descale_q=None,
descale_k=None,
descale_v=None,
descale_p=None,
) -> torch.Tensor:
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
@@ -301,6 +305,10 @@ def _flash_attn_backward(
deterministic,
None,
rng_state,
descale_q,
descale_k,
descale_v,
descale_p
)
return softmax_d

@@ -849,7 +857,7 @@ def forward(
descale_v=descale_v,
descale_p=descale_p,
)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_p)
ctx.dropout_p = dropout_p
ctx.softmax_scale = softmax_scale
ctx.causal = causal
@@ -862,7 +870,7 @@ def forward(

@staticmethod
def backward(ctx, dout, *args):
q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
q, k, v, out, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_p = ctx.saved_tensors
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
head_size_og = dout.size(3)
dout_padded = dout
@@ -887,6 +895,10 @@ def backward(ctx, dout, *args):
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
descale_q=descale_q,
descale_k=descale_k,
descale_v=descale_v,
descale_p=descale_p,
)
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
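Note (not part of the commit): the change above threads the optional FP8 descale tensors through ctx.save_for_backward so backward can hand the same factors to the Triton kernel. The pattern in isolation looks like the minimal sketch below; the class name and the kernel calls are placeholders, and only the save/restore of optional tensors mirrors the diff.

import torch

class _Fp8AttnSketch(torch.autograd.Function):
    """Minimal sketch of the save/restore pattern used above; the real
    forward/backward Triton kernels are replaced by placeholders."""

    @staticmethod
    def forward(ctx, q, k, v, descale_q=None, descale_k=None,
                descale_v=None, descale_p=None):
        out = q.clone()  # placeholder for the fused attention forward
        # None entries are legal in save_for_backward and come back as None.
        ctx.save_for_backward(q, k, v, out, descale_q, descale_k,
                              descale_v, descale_p)
        return out

    @staticmethod
    def backward(ctx, dout):
        q, k, v, out, descale_q, descale_k, descale_v, descale_p = ctx.saved_tensors
        # The real code calls the Triton backward here, forwarding the descale
        # factors so FP8 inputs are dequantized with the same scales as the forward.
        dq, dk, dv = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v)
        return dq, dk, dv, None, None, None, None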
92 changes: 47 additions & 45 deletions flash_attn/flash_attn_triton_amd/bwd_prefill.py
@@ -126,6 +126,8 @@ def _bwd_kernel_one_col_block(
dropout_p,
philox_seed,
batch_philox_offset,
descale_q,
descale_k,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
ACTUAL_BLOCK_DMODEL: tl.constexpr,
@@ -135,6 +137,7 @@
DROPOUT: tl.constexpr,
USE_EXP2: tl.constexpr,
GROUP_SIZE: tl.constexpr,
IS_FP8: tl.constexpr,
):
if CAUSAL:
# TODO: Causal can skip more blocks with something like lo = start_m * BLOCK_M
@@ -160,7 +163,8 @@
k_ptrs = k_offset + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
v_ptrs = v_offset + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vk
k = tl.load(k_ptrs, mask=kv_mask, other=0.0)
v = tl.load(v_ptrs, mask=kv_mask, other=0.0)
kT = tl.trans(k)
vT = tl.trans(tl.load(v_ptrs, mask=kv_mask, other=0.0))

# loop over rows
for start_m in range(lo, num_block_m):
@@ -179,7 +183,10 @@

# recompute p = softmax(qk, dim=-1).T
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, tl.trans(k))
if IS_FP8:
qk += (tl.dot(q, kT) * descale_q * descale_k)
else:
qk += tl.dot(q, kT)

if CAUSAL:
col_offset = N_CTX_Q - N_CTX_K
@@ -228,7 +235,7 @@ def _bwd_kernel_one_col_block(
dv += tl.dot(tl.trans(p_drop_scaled.to(do.type.element_ty)), do)

# compute dp
dp_drop_scaled = tl.dot(do, tl.trans(v))
dp_drop_scaled = tl.dot(do, vT)
dp = tl.where(dropout_mask, dp_drop_scaled, 0.0) * dropout_scale

# compute ds
@@ -244,7 +251,7 @@
dv += tl.dot(tl.trans(p), do)

# compute dp
dp = tl.dot(do, tl.trans(v))
dp = tl.dot(do, vT)

# compute ds
delta_ptrs = delta_offset + offs_m * stride_deltam
@@ -292,6 +299,8 @@ def _bwd_kernel(
L,
Delta,
Dropout_mask,
DESCALE_Q,
DESCALE_K,
stride_dq_all,
stride_qz,
stride_qh,
@@ -330,6 +339,7 @@
DROPOUT: tl.constexpr,
USE_EXP2: tl.constexpr,
IS_VARLEN: tl.constexpr,
IS_FP8: tl.constexpr,
):
# program ids
off_zh = tl.program_id(0)
@@ -374,8 +384,16 @@
else:
batch_philox_offset = 0
dropout_offset = 0


if IS_FP8:
stride_descale_q_z = HQ
stride_descale_kv_z = HK

descale_q = tl.load(DESCALE_Q + off_z * stride_descale_q_z + off_hq)
descale_k = tl.load(DESCALE_K + off_z * stride_descale_kv_z + off_hk)
else:
descale_q, descale_k = 1.0, 1.0

# output tensor offsets
dk_offset = DK + off_z * stride_kz + off_hk * stride_kh + k_start * stride_kn
dv_offset = DV + off_z * stride_vz + off_hk * stride_vh + k_start * stride_vn
@@ -430,7 +448,11 @@
start_n,
num_block_m,
num_block_n,
dropout_p, philox_seed, batch_philox_offset,
dropout_p,
philox_seed,
batch_philox_offset,
descale_q,
descale_k,
BLOCK_M=BLOCK_M,
BLOCK_DMODEL=BLOCK_DMODEL,
ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
@@ -439,7 +461,8 @@
CAUSAL=CAUSAL,
DROPOUT=DROPOUT,
USE_EXP2=USE_EXP2,
GROUP_SIZE=GROUP_SIZE
GROUP_SIZE=GROUP_SIZE,
IS_FP8=IS_FP8
)
else:
for start_n in range(0, num_block_n):
@@ -487,7 +510,11 @@
start_n,
num_block_m,
num_block_n,
dropout_p, philox_seed, batch_philox_offset,
dropout_p,
philox_seed,
batch_philox_offset,
descale_q,
descale_k,
BLOCK_M=BLOCK_M,
BLOCK_DMODEL=BLOCK_DMODEL,
ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
@@ -496,7 +523,8 @@
CAUSAL=CAUSAL,
DROPOUT=DROPOUT,
USE_EXP2=USE_EXP2,
GROUP_SIZE=GROUP_SIZE
GROUP_SIZE=GROUP_SIZE,
IS_FP8=IS_FP8
)


@@ -524,6 +552,11 @@ def attention_prefill_backward_triton_impl(
philox_offset,
use_exp2: bool,
sequence_parallel = True,
# fp8
descale_q=None,
descale_k=None,
descale_v=None,
descale_p=None
):
if DEBUG:
print()
@@ -695,41 +728,8 @@ def attention_prefill_backward_triton_impl(
IS_VARLEN=is_varlen
)

if False:
print("_bwd_kernel inputs")
print("do:", do, do.shape)
print("q:", q, q.shape)
print("k:", k, k.shape)
print("v:", v, v.shape)
print("sm_scale", sm_scale)
print("o:", o, o.shape)
print("dq:", dq, dq.shape)
print("dk:", dk, dk.shape)
print("dv:", dv, dv.shape)
print("L:", softmax_lse, softmax_lse.shape)
if DEBUG:
print("delta:", delta, delta.shape)
print("stride_qz, stride_qh, stride_qm, stride_qk:", stride_qz, stride_qh, stride_qm, stride_qk)
print("stride_kz, stride_kh, stride_kn, stride_kk:", stride_kz, stride_kh, stride_kn, stride_kk)
print("stride_vz, stride_vh, stride_vn, stride_vk:", stride_vz, stride_vh, stride_vn, stride_vk)
print("batch_q:", batch)
print("heads_q:",nheads_q)
print("max_seqlen_q:",max_seqlen_q)
print("max_seqlen_k:",max_seqlen_k)
print("dropout_p:",dropout_p)
print("philox_seed:", philox_seed)
print("philox_offset:",philox_offset)
print("BLOCK_M:",BLOCK_M)
print("BLOCK_N:",BLOCK_M)
print("BLOCK_DMODEL:",BLOCK_DMODEL)
print("ACTUAL_BLOCK_DMODEL:",ACTUAL_BLOCK_DMODEL)
print("SEQUENCE_PARALLEL:",sequence_parallel)
print("CAUSAL:",causal)
print("DROPOUT:", use_dropout)
print("num_warps:",num_warps)
print("num_stages:", num_stages)
print("USE_EXP2:", use_exp2)
print("num_blocks_m:", num_blocks_m)
print("num_blocks_n:", num_blocks_n)

_bwd_kernel[(batch * nheads_q, num_blocks_n if sequence_parallel else 1)](
q,
Expand All @@ -744,6 +744,8 @@ def attention_prefill_backward_triton_impl(
softmax_lse,
delta,
dropout_mask,
descale_q,
descale_k,
stride_dq_all,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
@@ -771,15 +773,15 @@ def attention_prefill_backward_triton_impl(
num_warps=num_warps,
num_stages=num_stages,
waves_per_eu = waves_per_eu,
IS_VARLEN=is_varlen
IS_VARLEN=is_varlen,
IS_FP8=is_fp8
)

if sequence_parallel:
dq = dq.sum(dim=0)

if DEBUG:
print("attention_prefill_backward_triton_impl outputs")
print("delta:", delta, delta.shape)
print("dv:", dv, dv.shape)
print("dk:", dk, dk.shape)
print("dq:", dq, dq.shape)
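Note (not part of the commit): as a sanity reference for the FP8 path, the dequantization in the kernel amounts to rescaling the recomputed logits by descale_q * descale_k before the softmax and then accumulating dv = p^T @ do, exactly as in the non-FP8 branch. A rough single-head PyTorch sketch (recomputing the softmax directly instead of from the saved LSE, and ignoring dropout, causal masking, and GQA):

import torch

def reference_dv_fp8(q_f8, k_f8, do, descale_q, descale_k, sm_scale):
    """Single-head reference: q_f8/k_f8 have shapes (seqlen_q, d) and
    (seqlen_k, d) in FP8 (or any low-precision dtype); do is the upstream
    gradient of the attention output."""
    q = q_f8.to(torch.float32)
    k = k_f8.to(torch.float32)
    # Mirrors `tl.dot(q, kT) * descale_q * descale_k` in the kernel.
    qk = (q @ k.T) * descale_q * descale_k
    p = torch.softmax(qk * sm_scale, dim=-1)      # (seqlen_q, seqlen_k)
    # Mirrors `dv += tl.dot(tl.trans(p), do)` in the no-dropout branch.
    return p.T @ do.to(torch.float32)             # (seqlen_k, d)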
10 changes: 5 additions & 5 deletions flash_attn/flash_attn_triton_amd/fwd_prefill.py
@@ -273,7 +273,7 @@ def get_autotune_configs():
)
@triton.jit
def attn_fwd(Q, K, V, bias,
DESCALE_Q, DESCALE_K, DESCALE_V, DESCALE_P, stride_q_inv_scale_z, stride_kv_inv_scale_z, stride_p_inv_scale_z,
DESCALE_Q, DESCALE_K, DESCALE_V, DESCALE_P, stride_descale_q_z, stride_descale_kv_z, stride_descale_p_z,
SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn,
stride_oz, stride_oh, stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah,
@@ -413,10 +413,10 @@ def attn_fwd(Q, K, V, bias,

# Load scale factors if IS_FP8.
if IS_FP8:
descale_q = tl.load(DESCALE_Q + off_z * stride_q_inv_scale_z + off_h_q)
descale_k = tl.load(DESCALE_K + off_z * stride_kv_inv_scale_z + off_h_k)
descale_v = tl.load(DESCALE_V + off_z * stride_kv_inv_scale_z + off_h_k)
descale_p = tl.load(DESCALE_P + off_z * stride_p_inv_scale_z + off_h_q)
descale_q = tl.load(DESCALE_Q + off_z * stride_descale_q_z + off_h_q)
descale_k = tl.load(DESCALE_K + off_z * stride_descale_kv_z + off_h_k)
descale_v = tl.load(DESCALE_V + off_z * stride_descale_kv_z + off_h_k)
descale_p = tl.load(DESCALE_P + off_z * stride_descale_p_z + off_h_q)
else:
descale_q, descale_k, descale_v, descale_p = 1.0, 1.0, 1.0, 1.0

8 changes: 8 additions & 0 deletions flash_attn/flash_attn_triton_amd/interface_fa.py
@@ -149,6 +149,10 @@ def bwd(
deterministic,
gen_,
rng_state,
descale_q,
descale_k,
descale_v,
descale_p
):
# NOTE: this might have perf costs
dq.zero_()
@@ -236,6 +240,10 @@ def bwd(
philox_seed,
philox_offset,
False,
descale_q = descale_q,
descale_k = descale_k,
descale_v = descale_v,
descale_p = descale_p
)
delta = delta_triton

