
Commit

forward out and lse are good. Something wrong with backward ref
micmelesse committed Oct 22, 2024
1 parent b5d663c commit bab936b
Showing 5 changed files with 142 additions and 78 deletions.
2 changes: 0 additions & 2 deletions flash_attn/flash_attn_triton_amd/bwd_prefill.py
@@ -498,7 +498,6 @@ def attention_prefill_backward_triton_impl(
max_seqlen_q: int,
max_seqlen_k: int,
use_exp2: bool,
bwd_preprocessing_use_o: bool,
BLOCK_M=64,
BLOCK_N=64,
):
@@ -522,7 +521,6 @@ def attention_prefill_backward_triton_impl(
print("max_seqlen_q:", max_seqlen_q)
print("max_seqlen_k:", max_seqlen_k)
print("use_exp2:", use_exp2)
print("bwd_preprocessing_use_o:", bwd_preprocessing_use_o)
print("BLOCK_M:", BLOCK_M)
print("BLOCK_N:", BLOCK_N)

29 changes: 26 additions & 3 deletions flash_attn/flash_attn_triton_amd/bwd_ref.py
@@ -42,29 +42,52 @@ def attention_backward_core_ref_impl(
print("p:", p)
# compute gradient wrt v
dv = torch.matmul(p.transpose(-2, -1), do.to(torch.float32))
if DEBUG:
print("dv:", dv)

# compute dp
dp = torch.matmul(do, v.transpose(-2, -1))
if DEBUG:
print("dp:", dp)

# calculate ds
if bwd_preprocessing_use_o:
delta = torch.sum(o * do, axis=-1).unsqueeze(-1).to(torch.float32) # what OAI kernel uses
else:
delta = torch.sum(p * dp, axis=-1).unsqueeze(-1) # what the math says you should use
if DEBUG:
print("delta:", delta)
ds = (p * (dp - delta)) * sm_scale
if DEBUG:
print("ds:", ds)


# compute gradient wrt k
dk = torch.matmul(ds.transpose(-2, -1), q.to(torch.float32))
if DEBUG:
print("dk:", dk)

# compute gradient wrt q
dq = torch.matmul(ds, k.to(torch.float32))
if DEBUG:
print("dq:", dq)

# cast back to original dtype
dq = dq.to(q.dtype)
dk = dk.to(k.dtype)
dv = dv.to(v.dtype)

return dq, dk, dv, delta.squeeze(-1)
# remove d dim with size 1
delta = delta.squeeze(-1)

if DEBUG:
print("attention_backward_core_ref_impl output")
print("dq:", dq, dq.shape)
print("dk:", dk, dk.shape)
print("dv:", dv, dv.shape)
print("delta:", delta, delta.shape)

return dq, dk, dv, delta

def attention_varlen_backward_pytorch_ref_impl(
do,
@@ -225,9 +248,9 @@ def attention_backward_pytorch_ref_impl(
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
use_exp2,
bwd_preprocessing_use_o,
use_exp2
):
bwd_preprocessing_use_o = True

if layout == "thd":
dq, dk, dv, delta = attention_varlen_backward_pytorch_ref_impl(
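
A side note on the two "delta" definitions in attention_backward_core_ref_impl (and on the bwd_preprocessing_use_o flag this commit hard-codes to True): they are mathematically equivalent, because o = p @ v and dp = do @ v^T imply sum_j p_ij * dp_ij = do_i . o_i. A minimal standalone sanity check, with illustrative shapes that are not taken from the repo:

# Sanity check (illustrative, not part of the commit): the two delta formulas agree.
import torch

torch.manual_seed(0)
p = torch.softmax(torch.randn(2, 4, 16, 16, dtype=torch.float64), dim=-1)  # attention probabilities
v = torch.randn(2, 4, 16, 32, dtype=torch.float64)
do = torch.randn(2, 4, 16, 32, dtype=torch.float64)                        # upstream gradient

o = torch.matmul(p, v)                      # forward output
dp = torch.matmul(do, v.transpose(-2, -1))  # gradient wrt p

delta_from_o = torch.sum(o * do, dim=-1)    # "what OAI kernel uses"
delta_from_p = torch.sum(p * dp, dim=-1)    # "what the math says you should use"
print(torch.allclose(delta_from_o, delta_from_p))  # True up to floating-point error
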
27 changes: 1 addition & 26 deletions flash_attn/flash_attn_triton_amd/fwd_prefill.py
@@ -5,7 +5,6 @@

DEBUG = False


@triton.jit
def cdiv_fn(x, y):
return (x + y - 1) // y
@@ -602,28 +601,4 @@ def attention_prefill_forward_triton_impl_explicit(
if is_varlen:
softmax_lse = softmax_lse.transpose(2, 1).reshape(-1, nheads_q).contiguous()

return o, softmax_lse, exp_scores, grid, head_size, philox_seed, philox_offset, scores, scores_scaled_shifted


def attention_prefill_forward_triton_impl(q, k, v, o, metadata):
if o is None:
o = torch.empty_like(q, dtype=v.dtype)
metadata.check_args(q, k, v, o)

return attention_prefill_forward_triton_impl_explicit(
q,
k,
v,
o,
metadata.sm_scale,
metadata.alibi_slopes,
metadata.causal,
metadata.bias,
metadata.dropout_p,
metadata.layout,
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
metadata.max_seqlens_q,
metadata.max_seqlens_k,
metadata.return_scores,
metadata.use_exp2)
return o, softmax_lse, exp_scores, grid, head_size, philox_seed, philox_offset, scores, scores_scaled_shifted
145 changes: 104 additions & 41 deletions flash_attn/flash_attn_triton_amd/interface_fa.py
@@ -1,11 +1,12 @@
import torch
import triton
from .fwd_prefill import attention_prefill_forward_triton_impl
from .fwd_prefill import attention_prefill_forward_triton_impl_explicit
from .bwd_prefill import attention_prefill_backward_triton_impl
from .fwd_decode import attention_decode_forward_triton_impl
from .fwd_ref import attention_forward_pytorch_ref_impl
from .bwd_ref import attention_backward_pytorch_ref_impl
from .utils import MetaData, get_shape_from_layout

DEBUG = False
DEBUG = True

def fwd(q,
k,
@@ -135,14 +136,9 @@ def bwd(
None,
None,
None,
False,
True,
False
)

if DEBUG:
print("dq:", dq, dq.shape)


return dq, dk, dv, None

def varlen_fwd(
@@ -193,28 +189,75 @@ def varlen_fwd(
o = torch.empty_like(q)

# Setup metadata
input_metadata = MetaData(sm_scale=softmax_scale)
input_metadata.use_exp2 = False
metadata = MetaData(sm_scale=softmax_scale)
metadata.use_exp2 = False
if return_softmax:
input_metadata.return_encoded_softmax = True
input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) # set layout to "thd" and other metadata
metadata.return_encoded_softmax = True
metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) # set layout to "thd" and other metadata

# get shapes
batch, nheads_q, nheads_k, head_size , seqlen_q, seqlen_k = get_shape_from_layout(q, k, input_metadata.layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)
batch, nheads_q, nheads_k, head_size , seqlen_q, seqlen_k = get_shape_from_layout(q, k, metadata.layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)

if causal:
input_metadata.need_causal()
metadata.need_causal()

if alibi_slopes is not None:
input_metadata.need_alibi(alibi_slopes, batch, nheads_q)
metadata.need_alibi(alibi_slopes, batch, nheads_q)

if dropout_p > 0.0:
input_metadata.need_dropout(dropout_p, return_softmax)
metadata.need_dropout(dropout_p, return_softmax)

# Check arguments
input_metadata.check_args(q, k, v, o)
metadata.check_args(q, k, v, o)
if o is None:
o = torch.empty_like(q, dtype=v.dtype)

if False:
(o_triton,
softmax_lse,
exp_scores,
grid,
head_size,
philox_seed,
philox_offset,
scores,
scores_scaled_shifted) = attention_prefill_forward_triton_impl_explicit(
q,
k,
v,
o,
metadata.sm_scale,
metadata.alibi_slopes,
metadata.causal,
metadata.bias,
metadata.dropout_p,
metadata.layout,
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
metadata.max_seqlens_q,
metadata.max_seqlens_k,
metadata.return_scores,
metadata.use_exp2)
else:
(o_triton,
softmax_lse,
exp_scores,
_,
_,
_,
_) = attention_forward_pytorch_ref_impl(
q,
k,
v,
metadata.sm_scale,
metadata.causal,
metadata.layout,
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
metadata.max_seqlens_q,
metadata.max_seqlens_k,
metadata.use_exp2)

o_triton, softmax_lse, exp_scores, grid, head_size, philox_seed, philox_offset, scores, scores_scaled_shifted = attention_prefill_forward_triton_impl(q, k, v, o, input_metadata)

return o_triton, q , k , v, o, softmax_lse, exp_scores, None

@@ -273,29 +316,49 @@ def varlen_bwd(
if dropout_p != 0.0:
raise ValueError("dropout is not supported on AMD yet")

_, _, _, _, _, _ = attention_prefill_backward_triton_impl(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
softmax_scale,
alibi_slopes,
causal,
"thd",
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
False,
True,
)
if False:
_, _, _, _, _, _ = attention_prefill_backward_triton_impl(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
softmax_scale,
alibi_slopes,
causal,
"thd",
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
False,
)
else:
dq_ref, dk_ref, dv_ref, delta_ref = attention_backward_pytorch_ref_impl(
dout,
q,
k,
v,
out,
softmax_lse,
softmax_scale,
causal,
"thd",
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
False,
)
dq.copy_(dq_ref)
dk.copy_(dk_ref)
dv.copy_(dv_ref)
softmax_d = delta_ref

softmax_d = None
return dq, dk, dv, softmax_d

def fwd_kvcache(
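
Since varlen_bwd now routes through attention_backward_pytorch_ref_impl and copies the reference gradients into the preallocated dq/dk/dv buffers, a natural next step is to diff the Triton backward against the reference on the same inputs. A hypothetical helper for that comparison (report_grad_diff and its tolerances are illustrative, not part of the repo):

import torch

def report_grad_diff(name, triton_grad, ref_grad, atol=1e-2, rtol=1e-2):
    # Compare one gradient tensor from the Triton kernel against the PyTorch reference.
    t, r = triton_grad.float(), ref_grad.float()
    max_diff = (t - r).abs().max().item()
    ok = torch.allclose(t, r, atol=atol, rtol=rtol)
    print(f"{name}: max abs diff = {max_diff:.3e}, allclose = {ok}")

# usage (tensor names are placeholders for the two backward paths' outputs):
# report_grad_diff("dq", dq_triton, dq_ref)
# report_grad_diff("dk", dk_triton, dk_ref)
# report_grad_diff("dv", dv_triton, dv_ref)
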
17 changes: 11 additions & 6 deletions tests/test_flash_attn_triton.py
@@ -1208,7 +1208,7 @@ def test_flash_attn_output(
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
(2, 2)
(4, 4)
# (1, 147),
# (113, 203),
# (128, 217),
@@ -1255,20 +1255,20 @@ def test_flash_attn_varlen_output(
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2)
assert nheads % nheads_k == 0
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
q = torch.ones(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
if softcap > 0:
# Ensure the values of qk are at least within softcap range.
q = q * softcap

if kvpacked:
kv = torch.randn(
kv = torch.ones(
batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
)
else:
k = torch.randn(
k = torch.ones(
batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
)
v = torch.randn(
v = torch.ones(
batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
)

@@ -1344,6 +1344,11 @@ def test_flash_attn_varlen_output(
deterministic=deterministic,
return_attn_probs=True,
)
if True:
print("out_unpad:", out_unpad, out_unpad.shape)
print("sm_lse:", sm_lse, sm_lse.shape)


out = output_pad_fn(out_unpad)
if dropout_p > 0.0:
S_dmask_converted = convert_flash_attn_S_to_softmax(
@@ -1453,7 +1458,7 @@ def test_flash_attn_varlen_output(
print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")

g = torch.randn_like(out)
g = torch.ones_like(out)
if ((d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90)):
if kvpacked:
(
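
The switch from torch.randn to torch.ones in the test inputs makes the expected forward results trivial to check by hand while the backward is being debugged: with constant q/k/v the scaled scores are constant per row, softmax is uniform, the output is the row-mean of v (all ones), and the non-causal LSE is sqrt(d) + log(seqlen). A small illustration with made-up sizes, not taken from the test:

import math
import torch

seqlen, d = 4, 8
q = k = v = torch.ones(seqlen, d)

scores = (q @ k.T) / math.sqrt(d)        # every entry equals sqrt(d)
p = torch.softmax(scores, dim=-1)        # uniform: 1/seqlen per entry
out = p @ v                              # all ones
lse = torch.logsumexp(scores, dim=-1)    # sqrt(d) + log(seqlen) per row

print(torch.allclose(out, torch.ones_like(out)))                                    # True
print(torch.allclose(lse, torch.full((seqlen,), math.sqrt(d) + math.log(seqlen))))  # True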
