Commit

Clean

This is a combination of 4 commits.

clean 1

clean 2

clean more

match main

typo fix
micmelesse committed Aug 29, 2024
1 parent 7b8a15c commit 8c701b3
Showing 6 changed files with 47 additions and 631 deletions.
39 changes: 4 additions & 35 deletions .github/workflows/amd_tests.yml
@@ -4,11 +4,11 @@ on:
  workflow_dispatch:
  pull_request:
    branches: [main_perf]
  merge_group:
    branches: [main_perf]
    types: [checks_requested]
  push:
  merge_group:
    branches: [main_perf]
    types: [checks_requested]
  push:
    branches: [main_perf, micmelesse/upstream_pr]

concurrency:
  group: ${{ github.ref }}
@@ -57,37 +57,6 @@ jobs:
      - name: Build
        run: |
          python setup.py install
      # - name: Flash Attention Mini Tests
      #   run: |
      #     pytest tests/test_flash_attn.py::test_flash_attn_output
      # - name: Flash Attention qkvpacked Tests
      #   run: |
      #     pytest tests/test_flash_attn.py::test_flash_attn_qkvpacked
      #     pytest tests/test_flash_attn.py::test_flash_attn_varlen_qkvpacked
      # - name: Flash Attention output Tests
      #   run: |
      #     pytest tests/test_flash_attn.py::test_flash_attn_output
      #     pytest tests/test_flash_attn.py::test_flash_attn_varlen_output
      # - name: Flash Attention causal Tests
      #   run: |
      #     pytest tests/test_flash_attn.py::test_flash_attn_causal
      #     pytest tests/test_flash_attn.py::test_flash_attn_varlen_causal
      # - name: Flash Attention kvcache Tests
      #   run: |
      #     pytest tests/test_flash_attn.py::test_flash_attn_kvcache
      #     pytest tests/test_flash_attn.py::test_flash_attn_splitkv
      # - name: Flash Attention race condition Tests
      #   run: |
      #     pytest tests/test_flash_attn.py::test_flash_attn_race_condition
      # - name: Flash Attention bwd Tests
      #   run: |
      #     pytest tests/test_flash_attn.py::test_flash_attn_bwd_overflow
      #     pytest tests/test_flash_attn.py::test_flash_attn_bwd_transpose
      #     pytest tests/test_flash_attn.py::test_flash_attn_bwd_varlen_overflow
      # - name: Flash Attention deterministic Tests
      #   run: |
      #     pytest tests/test_flash_attn.py::test_flash_attn_deterministic
      #     pytest tests/test_flash_attn.py::test_flash_attn_varlen_deterministic
      - name: Flash Attention Tests
        run: |
          pytest tests/test_flash_attn.py
9 changes: 1 addition & 8 deletions .gitignore
@@ -25,11 +25,4 @@ var/

# Dev
venv

# AMD
.eggs
.vscode
core
scripts
log*
*csv
scripts
192 changes: 19 additions & 173 deletions flash_attn/flash_attn_triton_interface_amd.py
@@ -3,23 +3,7 @@
from .flash_attn_triton_kernel_prefill_amd import MetaData, get_shape_from_layout, _attention_prefill, attention_prefill
from .flash_attn_triton_kernel_decode_amd import attention_decode

DEBUG = False

class AttentionContext:
    def __init__(self, q, k, v, o, M, sm_scale, causal, alibi_slopes, dropout_p, BLOCK_DMODEL):
        self.saved_tensors = (q, k, v, o, M)
        self.sm_scale = sm_scale
        self.grid = lambda META: (triton.cdiv(q.shape[2], META['BLOCK_M']), q.shape[1], q.shape[0])
        self.causal = causal
        self.alibi_slopes = alibi_slopes
        self.dropout_p = dropout_p
        self.BLOCK_DMODEL = BLOCK_DMODEL
        self.philox_seed = 0x1BF52
        self.philox_offset = 0x1D4B42
        self.return_encoded_softmax = False

    def save_for_backward(self, q, k, v, o, M):
        self.saved_tensors = (q, k, v, o, M)


def fwd(q,
        k,
@@ -34,21 +18,6 @@ def fwd(q,
        softcap,
        return_softmax,
        gen_):
    if DEBUG:
        print()
        print("flash_attn_triton_amd.py::fwd")
        print("q:", q, q.shape)
        print("k:", k, k.shape)
        print("v:", v, v.shape)
        print("alibi_slopes:", alibi_slopes)
        print("dropout_p:", dropout_p)
        print("softmax_scale:", softmax_scale)
        print("causal:", causal)
        print("window_size_left:", window_size_left)
        print("window_size_right:", window_size_right)
        print("softcap", softcap)
        print("return_softmax:", return_softmax)
        print("gen_:", gen_)

    if dropout_p != 0.0:
        raise ValueError("dropout is not supported on AMD yet")
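For context, this fwd entry point is normally reached through flash-attn's public Python wrapper rather than called directly. A minimal sketch of such a call, with assumed shapes and using the public flash_attn_func API (assumed usage, not shown in this diff):

# Sketch only -- assumed usage through the public API, not part of this commit.
import torch
from flash_attn import flash_attn_func

# (batch, seqlen, nheads, headdim), fp16 as the kernels expect
q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)

# dropout_p must stay 0.0: the AMD path above raises ValueError otherwise
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)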
@@ -101,27 +70,22 @@ def bwd(
    gen_,
    rng_state,
):
    if DEBUG:
        print()
        print("flash_attn_triton_amd.py::bwd")
        print("dout:", dout, dout.shape, dout.stride())
        print("q:", q, q.shape, q.stride())
        print("k:", k, k.shape, k.stride())
        print("v:", v, v.shape, v.stride())
        print("softmax_lse:", softmax_lse)
        print("dq:", dq, dq.shape, dq.stride())
        print("dk:", dk, dk.shape, dk.stride())
        print("dv:", dv, dv.shape, dv.stride())
        print("alibi_slopes:", alibi_slopes)
        print("dropout_p:", dropout_p)
        print("out:", out)
        print("softmax_scale:", softmax_scale)
        print("causal:", causal)
        print("window_size_left:", window_size_left)
        print("window_size_right:", window_size_right)
        print("deterministic:", deterministic)
        print("gen_:", gen_)
        print("rng_state:", rng_state)
    # dummy context to call backward directly
    class AttentionContext:
        def __init__(self, q, k, v, o, M, sm_scale, causal, alibi_slopes, dropout_p, BLOCK_DMODEL):
            self.saved_tensors = (q, k, v, o, M)
            self.sm_scale = sm_scale
            self.grid = lambda META: (triton.cdiv(q.shape[2], META['BLOCK_M']), q.shape[1], q.shape[0])
            self.causal = causal
            self.alibi_slopes = alibi_slopes
            self.dropout_p = dropout_p
            self.BLOCK_DMODEL = BLOCK_DMODEL
            self.philox_seed = 0x1BF52
            self.philox_offset = 0x1D4B42
            self.return_encoded_softmax = False

        def save_for_backward(self, q, k, v, o, M):
            self.saved_tensors = (q, k, v, o, M)

    if dropout_p != 0.0:
        raise ValueError("dropout is not supported on AMD yet")
@@ -147,43 +111,9 @@ def bwd(


    ctx = AttentionContext(q_bhsd, k_bhsd, v_bhsd, out_bhsd, softmax_lse, softmax_scale, causal, alibi_slopes, dropout_p, head_size)
    dq, dk, dv, _, _ = _attention_prefill.backward(ctx, dout_bhsd, None) # expect bhsd

    softmax_d = None # not sure what softmax_d is supposed to be
    if DEBUG:
        print()
        print("bwd output")
        print("dq:", dq, dq.shape)
        print("dk:", dk, dk.shape)
        print("dv:", dv, dv.shape)
        print("softmax_d:", softmax_d)
        print()
    return dq, dk, dv, softmax_d


    dq, dk, dv, softmax_lse, softmax_d = _attention_prefill.backward(ctx, dout_bhsd, None) # expect bhsd

def bwd(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    alibi_slopes,
    dropout_p,
    softmax_scale,
    causal,
    window_size_left,
    window_size_right,
    softcap,
    deterministic,
    gen_,
    rng_state,
):
    raise ValueError("bwd is not supported on AMD yet")
    return dq, dk, dv, softmax_d
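In the new bwd path above, the local AttentionContext stands in for torch.autograd.Function's ctx so that _attention_prefill.backward can be invoked directly, and softmax_d is returned as None with a comment questioning its meaning. As a point of reference, softmax_d in the FlashAttention backward is commonly the per-row delta term, the row-wise sum of dout * out; a minimal sketch under that assumption (compute_softmax_d is a hypothetical helper, not part of this commit):

# Sketch only -- assumes softmax_d is the backward "delta" term used by
# FlashAttention, i.e. D = rowsum(dO * O); the helper name is hypothetical.
import torch

def compute_softmax_d(dout_bhsd: torch.Tensor, out_bhsd: torch.Tensor) -> torch.Tensor:
    # inputs: (batch, nheads, seqlen, headdim) -> output: (batch, nheads, seqlen)
    return (dout_bhsd.float() * out_bhsd.float()).sum(dim=-1)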

def varlen_fwd(
    q,
@@ -207,30 +137,6 @@ def varlen_fwd(
    softcap,
    return_softmax,
    gen_):

    if DEBUG:
        print()
        print("flash_attn_triton_amd.py::varlen_fwd")
        print("q:", q, q.shape)
        print("k:", k, k.shape)
        print("v:", v, v.shape)
        print("cu_seqlens_q:", cu_seqlens_q)
        print("cu_seqlens_k:", cu_seqlens_k)
        print("seqused_k:", seqused_k)
        print("leftpad_k:", leftpad_k)
        print("block_table_:", block_table_)
        print("alibi_slopes:", alibi_slopes)
        print("max_seqlen_q:", max_seqlen_q)
        print("max_seqlen_k:", max_seqlen_k)
        print("dropout_p:", dropout_p)
        print("softmax_scale:", softmax_scale)
        print("zero_tensors:", zero_tensors)
        print("causal:", causal)
        print("window_size_left:", window_size_left)
        print("window_size_right:", window_size_right)
        print("softcap", softcap)
        print("return_softmax:", return_softmax)
        print("gen_:", gen_)

    if dropout_p != 0.0:
        raise ValueError("dropout is not supported on AMD yet")
@@ -263,37 +169,6 @@ def varlen_fwd(

    return tri_out, q , k , v, o, softmax_lse, softmax_dmask, None
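For reference, the varlen path consumes packed (total_tokens, nheads, headdim) tensors plus cu_seqlens offsets. A minimal sketch using the public flash_attn_varlen_func wrapper (assumed usage, not shown in this diff):

# Sketch only -- assumed usage, not part of this commit.
import torch
from flash_attn import flash_attn_varlen_func

seqlens = [5, 3]                      # two sequences packed into one batch
total = sum(seqlens)
q = torch.randn(total, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(total, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(total, 8, 64, device="cuda", dtype=torch.float16)
cu_seqlens = torch.tensor([0, 5, 8], device="cuda", dtype=torch.int32)  # cumulative offsets

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seqlens), max_seqlen_k=max(seqlens),
    dropout_p=0.0, causal=True,       # dropout must be 0.0 on this AMD path
)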

def varlen_bwd(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    cu_seqlens_q,
    cu_seqlens_k,
    alibi_slopes,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    zero_tensors,
    causal,
    window_size_left,
    window_size_right,
    deterministic,
    gen_,
    rng_state,
):
    if DEBUG:
        print()
        print("flash_attn_triton_amd.py::varlen_bwd")

    raise ValueError("varlen_bwd is not supported on AMD yet")

def varlen_bwd(
    dout,
    q,
@@ -344,30 +219,6 @@ def fwd_kvcache(
    rotary_interleaved,
    num_splits):

    if DEBUG:
        print()
        print("flash_attn_triton_amd.py::fwd_kvcache")
        print("q:", q, q.shape)
        print("k_cache:", k_cache, k_cache.shape)
        print("v_cache:", v_cache, v_cache.shape)
        print("k:", k, k.shape if k is not None else None)
        print("v:", v, v.shape if v is not None else None)
        print("cache_seqlens:", cache_seqlens, cache_seqlens.size())
        print("rotary_cos:", rotary_cos)
        print("rotary_sin:", rotary_sin)
        print("cache_batch_idx:", cache_batch_idx)
        print("cache_leftpad", cache_leftpad)
        print("block_table:", block_table, block_table.shape if block_table is not None else None)
        print("alibi_slopes:", alibi_slopes)
        print("out:", out)
        print("softmax_scale:", softmax_scale)
        print("causal:", causal)
        print("window_size_left:", window_size_left)
        print("window_size_right:", window_size_right)
        print("softcap", softcap)
        print("rotary_interleaved:", rotary_interleaved)
        print("num_splits:", num_splits)

    if out is None:
        out = torch.empty_like(q)

@@ -394,9 +245,4 @@ def fwd_kvcache(

    # launch kernel
    tri_out, softmax_lse = attention_decode(q, k_cache, v_cache, input_metadata)

    if DEBUG:
        print()
        print("tri_out:", tri_out, tri_out.shape)

    return tri_out, softmax_lse
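fwd_kvcache backs the decode path launched above via attention_decode. A minimal sketch of reaching it through the public flash_attn_with_kvcache wrapper (assumed usage, not shown in this diff):

# Sketch only -- assumed usage, not part of this commit.
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim, cache_len = 2, 8, 64, 256
q = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)  # one new token per sequence
k_cache = torch.zeros(batch, cache_len, nheads, headdim, device="cuda", dtype=torch.float16)
v_cache = torch.zeros(batch, cache_len, nheads, headdim, device="cuda", dtype=torch.float16)
cache_seqlens = torch.tensor([17, 42], device="cuda", dtype=torch.int32)        # tokens already in the cache

out = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=True)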