mismatch found on columns greater than 64
micmelesse committed Dec 2, 2024
1 parent 1a24f0c commit 40f31a7
Showing 4 changed files with 37 additions and 10 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -35,4 +35,5 @@ core.*
 *.csv
 *.png
 *.html
-*.json
+*.json
+*.txt
4 changes: 2 additions & 2 deletions flash_attn/flash_attn_triton_amd/bwd_prefill.py
@@ -1,7 +1,7 @@
 import torch
 import triton
 import triton.language as tl
-from .utils import get_shape_from_layout, get_strides_from_layout, DEBUG, PERF, write_tensor
+from .utils import get_shape_from_layout, get_strides_from_layout, DEBUG, write_dropout_mask
 
 @triton.jit
 def _bwd_preprocess_use_p(
@@ -1018,7 +1018,7 @@ def attention_prefill_backward_triton_impl(
print("copy_back:", copy_back)
print("dropout_mask:", dropout_mask, dropout_mask.shape if dropout_mask is not None else None)
print("dropout_fraction bwd:", 1.0 - (dropout_mask.sum()/ dropout_mask.numel()).item())
write_tensor(dropout_mask, "dropout_mask_bwd")
write_dropout_mask(dropout_mask, "dropout_mask_bwd")

if copy_back["dq"]:
dq_og.copy_(dq)
4 changes: 2 additions & 2 deletions flash_attn/flash_attn_triton_amd/fwd_prefill.py
@@ -1,7 +1,7 @@
 import torch
 import triton
 import triton.language as tl
-from .utils import get_shape_from_layout, get_strides_from_layout, is_cdna, is_rdna, DEBUG, AUTOTUNE, write_tensor
+from .utils import get_shape_from_layout, get_strides_from_layout, is_cdna, is_rdna, DEBUG, AUTOTUNE, write_dropout_mask
 
 # Convenience function to load with optional boundary checks.
 # "First" is the major dim, "second" is the minor dim.
@@ -616,6 +616,6 @@ def attention_prefill_forward_triton_impl(
print("sd_mask:", sd_mask, sd_mask.shape if sd_mask is not None else None)
print("dropout_mask:", dropout_mask, dropout_mask.shape if dropout_mask is not None else None)
print("dropout_fraction fwd:", 1.0 - (dropout_mask.sum()/ dropout_mask.numel()).item())
write_tensor(dropout_mask, "dropout_mask_fwd")
write_dropout_mask(dropout_mask, "dropout_mask_fwd")

return o, softmax_lse, sd_mask.to(o.dtype) if return_softmax else None
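Aside (not part of the commit): the "dropout_fraction" debug prints above are simply one minus the mean of the kept-element mask, which should land near the configured dropout probability for a large mask. A minimal standalone sanity check along those lines, with a hypothetical dropout_p and a randomly generated mask standing in for the kernel's real output:

import torch

# Hypothetical check: observed dropout fraction vs. requested dropout_p.
dropout_p = 0.17                                    # assumed value for illustration
mask = torch.rand(2, 4, 128, 128) >= dropout_p      # True = element kept

observed = 1.0 - (mask.sum() / mask.numel()).item()
print(f"requested dropout_p: {dropout_p:.3f}, observed fraction: {observed:.3f}")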
36 changes: 31 additions & 5 deletions flash_attn/flash_attn_triton_amd/utils.py
@@ -1,6 +1,7 @@

 import csv
 import json
+import math
 import torch
 import os
 import triton
@@ -259,15 +260,40 @@ def get_padded_headsize(size):
     padded_d_model = max(padded_d_model, 16)
     return padded_d_model
 
-def write_tensor(x, tensor_name = "tensor"):
+def write_dropout_mask(x, tensor_name = "tensor"):
+    batch, head, seqlen_m, seqlen_n = x.shape
     x = x.tolist()
 
     with open(f'{tensor_name}.csv', 'w') as f:
         writer = csv.writer(f)
-        writer.writerows(x)
-
-    with open(f'{tensor_name}.json', 'w') as f:
-        json.dump(x, f, indent=2)
+        for b in range(batch):
+            for h in range(head):
+                dropout_mask = x[b][h]
+                if True:
+                    BLOCK_M = 64
+                    BLOCK_N = 64
+
+                    # Calculate number of blocks in each dimension
+                    m_blocks = math.ceil(seqlen_m / BLOCK_M)
+                    n_blocks = math.ceil(seqlen_n / BLOCK_N)
+
+                    # Process each block
+                    for m_block in range(m_blocks):
+                        # Calculate row range for current block
+                        row_start = m_block * BLOCK_M
+                        row_end = min(row_start + BLOCK_M, seqlen_m)
+
+                        for n_block in range(n_blocks):
+                            # Calculate column range for current block
+                            col_start = n_block * BLOCK_N
+                            col_end = min(col_start + BLOCK_N, seqlen_n)
+
+                            # Extract and write the current block
+                            for row_idx in range(row_start, row_end):
+                                row_data = dropout_mask[row_idx][col_start:col_end]
+                                writer.writerow(row_data)
+                else:
+                    writer.writerows(dropout_mask)
 
 def _strides(x: torch.Tensor, *stride_names: str):
     if x is None:
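For context: the new write_dropout_mask dumps each (batch, head) slice tile by tile in 64x64 blocks, so consecutive CSV rows belong to the same block rather than following plain row-major order, which makes it easier to see which tile (e.g. columns beyond 64) disagrees between the forward and backward dumps. A minimal usage sketch, not part of the commit, that re-creates the same tiling loop with a toy 4x4 mask and 2x2 blocks purely for illustration:

import csv
import math
import torch

def write_mask_blocked(x, tensor_name="tensor", block_m=2, block_n=2):
    # Hypothetical re-creation of the tiling loop above, shrunk to tiny blocks
    # so the resulting CSV row order is easy to inspect by hand.
    batch, head, seqlen_m, seqlen_n = x.shape
    x = x.tolist()
    with open(f"{tensor_name}.csv", "w", newline="") as f:
        writer = csv.writer(f)
        for b in range(batch):
            for h in range(head):
                mask = x[b][h]
                m_blocks = math.ceil(seqlen_m / block_m)
                n_blocks = math.ceil(seqlen_n / block_n)
                for m_block in range(m_blocks):
                    row_start = m_block * block_m
                    row_end = min(row_start + block_m, seqlen_m)
                    for n_block in range(n_blocks):
                        col_start = n_block * block_n
                        col_end = min(col_start + block_n, seqlen_n)
                        # one CSV row per block row: consecutive CSV rows
                        # come from the same (m_block, n_block) tile
                        for row_idx in range(row_start, row_end):
                            writer.writerow(mask[row_idx][col_start:col_end])

torch.manual_seed(0)
demo_mask = torch.rand(1, 1, 4, 4) >= 0.5   # True = element kept
write_mask_blocked(demo_mask, "dropout_mask_demo")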
