
Commit 387488b

Add unit test and fix for flash_4 (#108)
1 parent 92f7296 commit 387488b

2 files changed: +49 −4 lines changed

segment_anything_fast/flash_4.py

+6 −4
@@ -107,14 +107,14 @@ def _fwd_kernel_aligned(
         v = tl.load(V_block_ptr)
         # -- compute qk ---
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=OUT_DTYPE)
-        qk += tl.dot(q, k, out_dtype=OUT_DTYPE)
+        qk += tl.dot(q, k) #, out_dtype=OUT_DTYPE)

         # -- compute rel_h[:, None] + rel_w[None, :] bias ---

         # Bias
         b0 = tl.load(B0 + b_offset + ((start_m * BLOCK_M + b_ptr_offsets_m)
                                       * stride_b0m)[:, None] + start_n // BLOCK_N)
-        qk += (b0 + b1)
+        qk += ((b0 + b1) * 1.44269504)

         # -- compute scaling constant ---
         m_i_new = tl.maximum(m_i, tl.max(qk, 1))
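The 1.44269504 factor introduced here is log2(e). Triton flash-attention kernels typically evaluate the softmax with exp2 rather than exp (with the query or qk product presumably pre-scaled by log2(e) elsewhere in the kernel), so a bias added directly to the logits has to carry its own log2(e) factor, otherwise it is effectively shrunk by that factor inside the exponent. A minimal sketch of the identity this relies on, independent of the kernel internals:

    import torch

    # exp(x) == 2 ** (x * log2(e)), so scaling added logits by log2(e) keeps
    # the softmax unchanged when a kernel computes it via exp2 instead of exp.
    LOG2_E = 1.44269504
    scores = torch.randn(16)
    bias = torch.randn(16)
    ref = torch.softmax(scores + bias, dim=-1)        # exp-based softmax
    p = torch.exp2((scores + bias) * LOG2_E)          # exp2 with rescaled logits
    assert torch.allclose(ref, p / p.sum(), atol=1e-6)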
@@ -198,6 +198,7 @@ def _attention_rel_h_rel_w_kernel_aligned_device(q, k, v, rel_h_w, sm_scale, o,
     P_SEQ = 0 if q.shape[-2] == k.shape[-2] else k.shape[-2] - q.shape[-2]
     assert P_SEQ == 0
     assert rel_h_w.is_contiguous(), str(rel_h_w.stride())
+    OUT_DTYPE = tl.float16 if q.dtype == torch.float16 else tl.bfloat16
     _fwd_kernel_aligned[grid](
         q, k, v,
         rel_h_w,
@@ -212,7 +213,7 @@ def _attention_rel_h_rel_w_kernel_aligned_device(q, k, v, rel_h_w, sm_scale, o,
         q.shape[1],
         q.shape[2],
         P_SEQ,
-        OUT_DTYPE=tl.float16 if q.dtype == torch.float16 else tl.bfloat16,
+        OUT_DTYPE=OUT_DTYPE,
         BIAS_LAST_SIZE=(rel_h_w.size(-1) // 2),
         B0_NUMEL=rel_h_w.size(-1),
         BLOCK_M=BLOCK_M,
@@ -346,7 +347,8 @@ def _attention_rel_h_rel_w(q_, k_, v_, rel_h_, rel_w_):
     def kernel_guards(q_, k_, v_):
         return (q_.dtype == torch.bfloat16 or q_.dtype == torch.float16) and q_.dtype == k_.dtype and k_.dtype == v_.dtype and USE_CUSTOM_KERNEL
     # vit_b and vit_l
-    if q_size_2_padded == 0 and q_.size(-1) == 64 and kernel_guards(q_, k_, v_):
+    # TODO: This kernel currently does not produce correct results for batch size 1 for this case
+    if q_.size(0) > 1 and q_size_2_padded == 0 and q_.size(-1) == 64 and kernel_guards(q_, k_, v_):
         rel_h_w = torch.cat([rel_h_.squeeze(-1), rel_w_.squeeze(-2)], dim=-1)
         o = torch.ops.customflash.custom_flash_aligned(
             q_, k_, v_, rel_h_w, sm_scale)
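For context on the compact bias layout that this branch builds (a sketch inferred from the shapes used in the new test, not from the kernel internals): rel_h_ has shape (B, H, S, w, 1) and rel_w_ has shape (B, H, S, 1, w) with S = w * w, so the squeeze-and-cat above yields a (B, H, S, 2 * w) tensor, and column j = jh * w + jw of the dense (S x S) bias is recovered as the sum of one element from each half:

    import torch

    B, H, w = 1, 2, 4
    S = w * w                                   # seq_len is a perfect square
    rel_h = torch.randn(B, H, S, w, 1)
    rel_w = torch.randn(B, H, S, 1, w)

    # Compact form handed to the kernel: (B, H, S, 2 * w)
    rel_h_w = torch.cat([rel_h.squeeze(-1), rel_w.squeeze(-2)], dim=-1)

    # Dense bias used by the reference path in the test: (B, H, S, S)
    dense = (rel_h + rel_w).view(B, H, S, S)

    # Column j = jh * w + jw of the dense bias splits into the two halves.
    jh, jw = 2, 3
    assert torch.allclose(dense[..., jh * w + jw],
                          rel_h_w[..., jh] + rel_h_w[..., w + jw])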

test/test_flash_4.py

+43
@@ -0,0 +1,43 @@
+import torch
+import itertools
+from segment_anything_fast.flash_4 import _attention_rel_h_rel_w
+
+def test_op(batch, head, seq_len, hidden_dim, dtype):
+    import math
+
+    sm_scale = 1.0 / math.sqrt(hidden_dim)
+    device = "cuda"
+    torch.manual_seed(20)
+    q = torch.empty(
+        (batch, head, seq_len, hidden_dim), dtype=dtype, device=device
+    ).normal_(mean=0.0, std=0.5)
+    k = torch.empty(
+        (batch, head, seq_len, hidden_dim), dtype=dtype, device=device
+    ).normal_(mean=0.0, std=0.5)
+    v = torch.empty(
+        (batch, head, seq_len, hidden_dim), dtype=dtype, device=device
+    ).normal_(mean=0.0, std=0.5)
+    w = int((seq_len) ** 0.5)
+    assert w * w == seq_len, "seq_len must be a perfect square"
+
+    rel_h = torch.empty(
+        (batch, head, seq_len, w, 1), dtype=dtype, device=device
+    ).normal_(mean=0, std=0.5)
+    rel_w = torch.empty(
+        (batch, head, seq_len, 1, w), dtype=dtype, device=device
+    ).normal_(mean=0, std=0.5)
+
+    tri_out = _attention_rel_h_rel_w(q, k, v, rel_h, rel_w)
+    # reference implementation
+    attn_bias = (rel_h + rel_w).view(
+        q.size(0), q.size(1), rel_h.size(2), rel_h.size(3) * rel_w.size(4)
+    )
+    ref_out = torch.nn.functional.scaled_dot_product_attention(
+        q, k, v, attn_mask=attn_bias
+    )
+
+    torch.testing.assert_close(ref_out, tri_out, rtol=1e-3, atol=1e-3)
+
+for batch, (head, seq_len), dtype in itertools.product([1, 8], [(16, 80), (12, 64)], [torch.float16, torch.bfloat16]):
+    print(f"batch: {batch} head: {head} seq_len: {seq_len} dtype: {dtype}")
+    test_op(batch, head, 4096, seq_len, dtype)
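A note on running the new test: the parameter sweep at the bottom executes at module scope rather than through a pytest entry point, so (assuming a CUDA-capable GPU, which the Triton kernel requires, and segment_anything_fast importable from the working directory) it can be launched directly:

    python test/test_flash_4.py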
