Enable MQA/GQA in backward (#100)
* simple failing test

* ref is working

* fix bug

* save

* find failing case

* forward varlen mqa/gqa works

* add mqa configs to bwd test

* varlen bwd ref fixed

* save failing case

* GQA flag

* ones passes

* go back to values

* save

* bhsd working with mqa

* remove repo

* test layouts

* clean up

* test back to normal

* clean up more

* use zeros_like

* zero out
micmelesse authored Nov 15, 2024
1 parent 4159ab3 commit 8947040
Showing 6 changed files with 428 additions and 248 deletions.
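For orientation before the diff: in grouped-query attention, HQ query heads share HK key/value heads (GROUP_SIZE = HQ // HK), so the backward pass must sum the dk/dv contributions from every query head in a group. Below is a minimal PyTorch reference of that relation, assuming consecutive query heads share a KV head; it is an illustrative sketch, not part of this commit.

```python
import torch

def gqa_attention_grads(q, k, v, do, sm_scale):
    """Reference grads for grouped-query attention (illustrative).

    q, do: (B, HQ, N, D); k, v: (B, HK, N, D), with HQ divisible by HK.
    """
    group = q.shape[1] // k.shape[1]
    q = q.detach().requires_grad_()
    k = k.detach().requires_grad_()
    v = v.detach().requires_grad_()
    # Each KV head serves `group` consecutive query heads.
    k_exp = k.repeat_interleave(group, dim=1)
    v_exp = v.repeat_interleave(group, dim=1)
    p = torch.softmax((q @ k_exp.transpose(-2, -1)) * sm_scale, dim=-1)
    (p @ v_exp).backward(do)
    # Autograd sums the per-query-head contributions back into each KV
    # head -- the accumulation the Triton kernel does with tl.atomic_add.
    return q.grad, k.grad, v.grad
```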
108 changes: 55 additions & 53 deletions flash_attn/flash_attn_triton_amd/bwd_prefill.py
@@ -112,13 +112,8 @@ def _bwd_kernel_one_col_block(
stride_deltaz,
stride_deltah,
stride_deltam,
- Z,
- H,
N_CTX_Q,
N_CTX_K,
- off_h,
- off_z,
- off_hz,
start_n,
num_block_m,
num_block_n,
@@ -129,6 +124,7 @@
SEQUENCE_PARALLEL: tl.constexpr,
CAUSAL: tl.constexpr,
USE_EXP2: tl.constexpr,
+ GROUP_SIZE: tl.constexpr,
):
if CAUSAL:
# TODO: Causal can skip more blocks with something like lo = start_m * BLOCK_M
@@ -153,8 +149,8 @@
# load k and v once per column block
k_ptrs = k_offset + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
v_ptrs = v_offset + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vk
- k = tl.load(k_ptrs, mask=kv_mask, other=0.0)
- v = tl.load(v_ptrs, mask=kv_mask, other=0.0)
+ k = tl.load(k_ptrs, mask=kv_mask, other=0.0).to(tl.float32)
+ v = tl.load(v_ptrs, mask=kv_mask, other=0.0).to(tl.float32)

# loop over rows
for start_m in range(lo, num_block_m * BLOCK_M, BLOCK_M):
@@ -168,8 +164,8 @@
q_mask = mask_m[:, None] & mask_d[None, :]

# load q, k, v, do on-chip
- q = tl.load(q_ptrs, mask=q_mask, other=0.0)
- do = tl.load(do_ptrs, mask=q_mask, other=0.0)
+ q = tl.load(q_ptrs, mask=q_mask, other=0.0).to(tl.float32)
+ do = tl.load(do_ptrs, mask=q_mask, other=0.0).to(tl.float32)

# recompute p = softmax(qk, dim=-1).T
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
@@ -196,9 +192,10 @@
# mask block in the cases where the data is smaller the block size
p_mask = mask_m[:, None] & mask_n[None, :]
p = tl.where(p_mask, p, 0.0)
+ p = p.to(tl.float32)

# compute dv
- dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
+ dv += tl.dot(tl.trans(p), do)

# compute dp
dp = tl.dot(do, tl.trans(v))
@@ -207,7 +204,7 @@
d_ptrs = d_offset + offs_m * stride_deltam
Di = tl.load(d_ptrs, mask=mask_m)
ds = (p * (dp - Di[:, None])) * sm_scale
- ds = tl.where(p_mask, ds, 0.0).to(Q.dtype.element_ty)
+ ds = tl.where(p_mask, ds, 0.0)

# compute dk = dot(ds.T, q)
dk += tl.dot(tl.trans(ds), q)
@@ -225,8 +222,13 @@
dv_ptrs = dv_offset + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vk

# write-back
- tl.store(dk_ptrs, dk.to(K.dtype.element_ty), mask=kv_mask)
- tl.store(dv_ptrs, dv.to(V.dtype.element_ty), mask=kv_mask)
+ if GROUP_SIZE != 1:
+ # use atomic_add to properly accumulate gradients from multiple query heads
+ tl.atomic_add(dk_ptrs, dk.to(K.dtype.element_ty), mask=kv_mask)
+ tl.atomic_add(dv_ptrs, dv.to(V.dtype.element_ty), mask=kv_mask)
+ else:
+ tl.store(dk_ptrs, dk.to(K.dtype.element_ty), mask=kv_mask)
+ tl.store(dv_ptrs, dv.to(V.dtype.element_ty), mask=kv_mask)

@triton.jit
def _bwd_kernel(
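Why the branch above matters: when GROUP_SIZE != 1, the GROUP_SIZE programs handling the query heads that share one KV head all write the same dk/dv region, so plain tl.store would drop updates; tl.atomic_add makes the accumulation safe. A serial PyTorch sketch of the same reduction, illustrative and not from the commit:

```python
import torch

def fold_group_grads(dk_per_q_head: torch.Tensor, group_size: int) -> torch.Tensor:
    """Sum per-query-head dk contributions into their shared KV heads.

    dk_per_q_head: (B, HQ, N, D) -> returns (B, HQ // group_size, N, D),
    matching the kernel's mapping off_hk = off_hq // GROUP_SIZE.
    """
    B, HQ, N, D = dk_per_q_head.shape
    return dk_per_q_head.view(B, HQ // group_size, group_size, N, D).sum(dim=2)
```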
@@ -258,7 +260,8 @@ def _bwd_kernel(
stride_deltah,
stride_deltam,
Z,
- H,
+ HQ,
+ HK,
num_block_m,
num_block_n,
cu_seqlens_q,
@@ -275,11 +278,17 @@
IS_VARLEN: tl.constexpr,
):
# program ids
- off_hz = tl.program_id(0)
+ off_zh = tl.program_id(0)
if SEQUENCE_PARALLEL:
start_n = tl.program_id(1)
- off_z = off_hz // H
- off_h = off_hz % H
+ off_z = off_zh // HQ
+ off_hq = off_zh % HQ

+ GROUP_SIZE = HQ // HK
+ if GROUP_SIZE != 1:
+ off_hk = off_hq // GROUP_SIZE
+ else:
+ off_hk = off_hq

if IS_VARLEN:
# Compute sequence lengths for the current batch
@@ -299,20 +308,20 @@


# input tensor offsets
- q_offset = Q + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm
- k_offset = K + off_z * stride_kz + off_h * stride_kh + k_start * stride_kn
- v_offset = V + off_z * stride_vz + off_h * stride_vh + k_start * stride_vn
- do_offset = DO + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm
- l_offset = L + off_z * stride_deltaz + off_h * stride_deltah + q_start * stride_deltam
- d_offset = D + off_z * stride_deltaz + off_h * stride_deltah + q_start * stride_deltam
+ q_offset = Q + off_z * stride_qz + off_hq * stride_qh + q_start * stride_qm
+ k_offset = K + off_z * stride_kz + off_hk * stride_kh + k_start * stride_kn
+ v_offset = V + off_z * stride_vz + off_hk * stride_vh + k_start * stride_vn
+ do_offset = DO + off_z * stride_qz + off_hq * stride_qh + q_start * stride_qm
+ l_offset = L + off_z * stride_deltaz + off_hq * stride_deltah + q_start * stride_deltam
+ d_offset = D + off_z * stride_deltaz + off_hq * stride_deltah + q_start * stride_deltam

# output tensor offsets
- dk_offset = DK + off_z * stride_kz + off_h * stride_kh + k_start * stride_kn
- dv_offset = DV + off_z * stride_vz + off_h * stride_vh + k_start * stride_vn
+ dk_offset = DK + off_z * stride_kz + off_hk * stride_kh + k_start * stride_kn
+ dv_offset = DV + off_z * stride_vz + off_hk * stride_vh + k_start * stride_vn
if SEQUENCE_PARALLEL:
- dq_offset = DQ + start_n * stride_dq_all + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm
+ dq_offset = DQ + start_n * stride_dq_all + off_z * stride_qz + off_hq * stride_qh + q_start * stride_qm
else:
- dq_offset = DQ + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm
+ dq_offset = DQ + off_z * stride_qz + off_hq * stride_qh + q_start * stride_qm

# inner loop
if SEQUENCE_PARALLEL:
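The offset arithmetic above addresses Q, DO, L, and D with the query-head index off_hq, while K, V, DK, and DV use the shared KV-head index off_hk = off_hq // GROUP_SIZE. A quick host-side sanity check of that mapping, with illustrative sizes:

```python
HQ, HK = 8, 2                 # 8 query heads sharing 2 KV heads
GROUP_SIZE = HQ // HK         # 4 query heads per KV head
mapping = [off_hq // GROUP_SIZE for off_hq in range(HQ)]
assert mapping == [0, 0, 0, 0, 1, 1, 1, 1]
```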
@@ -353,13 +362,8 @@ def _bwd_kernel(
stride_deltaz,
stride_deltah,
stride_deltam,
- Z,
- H,
N_CTX_Q,
N_CTX_K,
- off_h,
- off_z,
- off_hz,
start_n,
num_block_m,
num_block_n,
@@ -370,6 +374,7 @@
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
CAUSAL=CAUSAL,
USE_EXP2=USE_EXP2,
+ GROUP_SIZE=GROUP_SIZE
)
else:
for start_n in range(0, num_block_n):
@@ -410,13 +415,8 @@ def _bwd_kernel(
stride_deltaz,
stride_deltah,
stride_deltam,
- Z,
- H,
N_CTX_Q,
N_CTX_K,
- off_h,
- off_z,
- off_hz,
start_n,
num_block_m,
num_block_n,
@@ -427,6 +427,7 @@
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
CAUSAL=CAUSAL,
USE_EXP2=USE_EXP2,
+ GROUP_SIZE=GROUP_SIZE
)


@@ -454,7 +455,7 @@ def attention_prefill_backward_triton_impl(
):
if DEBUG:
print()
print("attention_prefill_backward_triton_new_impl")
print("attention_prefill_backward_triton_impl")
print("do:", do, do.shape)
print("q:", q, q.shape)
print("k:", k, k.shape)
@@ -488,7 +489,6 @@
stride_kz, stride_kh, stride_kn, stride_kk = k_strides
stride_vz, stride_vh, stride_vn, stride_vk = v_strides
stride_oz, stride_oh, stride_om, stride_ok = o_strides
- batch_headsize = batch * nheads_q
is_varlen = layout == "thd"

# FIXME: some configs lead to oom for some reason when using 64 x 64 blocks
@@ -538,22 +538,30 @@

# deal with dk, dv
if (dk is None) or (dv is None):
- dk = torch.empty_like(k)
- dv = torch.empty_like(v)
+ dk = torch.zeros_like(k)
+ dv = torch.zeros_like(v)
else:
- # store og
- dk_og = dk
- dv_og = dv

if (not dk.is_contiguous()):
+ dk_og = dk
dk = dk.contiguous()
copy_back["dk"] = True

if (not dv.is_contiguous()):
+ dv_og = dv
dv = dv.contiguous()
copy_back["dv"] = True

+ if DEBUG:
+ print("copy_back:", copy_back)

+ # zero out
+ dq.zero_()
+ dk.zero_()
+ dv.zero_()

# assert contigious
assert do.is_contiguous()
assert q.is_contiguous()
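The switch to zeros_like and the explicit .zero_() calls above are load-bearing: tl.atomic_add accumulates into whatever the output buffers already hold, so they must start at zero. A minimal sketch of the hazard, with illustrative shapes:

```python
import torch

k = torch.randn(2, 2, 16, 64)
dk = torch.empty_like(k)  # uninitialized memory: atomic adds would land on garbage
dk.zero_()                # safe accumulator; equivalent to torch.zeros_like(k)
```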
@@ -570,7 +578,7 @@
else:
stride_deltaz, stride_deltah, stride_deltam = delta.stride()

- _bwd_preprocess_use_o[(num_blocks_m, batch_headsize)](
+ _bwd_preprocess_use_o[(num_blocks_m, batch * nheads_q)](
o,
do,
delta,
@@ -622,7 +630,7 @@ def attention_prefill_backward_triton_impl(
print("num_blocks_m:", num_blocks_m)
print("num_blocks_n:", num_blocks_n)

- _bwd_kernel[(batch_headsize, num_blocks_n if sequence_parallel else 1)](
+ _bwd_kernel[(batch * nheads_q, num_blocks_n if sequence_parallel else 1)](
q,
k,
v,
@@ -641,6 +649,7 @@ def attention_prefill_backward_triton_impl(
stride_deltaz, stride_deltah, stride_deltam,
batch,
nheads_q,
+ nheads_k,
num_blocks_m,
num_blocks_n,
cu_seqlens_q,
@@ -660,18 +669,11 @@ def attention_prefill_backward_triton_impl(
IS_VARLEN=is_varlen
)

- if DEBUG:
- print("_bwd_kernel outputs")
- print("dq:", dq, dq.shape)
- print("dk:", dk, dk.shape)
- print("dv:", dv, dv.shape)
- print("delta:", delta, delta.shape)

if sequence_parallel:
dq = dq.sum(dim=0)

if DEBUG:
print("attention_prefill_backward_triton_new_impl outputs")
print("attention_prefill_backward_triton_impl outputs")
print("dq:", dq, dq.shape)
print("dk:", dk, dk.shape)
print("dv:", dv, dv.shape)
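On the sequence-parallel path, each column-block program writes its own dq slice (offset by start_n * stride_dq_all), and the partial buffers are reduced on the host with dq.sum(dim=0). Shape-wise, with illustrative sizes:

```python
import torch

num_blocks_n, B, H, N, D = 4, 2, 8, 128, 64
dq_partials = torch.zeros(num_blocks_n, B, H, N, D)  # one slice per column block
dq = dq_partials.sum(dim=0)                          # final (B, H, N, D) gradient
```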
