Commit

remove bwd
micmelesse committed Sep 4, 2024
1 parent 3040c6b commit 9d87589
Showing 1 changed file with 2 additions and 47 deletions.
flash_attn/flash_attn_triton_interface_amd.py (49 changes: 2 additions & 47 deletions)
@@ -1,10 +1,8 @@
 import torch
 import triton
-from .flash_attn_triton_kernel_prefill_amd import MetaData, get_shape_from_layout, _attention_prefill, attention_prefill
+from .flash_attn_triton_kernel_prefill_amd import MetaData, get_shape_from_layout, attention_prefill
 from .flash_attn_triton_kernel_decode_amd import attention_decode
 
-
-
 def fwd(q,
         k,
         v,
@@ -70,50 +68,7 @@ def bwd(
     gen_,
     rng_state,
 ):
-    # dummy context to call backward directly
-    class AttentionContext:
-        def __init__(self, q, k, v, o, M, sm_scale, causal, alibi_slopes, dropout_p, BLOCK_DMODEL):
-            self.saved_tensors = (q, k, v, o, M)
-            self.sm_scale = sm_scale
-            self.grid = lambda META: (triton.cdiv(q.shape[2], META['BLOCK_M']), q.shape[1], q.shape[0])
-            self.causal = causal
-            self.alibi_slopes = alibi_slopes
-            self.dropout_p = dropout_p
-            self.BLOCK_DMODEL = BLOCK_DMODEL
-            self.philox_seed = 0x1BF52
-            self.philox_offset = 0x1D4B42
-            self.return_encoded_softmax = False
-
-        def save_for_backward(self, q, k, v, o, M):
-            self.saved_tensors = (q, k, v, o, M)
-
-    if dropout_p != 0.0:
-        raise ValueError("dropout is not supported on AMD's Triton Backend yet")
-
-    if out is None:
-        out = torch.empty_like(q)
-
-    batch, max_seqlens_q, nheads_q, head_size = q.shape
-
-    # Transform inputs from bshd to bhsd layout
-    dout_bhsd = dout.permute(0, 2, 1, 3).contiguous()
-    q_bhsd = q.permute(0, 2, 1, 3).contiguous()
-    k_bhsd = k.permute(0, 2, 1, 3).contiguous()
-    v_bhsd = v.permute(0, 2, 1, 3).contiguous()
-    out_bhsd = out.permute(0, 2, 1, 3).contiguous() if out is not None else None
-
-    # Ensure all tensors have the same stride
-    dout_bhsd = dout_bhsd.view(dout_bhsd.shape)
-    q_bhsd = q_bhsd.view(q_bhsd.shape)
-    k_bhsd = k_bhsd.view(k_bhsd.shape)
-    v_bhsd = v_bhsd.view(v_bhsd.shape)
-    out_bhsd = out_bhsd.view(out_bhsd.shape) if out_bhsd is not None else None
-
-
-    ctx = AttentionContext(q_bhsd, k_bhsd, v_bhsd, out_bhsd, softmax_lse, softmax_scale, causal, alibi_slopes, dropout_p, head_size)
-    dq, dk, dv, softmax_lse, softmax_d = _attention_prefill.backward(ctx, dout_bhsd, None)  # expect bhsd
-
-    return dq, dk, dv, softmax_d
+    raise ValueError("bwd is not supported on AMD's Triton Backend yet")
 
 def varlen_fwd(
     q,
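
For context on what this commit drops: the removed bwd path permuted tensors from bshd to bhsd layout and then invoked the prefill kernel's backward directly through a hand-rolled context object. The snippet below is a minimal, generic sketch of that pattern only; the Square Function, FakeCtx class, and tensor shapes are illustrative stand-ins, not the library's API.

import torch

class Square(torch.autograd.Function):
    # Toy autograd Function standing in for the Triton attention kernel.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out

class FakeCtx:
    # Hand-rolled stand-in for the autograd context: backward() only needs
    # whatever attributes it actually reads (here, saved_tensors).
    def __init__(self, *tensors):
        self.saved_tensors = tensors

# bshd -> bhsd layout change, as the removed code did before the kernel call.
x_bshd = torch.randn(2, 128, 8, 64)               # (batch, seqlen, nheads, head_dim)
x_bhsd = x_bshd.permute(0, 2, 1, 3).contiguous()  # (batch, nheads, seqlen, head_dim)

# Invoke the staticmethod backward directly, bypassing autograd entirely.
grad = Square.backward(FakeCtx(torch.tensor([3.0])), torch.tensor([1.0]))
print(grad)  # tensor([6.])

The duck-typed context works because a staticmethod backward only reads the attributes it is handed, which is why the removed AttentionContext had to carry sm_scale, the grid lambda, and the philox constants alongside the saved tensors.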
