Commit b32a6b5

Add natten with static tiling & Morton Curve (#87)
* Add natten with static tiling
* move natten correctness test to test folder
1 parent 290c1e7 commit b32a6b5

File tree

4 files changed: +461 −3 lines changed

attn_gym/masks/__init__.py (+1)

```diff
@@ -3,3 +3,4 @@
 from attn_gym.masks.prefix_lm import generate_prefix_lm_mask
 from attn_gym.masks.document_mask import generate_doc_mask_mod
 from attn_gym.masks.dilated_sliding_window import generate_dilated_sliding_window
+from attn_gym.masks.natten import generate_natten, generate_tiled_natten, generate_morton_natten
```

attn_gym/masks/natten.py (+161)

```diff
@@ -42,6 +42,133 @@ def natten_mask_mod(
     natten_mask_mod.__name__ = f"natten_c{canvas_w}x{canvas_h}_k{kernel_w}x{kernel_h}"
     return natten_mask_mod
 
+def generate_tiled_natten(
+    W: int,
+    H: int,
+    K_W: int,
+    K_H: int,
+    T_W: int,
+    T_H: int,
+) -> _mask_mod_signature:
+    """Generates a NATTEN attention mask with a given kernel size and static tiling.
+
+    Args:
+        W: The width of the canvas.
+        H: The height of the canvas.
+        K_W: The width of the kernel.
+        K_H: The height of the kernel.
+        T_W: The width of the tile.
+        T_H: The height of the tile.
+    """
+
+    def get_x_y_tiled(idx: IntTensor) -> Tuple[IntTensor, IntTensor]:
+        """
+        Map 1-D index to 2-D coordinates for static tiles of T_H x T_W.
+        """
+        t_id = idx // (T_H * T_W)
+        t_x, t_y = t_id // (W // T_W), t_id % (W // T_W)
+        t_offset = idx % (T_H * T_W)
+        i_x, i_y = t_offset // T_W, t_offset % T_W
+        return t_x * T_W + i_x, t_y * T_H + i_y
+
+    def tiled_natten_mask(
+        b: IntTensor,
+        h: IntTensor,
+        q_idx: IntTensor,
+        kv_idx: IntTensor,
+    ) -> BoolTensor:
+        q_x, q_y = get_x_y_tiled(q_idx)
+        kv_x, kv_y = get_x_y_tiled(kv_idx)
+        kernel_x = q_x.clamp(K_W // 2, (W - 1) - K_W // 2)
+        kernel_y = q_y.clamp(K_H // 2, (H - 1) - K_H // 2)
+        hori_mask = (kernel_x - kv_x).abs() <= K_W // 2
+        vert_mask = (kernel_y - kv_y).abs() <= K_H // 2
+        return hori_mask & vert_mask
+
+    tiled_natten_mask.__name__ = f"tiled_natten_c{W}x{H}_k{K_W}x{K_H}_t{T_W}x{T_H}"
+    return tiled_natten_mask
```
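The decoder above is the inverse of a row-major tile encoding. A minimal round-trip sketch (plain PyTorch; `get_idx_tiled` mirrors the encoder that appears in the notebook diff further down, and the tiny canvas size is an arbitrary choice for illustration):

```python
import torch

W = H = 8      # tiny canvas for illustration
T_W = T_H = 2  # 2x2 tiles

def get_idx_tiled(x, y):
    # Encode (x, y) as: tile id first, then offset within the tile.
    t_x, t_y = x // T_W, y // T_H
    t_id = t_x * (W // T_W) + t_y
    i_x, i_y = x % T_W, y % T_H
    return t_id * (T_H * T_W) + i_x * T_W + i_y

def get_x_y_tiled(idx):
    # Decode: same arithmetic as in generate_tiled_natten above.
    t_id = idx // (T_H * T_W)
    t_x, t_y = t_id // (W // T_W), t_id % (W // T_W)
    t_offset = idx % (T_H * T_W)
    i_x, i_y = t_offset // T_W, t_offset % T_W
    return t_x * T_W + i_x, t_y * T_H + i_y

x = torch.arange(W).repeat_interleave(H)  # every (x, y) pair on the canvas
y = torch.arange(H).repeat(W)
rx, ry = get_x_y_tiled(get_idx_tiled(x, y))
assert torch.equal(rx, x) and torch.equal(ry, y)  # round-trip recovers coordinates
```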
```diff
+def interleave_bits_32(x):
+    """
+    Interleave the bits of a 16-bit integer x, producing a 32-bit integer
+    where the bits of x are interleaved with zeros.
+    """
+    x = x & 0x0000FFFF  # Ensure x is 16 bits
+    x = (x | (x << 8)) & 0x00FF00FF
+    x = (x | (x << 4)) & 0x0F0F0F0F
+    x = (x | (x << 2)) & 0x33333333
+    x = (x | (x << 1)) & 0x55555555
+    return x
+
+def morton_encode(x, y):
+    """
+    Encode 2-D coordinates (x, y) into a Morton code (Z-order curve index).
+
+    Parameters:
+        x (int): The x-coordinate.
+        y (int): The y-coordinate.
+
+    Returns:
+        int: The Morton code resulting from interleaving the bits of x and y.
+    """
+    return (interleave_bits_32(y) << 1) | interleave_bits_32(x)
+
+def deinterleave_bits_32(code):
+    """
+    Deinterleave bits to retrieve the original 16-bit integer.
+    """
+    code = code & 0x55555555
+    code = (code | (code >> 1)) & 0x33333333
+    code = (code | (code >> 2)) & 0x0F0F0F0F
+    code = (code | (code >> 4)) & 0x00FF00FF
+    code = (code | (code >> 8)) & 0x0000FFFF
+    return code
+
+def morton_decode(code):
+    """
+    Decode a Morton code to retrieve the original 2-D coordinates (x, y).
+
+    Parameters:
+        code (int): The Morton code.
+
+    Returns:
+        tuple: A tuple (x, y) representing the original coordinates.
+    """
+    x = deinterleave_bits_32(code)
+    y = deinterleave_bits_32(code >> 1)
+    return x, y
```
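Concretely, for (x, y) = (3, 5): x = 0b011 lands in the even bit positions and y = 0b101 in the odd ones, giving the code 0b100111 = 39. A worked check against the helpers above (plain Python, no dependencies):

```python
assert interleave_bits_32(0b011) == 0b00101   # x bits spread to even positions
assert interleave_bits_32(0b101) == 0b10001   # y bits, before the << 1 shift
assert morton_encode(3, 5) == 0b100111        # = 39
assert morton_decode(39) == (3, 5)            # decoding recovers (x, y)
```

Because nearby pixels share long code prefixes, a Z-order layout tends to keep a local attention window inside a small number of contiguous 1-D blocks.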
```diff
+
+def generate_morton_natten(
+    canvas_w: int,
+    canvas_h: int,
+    kernel_w: int,
+    kernel_h: int,
+) -> _mask_mod_signature:
+    """Generates a NATTEN attention mask with a given kernel size under a Morton curve layout.
+
+    Args:
+        canvas_w: The width of the canvas.
+        canvas_h: The height of the canvas.
+        kernel_w: The width of the kernel.
+        kernel_h: The height of the kernel.
+    """
+
+    def natten_mask_mod(
+        b: IntTensor,
+        h: IntTensor,
+        q_idx: IntTensor,
+        kv_idx: IntTensor,
+    ) -> BoolTensor:
+        q_x, q_y = morton_decode(q_idx)
+        kv_x, kv_y = morton_decode(kv_idx)
+        # The kernel nominally attempts to center itself on the query, but the
+        # kernel center is clamped to a fixed distance (the kernel half-length)
+        # from the canvas edge.
+        kernel_center_x = q_x.clamp(kernel_w // 2, (canvas_w - 1) - kernel_w // 2)
+        kernel_center_y = q_y.clamp(kernel_h // 2, (canvas_h - 1) - kernel_h // 2)
+        hori_mask = (kernel_center_x - kv_x).abs() <= kernel_w // 2
+        vert_mask = (kernel_center_y - kv_y).abs() <= kernel_h // 2
+        return hori_mask & vert_mask
+
+    natten_mask_mod.__name__ = f"morton_natten_c{canvas_w}x{canvas_h}_k{kernel_w}x{kernel_h}"
+    return natten_mask_mod
 
 def main(device: str = "cpu"):
     """Visualize the attention scores of NATTEN mask mod.
@@ -77,6 +204,40 @@ def make_tensor():
         device=device,
         name=natten_mask.__name__,
     )
+
+    tiled_natten_mask = generate_tiled_natten(
+        W=CANVAS_WIDTH,
+        H=CANVAS_HEIGHT,
+        K_W=kernel_size,
+        K_H=kernel_size,
+        T_W=2,
+        T_H=2,
+    )
+    visualize_attention_scores(
+        # TODO: update visualize_attention_scores to support 2D sequences
+        query.flatten(start_dim=2, end_dim=3),
+        key.flatten(start_dim=2, end_dim=3),
+        mask_mod=tiled_natten_mask,
+        device=device,
+        name=tiled_natten_mask.__name__,
+    )
+
+    morton_natten_mask = generate_morton_natten(
+        canvas_w=CANVAS_WIDTH,
+        canvas_h=CANVAS_HEIGHT,
+        kernel_w=kernel_size,
+        kernel_h=kernel_size,
+    )
+    visualize_attention_scores(
+        # TODO: update visualize_attention_scores to support 2D sequences
+        query.flatten(start_dim=2, end_dim=3),
+        key.flatten(start_dim=2, end_dim=3),
+        mask_mod=morton_natten_mask,
+        device=device,
+        name=morton_natten_mask.__name__,
+    )
 
 
 if __name__ == "__main__":
```
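To compare how the three layouts affect block sparsity, a usage sketch (assumed: a PyTorch build that ships `torch.nn.attention.flex_attention`; the `generate_natten` signature is taken to be analogous to the generators above; `BlockMask.sparsity()` reports the percentage of blocks skipped):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask
from attn_gym.masks.natten import generate_natten, generate_tiled_natten, generate_morton_natten

W = H = 64  # power-of-two canvas so Morton codes enumerate exactly W * H indices
K = 7
S = W * H

for mask_mod in (
    generate_natten(W, H, K, K),
    generate_tiled_natten(W, H, K, K, 8, 8),
    generate_morton_natten(W, H, K, K),
):
    block_mask = create_block_mask(mask_mod, B=None, H=None, Q_LEN=S, KV_LEN=S)
    print(f"{mask_mod.__name__}: {block_mask.sparsity():.1f}% sparse")
```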

examples/flex_attn.ipynb (+171 −3)

```diff
@@ -632,8 +632,8 @@
     "source": [
      "H = 128\n",
      "W = 128\n",
-     "K_H = 7\n",
-     "K_W = 7\n",
+     "K_H = 13\n",
+     "K_W = 13\n",
      "\n",
      "\n",
      "def get_x_y(idx):\n",
```
```diff
@@ -657,7 +657,175 @@
      "    return hori_mask & vert_mask\n",
      "\n",
      "\n",
-     "test_mask(mask_mod=natten_mask)"
+     "test_mask(mask_mod=natten_mask, S=H * W)"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "### Tiled NATTEN layout\n",
+     "The solution above unrolls the 2-D Q and KV into a 1-D attention problem in a naive column-major way. This breaks the locality of the very sparse Q, K, V layout: while the density of the NATTEN mask is `(13 * 13) / (128 * 128) = 1.0%`, the density of our block mask becomes 10.16% with 128x128 blocks. Q, K, V layouts that retain their 2-D spatial locality could improve the block sparsity and make the FlexAttention implementation more efficient.\n",
+     "\n",
+     "Static tiling, as proposed in [Faster NATTEN](https://arxiv.org/abs/2403.04690), maps static tiles of $T_h \\times T_w$ in 2-D space to contiguous regions of the 1-D Q, K, V."
+    ]
+   },
```
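The density numbers quoted in that cell can be reproduced with a short sanity check (not part of the commit; it assumes the `natten_mask`, `H`, `W`, `K_H`, `K_W` defined earlier in the notebook and `create_block_mask` imported from `torch.nn.attention.flex_attention`):

```python
S = H * W

# Density of the NATTEN mask itself: each query attends to a 13x13 window.
print(f"mask density:  {K_H * K_W / S:.2%}")  # 169 / 16384 ~ 1.03%

# Density of the induced block mask under the naive column-major unrolling
# (create_block_mask uses 128x128 blocks by default).
block_mask = create_block_mask(natten_mask, B=None, H=None, Q_LEN=S, KV_LEN=S)
print(f"block density: {100 - block_mask.sparsity():.2f}%")  # ~10.16% as quoted
```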
```diff
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "H = 128\n",
+     "W = 128\n",
+     "K_H = 13\n",
+     "K_W = 13\n",
+     "T_H, T_W = 8, 8\n",
+     "\n",
+     "def gen_tiled_natten(W, H, K_W, K_H, T_W, T_H):\n",
+     "    def get_idx_tiled(x, y):\n",
+     "        \"\"\"\n",
+     "        Map 2-D coordinates to 1-D index for static tiles of T_H x T_W.\n",
+     "        \"\"\"\n",
+     "        t_x, t_y = x // T_W, y // T_H\n",
+     "        t_id = t_x * (W // T_W) + t_y\n",
+     "        i_x, i_y = x % T_W, y % T_H\n",
+     "        t_offset = i_x * T_W + i_y\n",
+     "        return t_id * (T_H * T_W) + t_offset\n",
+     "\n",
+     "    def get_x_y_tiled(idx):\n",
+     "        \"\"\"\n",
+     "        Map 1-D index to 2-D coordinates for static tiles of T_H x T_W.\n",
+     "        \"\"\"\n",
+     "        t_id = idx // (T_H * T_W)\n",
+     "        t_x, t_y = t_id // (W // T_W), t_id % (W // T_W)\n",
+     "        t_offset = idx % (T_H * T_W)\n",
+     "        i_x, i_y = t_offset // T_W, t_offset % T_W\n",
+     "        return t_x * T_W + i_x, t_y * T_H + i_y\n",
+     "\n",
+     "    def tiled_natten_mask(b, h, q, kv):\n",
+     "        q_x, q_y = get_x_y_tiled(q)\n",
+     "        kv_x, kv_y = get_x_y_tiled(kv)\n",
+     "        kernel_x = q_x.clamp(K_W // 2, (W - 1) - K_W // 2)\n",
+     "        kernel_y = q_y.clamp(K_H // 2, (H - 1) - K_H // 2)\n",
+     "        hori_mask = (kernel_x - kv_x).abs() <= K_W // 2\n",
+     "        vert_mask = (kernel_y - kv_y).abs() <= K_H // 2\n",
+     "        return hori_mask & vert_mask\n",
+     "\n",
+     "    return tiled_natten_mask\n",
+     "\n",
+     "# tiled_natten_mask = gen_tiled_natten(W, H, K_W, K_H, T_W, T_H)\n",
+     "from attn_gym.masks.natten import generate_tiled_natten\n",
+     "tiled_natten_mask_mod = generate_tiled_natten(W, H, K_W, K_H, T_W, T_H)\n",
+     "\n",
+     "test_mask(mask_mod=tiled_natten_mask_mod, S=H * W)"
+    ]
+   },
```
723+
{
724+
"cell_type": "markdown",
725+
"metadata": {},
726+
"source": [
727+
"Verify that Naive NATTEN Mask and tiled NATTEN generate the same output"
728+
]
729+
},
```diff
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "def run_natten(\n",
+     "    mask=None,\n",
+     "    encoder=None,\n",
+     "    decoder=None,\n",
+     "    query=None,\n",
+     "    key=None,\n",
+     "    value=None,\n",
+     "    gradOut=None,\n",
+     "    B=16,\n",
+     "    H=16,\n",
+     "    W=128,\n",
+     "    D=64,\n",
+     "    print_mask=True,\n",
+     "):\n",
+     "    if decoder:\n",
+     "        permuter_x, permuter_y = decoder(torch.arange(W * W))\n",
+     "        permuter_index = permuter_x * W + permuter_y\n",
+     "        q = query[:, :, permuter_x, permuter_y, :].clone().detach().requires_grad_(query.requires_grad)\n",
+     "        k = key[:, :, permuter_x, permuter_y, :].clone().detach().requires_grad_(key.requires_grad)\n",
+     "        v = value[:, :, permuter_x, permuter_y, :].clone().detach().requires_grad_(value.requires_grad)\n",
+     "        dO = gradOut[:, :, permuter_x, permuter_y, :]\n",
+     "    else:\n",
+     "        q = query.flatten(2, 3).clone().detach().requires_grad_(query.requires_grad)\n",
+     "        k = key.flatten(2, 3).clone().detach().requires_grad_(key.requires_grad)\n",
+     "        v = value.flatten(2, 3).clone().detach().requires_grad_(value.requires_grad)\n",
+     "        dO = gradOut.flatten(2, 3)\n",
+     "    block_mask = create_block_mask_cached(mask, 1, 1, W * W, W * W, device=query.device)\n",
+     "    if print_mask:\n",
+     "        print(f\"\\nBlock Mask:\\n{block_mask}\")\n",
+     "\n",
+     "    out = flex_attention(q, k, v, block_mask=block_mask)\n",
+     "    out.backward(dO)\n",
+     "\n",
+     "    if encoder:\n",
+     "        i_x = torch.arange(W)[:, None].broadcast_to(W, W).flatten()\n",
+     "        i_y = torch.arange(W)[None, :].broadcast_to(W, W).flatten()\n",
+     "        depermuter = encoder(i_x, i_y)\n",
+     "        out = out[:, :, depermuter, :].reshape(B, H, W, W, D)\n",
+     "        q_grad = q.grad[:, :, depermuter, :].reshape(B, H, W, W, D)\n",
+     "        k_grad = k.grad[:, :, depermuter, :].reshape(B, H, W, W, D)\n",
+     "        v_grad = v.grad[:, :, depermuter, :].reshape(B, H, W, W, D)\n",
+     "        results = [out, q_grad, k_grad, v_grad]\n",
+     "    else:\n",
+     "        out = out.reshape(B, H, W, W, D)\n",
+     "        q_grad = q.grad.reshape(B, H, W, W, D)\n",
+     "        k_grad = k.grad.reshape(B, H, W, W, D)\n",
+     "        v_grad = v.grad.reshape(B, H, W, W, D)\n",
+     "        results = [out, q_grad, k_grad, v_grad]\n",
+     "\n",
+     "    del q, k, v, dO\n",
+     "\n",
+     "    return results\n",
+     "\n",
+     "\n",
+     "def test_natten_masks(\n",
+     "    naive,\n",
+     "    tiled,\n",
+     "    B=16,\n",
+     "    H=16,\n",
+     "    W=128,\n",
+     "    D=64,\n",
+     "    skip_correctness=False,\n",
+     "    print_mask=True,\n",
+     "):\n",
+     "    query = torch.randn(B, H, W, W, D, device=\"cuda\", dtype=torch.float16, requires_grad=True)\n",
+     "    key = torch.randn(B, H, W, W, D, device=\"cuda\", dtype=torch.float16, requires_grad=True)\n",
+     "    value = torch.randn(B, H, W, W, D, device=\"cuda\", dtype=torch.float16, requires_grad=True)\n",
+     "    gradOut = torch.randn(B, H, W, W, D, device=\"cuda\", dtype=torch.float16)\n",
+     "\n",
+     "    naive_results = run_natten(mask=naive[0], encoder=naive[1], decoder=naive[2], query=query, key=key, value=value, gradOut=gradOut, print_mask=print_mask)\n",
+     "    tiled_results = run_natten(mask=tiled[0], encoder=tiled[1], decoder=tiled[2], query=query, key=key, value=value, gradOut=gradOut, print_mask=print_mask)\n",
+     "\n",
+     "    if not skip_correctness:\n",
+     "        for naive_out, tiled_out in zip(naive_results, tiled_results):\n",
+     "            torch.testing.assert_close(naive_out, tiled_out, atol=1e-1, rtol=1e-2)\n",
+     "\n",
+     "        print(\"Correctness check passed ✅\")\n",
+     "\n",
+     "    # Clean up to save memory\n",
+     "    del query, key, value, gradOut, naive_results, tiled_results\n",
+     "    torch.cuda.empty_cache()\n",
+     "\n",
+     "test_natten_masks(\n",
+     "    naive=[natten_mask, None, None],\n",
+     "    tiled=[tiled_natten_mask, get_idx_tiled, get_x_y_tiled],\n",
+     ")"
    ]
   },
   {
```
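In principle the same harness can also exercise the Morton layout, since `morton_encode` / `morton_decode` match the `encoder` / `decoder` signatures that `run_natten` expects. A hypothetical follow-up cell (not part of this commit):

```python
from attn_gym.masks.natten import generate_morton_natten, morton_encode, morton_decode

# W must be a power of two so Morton codes enumerate exactly W * W positions.
morton_natten_mask = generate_morton_natten(W, W, K_W, K_H)

test_natten_masks(
    naive=[natten_mask, None, None],
    tiled=[morton_natten_mask, morton_encode, morton_decode],
)
```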
