632 | 632 | "source": [
633 | 633 | "H = 128\n",
634 | 634 | "W = 128\n",
635 | | - "K_H = 7\n", |
636 | | - "K_W = 7\n", |
| 635 | + "K_H = 13\n", |
| 636 | + "K_W = 13\n", |
637 | 637 | "\n",
638 | 638 | "\n",
639 | 639 | "def get_x_y(idx):\n",
657 | 657 | "    return hori_mask & vert_mask\n",
658 | 658 | "\n",
659 | 659 | "\n",
660 | | - "test_mask(mask_mod=natten_mask)" |
| 660 | + "test_mask(mask_mod=natten_mask, S=H * W)" |
| 661 | + ] |
| 662 | + }, |
| 663 | + { |
| 664 | + "cell_type": "markdown", |
| 665 | + "metadata": {}, |
| 666 | + "source": [ |
| 667 | + "### Tiled NATTEN layout\n", |
| 668 | + "The solution above unrolls 2-D Q and KV into 1-D attention problem in a naive column major way. This breaks the locality of the very sparse Q K V layout: While the density of the MATTEN mask is `(13 * 13) / (128 * 128) = 1.0%`, the density of our block mask becomes 10.16% with 128x128 blocks. Q K V layouts with that retains their 2-D spatial locality could improve the block sparsity and make flexattention implementation more efficient. \n", |
| 669 | + "\n", |
| 670 | + "Static tiling as proposed in the [faster NATTEN](https://arxiv.org/abs/2403.04690) maps static tiles of $ T_h \\times T_w $ in the 2-D space in contiguous region in 1-D Q K V. " |
| 671 | + ] |
| 672 | + }, |
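A quick way to sanity-check the density numbers quoted above is to build the block mask for the naive unrolling and read off its sparsity. The snippet below is a minimal sketch, not part of the notebook diff: it re-derives the naive `get_x_y` / `natten_mask` helpers to mirror the earlier cells, and assumes a torch version (>= 2.5) where `create_block_mask` and `BlockMask.sparsity()` are available on a CUDA device.

```python
# Hedged sketch: estimate element vs. block density for the naive 1-D unrolling.
import torch
from torch.nn.attention.flex_attention import create_block_mask

H = W = 128
K_H = K_W = 13

def get_x_y(idx):
    # Naive unrolling used earlier in the notebook: 1-D index -> (row, col).
    return idx // W, idx % W

def natten_mask(b, h, q_idx, kv_idx):
    q_x, q_y = get_x_y(q_idx)
    kv_x, kv_y = get_x_y(kv_idx)
    kernel_x = q_x.clamp(K_W // 2, (W - 1) - K_W // 2)
    kernel_y = q_y.clamp(K_H // 2, (H - 1) - K_H // 2)
    return ((kernel_x - kv_x).abs() <= K_W // 2) & ((kernel_y - kv_y).abs() <= K_H // 2)

S = H * W
bm = create_block_mask(natten_mask, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cuda")
print(f"element density ~= {K_H * K_W / (H * W):.2%}")              # about 1.0%
print(f"block density   ~= {100 - bm.sparsity():.2f}% (128x128 blocks)")
```

With the row-contiguous unrolling, each 128-wide query block covers one image row, and it needs roughly 13 of the 128 kv blocks (the 13 neighboring rows), which is where the ~10% block density comes from.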
| 673 | + { |
| 674 | + "cell_type": "code", |
| 675 | + "execution_count": null, |
| 676 | + "metadata": {}, |
| 677 | + "outputs": [], |
| 678 | + "source": [ |
| 679 | + "H = 128\n", |
| 680 | + "W = 128\n", |
| 681 | + "K_H = 13\n", |
| 682 | + "K_W = 13\n", |
| 683 | + "T_H, T_W = 8, 8\n", |
| 684 | + "\n", |
| 685 | + "def gen_tiled_natten(W, H, K_W, K_H, T_W, T_H):\n", |
| 686 | + " def get_idx_tiled(x, y):\n", |
| 687 | + " \"\"\"\n", |
| 688 | + " Map 2-D coordinates to 1-D index for static tiles of T_H x T_W.\n", |
| 689 | + " \"\"\"\n", |
| 690 | + " t_x, t_y = x // T_W, y // T_H\n", |
| 691 | + " t_id = t_x * (W // T_W) + t_y\n", |
| 692 | + " i_x, i_y = x % T_W, y % T_H\n", |
| 693 | + " t_offset = i_x * T_W + i_y\n", |
| 694 | + " return t_id * (T_H * T_W) + t_offset\n", |
| 695 | + "\n", |
| 696 | + " def get_x_y_tiled(idx):\n", |
| 697 | + " \"\"\"\n", |
| 698 | + " Map 1-D index to 2-D coordinates for static tiles of T_H x T_W.\n", |
| 699 | + " \"\"\"\n", |
| 700 | + " t_id = idx // (T_H * T_W)\n", |
| 701 | + " t_x, t_y = t_id // (W // T_W), t_id % (W // T_W)\n", |
| 702 | + " t_offset = idx % (T_H * T_W)\n", |
| 703 | + " i_x, i_y = t_offset // T_W, t_offset % T_W\n", |
| 704 | + " return t_x*T_W + i_x, t_y*T_H + i_y\n", |
| 705 | + "\n", |
| 706 | + " def tiled_natten_mask(b, h, q, kv):\n", |
| 707 | + " q_x, q_y = get_x_y_tiled(q)\n", |
| 708 | + " kv_x, kv_y = get_x_y_tiled(kv)\n", |
| 709 | + " kernel_x = q_x.clamp(K_W // 2, (W - 1) - K_W // 2)\n", |
| 710 | + " kernel_y = q_y.clamp(K_H // 2, (H - 1) - K_H // 2)\n", |
| 711 | + " hori_mask = (kernel_x - kv_x).abs() <= K_W // 2\n", |
| 712 | + " vert_mask = (kernel_y - kv_y).abs() <= K_H // 2\n", |
| 713 | + " return hori_mask & vert_mask\n", |
| 714 | + " return tiled_natten_mask\n", |
| 715 | + "\n", |
| 716 | + "# tiled_natten_mask = gen_tiled_natten(W, H, K_W, K_H, T_W, T_H)\n", |
| 717 | + "from attn_gym.masks.natten import generate_tiled_natten\n", |
| 718 | + "tiled_natten_mask_mod = generate_tiled_natten(W, H, K_W, K_H, T_W, T_H)\n", |
| 719 | + "\n", |
| 720 | + "test_mask(mask_mod=tiled_natten_mask_mod, S=H * W)" |
| 721 | + ] |
| 722 | + }, |
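One quick sanity check, sketched below under the assumption that the cell above has been run (so `H`, `W`, `get_idx_tiled`, and `get_x_y_tiled` are in scope): the tiled encode and decode should be exact inverses over all `H * W` positions.

```python
# Hedged sketch: verify the tiled layout round-trips (names assume the cell above has run).
import torch

idx = torch.arange(H * W)
x, y = get_x_y_tiled(idx)                     # tiled 1-D index -> 2-D coordinates
assert torch.equal(get_idx_tiled(x, y), idx)  # encoding the coordinates recovers the index
print("tiled encode/decode round-trip OK")
```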
| 723 | + { |
| 724 | + "cell_type": "markdown", |
| 725 | + "metadata": {}, |
| 726 | + "source": [ |
| 727 | + "Verify that Naive NATTEN Mask and tiled NATTEN generate the same output" |
| 728 | + ] |
| 729 | + }, |
| 730 | + { |
| 731 | + "cell_type": "code", |
| 732 | + "execution_count": null, |
| 733 | + "metadata": {}, |
| 734 | + "outputs": [], |
| 735 | + "source": [ |
| 736 | + "def run_natten(\n", |
| 737 | + " mask = None,\n", |
| 738 | + " encoder = None, \n", |
| 739 | + " decoder = None,\n", |
| 740 | + " query = None, \n", |
| 741 | + " key = None,\n", |
| 742 | + " value = None, \n", |
| 743 | + " gradOut = None,\n", |
| 744 | + " B=16,\n", |
| 745 | + " H=16,\n", |
| 746 | + " W=128,\n", |
| 747 | + " D=64,\n", |
| 748 | + " print_mask=True,\n", |
| 749 | + "):\n", |
| 750 | + " if decoder:\n", |
| 751 | + " permuter_x, permuter_y = decoder(torch.arange(W*W))\n", |
| 752 | + " permuter_index = permuter_x * W + permuter_y\n", |
| 753 | + " q = query[:, :, permuter_x, permuter_y, :].clone().detach().requires_grad_(query.requires_grad)\n", |
| 754 | + " k = key[:, :, permuter_x, permuter_y, :].clone().detach().requires_grad_(key.requires_grad)\n", |
| 755 | + " v = value[:, :, permuter_x, permuter_y, :].clone().detach().requires_grad_(value.requires_grad)\n", |
| 756 | + " dO = gradOut[:, :, permuter_x, permuter_y, :]\n", |
| 757 | + " else: \n", |
| 758 | + " q = query.flatten(2, 3).clone().detach().requires_grad_(query.requires_grad)\n", |
| 759 | + " k = key.flatten(2, 3).clone().detach().requires_grad_(key.requires_grad)\n", |
| 760 | + " v = value.flatten(2, 3).clone().detach().requires_grad_(value.requires_grad)\n", |
| 761 | + " dO = gradOut.flatten(2, 3)\n", |
| 762 | + " block_mask = create_block_mask_cached(mask, 1, 1, W*W, W*W, device=query.device)\n", |
| 763 | + " if print_mask:\n", |
| 764 | + " print(f\"\\nBlock Mask:\\n{block_mask}\")\n", |
| 765 | + " \n", |
| 766 | + " out = flex_attention(q, k, v, block_mask=block_mask)\n", |
| 767 | + " \n", |
| 768 | + " out.backward(dO)\n", |
| 769 | + " \n", |
| 770 | + " if encoder: \n", |
| 771 | + " i_x = torch.arange(W)[:, None].broadcast_to(W, W).flatten() \n", |
| 772 | + " i_y = torch.arange(W)[None, :].broadcast_to(W, W).flatten() \n", |
| 773 | + " depermuter = encoder(i_x, i_y)\n", |
| 774 | + " out = out[:, :, depermuter, :].reshape(B, H, W, W, D)\n", |
| 775 | + " q_grad = q.grad[:, :, depermuter, :].reshape(B, H, W, W, D)\n", |
| 776 | + " k_grad = k.grad[:, :, depermuter, :].reshape(B, H, W, W, D)\n", |
| 777 | + " v_grad = v.grad[:, :, depermuter, :].reshape(B, H, W, W, D)\n", |
| 778 | + " results = [out, q_grad, k_grad, v_grad]\n", |
| 779 | + " else:\n", |
| 780 | + " out= out.reshape(B, H, W, W, D)\n", |
| 781 | + " q_grad = q.grad.reshape(B, H, W, W, D)\n", |
| 782 | + " k_grad = k.grad.reshape(B, H, W, W, D)\n", |
| 783 | + " v_grad = v.grad.reshape(B, H, W, W, D)\n", |
| 784 | + " results = [out, q_grad, k_grad, v_grad]\n", |
| 785 | + " \n", |
| 786 | + " del q, k, v, dO\n", |
| 787 | + " \n", |
| 788 | + " return results\n", |
| 789 | + "\n", |
| 790 | + "\n", |
| 791 | + "def test_natten_masks(\n", |
| 792 | + " naive,\n", |
| 793 | + " tiled,\n", |
| 794 | + " B=16,\n", |
| 795 | + " H=16,\n", |
| 796 | + " W=128,\n", |
| 797 | + " D=64,\n", |
| 798 | + " skip_correctness=False,\n", |
| 799 | + " print_mask=True,\n", |
| 800 | + "): \n", |
| 801 | + " query = torch.randn(\n", |
| 802 | + " B, H, W, W, D, device=\"cuda\", dtype=torch.float16, requires_grad=True\n", |
| 803 | + " )\n", |
| 804 | + " key = torch.randn(\n", |
| 805 | + " B, H, W, W, D, device=\"cuda\", dtype=torch.float16, requires_grad=True\n", |
| 806 | + " )\n", |
| 807 | + " value = torch.randn(\n", |
| 808 | + " B, H, W, W, D, device=\"cuda\", dtype=torch.float16, requires_grad=True\n", |
| 809 | + " )\n", |
| 810 | + " gradOut = torch.randn(B, H, W, W, D, device=\"cuda\", dtype=torch.float16)\n", |
| 811 | + " \n", |
| 812 | + " naive_results = run_natten(mask=naive[0], encoder=naive[1], decoder=naive[2], query=query, key=key, value=value, gradOut=gradOut, print_mask=print_mask)\n", |
| 813 | + " tiled_results = run_natten(mask=tiled[0], encoder=tiled[1], decoder=tiled[2], query=query, key=key, value=value, gradOut=gradOut, print_mask=print_mask)\n", |
| 814 | + " \n", |
| 815 | + " if not skip_correctness:\n", |
| 816 | + " for naive, tiled in zip(naive_results, tiled_results):\n", |
| 817 | + " torch.testing.assert_close(naive, tiled, atol=1e-1, rtol=1e-2)\n", |
| 818 | + "\n", |
| 819 | + " print(\"Correctness check passed ✅\")\n", |
| 820 | + "\n", |
| 821 | + " # Clean up to save memory\n", |
| 822 | + " del query, key, value, gradOut, naive_results, tiled_results\n", |
| 823 | + " torch.cuda.empty_cache()\n", |
| 824 | + "\n", |
| 825 | + "test_natten_masks(\n", |
| 826 | + " naive=[natten_mask, None, None],\n", |
| 827 | + " tiled=[tiled_natten_mask, get_idx_tiled, get_x_y_tiled],\n", |
| 828 | + ")" |
661 | 829 | ]
662 | 830 | },
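The gather/scatter bookkeeping in `run_natten` is the easy part to get wrong, so a small self-check may help. The sketch below is illustrative only; it assumes `get_idx_tiled` and `get_x_y_tiled` from the tiled-layout cell are in scope and shows that permuting a row-major tensor into tiled order and then applying the `depermuter` recovers the original layout, mirroring how outputs and gradients are re-ordered above.

```python
# Hedged sketch: the permute / de-permute pair used by run_natten is a round trip.
import torch

W = 128
perm_x, perm_y = get_x_y_tiled(torch.arange(W * W))          # tiled order -> 2-D coords
i_x = torch.arange(W)[:, None].broadcast_to(W, W).flatten()  # row-major row indices
i_y = torch.arange(W)[None, :].broadcast_to(W, W).flatten()  # row-major col indices
depermuter = get_idx_tiled(i_x, i_y)                         # 2-D coords -> tiled order

t = torch.randn(W, W)
tiled_order = t[perm_x, perm_y]                  # gather into the tiled layout
restored = tiled_order[depermuter].reshape(W, W) # scatter back to row-major
assert torch.equal(restored, t)
print("permute/de-permute round-trip OK")
```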
663 | 831 | {