
Commit

clean up
micmelesse committed Oct 26, 2024
1 parent 71cb8ea commit 9411609
Showing 2 changed files with 12 additions and 228 deletions.
184 changes: 0 additions & 184 deletions flash_attn/flash_attn_triton_amd/bench_old.py

This file was deleted.

56 changes: 12 additions & 44 deletions flash_attn/flash_attn_triton_amd/bwd_prefill.py
@@ -133,7 +133,7 @@ def _bwd_kernel_one_col_block(
USE_EXP2: tl.constexpr,
):
if CAUSAL:
# lo = start_n * BLOCK_M
# TODO: Causal can skip more blocks with something like lo = start_m * BLOCK_M
lo = 0
else:
lo = 0
@@ -157,8 +157,6 @@ def _bwd_kernel_one_col_block(
v_ptrs = v_offset + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vk
k = tl.load(k_ptrs, mask=kv_mask, other=0.0)
v = tl.load(v_ptrs, mask=kv_mask, other=0.0)
# print("k:", k)
# print("v:", v)

# loop over rows
for start_m in range(lo, num_block_m * BLOCK_M, BLOCK_M):
@@ -174,8 +172,6 @@ def _bwd_kernel_one_col_block(
# load q, k, v, do on-chip
q = tl.load(q_ptrs, mask=q_mask, other=0.0)
do = tl.load(do_ptrs, mask=q_mask, other=0.0)
# print("q:", q)
# print("do:", do)

# recompute p = softmax(qk, dim=-1).T
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
@@ -184,15 +180,10 @@ def _bwd_kernel_one_col_block(
if CAUSAL:
col_offset = N_CTX_Q - N_CTX_K
causal_mask = offs_m[:, None] >= (col_offset + offs_n[None, :])
# print("causal_mask:", causal_mask)

# Apply the mask
qk = tl.where(causal_mask, qk, float("-inf"))
# print("qk after causal:", qk)
# print("qk:", qk)

l_ptrs = l_offset + offs_m * stride_deltam
l_i = tl.load(l_ptrs, mask=mask_m)
# print("l_i:", l_i)

# compute p
if USE_EXP2:
@@ -203,25 +194,21 @@ def _bwd_kernel_one_col_block(
else:
qk *= sm_scale
p = tl.math.exp(qk - l_i[:, None])
# print("p:", p)

# mask block in the cases where the data is smaller than the block size
p_mask = mask_m[:, None] & mask_n[None, :]
p = tl.where(p_mask, p, 0.0)
# print("p masked:", p)

# compute dv
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
# print("dv:", dv)

# compute dp
dp = tl.dot(do, tl.trans(v))
# print("dp:", dp)

# compute ds , ds = p * (dp - delta[:, None])
d_ptrs = d_offset + offs_m * stride_deltam
Di = tl.load(d_ptrs, mask=mask_m)
ds = (p * (dp - Di[:, None])) * sm_scale
# print("ds:", ds)
ds = tl.where(p_mask, ds, 0.0).to(Q.dtype.element_ty)

# compute dk = dot(ds.T, q)
@@ -238,19 +225,11 @@ def _bwd_kernel_one_col_block(
# write-back dv and dk
dk_ptrs = dk_offset + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
dv_ptrs = dv_offset + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vk

# write-back
# print("dv:", dv)
tl.store(dk_ptrs, dk.to(K.dtype.element_ty), mask=kv_mask)
tl.store(dv_ptrs, dv.to(V.dtype.element_ty), mask=kv_mask)

# @triton.autotune(
# configs=[
# triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
# num_warps=4),
# ],
# key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'],
# use_cuda_graph=True,
# )
@triton.jit
def _bwd_kernel(
Q,
@@ -496,32 +475,21 @@ def attention_prefill_backward_triton_impl(
print("max_seqlen_k:", max_seqlen_k)
print("use_exp2:", use_exp2)



# make contiguous
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
softmax_lse = softmax_lse.contiguous()

# get strides and shape
if True:
batch, nheads_q, nheads_k, head_size, max_seqlen_q, max_seqlen_k = get_shape_from_layout(q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)
q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, layout)
stride_qz, stride_qh, stride_qm, stride_qk = q_strides
stride_kz, stride_kh, stride_kn, stride_kk = k_strides
stride_vz, stride_vh, stride_vn, stride_vk = v_strides
stride_oz, stride_oh, stride_om, stride_ok = o_strides
stride_dq_all = q.numel()
batch_headsize = batch * nheads_q
else:
batch_q, heads_q, seqlen_q, head_size_q = q.shape
batch_k, heads_k, seqlen_k, head_size_k = k.shape
batch_headsize = batch_q * heads_q
stride_dq_all = dq.numel()
stride_qz, stride_qh, stride_qm, stride_qk = q.stride(0), q.stride(1), q.stride(2), q.stride(3)
stride_kz, stride_kh, stride_kn, stride_kk = k.stride(0), k.stride(1), k.stride(2), k.stride(3)
stride_vz, stride_vh, stride_vn, stride_vk = v.stride(0), v.stride(1), v.stride(2), v.stride(3)
batch, nheads_q, nheads_k, head_size, max_seqlen_q, max_seqlen_k = get_shape_from_layout(q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)
q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, layout)
stride_qz, stride_qh, stride_qm, stride_qk = q_strides
stride_kz, stride_kh, stride_kn, stride_kk = k_strides
stride_vz, stride_vh, stride_vn, stride_vk = v_strides
stride_oz, stride_oh, stride_om, stride_ok = o_strides
stride_dq_all = q.numel()
batch_headsize = batch * nheads_q
is_varlen = layout == "thd"

# FIXME: some configs lead to oom for some reason when using 64 x 64 blocks
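As a reference for the dv/dp/ds/dk lines in the hunks above, below is a minimal PyTorch sketch of the backward math the kernel recomputes per block. It assumes the non-causal, natural-exp path (not USE_EXP2), a single head, and dense tensors; the names sm_scale and delta mirror the kernel, but the function itself is illustrative and is not code from this repository.

import torch

def attention_bwd_reference(q, k, v, do, sm_scale):
    # q, do: [M, d]; k, v: [N, d]
    qk = (q @ k.T) * sm_scale                    # scaled attention scores
    l = torch.logsumexp(qk, dim=-1)              # softmax_lse saved by the forward pass
    p = torch.exp(qk - l[:, None])               # recomputed softmax probabilities
    o = p @ v                                    # forward output, needed only for delta
    delta = (do * o).sum(dim=-1)                 # Di in the kernel
    dv = p.T @ do                                # dv = dot(p^T, do)
    dp = do @ v.T                                # dp = dot(do, v^T)
    ds = p * (dp - delta[:, None]) * sm_scale    # ds = p * (dp - delta) * sm_scale
    dk = ds.T @ q                                # dk = dot(ds^T, q)
    dq = ds @ k                                  # dq (not shown in the hunks above)
    return dq, dk, dv

The kernel accumulates dk and dv per column block with masked loads and stores; the sketch collapses that blocking into dense matmuls to make the underlying formulas easy to check.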

0 comments on commit 9411609
