From 7101dcfb98baa4eaf1f0c16b1c8d0b71e48db3b4 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 9 Dec 2024 09:26:20 +0000 Subject: [PATCH] Revert "[inductor][cpp] Add FlexAttention support for CPU inference (#141453)" This reverts commit 7edbde3334df3223c009769d8226d06071e1fff9. Reverted https://github.com/pytorch/pytorch/pull/141453 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I think it is failing periodic NO_AVX2 ([comment](https://github.com/pytorch/pytorch/pull/141453#issuecomment-2527377475)) --- aten/src/ATen/native/CPUBlas.cpp | 84 +- aten/src/ATen/native/CPUBlas.h | 23 +- test/inductor/test_flex_attention.py | 496 +++----- .../codegen/cpp_flex_attention_template.py | 1096 ----------------- torch/_inductor/kernel/flex_attention.py | 212 ---- torch/nn/attention/flex_attention.py | 15 +- 6 files changed, 174 insertions(+), 1752 deletions(-) delete mode 100644 torch/_inductor/codegen/cpp_flex_attention_template.py diff --git a/aten/src/ATen/native/CPUBlas.cpp b/aten/src/ATen/native/CPUBlas.cpp index 7ef54320aa808..9e2c9fb5194a6 100644 --- a/aten/src/ATen/native/CPUBlas.cpp +++ b/aten/src/ATen/native/CPUBlas.cpp @@ -1125,9 +1125,6 @@ struct Brgemm : public KernelCache { if (dtype == ScalarType::Half) { static bool fp16_support = dnnl::get_effective_cpu_isa() >= dnnl::cpu_isa::avx512_core_fp16; return fp16_support; - } else if (dtype == ScalarType::Float) { - static bool fp32_support = dnnl::get_effective_cpu_isa() >= dnnl::cpu_isa::avx2; - return fp32_support; } else if (dtype == ScalarType::BFloat16) { static bool bf16_support = dnnl::get_effective_cpu_isa() >= dnnl::cpu_isa::avx512_core; return bf16_support; @@ -1195,29 +1192,18 @@ void brgemm( int64_t ld_b, int64_t ld_c, const bool add_C, - const float* A, - const float* B, - float* C, - bool is_vnni) { - - TORCH_CHECK(!is_vnni, - "Float Brgemm does not support vnni layout."); - + const at::Half* A, + const at::Half* B, + float* C) { #if defined(ONEDNN_UKERNEL_ENABLED) - if (Brgemm::device_check(ScalarType::Float)) { - Brgemm::call( + if (Brgemm::device_check(ScalarType::Half)) { + Brgemm::call( M, N, K, ld_a, ld_b, ld_c, add_C, A, B, C); return; } #endif - // fallback path - auto beta = add_C ? 1 : 0; - gemm( - at::native::TransposeType::NoTranspose, - at::native::TransposeType::NoTranspose, - N, M, K, 1, - B, ld_b, A, ld_a, - beta, C, ld_c); + TORCH_CHECK(false, + "Half Brgemm is only supported on X64 when oneDNN ukernel is enabled and avx512_fp16 is supported"); } void brgemm( @@ -1230,64 +1216,22 @@ void brgemm( const bool add_C, const at::BFloat16* A, const at::BFloat16* B, - float* C, - bool is_vnni) { + float* C) { #if defined(ONEDNN_UKERNEL_ENABLED) - if (is_vnni && Brgemm::device_check(ScalarType::BFloat16)) { + if (Brgemm::device_check(ScalarType::BFloat16)) { Brgemm::call( M, N, K, ld_a, ld_b, ld_c, add_C, A, B, C); return; } #endif - // fallback path - TORCH_CHECK(!is_vnni, - "BFloat16 Brgemm VNNI format is only supported on X64 when oneDNN ukernel is enabled and `amx` is supported"); - auto beta = add_C ? 
1 : 0; - gemm( - at::native::TransposeType::NoTranspose, - at::native::TransposeType::NoTranspose, - N, M, K, 1, - B, ld_b, A, ld_a, - beta, C, ld_c); -} - -void brgemm( - int64_t M, - int64_t N, - int64_t K, - int64_t ld_a, - int64_t ld_b, - int64_t ld_c, - const bool add_C, - const at::Half* A, - const at::Half* B, - float* C, - bool is_vnni) { -#if defined(ONEDNN_UKERNEL_ENABLED) - if (is_vnni && Brgemm::device_check(ScalarType::Half)) { - Brgemm::call( - M, N, K, ld_a, ld_b, ld_c, add_C, A, B, C); - return; - } -#endif - // fallback path - TORCH_CHECK(!is_vnni, - "Half Brgemm VNNI format is only supported on X64 when oneDNN ukernel is enabled and `amx_fp16` is supported"); - auto beta = add_C ? 1 : 0; - gemm( - at::native::TransposeType::NoTranspose, - at::native::TransposeType::NoTranspose, - N, M, K, 1, - B, ld_b, A, ld_a, - beta, C, ld_c); + TORCH_CHECK(false, + "BFloat16 Brgemm is only supported on X64 when oneDNN ukernel is enabled and avx512 is supported"); } -void brgemm_release(bool is_vnni) { +void brgemm_release() { #if defined(ONEDNN_UKERNEL_ENABLED) - if (is_vnni) { - dnnl::ukernel::brgemm::release_hw_context(); - Brgemm::get_current() = nullptr; - } + dnnl::ukernel::brgemm::release_hw_context(); + Brgemm::get_current() = nullptr; #endif } diff --git a/aten/src/ATen/native/CPUBlas.h b/aten/src/ATen/native/CPUBlas.h index 046cb9b439ca1..486422449b4c5 100644 --- a/aten/src/ATen/native/CPUBlas.h +++ b/aten/src/ATen/native/CPUBlas.h @@ -204,8 +204,7 @@ TORCH_API void brgemm( const bool add_C, const at::Half* A, const at::Half* B, - float* C, - bool is_vnni = true); + float* C); TORCH_API void brgemm( int64_t M, @@ -217,24 +216,10 @@ TORCH_API void brgemm( const bool add_C, const at::BFloat16* A, const at::BFloat16* B, - float* C, - bool is_vnni = true); - -TORCH_API void brgemm( - int64_t M, - int64_t N, - int64_t K, - int64_t ld_a, - int64_t ld_b, - int64_t ld_c, - const bool add_C, - const float* A, - const float* B, - float* C, - bool is_vnni = false); + float* C); // Release brgemm hardware context -TORCH_API void brgemm_release(bool is_vnni = true); +TORCH_API void brgemm_release(); // Pack B matrix to get better performance if needed void pack( @@ -248,6 +233,6 @@ void pack( void* out); // Whether pack is supported in the platform. 
-TORCH_API bool could_pack(ScalarType dt_in); +bool could_pack(ScalarType dt_in); } // namespace at::native::cpublas diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index 2e68a9878252f..bc4588ad54c23 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -101,34 +101,13 @@ def create_block_mask_test(score_mod, query, key): return block_mask -TEST_ON_CUDA = ( - torch.cuda.is_available() - and torch.utils._triton.has_triton() - and torch.cuda.get_device_capability() >= (8, 0) +test_dtypes = ( + [torch.float16, torch.bfloat16, torch.float32] + if PLATFORM_SUPPORTS_BF16 + else [torch.float16, torch.float32] ) -if TEST_ON_CUDA: - test_device = "cuda" - test_dtypes = ( - [torch.float32, torch.bfloat16, torch.float16] - if PLATFORM_SUPPORTS_BF16 - else [torch.float16, torch.float32] - ) - test_dtypes_fast = [torch.float16] -else: - test_device = "cpu" - torch_config_string = torch.__config__.show() - LONG_COMPILATION_ON_CPU = False - if "CLANG" in torch_config_string.upper(): - # if the compiler is clang, skip UT for CPU due to long compilation time found in CI - # TODO: check reason of long compile time - LONG_COMPILATION_ON_CPU = True - test_dtypes = ( - [torch.float32, torch.bfloat16] - if torch.ops.mkldnn._is_mkldnn_bf16_supported() - else [torch.float32] - ) - test_dtypes_fast = [torch.float32] +test_dtypes_fast = [torch.float16] # --------- Useful score mod functions for testing --------- @@ -307,12 +286,6 @@ def batch_reserve(paged_attention: PagedAttention, target_seq_len: Tensor): class TestFlexAttention(InductorTestCase): - def setUp(self): - super().setUp() - self.device = test_device - if self.device == "cpu" and LONG_COMPILATION_ON_CPU: - self.skipTest("skip UT for CPU due to long compilation time found in CI") - def _check_equal( self, golden_out: torch.Tensor, @@ -416,74 +389,45 @@ def run_test( KV_S = Q_S if V_D is None: V_D = Q_D - - if self.device == "cpu": - test_inference_only = True - else: - test_inference_only = False - q = torch.randn( - (Q_B, Q_H, Q_S, Q_D), - dtype=dtype, - device=self.device, - requires_grad=not test_inference_only, + (Q_B, Q_H, Q_S, Q_D), dtype=dtype, device="cuda", requires_grad=True ) k = torch.randn( - (KV_B, KV_H, KV_S, Q_D), - dtype=dtype, - device=self.device, - requires_grad=not test_inference_only, + (KV_B, KV_H, KV_S, Q_D), dtype=dtype, device="cuda", requires_grad=True ) v = torch.randn( - (KV_B, KV_H, KV_S, V_D), - dtype=dtype, - device=self.device, - requires_grad=not test_inference_only, + (KV_B, KV_H, KV_S, V_D), dtype=dtype, device="cuda", requires_grad=True ) - if block_mask is None: - block_mask = create_block_mask( - noop_mask, Q_B, Q_H, Q_S, KV_S, device=self.device - ) q_ref, k_ref, v_ref = query_key_value_clones(q, k, v) q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64) sdpa_partial = create_attention( score_mod, block_mask, enable_gqa=(not Q_H == KV_H) ) - compiled_sdpa = torch.compile(sdpa_partial) golden_out = sdpa_partial(q_gold, k_gold, v_gold) ref_out = sdpa_partial(q_ref, k_ref, v_ref) compiled_out = compiled_sdpa(q, k, v) - if test_inference_only: - self._check_out( - golden_out, - ref_out, - compiled_out, - is_paged_attention=False, - ) - else: - backward_grad = torch.randn( - (Q_B, Q_H, Q_S, V_D), dtype=dtype, device=self.device - ) - golden_out.backward(backward_grad.to(torch.float64)) - ref_out.backward(backward_grad) - compiled_out.backward(backward_grad) + backward_grad = torch.randn((Q_B, Q_H, Q_S, V_D), 
dtype=dtype, device="cuda") - self._check_out_and_grad( - golden_out, - ref_out, - compiled_out, - q_gold, - q_ref, - q, - k_gold, - k_ref, - k, - v_gold, - v_ref, - v, - ) + golden_out.backward(backward_grad.to(torch.float64)) + ref_out.backward(backward_grad) + compiled_out.backward(backward_grad) + + self._check_out_and_grad( + golden_out, + ref_out, + compiled_out, + q_gold, + q_ref, + q, + k_gold, + k_ref, + k, + v_gold, + v_ref, + v, + ) def preprocess_paged_attention( self, @@ -512,7 +456,7 @@ def preprocess_paged_attention( KV_H, MAX_CACHED_SEQ_LEN, QK_D, - device=self.device, + device="cuda", dtype=dtype, ) v_cache = torch.zeros( @@ -520,7 +464,7 @@ def preprocess_paged_attention( KV_H, MAX_CACHED_SEQ_LEN, V_D, - device=self.device, + device="cuda", dtype=dtype, ) @@ -534,37 +478,32 @@ def preprocess_paged_attention( # (KV_S//4+1, ..., KV_S//4 + KV_S//2) are allocated to batch index 1, etc. # Thus, kv tensors of batch index 1 will be scattered in the kv cache, simulating # a real use case of paged attention. - paged_attention = PagedAttention( - n_pages, page_size, max_batch_size, device=self.device - ) + paged_attention = PagedAttention(n_pages, page_size, max_batch_size) batch_reserve( paged_attention, - torch.tensor( - [KV_S // 4, KV_S // 2, KV_S // 4, KV_S // 3], device=self.device - ), + torch.tensor([KV_S // 4, KV_S // 2, KV_S // 4, KV_S // 3], device="cuda"), ) batch_reserve( paged_attention, - torch.tensor( - [KV_S // 4, KV_S // 2, KV_S // 2, KV_S // 2], device=self.device - ), + torch.tensor([KV_S // 4, KV_S // 2, KV_S // 2, KV_S // 2], device="cuda"), ) batch_reserve( paged_attention, - torch.tensor([KV_S // 2, KV_S, KV_S // 2, KV_S], device=self.device), + torch.tensor([KV_S // 2, KV_S, KV_S // 2, KV_S], device="cuda"), ) batch_reserve( - paged_attention, torch.tensor([KV_S, KV_S, KV_S, KV_S], device=self.device) + paged_attention, torch.tensor([KV_S, KV_S, KV_S, KV_S], device="cuda") ) # update cache with k and v - input_pos = torch.arange(KV_S, device=self.device, dtype=torch.int32) - batch_idx = torch.arange(KV_B, device=self.device, dtype=torch.int32) + input_pos = torch.arange(KV_S, device="cuda", dtype=torch.int32) + batch_idx = torch.arange(KV_B, device="cuda", dtype=torch.int32) paged_attention.assign(batch_idx, input_pos, k, v, k_cache, v_cache) # convert block mask and score mod converted_block_mask = paged_attention.convert_logical_block_mask(block_mask) converted_score_mod = paged_attention.get_score_mod(score_mod) + return k_cache, v_cache, converted_block_mask, converted_score_mod def run_paged_attention( @@ -583,14 +522,9 @@ def run_paged_attention( k.shape[1], k.shape[2], ) - if self.device == "cpu": - test_inference_only = True - else: - test_inference_only = False + if block_mask is None: - block_mask = create_block_mask( - noop_mask, B, 1, Q_S, KV_S, device=self.device - ) + block_mask = create_block_mask(noop_mask, B, 1, Q_S, KV_S) ( k_cache, @@ -598,42 +532,21 @@ def run_paged_attention( converted_block_mask, converted_score_mod, ) = self.preprocess_paged_attention( - score_mod, - q, - k, - v, - block_mask, - dtype, - block_mask.BLOCK_SIZE[1], + score_mod, q, k, v, block_mask, dtype, block_mask.BLOCK_SIZE[1] ) compiled_sdpa = torch.compile(flex_attention) # compute - return_lse = True - if test_inference_only: - return_lse = False - compiled_lse = None - compiled_out = compiled_sdpa( - q, - k_cache, - v_cache, - return_lse=return_lse, - block_mask=converted_block_mask, - score_mod=converted_score_mod, - enable_gqa=(not Q_H == KV_H), - ) - - else: - 
compiled_out, compiled_lse = compiled_sdpa( - q, - k_cache, - v_cache, - return_lse=return_lse, - block_mask=converted_block_mask, - score_mod=converted_score_mod, - enable_gqa=(not Q_H == KV_H), - ) + compiled_out, compiled_lse = compiled_sdpa( + q, + k_cache, + v_cache, + return_lse=True, + block_mask=converted_block_mask, + score_mod=converted_score_mod, + enable_gqa=(not Q_H == KV_H), + ) return compiled_out, compiled_lse def run_test_with_paged_attention( @@ -651,32 +564,21 @@ def run_test_with_paged_attention( block_mask: Optional[BlockMask] = None, ): assert Q_H % KV_H == 0 - if self.device == "cpu": - test_inference_only = True - else: - test_inference_only = False + q = torch.randn( - (Q_B, Q_H, Q_S, QK_D), dtype=dtype, device=self.device, requires_grad=False + (Q_B, Q_H, Q_S, QK_D), dtype=dtype, device="cuda", requires_grad=False ) k = torch.randn( - (KV_B, KV_H, KV_S, QK_D), - dtype=dtype, - device=self.device, - requires_grad=False, + (KV_B, KV_H, KV_S, QK_D), dtype=dtype, device="cuda", requires_grad=False ) v = torch.randn( - (KV_B, KV_H, KV_S, V_D), - dtype=dtype, - device=self.device, - requires_grad=False, + (KV_B, KV_H, KV_S, V_D), dtype=dtype, device="cuda", requires_grad=False ) q_ref, k_ref, v_ref = query_key_value_clones(q, k, v) q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64) if block_mask is None: - block_mask = create_block_mask( - noop_mask, Q_B, 1, Q_S, KV_S, device=self.device - ) + block_mask = create_block_mask(noop_mask, Q_B, 1, Q_S, KV_S) sdpa_partial = create_attention( score_mod, block_mask, enable_gqa=(not Q_H == KV_H) @@ -687,20 +589,19 @@ def run_test_with_paged_attention( compiled_out, compiled_lse = self.run_paged_attention( score_mod, q, k, v, dtype, block_mask ) + self._check_out( golden_out, ref_out, compiled_out, is_paged_attention=True, ) - - if not test_inference_only: - self._check_out( - golden_lse, - ref_lse, - compiled_lse, - is_paged_attention=True, - ) + self._check_out( + golden_lse, + ref_lse, + compiled_lse, + is_paged_attention=True, + ) def run_test_with_call( self, @@ -715,27 +616,14 @@ def run_test_with_call( KV_S: int = S, V_D: int = D, ): - if self.device == "cpu": - test_inference_only = True - else: - test_inference_only = False q = torch.randn( - (Q_B, Q_H, Q_S, Q_D), - dtype=dtype, - device=self.device, - requires_grad=not test_inference_only, + (Q_B, Q_H, Q_S, Q_D), dtype=dtype, device="cuda", requires_grad=True ) k = torch.randn( - (KV_B, KV_H, KV_S, Q_D), - dtype=dtype, - device=self.device, - requires_grad=not test_inference_only, + (KV_B, KV_H, KV_S, Q_D), dtype=dtype, device="cuda", requires_grad=True ) v = torch.randn( - (KV_B, KV_H, KV_S, V_D), - dtype=dtype, - device=self.device, - requires_grad=not test_inference_only, + (KV_B, KV_H, KV_S, V_D), dtype=dtype, device="cuda", requires_grad=True ) q_ref, k_ref, v_ref = query_key_value_clones(q, k, v) q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64) @@ -743,36 +631,27 @@ def run_test_with_call( golden_out = sdpa_call(q_gold, k_gold, v_gold) ref_out = sdpa_call(q_ref, k_ref, v_ref) compiled_out = compiled_sdpa(q, k, v) - if test_inference_only: - self._check_out( - golden_out, - ref_out, - compiled_out, - is_paged_attention=False, - ) - else: - backward_grad = torch.randn( - (Q_B, Q_H, Q_S, V_D), dtype=dtype, device=self.device - ) - golden_out.backward(backward_grad.to(torch.float64)) - ref_out.backward(backward_grad) - compiled_out.backward(backward_grad) + backward_grad = torch.randn((Q_B, Q_H, Q_S, V_D), dtype=dtype, 
device="cuda") - self._check_out_and_grad( - golden_out, - ref_out, - compiled_out, - q_gold, - q_ref, - q, - k_gold, - k_ref, - k, - v_gold, - v_ref, - v, - ) + golden_out.backward(backward_grad.to(torch.float64)) + ref_out.backward(backward_grad) + compiled_out.backward(backward_grad) + + self._check_out_and_grad( + golden_out, + ref_out, + compiled_out, + q_gold, + q_ref, + q, + k_gold, + k_ref, + k, + v_gold, + v_ref, + v, + ) def run_dynamic_test( self, @@ -911,32 +790,13 @@ def run_automatic_dynamic_test( S: int = S, D: int = D, ): - if self.device == "cpu": - test_inference_only = True - else: - test_inference_only = False MAX_S = S - block_mask1 = create_block_mask(noop_mask, 1, 1, S, S, device=self.device) + block_mask1 = create_block_mask(noop_mask, 1, 1, S, S) sdpa_partial1 = create_attention(score_mod, block_mask=block_mask1) # The first eager batch, shape (B, H, S, D) - q1 = torch.randn( - (B, H, S, D), - dtype=dtype, - device=self.device, - requires_grad=not test_inference_only, - ) - k1 = torch.randn( - (B, H, S, D), - dtype=dtype, - device=self.device, - requires_grad=not test_inference_only, - ) - v1 = torch.randn( - (B, H, S, D), - dtype=dtype, - device=self.device, - requires_grad=not test_inference_only, - ) + q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") + k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") + v1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") golden_out1 = sdpa_partial1( q1.to(torch.float64), k1.to(torch.float64), v1.to(torch.float64) ) @@ -945,26 +805,11 @@ def run_automatic_dynamic_test( # The second eager batch, shape (B * 2, H, S / 2, D) B = int(B * 2) S = int(S / 2) - block_mask2 = create_block_mask(noop_mask, 1, 1, S, S, device=self.device) + block_mask2 = create_block_mask(noop_mask, 1, 1, S, S) sdpa_partial2 = create_attention(score_mod, block_mask=block_mask2) - q2 = torch.randn( - (B, H, S, D), - dtype=dtype, - device=self.device, - requires_grad=not test_inference_only, - ) - k2 = torch.randn( - (B, H, S, D), - dtype=dtype, - device=self.device, - requires_grad=not test_inference_only, - ) - v2 = torch.randn( - (B, H, S, D), - dtype=dtype, - device=self.device, - requires_grad=not test_inference_only, - ) + q2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") + k2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") + v2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") golden_out2 = sdpa_partial2( q2.to(torch.float64), k2.to(torch.float64), v2.to(torch.float64) ) @@ -973,26 +818,11 @@ def run_automatic_dynamic_test( # The third eager batch, shape (B * 4, H, S / 4, D) B = int(B * 2) S = int(S / 2) - block_mask3 = create_block_mask(noop_mask, 1, 1, S, S, device=self.device) + block_mask3 = create_block_mask(noop_mask, 1, 1, S, S) sdpa_partial3 = create_attention(score_mod, block_mask=block_mask3) - q3 = torch.randn( - (B, H, S, D), - dtype=dtype, - device=self.device, - requires_grad=not test_inference_only, - ) - k3 = torch.randn( - (B, H, S, D), - dtype=dtype, - device=self.device, - requires_grad=not test_inference_only, - ) - v3 = torch.randn( - (B, H, S, D), - dtype=dtype, - device=self.device, - requires_grad=not test_inference_only, - ) + q3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") + k3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") + v3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") golden_out3 = sdpa_partial3( q3.to(torch.float64), k3.to(torch.float64), v3.to(torch.float64) ) @@ -1027,6 +857,7 @@ def run_automatic_dynamic_test( self._check_equal(golden_out3, 
ref_out3, compiled_out3, fudge_factor) self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2) + @supported_platform @common_utils.parametrize("dtype", test_dtypes) @common_utils.parametrize("score_mod", test_score_mods) def test_builtin_score_mods(self, dtype: torch.dtype, score_mod: Callable): @@ -1073,6 +904,7 @@ def test_builtin_score_mods_dynamic( ): self.run_dynamic_test(score_mask_mod, dtype) + @supported_platform @common_utils.parametrize("dtype", test_dtypes_fast) @common_utils.parametrize("score_mod", test_score_mods) def test_builtin_score_mods_automatic_dynamic( @@ -1080,6 +912,7 @@ def test_builtin_score_mods_automatic_dynamic( ): self.run_automatic_dynamic_test(score_mod, dtype) + @supported_platform @common_utils.parametrize("dtype", test_dtypes_fast) @common_utils.parametrize("score_mod", test_score_mods) def test_builtin_score_mods_different_seqlen( @@ -1100,6 +933,7 @@ def test_builtin_score_mods_different_seqlen( self.run_test(*inputs) self.run_test_with_paged_attention(*inputs) + @supported_platform @common_utils.parametrize("dtype", test_dtypes) @common_utils.parametrize("score_mod", test_score_mods) @common_utils.parametrize("BLOCK_SIZE", test_block_size) @@ -1109,12 +943,11 @@ def test_builtin_score_mods_different_block_size( score_mod: Callable, BLOCK_SIZE: Union[int, Tuple[int, int]], ): - block_mask = create_block_mask( - noop_mask, B, H, S, S, BLOCK_SIZE=BLOCK_SIZE, device=self.device - ) + block_mask = create_block_mask(noop_mask, B, H, S, S, BLOCK_SIZE=BLOCK_SIZE) self.run_test(score_mod, dtype, block_mask=block_mask) self.run_test_with_paged_attention(score_mod, dtype, block_mask=block_mask) + @supported_platform @common_utils.parametrize("dtype", test_dtypes_fast) @common_utils.parametrize("batch_dims", test_Bq_Bkv) @common_utils.parametrize("head_dims", test_Hq_Hkv) @@ -1132,7 +965,7 @@ def test_kv_batch_broadcast( Bq, Bkv = batch_dims assert Bq > 1 and Bkv == 1 - block_mask = create_block_mask(noop_mask, Bq, 1, S, S, device=self.device) + block_mask = create_block_mask(noop_mask, Bq, 1, S, S) self.run_test( score_mod, @@ -1148,6 +981,7 @@ def test_kv_batch_broadcast( block_mask, ) + @supported_platform @common_utils.parametrize("dtype", test_dtypes_fast) @common_utils.parametrize("batch_dims", test_Bq_Bkv) @common_utils.parametrize("head_dims", test_Hq_Hkv) @@ -1168,13 +1002,25 @@ def test_kv_batch_broadcast_causal_mask( def mask_mod(b, h, q, kv): return q >= kv - block_mask = create_block_mask(mask_mod, Bq, 1, S, S, device=self.device) + block_mask = create_block_mask(mask_mod, Bq, 1, S, S) attention = functools.partial( flex_attention, block_mask=block_mask, enable_gqa=(not Hq == Hkv) ) - self.run_test_with_call(attention, dtype, Bq, Hq, S, D, Bkv, Hkv, S, D) + self.run_test_with_call( + attention, + torch.float16, + Bq, + Hq, + S, + D, + Bkv, + Hkv, + S, + D, + ) + @supported_platform @common_utils.parametrize("dtype", test_dtypes_fast) @common_utils.parametrize("score_mod", test_score_mods) def test_GQA(self, dtype: torch.dtype, score_mod: Callable): @@ -3065,6 +2911,21 @@ def global_causal(b, h, q_idx, kv_idx): torch.testing.assert_close(flex_k_grad, k.grad, atol=3e-3, rtol=2e-3) torch.testing.assert_close(flex_v_grad, v.grad, atol=3e-3, rtol=2e-3) + def test_cpu_error_message(self): + make_tensor = functools.partial( + torch.randn, + (2, 2, 128, 16), + device="cpu", + dtype=torch.float32, + requires_grad=False, + ) + query, key, value = make_tensor(), make_tensor(), make_tensor() + with self.assertRaisesRegex( + ValueError, + "FlexAttention 
is only supported on CUDA devices. Found input tensors on cpu device.", + ): + flex_attention(query, key, value) + @supported_platform def test_mixed_device_error_message(self): # Create tensors on different devices @@ -3437,40 +3298,6 @@ def forward(self, arg0_1: "i32[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i3 """, # noqa: B950 ) - @unittest.skipIf(TEST_ON_CUDA, "Testing CPU error message") - def test_cpu_error_message_return_lse(self): - make_tensor = functools.partial( - torch.randn, - (2, 2, 128, 16), - device="cpu", - dtype=torch.float32, - requires_grad=False, - ) - query, key, value = make_tensor(), make_tensor(), make_tensor() - attention = torch.compile(flex_attention) - with self.assertRaisesRegex( - torch._dynamo.exc.BackendCompilerFailed, - r"NotImplementedError: torch.compile on CPU only supports inference and `return_lse` is not supported yet.", - ): - attention(query, key, value, return_lse=True) - - @unittest.skipIf(TEST_ON_CUDA, "Testing CPU error message") - def test_validate_cpu_dtype_error_message(self): - make_tensor = functools.partial( - torch.randn, - (2, 2, 128, 16), - device="cpu", - dtype=torch.half, - requires_grad=False, - ) - query, key, value = make_tensor(), make_tensor(), make_tensor() - attention = torch.compile(flex_attention) - with self.assertRaisesRegex( - torch._dynamo.exc.BackendCompilerFailed, - r"`torch.float` and `torch.bfloat16` are supported in FlexAttention for CPU device. Found input tensors are `torch.float16`.", - ): - attention(query, key, value) - @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU") def test_device_cuda_1(self): class TestModule(torch.nn.Module): @@ -3999,12 +3826,6 @@ def create_inputs(S): class TestPagedAttention(InductorTestCase): - def setUp(self): - super().setUp() - self.device = test_device - if self.device == "cpu" and LONG_COMPILATION_ON_CPU: - self.skipTest("skip UT for CPU due to long compilation time found in CI") - def _check_equal( self, golden_out: torch.Tensor, @@ -4302,7 +4123,8 @@ def test_update(self): ) self.assertEqual(k_cache, expected_cache) - @common_utils.parametrize("dtype", test_dtypes_fast) + @supported_platform + @common_utils.parametrize("dtype", test_dtypes) @common_utils.parametrize("score_mod", test_score_mods) def test_paged_builtin_score_mods(self, dtype: torch.dtype, score_mod: Callable): n_pages, page_size, max_batch_size, max_seq_len = 32, 128, 4, 512 @@ -4312,15 +4134,15 @@ def causal_mask(b, h, q, kv): return q >= kv block_mask = create_block_mask( - causal_mask, max_batch_size, 1, max_seq_len, max_seq_len, device=self.device + causal_mask, max_batch_size, 1, max_seq_len, max_seq_len ) q = torch.randn( max_batch_size, n_heads, max_seq_len, head_dim, - device=self.device, - dtype=dtype, + device="cuda", + dtype=torch.float16, requires_grad=False, ) k = torch.randn( @@ -4328,8 +4150,8 @@ def causal_mask(b, h, q, kv): n_heads, max_seq_len, head_dim, - device=self.device, - dtype=dtype, + device="cuda", + dtype=torch.float16, requires_grad=False, ) v = torch.randn( @@ -4337,8 +4159,8 @@ def causal_mask(b, h, q, kv): n_heads, max_seq_len, head_dim, - device=self.device, - dtype=dtype, + device="cuda", + dtype=torch.float16, requires_grad=False, ) @@ -4356,39 +4178,27 @@ def causal_mask(b, h, q, kv): n_heads, MAX_CACHED_SEQ_LEN, head_dim, - device=self.device, - dtype=dtype, + device="cuda", + dtype=torch.float16, ) v_cache = torch.zeros( 1, n_heads, MAX_CACHED_SEQ_LEN, head_dim, - device=self.device, - dtype=dtype, + device="cuda", + dtype=torch.float16, ) - paged_cache = 
PagedAttention( - n_pages, page_size, max_batch_size, device=self.device - ) - batch_reserve( - paged_cache, torch.tensor([100, 200, 50, 300], device=self.device) - ) - batch_reserve( - paged_cache, torch.tensor([100, 512, 300, 300], device=self.device) - ) - batch_reserve( - paged_cache, torch.tensor([512, 512, 300, 300], device=self.device) - ) - batch_reserve( - paged_cache, torch.tensor([512, 512, 512, 300], device=self.device) - ) - batch_reserve( - paged_cache, torch.tensor([512, 512, 512, 512], device=self.device) - ) + paged_cache = PagedAttention(n_pages, page_size, max_batch_size) + batch_reserve(paged_cache, torch.tensor([100, 200, 50, 300], device="cuda")) + batch_reserve(paged_cache, torch.tensor([100, 512, 300, 300], device="cuda")) + batch_reserve(paged_cache, torch.tensor([512, 512, 300, 300], device="cuda")) + batch_reserve(paged_cache, torch.tensor([512, 512, 512, 300], device="cuda")) + batch_reserve(paged_cache, torch.tensor([512, 512, 512, 512], device="cuda")) - batch_idx = torch.arange(max_batch_size, device=self.device, dtype=torch.int32) - input_pos = torch.arange(max_seq_len, device=self.device, dtype=torch.int32) + batch_idx = torch.arange(max_batch_size, device="cuda", dtype=torch.int32) + input_pos = torch.arange(max_seq_len, device="cuda", dtype=torch.int32) paged_cache.assign(batch_idx, input_pos, k, v, k_cache, v_cache) new_block_mask = paged_cache.convert_logical_block_mask(block_mask) diff --git a/torch/_inductor/codegen/cpp_flex_attention_template.py b/torch/_inductor/codegen/cpp_flex_attention_template.py deleted file mode 100644 index 12d66ad06f255..0000000000000 --- a/torch/_inductor/codegen/cpp_flex_attention_template.py +++ /dev/null @@ -1,1096 +0,0 @@ -# mypy: allow-untyped-defs -import contextlib -import logging -import re -from typing import List, Optional -from unittest.mock import patch - -import sympy - -import torch -import torch.utils - -from .. 
import ir -from ..ir import TensorBox -from ..select_algorithm import DataProcessorTemplateWrapper -from ..utils import parallel_num_threads -from ..virtualized import V -from .cpp_template import CppTemplate - - -log = logging.getLogger(__name__) - -# TODO: reuse cpp codegen to generate below pointwise/reduction kernels -SOFTMAX_FUSIONS = r""" -// 1) out = exp(a - val) -// 2) val = sum(out) -template -inline void {{kernel_name}}_exp_reduce_sum_fusion_kernel( - T1* a, - const int& size, - T2* out, - T1& val) { - auto vec_size = at::vec::Vectorized::size(); - auto vec_max = at::vec::Vectorized(val); - T1 tmp_sum = 0; - auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); - for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) { - auto tmp0 = at::vec::Vectorized::loadu(a + i); - auto tmp1 = tmp0 - vec_max; - auto tmp2 = tmp1.exp_u20(); - vec_tmp_sum += tmp2; - at::native::_store(out + i, tmp2); - } - tmp_sum = at::vec::vec_reduce_all( - [](at::vec::Vectorized& x, at::vec::Vectorized& y) { - return x + y; - }, - vec_tmp_sum); - for (long i = vec_size * (size / vec_size); i < size; i++) { - auto tmp0 = a[i]; - auto tmp1 = tmp0 - val; - auto tmp2 = exp(tmp1); - tmp_sum += tmp2; - out[i] = tmp2; - } - val = tmp_sum; -} - -// 1) out = a * scale -// 2) max = max(out) -template -inline void {{kernel_name}}_mul_reduce_max_fusion_kernel( - const scalar_t* a, - const scalar_t& scale, - const int& size, - scalar_t* out, - scalar_t& max) { - auto vec_size = at::vec::Vectorized::size(); - auto vec_scale = at::vec::Vectorized(scale); - scalar_t tmp_max = -std::numeric_limits::infinity(); - auto vec_tmp_max = at::vec::Vectorized(tmp_max); - for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) { - auto tmp0 = at::vec::Vectorized::loadu(a + i); - auto tmp1 = tmp0 * vec_scale; - vec_tmp_max = at::vec::maximum(vec_tmp_max, tmp1); - at::native::_store(out + i, tmp1); - } - for (long i = vec_size * (size / vec_size); i < size; i++) { - auto tmp0 = a[i]; - auto tmp1 = tmp0 * scale; - tmp_max = std::max(tmp_max, tmp1); - out[i] = tmp1; - } - max = std::max( - tmp_max, - at::vec::vec_reduce_all( - [](at::vec::Vectorized& x, at::vec::Vectorized& y) { - return at::vec::maximum(x, y); - }, - vec_tmp_max)); -} - -template -static inline scalar_t* {{kernel_name}}_conditional_data_ptr(scalar_t* ptr, scalar_t* ptr2) { - TORCH_CHECK(ptr2 == nullptr); - return ptr; -} - -template , int> = 0> -static inline scalar_t* {{kernel_name}}_conditional_data_ptr(float* ptr, scalar_t* ptr2) { - return ptr2; -} - -template -inline void {{kernel_name}}_fill_stub(scalar_t* data, scalar_t val, int64_t size) { - using Vec = at::vec::Vectorized; - Vec data_vec = Vec(val); - int64_t d = 0; - for (; d < size - (size % Vec::size()); d += Vec::size()) { - data_vec.store(data + d); - } - #if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE) - # pragma unroll - #endif - for (; d < size; d++) { - data[d] = val; - } -} - -// out = a * scale -template -inline void {{kernel_name}}_mul_scale_kernel( - scalar_t* a, - scalar_t scale, - int64_t size) { - auto vec_size = at::vec::Vectorized::size(); - auto vec_scale = at::vec::Vectorized(scale); - for (int64_t i = 0; i < vec_size * (size / vec_size); i += vec_size) { - auto tmp0 = at::vec::Vectorized::loadu(a + i); - auto tmp1 = tmp0 * vec_scale; - at::native::_store(a + i, tmp1); - } - for (int64_t i = vec_size * (size / vec_size); i < size; i++) { - auto tmp0 = a[i]; - auto tmp1 = tmp0 * scale; - a[i] = tmp1; - } -} - -""" - -BRGEMM_PACK_FUNCTIONS = r""" -template -inline void 
{{kernel_name}}_copy_value_with_pad( - const scalar_t* value_ptr, - scalar_t* dst_ptr, - int64_t rows, - int64_t cols, - int64_t prows, - int64_t pcols, - int64_t ldi) { - auto vec_size = at::vec::Vectorized::size(); - int64_t i = 0; - for (; i < rows; i++) { - int64_t j = 0; - for (; j < cols - (cols % vec_size); j += vec_size) { - auto vec_v = - at::vec::Vectorized::loadu(value_ptr + i * ldi + j); - vec_v.store(dst_ptr + i * pcols + j); - } - - if (j < cols) { - auto vec_v = at::vec::Vectorized::loadu( - value_ptr + i * ldi + j, cols - j); - vec_v.store(dst_ptr + i * pcols + j, cols - j); - } - - // col padding - auto psize = pcols - cols; - if (psize > 0) { - auto zero_vec = at::vec::Vectorized(0); - int64_t pj = 0; - for (; pj < psize - (psize % vec_size); pj += vec_size) { - zero_vec.store(dst_ptr + i * pcols + cols + pj); - } - if (pj < psize) { - zero_vec.store(dst_ptr + i * pcols + cols + pj, psize - pj); - } - } - } - // row padding - for (; i < prows; i++) { - auto zero_vec = at::vec::Vectorized(0); - int64_t j = 0; - for (; j < pcols - (pcols % vec_size); j += vec_size) { - zero_vec.store(dst_ptr + i * pcols + j); - } - if (j < pcols) { - zero_vec.store(dst_ptr + i * pcols + j, pcols - j); - } - - } -} -// Transpose a [2, 32] matrix to [32, 2] -// Note: the output leading dimension should be 2, -// that is, the output must be contiguous -static inline void {{kernel_name}}_transpose_pad_2x32_block( - const uint16_t* src, - uint16_t* dst, - int64_t ld_src, - int krem = 2, - int nrem = 32) { -#if defined(CPU_CAPABILITY_AVX512) - __m512i r0, r1; - __m512i d0, d1; - // load - if (nrem < 32) { - __mmask32 mask_krem_v = (1LL << nrem) - 1; - r0 = _mm512_maskz_loadu_epi16(mask_krem_v, src); - // if krem is not 2, pad with zeros - if (krem == 2) { - r1 = _mm512_maskz_loadu_epi16(mask_krem_v, src + ld_src); - } else { - r1 = _mm512_setzero_si512(); - } - } else { - r0 = _mm512_loadu_si512(reinterpret_cast(src)); - if (krem == 2) { - r1 = _mm512_loadu_si512(reinterpret_cast(src + ld_src)); - } else { - r1 = _mm512_setzero_si512(); - } - } - // transpose - d0 = _mm512_unpacklo_epi16(r0, r1); - d1 = _mm512_unpackhi_epi16(r0, r1); - r0 = _mm512_shuffle_i32x4(d0, d1, 0x88); - r1 = _mm512_shuffle_i32x4(d0, d1, 0xdd); - d0 = _mm512_shuffle_i32x4(r0, r1, 0x88); - d1 = _mm512_shuffle_i32x4(r0, r1, 0xdd); - - // store - if (nrem < 16) { - __mmask32 mask_rem_v = (1LL << (nrem * 2)) - 1; - _mm512_mask_storeu_epi16(dst, mask_rem_v, d0); - } else if (nrem == 16) { - _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), d0); - } else if (nrem < 32) { - __mmask32 mask_rem_v = (1LL << (nrem * 2 - 32)) - 1; - _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), d0); - _mm512_mask_storeu_epi16( - reinterpret_cast<__m512i*>(dst + 32), mask_rem_v, d1); - } else { - // normal store - _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), d0); - _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 32), d1); - } -#else -TORCH_CHECK(false, "transpose_pad_2x32_block is only supported when avx512 is supported") -#endif -} - -// To use AMX to accelerate GEMM, -// reorder the memory format [K, N] -> [K/2, N, 2] -// Note: If K % 2 != 0, pad K implicitly -static inline void {{kernel_name}}_pack_vnni2( - const uint16_t* src, - uint16_t* dst, - int64_t ld_src, - int64_t K, - int64_t N) { -#if defined(CPU_CAPABILITY_AVX512) - int64_t bk = 0; - int64_t _K = K / 2 * 2; - int64_t _N = N / 32 * 32; - for (; bk < _K; bk += 2) { - int64_t bn = 0; - for (; bn < _N; bn += 32) { - {{kernel_name}}_transpose_pad_2x32_block(src + bk * 
ld_src + bn, dst + bk * N + bn * 2, ld_src); - } - int64_t nrem = N - bn; - if (nrem > 0) { - {{kernel_name}}_transpose_pad_2x32_block(src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 2, nrem); - } - } - if (K % 2 == 1) { - int64_t bn = 0; - for (; bn < _N; bn += 32) { - {{kernel_name}}_transpose_pad_2x32_block(src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 1); - } - int64_t nrem = N - bn; - if (nrem > 0) { - {{kernel_name}}_transpose_pad_2x32_block(src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 1, nrem); - } - } -#else -TORCH_CHECK(false, "pack_vnni2 is only supported when avx512 is supported") -#endif -} -""" - -ALLOCATE_BUFFER = r""" - int64_t {{buffer_name}}_dtype_itemsize = std::is_same_v<{{buffer_dtype}}, at::BFloat16> ? 2 : 4; - auto& {{buffer_name}}_allocator = *at::getCPUAllocator(); - auto {{buffer_name}}_work_data = {{buffer_name}}_allocator.allocate({{buffer_size}}*{{buffer_name}}_dtype_itemsize); - void* {{buffer_name}}_data_ptr = {{buffer_name}}_work_data.get(); - {{buffer_dtype}}* {{buffer_name}} = ({{buffer_dtype}}*){{buffer_name}}_data_ptr; -""" - -FLEX_ATTENTION_TEMPLATE = r""" -{{template.header().getvalue()}} -#include -#include -#include -{{template.codegen_softmax_fusion(kernel.kernel_name)}} -{{template.codegen_brgemm_pack_function(kernel.kernel_name)}} -{%- set kernel_args = {"query": query, "key": key, "value": value, - "kv_num_blocks": kv_num_blocks, "kv_indices": kv_indices, "full_kv_num_blocks": full_kv_num_blocks} %} -{%- set kernel_args = template.update_kernel_args(kernel_args) %} - -extern "C" -{{kernel.def_kernel(inputs=kernel_args, outputs={"output": output}, extra_sizevars=template.extra_sizevars)}} -{ - int64_t kvBlockSize = {{kvBlockSize}}; - kvBlockSize = kvBlockSize>{{kernel.size(key, 1)}} ? 
{{kernel.size(key, 1)}} - : kvBlockSize; - int64_t num_thread = {{num_thread}}; - - // dtypes of kernel and internal buffers - using scalar_t = {{kernel.dtype(query)}}; - constexpr bool is_reduced_type = std::is_reduced_floating_point_v; - using accum_t = at::opmath_type<{{kernel.dtype(query)}}>; - using Vec = at::vec::Vectorized; - accum_t scaling_factor = {{scale}}; - int64_t batchSize = {{kernel.size(query, 0)}}; - int64_t qSize = {{kernel.size(query, 1)}}; - int64_t num_head = {{kernel.size(query, 2)}}; - int64_t headSize = {{kernel.size(query, 3)}}; - int64_t batchSize_k = {{kernel.size(key, 0)}}; - int64_t num_head_k = {{kernel.size(key, 2)}}; - int64_t headSize_v = {{kernel.size(value, 3)}}; - bool is_broadcast_bs_kv = batchSize != batchSize_k; - bool is_broadcast_head_kv = num_head != num_head_k; - int64_t gqa_shards = num_head / num_head_k; - int64_t bs_shards = batchSize / batchSize_k; - - int64_t batchSize_kvi = {{kernel.size(kv_indices, 0)}}; - int64_t num_head_kvi = {{kernel.size(kv_indices, 1)}}; - int64_t block_num_kvi = {{kernel.size(kv_indices, 3)}}; - bool is_broadcast_bs_kvi = batchSize != batchSize_kvi; - bool is_broadcast_head_kvi = num_head != num_head_kvi; - int64_t gqa_shards_kvi = num_head / num_head_kvi; - int64_t bs_shards_kvi = batchSize / batchSize_kvi; - int64_t kviStrideB = {{kernel.stride(kv_indices, 0)}}; - int64_t kviStrideH = {{kernel.stride(kv_indices, 1)}}; - int64_t kviStrideQ = {{kernel.stride(kv_indices, 2)}}; - auto kv_indices_data = kv_indices; - - // Strides - int64_t qStrideB = {{kernel.stride(query, 0)}}; - int64_t qStrideM = {{kernel.stride(query, 1)}}; - int64_t qStrideH = {{kernel.stride(query, 2)}}; - int64_t kStrideB = {{kernel.stride(key, 0)}}; - int64_t kStrideN = {{kernel.stride(key, 1)}}; - int64_t kStrideH = {{kernel.stride(key, 2)}}; - int64_t vStrideB = {{kernel.stride(value, 0)}}; - int64_t vStrideN = {{kernel.stride(value, 1)}}; - int64_t vStrideH = {{kernel.stride(value, 2)}}; - int64_t oStrideB = {{kernel.stride(output, 0)}}; - int64_t oStrideM = {{kernel.stride(output, 2)}}; - int64_t oStrideH = {{kernel.stride(output, 1)}}; - - // Check total kv block number for kv value. - int64_t block_num_kv_count = 0; - bool has_block_indice_zero = true; - for (int64_t kv_count = 0; kv_count < block_num_kvi; kv_count++) { - if (*(kv_indices + kv_count) > 0) { - block_num_kv_count++; - } else if (*(kv_indices + kv_count) == 0) { - if (has_block_indice_zero) { - has_block_indice_zero = false; - block_num_kv_count++; - } else { - break; - } - } - } - // Check to use kv_indice if total block size is bigger than kv length, e.g., - // in PagedAttention case. - bool use_kv_indice = false; - if (block_num_kvi != block_num_kv_count && batchSize_k == 1) { - use_kv_indice = true; - } - int64_t kvSize = use_kv_indice ? block_num_kv_count * kvBlockSize - : {{kernel.size(key, 1)}}; - - // Split size heuristics tuned for q/k len - int64_t qSplitSize = 32; - int64_t kvSplitSize = 512; - if (qSize >= 768) { - qSplitSize = 256; - kvSplitSize = 512; - } else if (qSize >= 192) { - qSplitSize = 64; - kvSplitSize = 512; - } - if (kvBlockSize < kvSplitSize) { - kvSplitSize = kvBlockSize; - } - - qSplitSize = qSplitSize > qSize ? qSize : qSplitSize; - kvSplitSize = kvSplitSize > kvSize ? 
kvSize : kvSplitSize; - int64_t qSlice = (qSize + qSplitSize - 1) / qSplitSize; - int64_t kvSlice = (kvSize + kvSplitSize - 1) / kvSplitSize; - int64_t kvTail = (kvSize - 1) % kvSplitSize + 1; - - bool need_pack = false; - // Whether pack is needed for BFloat16 - if (std::is_same_v) { - // check platform ability - need_pack = at::native::cpublas::could_pack(at::kBFloat16); - } - if (need_pack) { - // When the number of gemm is greater than the number of pack, - // the pack overhead can be overlaped. - int64_t thresh_size = 64 ; - need_pack = kvSize >= thresh_size && qSize >= thresh_size; - if (need_pack) { - double pack_size = batchSize * num_head * kvSize * headSize; - double qs_per_thread = (batchSize * num_head * qSlice + num_thread - 1) / num_thread; - double gemm_size_per_thread = qs_per_thread * qSplitSize * kvSize * headSize; - need_pack = gemm_size_per_thread / pack_size >= 4; - } - } - - // Pad is needed for packing when K is not even - bool headSize_even = headSize % 2 == 0; - int64_t eheadSize = need_pack && !headSize_even ? headSize + 1: headSize; - int64_t ekvSplitSize = need_pack && (kvSplitSize % 2 != 0) ? kvSplitSize + 1 : kvSplitSize; - int64_t ekvTail = need_pack && (kvTail % 2 != 0) ? kvTail + 1 : kvTail; - int64_t kv_padding_size = (kvSize - 1) / kvSplitSize * ekvSplitSize + ekvTail; - - // Allocate per thread temp buf (accumulate type) - int64_t _size_per_thread = - /* qk */ qSplitSize * kvSplitSize + - /* qk_max */ qSplitSize + - /* qk_sum */ qSplitSize + - /* dst */ qSplitSize * headSize_v; - - // Inputs/outputs buffers - const scalar_t* q_data = query; - const scalar_t* k_data = key; - const scalar_t* v_data = value; - scalar_t* out_data = output; - - // Buffers to store accum results, padding query and transpose/packing key/value - {{template.codegen_allocate_buffer("buf_data", "accum_t", "num_thread*_size_per_thread")}} - {{template.codegen_allocate_buffer("buf_reduced_data", "scalar_t", "num_thread*qSplitSize*ekvSplitSize")}} - {{template.codegen_allocate_buffer("key_reorder_ptr", "scalar_t", "batchSize*num_head*eheadSize*kvSize")}} - {{template.codegen_allocate_buffer("value_reorder_ptr", "scalar_t", "batchSize*num_head*kv_padding_size*headSize_v")}} - {{template.codegen_allocate_buffer("transpose_buffer_ptr", "scalar_t", "num_thread*kvSplitSize*headSize")}} - {{template.codegen_allocate_buffer("query_padding_ptr", "scalar_t", "num_thread*qSplitSize*eheadSize")}} - - // Reorder K, V and transpose K - at::parallel_for(0, batchSize * num_head * kvSlice, 1, [&](int64_t begin, int64_t end) { - int ompIdx = at::get_thread_num(); - int64_t i = 0, j = 0, l = 0, n = 0; - scalar_t* transpose_ptr = need_pack? transpose_buffer_ptr + ompIdx * kvSplitSize * headSize : nullptr; - at::native::data_index_init(begin, i, batchSize, j, num_head, l, kvSlice); - for ([[maybe_unused]] auto z : c10::irange(begin, end)) { - n = l * kvSplitSize; - int64_t cur_kvSplitSize = std::min(kvSplitSize, kvSize - n); - auto i_kv = is_broadcast_bs_kv ? i/bs_shards : i; - auto j_kv = is_broadcast_head_kv ? j/gqa_shards : j; - auto kv_block_num = n / cur_kvSplitSize; - auto kv_block_offset = n - kv_block_num * cur_kvSplitSize; - // getting kv indices by [BS, Head, 1, kv_block_num] - auto i_kvi = is_broadcast_bs_kvi ? i/bs_shards_kvi : i; - auto j_kvi = is_broadcast_head_kvi ? 
j/gqa_shards_kvi : j; - auto kv_logical_data = kv_indices_data + i_kvi * kviStrideB + - j_kvi * kviStrideH + kv_block_num; - auto k_addr = - k_data + i_kv * kStrideB + j_kv * kStrideH + n * kStrideN; - auto v_addr = - v_data + i_kv * vStrideB + j_kv * vStrideH + n * vStrideN; - if (use_kv_indice) { - k_addr = - k_data + i_kv * kStrideB + j_kv * kStrideH + - (*kv_logical_data * cur_kvSplitSize + kv_block_offset) * kStrideN; - v_addr = - v_data + i_kv * vStrideB + j_kv * vStrideH + - (*kv_logical_data * cur_kvSplitSize + kv_block_offset) * vStrideN; - } - if (need_pack) { - // transpose [cur_kvSplitSize, headSize] -> [headSize, cur_kvSplitSize] - at::native::utils::transpose( - cur_kvSplitSize, - headSize, - /* src_ptr */ - reinterpret_cast(k_addr), - /* ld_src */ kStrideN, - /* dst */ reinterpret_cast(transpose_ptr), - /* ld_dst */ cur_kvSplitSize); - - // Pack [headSize, cur_kvSplitSize] - {{kernel.kernel_name}}_pack_vnni2( - /* src */ reinterpret_cast(transpose_ptr), - /* dst */ reinterpret_cast(key_reorder_ptr + i * num_head * eheadSize * kvSize + - j * eheadSize * kvSize + n * eheadSize), - /* ld_src */ cur_kvSplitSize, - /* K */ headSize, - /* N */ cur_kvSplitSize); - - // Pack [cur_kvSplitSize, headSize_v] - {{kernel.kernel_name}}_pack_vnni2( - /* src */ reinterpret_cast(v_addr), - /* dst */ reinterpret_cast(value_reorder_ptr + - i * num_head * kv_padding_size * headSize_v + - j * kv_padding_size * headSize_v + n * headSize_v), - /* ld_src */ vStrideN, - /* K */ cur_kvSplitSize, - /* N */ headSize_v); - } else { - using trans_t = std::conditional_t, uint16_t, float>; - at::native::utils::transpose( - cur_kvSplitSize, - headSize, - /* src_ptr */ - reinterpret_cast(k_addr), - /* ld_src */ kStrideN, - /* dst */ reinterpret_cast(key_reorder_ptr + i * num_head * eheadSize * kvSize + - j * eheadSize * kvSize + n * eheadSize), - /* ld_dst */ cur_kvSplitSize); - } - // Move to the next query - at::native::data_index_step(i, batchSize, j, num_head, l, kvSlice); - } - }); - - // Attention loop below - at::parallel_for(0, batchSize * num_head * qSlice, 1, [&](int64_t begin, int64_t end) { - int64_t i = 0, j = 0, k = 0; - at::native::data_index_init(begin, i, batchSize, j, num_head, k, qSlice); - int ompIdx = at::get_thread_num(); - accum_t* buf_ptr = buf_data + ompIdx * _size_per_thread; - accum_t* qk_data = buf_ptr; - accum_t* qk_max_data = qk_data + qSplitSize * kvSplitSize; - accum_t* qk_sum_data = qk_max_data + qSplitSize; - accum_t* dst_data = qk_sum_data + qSplitSize; - scalar_t *qk_reduced_data = - is_reduced_type - ? buf_reduced_data + ompIdx * qSplitSize * ekvSplitSize - : nullptr; - scalar_t* query_t_padding_ptr = (!headSize_even && need_pack) - ? 
query_padding_ptr + ompIdx * qSplitSize * eheadSize - : nullptr; - - for ([[maybe_unused]] auto z : c10::irange(begin, end)) { - int64_t m = k * qSplitSize; - int64_t cur_qSplitSize = std::min(qSplitSize, qSize - m); - // Initialize max and sum - {{kernel.kernel_name}}_fill_stub(qk_max_data, - -std::numeric_limits::infinity(), cur_qSplitSize); - {{kernel.kernel_name}}_fill_stub(qk_sum_data, - static_cast(0), cur_qSplitSize); - - if (!headSize_even && need_pack) { - // Pad query if headSize is not even - {{kernel.kernel_name}}_copy_value_with_pad( - q_data + i * qStrideB + j * qStrideH + m * qStrideM, - query_t_padding_ptr, - cur_qSplitSize, - headSize, - cur_qSplitSize, - eheadSize, - qStrideM - ); - } - for (int64_t n = 0; n < kvSize; n += kvSplitSize) { - int64_t cur_kvSplitSize = std::min(kvSplitSize, kvSize - n); - int64_t cur_ekvSplitSize = (need_pack && cur_kvSplitSize % 2 != 0) ? cur_kvSplitSize + 1 : cur_kvSplitSize; - - // Calculate scale * q @ k.T - auto i_kv = is_broadcast_bs_kv ? i/bs_shards : i; - auto j_kv = is_broadcast_head_kv ? j/gqa_shards : j; - auto kv_block_num = n / kvBlockSize; - auto kv_block_offset = n - kv_block_num * kvBlockSize; - // getting kv indices by [BS, Head, 1, kv_block_num] - auto i_kvi = is_broadcast_bs_kvi ? i/bs_shards_kvi : i; - auto j_kvi = is_broadcast_head_kvi ? j/gqa_shards_kvi : j; - auto kv_logical_data = kv_indices_data + i_kvi * kviStrideB + - j_kvi * kviStrideH + kv_block_num; - if (!need_pack) { - auto k_addr_t = key_reorder_ptr + i * num_head * eheadSize * kvSize + - j * eheadSize * kvSize + n * eheadSize; - // TODO: use the micro-gemm template instead of brgemm API - at::native::cpublas::brgemm( - cur_qSplitSize, - cur_kvSplitSize, - eheadSize, - qStrideM, - cur_kvSplitSize, - cur_kvSplitSize, - false, - q_data + i * qStrideB + j * qStrideH + - m * qStrideM, - k_addr_t, - qk_data, - need_pack); - } else { - at::native::cpublas::brgemm( - cur_qSplitSize, - cur_kvSplitSize, - eheadSize, - headSize_even ? qStrideM : eheadSize, - cur_kvSplitSize, - cur_kvSplitSize, - false, - !headSize_even - ? 
query_t_padding_ptr - : q_data + i * qStrideB + j * qStrideH + m * qStrideM, - key_reorder_ptr + i * num_head * eheadSize * kvSize + - j * eheadSize * kvSize + n * eheadSize, - qk_data, - need_pack); - } - - {{kernel.kernel_name}}_mul_scale_kernel(qk_data, scaling_factor, cur_qSplitSize*cur_kvSplitSize); - -{%- if score_mod and mask_mod %} - // TODO: vectorization optimization for below score and mask codegen functions - // apply score mod function - for (int64_t row = 0; row < cur_qSplitSize; ++row) { - for (int64_t col = 0; col < cur_kvSplitSize; col++) { - std::vector b_idx = {i}; - std::vector h_idx = {j}; - std::vector q_idx = {m+row}; - int64_t phisical_kv_idx = n+col; - if (use_kv_indice) { - phisical_kv_idx= *kv_logical_data * kvBlockSize + col; - } - std::vector kv_idx = {phisical_kv_idx}; - accum_t* in_ptr0 = qk_data + row * cur_kvSplitSize + col; - auto in_ptr1 = b_idx.data(); - auto in_ptr2 = h_idx.data(); - auto in_ptr3 = q_idx.data(); - auto in_ptr4 = kv_idx.data(); - {{ template.generate_other_buffer("score_others", 0, "len_score_other", kernel.args) }} - accum_t* out_ptr{{score_buf_idx}} = in_ptr0; - {{ template.modification(score_mod, score_buf_name, score_buf_idx) }} - } - } - // Apply block mask, fill unused with -inf - for (int64_t row = 0; row < cur_qSplitSize; ++row) { - for (int64_t col = 0; col < cur_kvSplitSize; col++) { - std::vector b_idx = {i}; - std::vector h_idx = {j}; - std::vector q_idx = {m+row}; - int64_t phisical_kv_idx = n+col; - if (use_kv_indice) { - phisical_kv_idx= *kv_logical_data * kvBlockSize + col; - } - std::vector kv_idx = {phisical_kv_idx}; - accum_t* qk_block = qk_data + row * cur_kvSplitSize + col; - auto in_ptr1 = b_idx.data(); - auto in_ptr2 = h_idx.data(); - auto in_ptr3 = q_idx.data(); - auto in_ptr4 = kv_idx.data(); - {{ template.generate_other_buffer("mask_others", -1, "len_mask_other", kernel.args) }} - std::vector temp = {0}; - int64_t* out_ptr{{mask_buf_idx}} = temp.data(); - {{ template.modification(mask_mod, mask_buf_name, mask_buf_idx) }} - *qk_block = *out_ptr{{mask_buf_idx}} != 0 - ? *qk_block - : -std::numeric_limits::infinity(); - } - } -{%- endif %} - // Update coefficients with Softmax - accum_t tmp_max = 0, tmp_sum = 0, exp_tmp = 0; - for (int64_t row = 0; row < cur_qSplitSize; ++row) { - // apply scaling factor and max per row in fusion - {{kernel.kernel_name}}_mul_reduce_max_fusion_kernel( - qk_data + row * cur_kvSplitSize, - static_cast(1), - cur_kvSplitSize, - qk_data + row * cur_kvSplitSize, - tmp_max); - tmp_max = qk_max_data[row] > tmp_max ? 
qk_max_data[row] : tmp_max; - if (tmp_max == -std::numeric_limits::infinity()) { - // to avoid `nan = exp2f(-inf - (-inf))` - {{kernel.kernel_name}}_fill_stub( - {{kernel.kernel_name}}_conditional_data_ptr(qk_data, qk_reduced_data) + row * cur_ekvSplitSize, - static_cast(0), cur_kvSplitSize); - } else { - tmp_sum = tmp_max; - // qk <- exp(qk - max) and sum per row - {{kernel.kernel_name}}_exp_reduce_sum_fusion_kernel( - qk_data + row * cur_kvSplitSize, cur_kvSplitSize, - {{kernel.kernel_name}}_conditional_data_ptr(qk_data, qk_reduced_data) + row * cur_ekvSplitSize, - tmp_sum); - // exp_tmp <- exp(max[row] - max) - exp_tmp = std::exp(qk_max_data[row] - tmp_max); - // sum[row] <- sum + exp_tmp * sum[row] - qk_sum_data[row] = tmp_sum + exp_tmp * qk_sum_data[row]; - // max[row] <- max - qk_max_data[row] = tmp_max; - // dst <- dst * exp_tmp - if (n > 0) { - at::vec::map( - [exp_tmp](Vec x) { return x * Vec(exp_tmp); }, - dst_data + row * headSize_v, - dst_data + row * headSize_v, - headSize_v); - } - } - if (need_pack && cur_kvSplitSize % 2 != 0) { - // Pad: [qSplitSize, cur_kvSplitSize] -> [qSplitSize, cur_kvSplitSize + 1] - *(qk_reduced_data + row * (1 + cur_kvSplitSize) + cur_kvSplitSize) = scalar_t(0); - } - } - // Calculate Softmax(q @ k.T) @ v - if (!need_pack) { - auto v_addr = - v_data + i_kv * vStrideB + j_kv * vStrideH + n * vStrideN; - if (use_kv_indice) { - v_addr = - v_data + i_kv * vStrideB + j_kv * vStrideH + - (*kv_logical_data * kvBlockSize + kv_block_offset) * vStrideN; - } - at::native::cpublas::brgemm( - cur_qSplitSize, - headSize_v, - cur_ekvSplitSize, - cur_ekvSplitSize, - vStrideN, - headSize_v, - n > 0, - {{kernel.kernel_name}}_conditional_data_ptr(qk_data, qk_reduced_data), - v_addr, - dst_data, - need_pack); - } else { - int64_t psize = n / kvSplitSize * ekvSplitSize; - at::native::cpublas::brgemm( - cur_qSplitSize, - headSize_v, - cur_ekvSplitSize, - cur_ekvSplitSize, - headSize_v, - headSize_v, - n > 0, - qk_reduced_data, - value_reorder_ptr + - i * num_head * kv_padding_size * headSize_v + - j * kv_padding_size * headSize_v + psize * headSize_v, - dst_data, - need_pack); - } - } - // dst <- dst / sum[row] - // reorder MHA output with strides - for (int64_t row = 0; row < cur_qSplitSize; ++row) { - // Row sums for full masked out rows are 0, we set them to 1 - // in order to avoid NaNs in the output and instead set fully - // masked out rows to 0 - qk_max_data[row] = qk_max_data[row] == -std::numeric_limits::infinity() ? 0 : qk_max_data[row]; - qk_sum_data[row] = qk_sum_data[row] == 0 ? 
1 : qk_sum_data[row]; - accum_t sum_reciprocal = 1 / qk_sum_data[row]; - at::vec::map( - [sum_reciprocal](Vec x) { return x * Vec(sum_reciprocal); }, - out_data + i * oStrideB + j * oStrideH + m * oStrideM + row * oStrideM, - dst_data + row * headSize_v, - headSize_v); - } - // Move to the next query - at::native::data_index_step(i, batchSize, j, num_head, k, qSlice); - } - - at::native::cpublas::brgemm_release(need_pack); - - }); -} -""" - - -class CppFlexAttentionTemplate(CppTemplate): - def __init__( - self, - input_nodes, - layout: ir.Layout, - scale, - score_mod, - mask_mod, - kv_block_size, - has_other_buffer, - no_full_kv_block, - fake_buffers, - len_score_other, - len_mask_other, - kernel_input_name_to_buffer, - ) -> None: - assert layout.dtype in [torch.float, torch.bfloat16] - super().__init__("flex_attention", input_nodes, layout, parallel_num_threads()) - self.scale = scale - self.score_mod = score_mod - self.mask_mod = mask_mod - self.score_buf_name = ( - V.graph.register_buffer(self.score_mod) if self.score_mod else None - ) - self.mask_buf_name = ( - V.graph.register_buffer(self.mask_mod) if self.mask_mod else None - ) - - def get_idx(buf_name): - match = re.search(r"\d+", buf_name) - assert match, f"incorrect score buf name: {buf_name}" - return match.group() - - self.score_buf_idx = ( - get_idx(self.score_buf_name) if self.score_buf_name else None - ) - self.mask_buf_idx = get_idx(self.mask_buf_name) if self.mask_buf_name else None - self.kv_block_size = kv_block_size - self.has_other_buffer = has_other_buffer - self.no_full_kv_block = no_full_kv_block - self.other_buffer_input_offset = 1 - if self.no_full_kv_block: - self.other_buffer_input_offset = 0 - self.fake_buffers = fake_buffers - self.len_score_other = len_score_other - self.len_mask_other = len_mask_other - self.kernel_input_name_to_buffer = kernel_input_name_to_buffer - self.extra_sizevars = list( - { - val - for val in self.kernel_input_name_to_buffer.values() - if isinstance(val, sympy.Symbol) - } - ) - self.other_buf_start_idx = 5 - self.score_mod_other_buffers = ( - self.input_nodes[ - self.other_buf_start_idx - + self.other_buffer_input_offset : self.other_buf_start_idx - + self.other_buffer_input_offset - + self.len_score_other - ] - if self.has_other_buffer - else None - ) - self.mask_mod_other_buffers = ( - self.input_nodes[ - self.other_buf_start_idx - + self.other_buffer_input_offset - + self.len_score_other : - ] - if self.has_other_buffer - else None - ) - self.other_ptr_data = {} # type: ignore[var-annotated] - - def update_kernel_args(self, kernel_args): - kernel_args.update( - { - key: value - for key, value in self.kernel_input_name_to_buffer.items() - if not isinstance(value, sympy.Symbol) - } - ) - return kernel_args - - def generate_other_buffer(self, buf_list, start_offset, len_attr, kernel_args): - kernel_input_name_to_buffer_name = { - key: value if isinstance(value, sympy.Symbol) else value.get_name() - for key, value in self.kernel_input_name_to_buffer.items() - } - - def get_arg(name): - return kernel_input_name_to_buffer_name.get(name) - - def get_arg_name(name): - if isinstance(get_arg(name), sympy.Symbol): - return kernel_args.sizevars.get(get_arg(name)) - return kernel_args.input_buffers.get(get_arg(name)) - - if not self.has_other_buffer: - return "" - - if start_offset == -1: - start_offset = getattr(self, len_attr) - - length = getattr(self, len_attr) - for i in range(length): - pointer = f"in_ptr{self.other_buf_start_idx + start_offset + i}" - buffer_key = f"{buf_list}_{i}" - if 
pointer not in self.other_ptr_data: - self.other_ptr_data[pointer] = ( - get_arg_name(buffer_key), - get_arg(buffer_key), - ) - - return "\n".join( - f"auto {ptr} = {name};" for ptr, (name, _) in self.other_ptr_data.items() - ) - - def modification(self, subgraph_buffer, output_name, output_idx): - assert isinstance(subgraph_buffer, ir.ComputedBuffer) - subgraph_buffer_data = subgraph_buffer.data - from ..loop_body import LoopBody - from ..utils import sympy_index_symbol_with_prefix, SymT - from ..virtualized import V - from .cpp import CppKernelProxy, KernelGroup - - kernel_group = KernelGroup() - kernel_input_args = { - "score": "in_ptr0", - "b": "in_ptr1", - "h": "in_ptr2", - "q_idx": "in_ptr3", - "kv_idx": "in_ptr4", - } - if self.has_other_buffer: - kernel_input_args.update( - {arg: ptr for ptr, (_, arg) in self.other_ptr_data.items()} - ) - - kernel_output_args = {output_name: f"out_ptr{output_idx}"} - - args = kernel_group.args - for name, inp in kernel_input_args.items(): - args.input_buffers[name] = inp - - for name, inp in kernel_output_args.items(): - args.output_buffers[name] = inp - - for name in self.extra_sizevars: - args.sizevars[name] = f"k{name}" - - kernel_group.args = args - - cpp_kernel_proxy = CppKernelProxy(kernel_group) - bodies = [] - var_sizes_list = [] - - var_sizes = tuple([]) # type: ignore[var-annotated] # noqa: C409 - output_index = 0 - var_ranges = { - sympy_index_symbol_with_prefix(SymT.INDEX, i): sz - for i, sz in enumerate(var_sizes) - } - - def fn(*args): - V.ops.store( - output_name, - output_index, - subgraph_buffer_data.make_loader()(args).value, - ) - - body = LoopBody( - fn, - (list(var_ranges.keys())), - var_ranges, - list(var_ranges.keys()), - tuple(), - ) - - from ..loop_body import MemoryUsageType - - assert all( - mem.buffer_name in kernel_group.args.input_buffers - for mem in body.memory_usage[MemoryUsageType.LOAD] - ), "All the buffers in the score and mask subgraph should be in kernel_group.args.input_buffers" - - bodies.append(body) - var_sizes_list.append((var_sizes, ())) - - cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list) - kernel_group.finalize_kernel(cpp_kernel_proxy, []) - return kernel_group.loops_code.getvalue() - - @staticmethod - def add_choices( - choices, - input_nodes, - layout, - scale, - score_mod, - mask_mod, - kv_block_size, - has_other_buffer, - no_full_kv_block, - fake_buffers, - len_score_other, - len_mask_other, - kernel_input_name_to_buffer, - ): - def preprocessor(input_nodes, layout): - return input_nodes, layout - - def postprocessor(output): - return output - - template = DataProcessorTemplateWrapper( - CppFlexAttentionTemplate, - preprocessor, - postprocessor, - input_nodes=input_nodes, - layout=layout, - scale=scale, - score_mod=score_mod, - mask_mod=mask_mod, - kv_block_size=kv_block_size, - has_other_buffer=has_other_buffer, - no_full_kv_block=no_full_kv_block, - fake_buffers=fake_buffers, - len_score_other=len_score_other, - len_mask_other=len_mask_other, - kernel_input_name_to_buffer=kernel_input_name_to_buffer, - ) - template.maybe_append_choice(choices) - return template - - def apply_score_mod(self, score, b, h, q_idx, kv_idx): - return self.score_mod.graph_module(score, b, h, q_idx, kv_idx).item() - - def render( # type: ignore[override,return] - self, - kernel, - template_buffer_node: Optional[ir.CppTemplateBuffer] = None, - epilogue_nodes: Optional[List[ir.IRNode]] = None, - **kwargs, - ) -> str: - if epilogue_nodes is not None and epilogue_nodes != []: - raise NotImplementedError( - 
"Unsupported for `epilogue_nodes` in CppFlexAttentionTemplate." - ) - # Query (Batch x Num_heads x Q_seq_len x Dim_per_head) - # -> (Batch x Q_seq_len x Num_heads x Dim_per_head) - # Key (Batch x Num_heads x KV_seq_len x Dim_per_head) - # -> (Batch x KV_seq_len x Num_heads x Dim_per_head) - # Value (Batch x Num_heads x KV_seq_len x Dim_per_head) - # -> (Batch x KV_seq_len x Num_heads x Dim_per_head) - - query = kernel.permute(self.input_nodes[0], [0, 2, 1, 3]) - key = kernel.permute(self.input_nodes[1], [0, 2, 1, 3]) - value = kernel.permute(self.input_nodes[2], [0, 2, 1, 3]) - - num_threads = parallel_num_threads() - buf_out = TensorBox.create(self.output_node) - if template_buffer_node is not None: - buf_out = template_buffer_node - options = dict( - query=query, - key=key, - value=value, - kv_num_blocks=self.input_nodes[3], - kv_indices=self.input_nodes[4], - full_kv_num_blocks=self.input_nodes[5] - if not self.no_full_kv_block - else None, - score_mod_other_buffers=self.score_mod_other_buffers, - mask_mod_other_buffers=self.mask_mod_other_buffers, - scale=self.scale, - accumulate_dtype=torch.float, - query_dtype=query.layout.dtype, - kvBlockSize=self.kv_block_size, - template=self, - output=buf_out, - kernel=kernel, - num_thread=num_threads, - score_mod=self.score_mod, - mask_mod=self.mask_mod, - score_buf_name=self.score_buf_name, - mask_buf_name=self.mask_buf_name, - score_buf_idx=self.score_buf_idx, - mask_buf_idx=self.mask_buf_idx, - ) - with contextlib.ExitStack() as stack: - for buf in self.fake_buffers: - stack.enter_context( - patch.object(V.graph, "get_dtype", self._fake_get_dtype(buf)) - ) - return self._template_from_string(FLEX_ATTENTION_TEMPLATE).render(**options) - - def codegen_softmax_fusion(self, kernel_name: str): - # TODO: use inductor IR to rewrite those fusions - return self._template_from_string(SOFTMAX_FUSIONS).render( - dict(kernel_name=kernel_name) - ) - - def codegen_brgemm_pack_function(self, kernel_name: str): - # TODO: make them general for common bmm templates - return self._template_from_string(BRGEMM_PACK_FUNCTIONS).render( - dict(kernel_name=kernel_name) - ) - - def codegen_allocate_buffer(self, buffer_name: str, buffer_dtype, buffer_size): - return self._template_from_string(ALLOCATE_BUFFER).render( - dict( - buffer_name=buffer_name, - buffer_dtype=buffer_dtype, - buffer_size=buffer_size, - ) - ) diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 9fb795aa0dc96..369aea4b7afd6 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -15,7 +15,6 @@ from .. import config from ..ir import ( - Buffer, ComputedBuffer, ExternKernel, FixedLayout, @@ -804,202 +803,6 @@ def create_indices_fake(x) -> torch.Tensor: from torch._inductor.kernel.flex_decoding import create_flex_decoding_kernel -from ..codegen.cpp_flex_attention_template import CppFlexAttentionTemplate - - -def lower_cpu( - query, - key, - value, - subgraph, - block_mask, - scale, - kernel_options, - score_mod_other_buffers, - mask_mod_other_buffers, -): - ( - _, # q_length - _, # kv_length - kv_num_blocks, - kv_indices, - full_kv_num_blocks, - full_kv_indices, - q_num_blocks, - q_indices, - full_q_num_blocks, - full_q_indices, - SPARSE_Q_BLOCK_SIZE, - SPARSE_KV_BLOCK_SIZE, - mask_graph, - ) = block_mask - - if kernel_options["OUTPUT_LOGSUMEXP"]: - raise NotImplementedError( - "torch.compile on CPU only supports inference and `return_lse` is not supported yet." 
- ) - - fake_buffers: List[Buffer] = [] # noqa: F821 - placeholder_inps = [ - create_placeholder(name, dtype, query.get_device()) - for name, dtype in [ - ("score", torch.float), - ("b", torch.int64), - ("h", torch.int64), - ("q_idx", torch.int64), - ("kv_idx", torch.int64), - ] - ] - subgraph_buffer = build_subgraph_buffer( - placeholder_inps + list(score_mod_other_buffers), subgraph - ) - if subgraph_buffer is not None: - if isinstance(subgraph_buffer, list): - for _buf in subgraph_buffer: - if _buf is not None: - _buf.freeze_layout() - else: - subgraph_buffer.freeze_layout() - mask_graph_placeholder_inps = [ - create_placeholder(name, dtype, query.get_device()) - for name, dtype in [ - ("b", torch.int64), - ("h", torch.int64), - ("q_idx", torch.int64), - ("kv_idx", torch.int64), - ] - ] - mask_graph_buffer = build_subgraph_buffer( - mask_graph_placeholder_inps + list(mask_mod_other_buffers), mask_graph - ) - - buffer_list = ( - placeholder_inps - + list(score_mod_other_buffers) - + mask_graph_placeholder_inps - + list(mask_mod_other_buffers) - ) - for item in buffer_list: - if isinstance(item, TensorBox): - fake_buffers.append(item.data.data) # type: ignore[attr-defined] - - ( - query, - key, - value, - kv_num_blocks, - kv_indices, - full_kv_num_blocks, - full_kv_indices, - q_num_blocks, - q_indices, - full_q_num_blocks, - full_q_indices, - ) = maybe_realize( - [ - query, - key, - value, - kv_num_blocks, - kv_indices, - full_kv_num_blocks, - full_kv_indices, - q_num_blocks, - q_indices, - full_q_num_blocks, - full_q_indices, - ] - ) - - if len({query.get_name(), key.get_name(), value.get_name()}) != 3: - raise NotImplementedError( - "Unsupported for now if query, key, value are the same buffer." - ) - if query.get_dtype() not in [torch.float, torch.bfloat16]: - raise NotImplementedError( - "`torch.float` and `torch.bfloat16` are supported in FlexAttention for CPU device. " - f"Found input tensors are `{query.get_dtype()}`." - ) - score_mod_other_buffers = maybe_realize(score_mod_other_buffers) - mask_mod_other_buffers = maybe_realize(mask_mod_other_buffers) - Bq, Hq, seq_len_q, qk_head_dim = query.get_size() - Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size() - B = Bq - - # Construct output layout with strides matching the query. - out_size = [B, Hq, seq_len_q, v_head_dim] - fill_order = get_fill_order(query.get_stride()) - out_strides = construct_strides(out_size, fill_order) - - layout = FixedLayout( - query.get_device(), - query.get_dtype(), - [B, Hq, seq_len_q, v_head_dim], - stride=[sympy.sympify(s) for s in out_strides], - ) - _choices: List[Any] = [] - input_nodes = [query, key, value, kv_num_blocks, kv_indices] - if not full_kv_num_blocks: - no_full_kv_block = True - else: - no_full_kv_block = False - input_nodes += [full_kv_num_blocks] - has_other_buffer = False - kernel_input_name_to_buffer = {} - if score_mod_other_buffers or mask_mod_other_buffers: - has_other_buffer = True - - for prefix, buffers in [ - ("score_others", score_mod_other_buffers), - ("mask_others", mask_mod_other_buffers), - ]: - kernel_input_name_to_buffer.update( - {f"{prefix}_{i}": buf for i, buf in enumerate(buffers)} - ) - input_nodes += [ - value - for value in kernel_input_name_to_buffer.values() - if not isinstance(value, sympy.Symbol) - ] - - skip_mask_score = kernel_options.get("SKIP_MASK_SCORE", False) - # Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards. 
- SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE) - SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE) - assert V.graph.sizevars.evaluate_expr( - sympy.Le(seq_len_q, sympy.Mul(kv_indices.get_size()[-2], SPARSE_Q_BLOCK_SIZE)) - ), "Q seqlen must be smaller than the block_mask size in the Q dimension, considering pass a larger block_mask." - assert V.graph.sizevars.evaluate_expr( - sympy.Le(seq_len_kv, sympy.Mul(kv_indices.get_size()[-1], SPARSE_KV_BLOCK_SIZE)) - ), "KV seqlen must be smaller than the block_mask size in the KV dimension, considering pass a larger block_mask." - CppFlexAttentionTemplate.add_choices( - choices=_choices, - input_nodes=input_nodes, - layout=layout, - scale=scale, - score_mod=None if skip_mask_score else subgraph_buffer, - mask_mod=None if skip_mask_score else mask_graph_buffer, - kv_block_size=SPARSE_KV_BLOCK_SIZE, - has_other_buffer=has_other_buffer, - no_full_kv_block=no_full_kv_block, - fake_buffers=fake_buffers, - len_score_other=len(score_mod_other_buffers), - len_mask_other=len(mask_mod_other_buffers), - kernel_input_name_to_buffer=kernel_input_name_to_buffer, - ) - inputs_for_autotuning = [ - query, - key, - value, - ] - res = autotune_select_algorithm( - "flex_attention", - _choices, - inputs_for_autotuning, - layout, - ) - return (res,) - # TODO: We probably also need a layout constraint? @register_lowering(torch.ops.higher_order.flex_attention, type_promotion_kind=None) @@ -1014,20 +817,6 @@ def flex_attention( score_mod_other_buffers, mask_mod_other_buffers, ): - if query.get_device().type == "cpu": - return lower_cpu( - query, - key, - value, - subgraph, - block_mask, - scale, - kernel_options, - score_mod_other_buffers, - mask_mod_other_buffers, - ) - - # below is cuda path if device is not cpu ( _, # q_length _, # kv_length @@ -1043,7 +832,6 @@ def flex_attention( SPARSE_KV_BLOCK_SIZE, mask_graph, ) = block_mask - placeholder_inps = [ create_placeholder(name, dtype, query.get_device()) for name, dtype in [ diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py index c103081e04eed..0556697fa0e74 100644 --- a/torch/nn/attention/flex_attention.py +++ b/torch/nn/attention/flex_attention.py @@ -1092,15 +1092,6 @@ def _apply_kernel_options( # we always write unless in no_grad output_logsumexp = torch.is_grad_enabled() kernel_options["OUTPUT_LOGSUMEXP"] = output_logsumexp - any_inputs_on_cpu_device = ( - query.device.type == "cpu" - or key.device.type == "cpu" - or value.device.type == "cpu" - ) - if any_inputs_on_cpu_device: - # CPU with torch.compile now supports infernece, and will not return lse - # TODO: support CPU for training and return lse - kernel_options["OUTPUT_LOGSUMEXP"] = False return kernel_options @@ -1123,12 +1114,12 @@ def _validate_embed_dim(query: Tensor, key: Tensor, value: Tensor): def _validate_device(query: Tensor, key: Tensor, value: Tensor): - """TODO: Remove once non cuda/cpu devices support is added + """TODO: Remove once non cuda device support is added We only need to check query since we have already that q,k,v are on the same device """ - if query.device.type != "cuda" and query.device.type != "cpu": + if query.device.type != "cuda": raise ValueError( - "FlexAttention is only supported on CUDA or CPU devices. " + "FlexAttention is only supported on CUDA devices. " f"Found input tensors on {query.device.type} device." )
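
Post-revert behavior, for reference: because _validate_device in
torch/nn/attention/flex_attention.py once again accepts CUDA inputs only,
calling flex_attention with CPU tensors is rejected up front instead of being
lowered through the removed Inductor C++ template. The snippet below is a
minimal sketch of that effect, assuming a build that contains this revert;
the tensor shapes are arbitrary and chosen only for illustration.

import torch
from torch.nn.attention.flex_attention import flex_attention

# CPU tensors in the (batch, num_heads, seq_len, head_dim) layout that
# flex_attention expects; float32 on the default CPU device.
q = torch.randn(1, 4, 128, 64)
k = torch.randn(1, 4, 128, 64)
v = torch.randn(1, 4, 128, 64)

try:
    # No score_mod/block_mask: plain attention. With this revert applied,
    # input validation is expected to raise before any compilation or
    # lowering happens.
    flex_attention(q, k, v)
except ValueError as err:
    # Expected message after the revert:
    # "FlexAttention is only supported on CUDA devices. Found input tensors on cpu device."
    print(err)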