import torch
import itertools
import math

from segment_anything_fast.flash_4 import _attention_rel_h_rel_w


def test_op(batch, head, seq_len, hidden_dim, dtype):
    # sm_scale is the 1 / sqrt(hidden_dim) softmax scale that
    # scaled_dot_product_attention applies by default; it is kept here for
    # reference only and never passed explicitly.
    sm_scale = 1.0 / math.sqrt(hidden_dim)
    device = "cuda"
    torch.manual_seed(20)
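    # Random Q, K, V in the (batch, head, seq_len, hidden_dim) layout that
    # both the fused kernel and the reference path consume.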
    q = torch.empty(
        (batch, head, seq_len, hidden_dim), dtype=dtype, device=device
    ).normal_(mean=0.0, std=0.5)
    k = torch.empty(
        (batch, head, seq_len, hidden_dim), dtype=dtype, device=device
    ).normal_(mean=0.0, std=0.5)
    v = torch.empty(
        (batch, head, seq_len, hidden_dim), dtype=dtype, device=device
    ).normal_(mean=0.0, std=0.5)
    w = int(seq_len ** 0.5)
    assert w * w == seq_len, "seq_len must be a perfect square"

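    # The sequence is treated as a w x w spatial grid: rel_h is a per-row bias
    # and rel_w a per-column bias, so rel_h + rel_w broadcasts to a
    # (batch, head, seq_len, w, w) relative-position bias.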
    rel_h = torch.empty(
        (batch, head, seq_len, w, 1), dtype=dtype, device=device
    ).normal_(mean=0, std=0.5)
    rel_w = torch.empty(
        (batch, head, seq_len, 1, w), dtype=dtype, device=device
    ).normal_(mean=0, std=0.5)

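    # Fused kernel under test (segment_anything_fast.flash_4): attention with
    # the decomposed rel_h / rel_w bias applied inside the kernel.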
    tri_out = _attention_rel_h_rel_w(q, k, v, rel_h, rel_w)

    # Reference implementation: materialize the full additive bias of shape
    # (batch, head, seq_len, w * w) == (batch, head, seq_len, seq_len) and let
    # scaled_dot_product_attention add it to the attention logits.
    attn_bias = (rel_h + rel_w).view(
        q.size(0), q.size(1), rel_h.size(2), rel_h.size(3) * rel_w.size(4)
    )
    ref_out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=attn_bias
    )

    torch.testing.assert_close(ref_out, tri_out, rtol=1e-3, atol=1e-3)


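# Sweep batch sizes, (head, hidden_dim) pairs, and dtypes. seq_len is fixed
# at 4096, i.e. a 64 x 64 grid, so the perfect-square assertion holds.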
for batch, (head, hidden_dim), dtype in itertools.product(
    [1, 8], [(16, 80), (12, 64)], [torch.float16, torch.bfloat16]
):
    print(f"batch: {batch} head: {head} hidden_dim: {hidden_dim} dtype: {dtype}")
    test_op(batch, head, 4096, hidden_dim, dtype)