From e43f794cffca003dee1aba475fd0c7031f27a50a Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Wed, 9 Apr 2025 21:15:05 -0700 Subject: [PATCH 1/2] [Executorch][SDPA] Refactor + Make quantized sdpa handle sequence at dim 1 or 2 Pull Request resolved: https://github.com/pytorch/executorch/pull/9943 For quantized SDPA we want to evaluate performance impact of having seq at dim 1 as well as dim 2. This diff refactors the code to enable this. The same should be done also for float SDPA but left for future. ghstack-source-id: 277233484 @exported-using-ghexport Differential Revision: [D71833060](https://our.internmc.facebook.com/intern/diff/D71833060/) --- extension/llm/custom_ops/op_sdpa.cpp | 35 +++--- extension/llm/custom_ops/op_sdpa.h | 1 + extension/llm/custom_ops/op_sdpa_aot.cpp | 18 ++- extension/llm/custom_ops/op_sdpa_impl.h | 119 ++++++++++++------ .../llm/custom_ops/test_quantized_sdpa.py | 78 +++++++++++- 5 files changed, 189 insertions(+), 62 deletions(-) diff --git a/extension/llm/custom_ops/op_sdpa.cpp b/extension/llm/custom_ops/op_sdpa.cpp index 391d2ab0646..c5c9b79b280 100644 --- a/extension/llm/custom_ops/op_sdpa.cpp +++ b/extension/llm/custom_ops/op_sdpa.cpp @@ -264,14 +264,14 @@ Tensor& flash_attention_kernel_out( InvalidArgument, output); - auto q_seq_len = query.size(2); + auto seq_len = query.size(2); ET_SWITCH_FLOAT_TYPES( query.scalar_type(), ctx, "flash_attention", CTYPE, [&] { // TODO we need to re-evaluate this for ARM CPUs // And there can be many so instead of templatizing // we might consider another appraoch - if (q_seq_len >= 768) { + if (seq_len >= 768) { sdpa::impl::cpu_flash_attention( output, query, @@ -287,7 +287,7 @@ Tensor& flash_attention_kernel_out( nullopt, nullopt, nullopt); - } else if (q_seq_len >= 192) { + } else if (seq_len >= 192) { sdpa::impl::cpu_flash_attention( output, query, @@ -341,7 +341,8 @@ Tensor& custom_sdpa_out_impl( const optional& k_zero_points = nullopt, const optional& k_scales = nullopt, const optional& v_zero_points = nullopt, - const optional& v_scales = nullopt) { + const optional& v_scales = nullopt, + bool is_seq_at_dim_2 = false) { ET_KERNEL_CHECK_MSG( ctx, !attn_mask.has_value() || !is_causal, @@ -357,13 +358,15 @@ Tensor& custom_sdpa_out_impl( "Invalid arguments"); int64_t seq_len = q.size(1); - auto q_seq_len = q.size(1); + SeqDim seq_dim{SeqDim::TWO}; + if (!is_seq_at_dim_2) { + seq_dim = SeqDim::ONE; + } - bool is_seq_at_dim_1{true}; if (q.scalar_type() == ScalarType::Char) { - is_seq_at_dim_1 = false; - seq_len = q.size(2); - q_seq_len = q.size(2); + if (seq_dim == SeqDim::TWO) { + seq_len = q.size(2); + } ET_KERNEL_CHECK_MSG( ctx, q_scales.has_value() && q_zero_points.has_value() && @@ -412,7 +415,7 @@ Tensor& custom_sdpa_out_impl( // TODO we need to re-evaluate this for ARM CPUs // And there can be many so instead of templatizing // we might consider another appraoch - if (q_seq_len >= 768) { + if (seq_len >= 768) { sdpa::impl::cpu_flash_attention( output, q, @@ -428,10 +431,10 @@ Tensor& custom_sdpa_out_impl( k_scales, // k_scales v_zero_points, // v_zero_points v_scales, // v_scales - is_seq_at_dim_1, /* is_seq_at_dim_1 */ + seq_dim, /* seq_dim */ start_pos, num_keys_for_causal_attention); - } else if (q_seq_len >= 192) { + } else if (seq_len >= 192) { sdpa::impl::cpu_flash_attention( output, q, @@ -447,7 +450,7 @@ Tensor& custom_sdpa_out_impl( k_scales, // k_scales v_zero_points, // v_zero_points v_scales, // v_scales - is_seq_at_dim_1, /* is_seq_at_dim_1 */ + seq_dim, /* seq_dim */ start_pos, 
num_keys_for_causal_attention); } else { @@ -466,7 +469,7 @@ Tensor& custom_sdpa_out_impl( k_scales, // k_scales v_zero_points, // v_zero_points v_scales, // v_scales - is_seq_at_dim_1, /* is_seq_at_dim_1 */ + seq_dim, /* seq_dim */ start_pos, num_keys_for_causal_attention); } @@ -492,6 +495,7 @@ Tensor& custom_quantized_sdpa_out( const optional& k_scales, const optional& v_zero_points, const optional& v_scales, + const bool is_seq_at_dim_2, Tensor& output) { return custom_sdpa_out_impl( ctx, @@ -509,7 +513,8 @@ Tensor& custom_quantized_sdpa_out( k_zero_points, k_scales, v_zero_points, - v_scales); + v_scales, + is_seq_at_dim_2); } #endif // ENABLE_CUSTOM_QUANTIZED_SDPA diff --git a/extension/llm/custom_ops/op_sdpa.h b/extension/llm/custom_ops/op_sdpa.h index 92b8a41b706..3deb27b3989 100644 --- a/extension/llm/custom_ops/op_sdpa.h +++ b/extension/llm/custom_ops/op_sdpa.h @@ -74,6 +74,7 @@ Tensor& custom_quantized_sdpa_out( const optional& k_scales, const optional& v_zero_points, const optional& v_scales, + const bool is_seq_at_dim_1, Tensor& output); #endif // ENABLE_CUSTOM_QUANTIZED_SDPA } // namespace native diff --git a/extension/llm/custom_ops/op_sdpa_aot.cpp b/extension/llm/custom_ops/op_sdpa_aot.cpp index a3adcbbf866..2da915a19b8 100644 --- a/extension/llm/custom_ops/op_sdpa_aot.cpp +++ b/extension/llm/custom_ops/op_sdpa_aot.cpp @@ -96,6 +96,7 @@ Tensor& custom_quantized_sdpa_out_no_context( const optional k_scales, const optional v_zero_points, const optional v_scales, + const bool is_seq_at_dim_2, Tensor& output); at::Tensor custom_quantized_sdpa_aten( @@ -115,7 +116,8 @@ at::Tensor custom_quantized_sdpa_aten( const std::optional& k_zero_points, const std::optional& k_scales, const std::optional& v_zero_points, - const std::optional& v_scales); + const std::optional& v_scales, + const bool is_seq_at_dim_2); #endif // ENABLE_CUSTOM_QUANTIZED_SDPA Tensor& update_cache_out_no_context( @@ -258,6 +260,7 @@ Tensor& custom_quantized_sdpa_out_no_context( const optional k_scales, const optional v_zero_points, const optional v_scales, + const bool is_seq_at_dim_2, Tensor& output) { executorch::aten::RuntimeContext context{}; return torch::executor::native::custom_quantized_sdpa_out( @@ -276,6 +279,7 @@ Tensor& custom_quantized_sdpa_out_no_context( k_scales, v_zero_points, v_scales, + is_seq_at_dim_2, output); } @@ -296,9 +300,10 @@ at::Tensor custom_quantized_sdpa_aten( const std::optional& k_zero_points, const std::optional& k_scales, const std::optional& v_zero_points, - const std::optional& v_scales) { + const std::optional& v_scales, + const bool is_seq_at_dim_2) { auto output = at::empty(q.sizes()); - WRAP_TO_ATEN(custom_quantized_sdpa_out_no_context, 14) + WRAP_TO_ATEN(custom_quantized_sdpa_out_no_context, 15) (q, k, v, @@ -313,6 +318,7 @@ at::Tensor custom_quantized_sdpa_aten( k_scales, v_zero_points, v_scales, + is_seq_at_dim_2, output); return output; } @@ -371,13 +377,13 @@ TORCH_LIBRARY_FRAGMENT(llama, m) { "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, " "float? scale=None, Tensor? q_zero_points=None, Tensor? q_scales=None, " "Tensor? k_zero_points=None, Tensor? k_scales=None, Tensor? v_zero_points=None, " - "Tensor? v_scales=None) -> Tensor"); + "Tensor? v_scales=None, bool is_seq_at_dim_2=False) -> Tensor"); m.def( "custom_quantized_sdpa.out(Tensor query, Tensor key, Tensor value, SymInt start_pos, " "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, " "float? scale=None, Tensor? q_zero_points=None, Tensor? q_scales=None, " "Tensor? 
k_zero_points=None, Tensor? k_scales=None, Tensor? v_zero_points=None, " - "Tensor? v_scales=None, *, Tensor(a!) out) -> Tensor(a!)"); + "Tensor? v_scales=None, bool is_seq_at_dim_2=False, *, Tensor(a!) out) -> Tensor(a!)"); #endif // ENABLE_CUSTOM_QUANTIZED_SDPA } @@ -404,6 +410,6 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) { m.impl( "custom_quantized_sdpa.out", WRAP_TO_ATEN( - torch::executor::native::custom_quantized_sdpa_out_no_context, 14)); + torch::executor::native::custom_quantized_sdpa_out_no_context, 15)); #endif // ENABLE_CUSTOM_QUANTIZED_SDPA } diff --git a/extension/llm/custom_ops/op_sdpa_impl.h b/extension/llm/custom_ops/op_sdpa_impl.h index 5a0fb708220..7607a1e283d 100644 --- a/extension/llm/custom_ops/op_sdpa_impl.h +++ b/extension/llm/custom_ops/op_sdpa_impl.h @@ -32,6 +32,8 @@ namespace executor { namespace native { +enum class SeqDim { ONE = 1, TWO }; + namespace sdpa::impl { struct MaybeQuantizedMatrixData { @@ -39,6 +41,8 @@ struct MaybeQuantizedMatrixData { const int8_t* zero_points{nullptr}; const float* scales{nullptr}; int64_t m = 0, n = 0; + const int64_t zero_points_stride{1}; + const int64_t scales_stride{1}; ScalarType dtype{ScalarType::Float}; MaybeQuantizedMatrixData() = default; MaybeQuantizedMatrixData( @@ -47,12 +51,15 @@ struct MaybeQuantizedMatrixData { const float* scales_, int64_t m_, int64_t n_, + int64_t qparams_stride, ScalarType dtype_) : data(data_), zero_points(zero_points_), scales(scales_), m(m_), n(n_), + zero_points_stride(qparams_stride), + scales_stride(qparams_stride), dtype(dtype_) {} }; @@ -91,8 +98,9 @@ void _q_at_k_gemm( static_cast(k_data.zero_points), static_cast(q_data.scales), static_cast(k_data.scales), - 1, - 1); + // LHS and RHS are assumed to have same stride for qparams + q_data.zero_points_stride, + k_data.zero_points_stride); } else { ET_CHECK_MSG( false, "Accumulation in dtype other than float not supported yet"); @@ -152,7 +160,7 @@ void _qk_at_v_gemm( static_cast(v_data.zero_points), static_cast(v_data.scales), beta, - 1); + v_data.zero_points_stride); } else { ET_CHECK_MSG( false, "Accumulation in dtype other than float not supported yet"); @@ -351,6 +359,40 @@ sdpa_with_kv_cache does not use attn_mask. TODO: Just handle conversion of bool mask to float */ +/** + * @brief Implements Flash Attention algorithm on CPU + * + * This function computes scaled dot-product attention with optimizations for + CPU. + * It supports both regular and quantized attention computation. 
+ * + * @tparam scalar_t The data type for computation (e.g., float) + * @tparam q_split_size Block size for query matrix in tiling algorithm + * @tparam kv_split_size Block size for key/value matrices in tiling algorithm + * + * @param output Output tensor to store attention results + * @param query Query tensor [Batch x Num_heads x Q_seq_len x Dim_per_head] + * @param key Key tensor [Batch x Num_heads_kv x KV_seq_len x Dim_per_head] + * @param value Value tensor [Batch x Num_heads_kv x KV_seq_len x Dim_per_head] + * @param dropout_p Dropout probability (not used in current implementation) + * @param is_causal Whether to apply causal mask (lower triangular) + * @param attn_mask Optional explicit attention mask + * @param scale Optional custom scaling factor (default: 1/sqrt(head_dim)) + * @param q_zero_points Optional zero points for quantized query + * @param q_scales Optional scales for quantized query + * @param k_zero_points Optional zero points for quantized key + * @param k_scales Optional scales for quantized key + * @param v_zero_points Optional zero points for quantized value + * @param v_scales Optional scales for quantized value + * @param seq_dim Which dimension is sequence dimension. + If SeqDim::One, then query, key, value are + expected to be in shape [Batch x Q_seq_len x Dim_per_head x Num_heads] and + output is expected to be in shape [Batch x Q_seq_len x Dim_per_head x + Num_heads] + * @param start_pos Starting position for causal masking in generation + * @param num_keys_for_causal_attention Number of keys to consider for causal + attention (-1 for all) + */ template void cpu_flash_attention( Tensor& output, @@ -367,22 +409,10 @@ void cpu_flash_attention( const optional& k_scales, const optional& v_zero_points, const optional& v_scales, - bool is_seq_at_dim_1 = false, + const SeqDim seq_dim = SeqDim::TWO, const int64_t start_pos = 0, const int64_t num_keys_for_causal_attention = -1) { (void)dropout_p; - // Query (Batch x Num_heads x Q_seq_len x Dim_per_head) - // Key (Batch x Num_heads x KV_seq_len x Dim_per_head) - // Value (Batch x Num_heads x KV_seq_len x Dim_per_head) - - /* - // -> (Batch x Q_seq_len x Num_heads x Dim_per_head) - at::Tensor query = q.transpose(1, 2); - // -> (Batch x KV_seq_len x Num_heads x Dim_per_head) - at::Tensor key = k.transpose(1, 2); - // -> (Batch x KV_seq_len x Num_heads x Dim_per_head) - at::Tensor value = v.transpose(1, 2); - */ // Without this we have out-of-bounds writes for // causal masking @@ -408,7 +438,7 @@ void cpu_flash_attention( int64_t kvSize = value.size(2); int64_t num_heads_kv = key.size(1); - if (is_seq_at_dim_1) { + if (seq_dim == SeqDim::ONE) { num_head = query.size(2); num_heads_kv = key.size(2); qSize = query.size(1); @@ -466,7 +496,7 @@ void cpu_flash_attention( int64_t qStrideH = strides[1]; int64_t qStrideM = strides[2]; - if (is_seq_at_dim_1) { + if (seq_dim == SeqDim::ONE) { qStrideH = strides[2]; qStrideM = strides[1]; } @@ -476,7 +506,7 @@ void cpu_flash_attention( int64_t kStrideH = strides[1]; int64_t kStrideN = strides[2]; - if (is_seq_at_dim_1) { + if (seq_dim == SeqDim::ONE) { kStrideH = strides[2]; kStrideN = strides[1]; } @@ -486,7 +516,7 @@ void cpu_flash_attention( int64_t vStrideH = strides[1]; int64_t vStrideN = strides[2]; - if (is_seq_at_dim_1) { + if (seq_dim == SeqDim::ONE) { vStrideH = strides[2]; vStrideN = strides[1]; } @@ -502,20 +532,36 @@ void cpu_flash_attention( int64_t v_quant_params_StrideN = 0; if (is_quantized_sdpa) { - strides = q_zero_points.value().strides(); - 
q_quant_params_StrideB = strides[0]; - q_quant_params_StrideH = strides[1]; - q_quant_params_StrideM = strides[2]; - - strides = k_zero_points.value().strides(); - k_quant_params_StrideB = strides[0]; - k_quant_params_StrideH = strides[1]; - k_quant_params_StrideN = strides[2]; - - strides = v_zero_points.value().strides(); - v_quant_params_StrideB = strides[0]; - v_quant_params_StrideH = strides[1]; - v_quant_params_StrideN = strides[2]; + auto q_strides = q_zero_points.value().strides(); + q_quant_params_StrideB = q_strides[0]; + q_quant_params_StrideH = q_strides[1]; + q_quant_params_StrideM = q_strides[2]; + + auto k_strides = k_zero_points.value().strides(); + k_quant_params_StrideB = k_strides[0]; + k_quant_params_StrideH = k_strides[1]; + k_quant_params_StrideN = k_strides[2]; + + auto v_strides = v_zero_points.value().strides(); + v_quant_params_StrideB = v_strides[0]; + v_quant_params_StrideH = v_strides[1]; + v_quant_params_StrideN = v_strides[2]; + + ET_CHECK_MSG( + (v_quant_params_StrideN == k_quant_params_StrideN) && + (v_quant_params_StrideN == q_quant_params_StrideM), + "Quant params strides must be same for seq dim"); + + if (seq_dim == SeqDim::ONE) { + q_quant_params_StrideH = q_strides[2]; + q_quant_params_StrideM = q_strides[1]; + + k_quant_params_StrideH = k_strides[2]; + k_quant_params_StrideN = k_strides[1]; + + v_quant_params_StrideH = v_strides[2]; + v_quant_params_StrideN = v_strides[1]; + } } strides = output.strides(); @@ -523,7 +569,7 @@ void cpu_flash_attention( int64_t oStrideH = strides[1]; int64_t oStrideM = strides[2]; - if (is_seq_at_dim_1) { + if (seq_dim == SeqDim::ONE) { oStrideH = strides[2]; oStrideM = strides[1]; } @@ -679,6 +725,7 @@ void cpu_flash_attention( q_scales_ptr, qBlockSize, headSize, + q_quant_params_StrideM, query.scalar_type()); MaybeQuantizedMatrixData k_sub_matrix_data = MaybeQuantizedMatrixData( static_cast(k_sub_matrix_data_ptr), @@ -686,6 +733,7 @@ void cpu_flash_attention( k_scales_ptr, kvBlockSize, headSize, + k_quant_params_StrideN, key.scalar_type()); _q_at_k_gemm( qBlockSize, @@ -835,6 +883,7 @@ void cpu_flash_attention( v_scales_ptr, kvBlockSize, headSize, + v_quant_params_StrideN, value.scalar_type()); // Calculate Softmax(q @ k.T) @ v _qk_at_v_gemm( diff --git a/extension/llm/custom_ops/test_quantized_sdpa.py b/extension/llm/custom_ops/test_quantized_sdpa.py index f5540a4e614..f7b28e1508f 100644 --- a/extension/llm/custom_ops/test_quantized_sdpa.py +++ b/extension/llm/custom_ops/test_quantized_sdpa.py @@ -35,6 +35,7 @@ def setUp(self): self.float_dtype = torch.float32 self.q_shape = None self.kv_shape = None + self.is_seq_at_dim_2 = True def _scale_tensor(self, tensor, min_value, max_value, scale=True): normalized_tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min()) @@ -105,6 +106,10 @@ def _sdpa_ref( self.float_dtype, ) + if not self.is_seq_at_dim_2: + q = q.transpose(1, 2).contiguous() + k = k.transpose(1, 2).contiguous() + v = v.transpose(1, 2).contiguous() num_heads_q = q.size(1) num_heads_kv = k.size(1) seq_len = q.size(2) @@ -119,6 +124,8 @@ def _sdpa_ref( k = k.repeat_interleave(n_reps, dim=1) v = v.repeat_interleave(n_reps, dim=1) out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + if not self.is_seq_at_dim_2: + out = out.transpose(1, 2).contiguous() return out def _int_matmul( @@ -212,7 +219,7 @@ def _test_sdpa_common( seq_len, scale_tensors=False, atol=1e-5, - is_seq_at_dim_2=True, + is_seq_at_dim_2=False, ): # Range arbitrarily chosen to reproduce a numerical error on x86 in some of 
the long context tests tensor_scale_max = 15 @@ -221,9 +228,10 @@ def _test_sdpa_common( self.n_heads_q = n_heads_q self.head_dim = head_dim self.max_seq_len = max_seq_len + self.is_seq_at_dim_2 = is_seq_at_dim_2 seq_dim = 2 self.q_shape = (self.n_batch, self.n_heads_q, seq_len, self.head_dim) - self.kv_shape = (self.n_batch, self.n_heads_q, self.max_seq_len, self.head_dim) + self.kv_shape = (self.n_batch, self.n_heads_kv, self.max_seq_len, self.head_dim) if not is_seq_at_dim_2: seq_dim = 1 self.q_shape = (self.n_batch, seq_len, self.n_heads_q, self.head_dim) @@ -286,7 +294,6 @@ def _test_sdpa_common( quantized_dtype, ) - start_pos = 0 seq_len = q.size(seq_dim) attn_mask = self.mask[start_pos : start_pos + seq_len, :] attn_mask = attn_mask[:, : start_pos + seq_len] @@ -334,6 +341,7 @@ def _test_sdpa_common( k_scale_fp32, v_zero_point_int8, v_scale_fp32, + is_seq_at_dim_2, ) self.assertTrue(torch.allclose(ref_output, op_output, atol=atol)) # Following line crashes due to some weird issues in mkldnn with crash in mkl_sgemm with `wild jump` @@ -374,6 +382,7 @@ def _test_sdpa_common( k_scale_fp32, v_zero_point_int8, v_scale_fp32, + is_seq_at_dim_2, ) self.assertTrue(torch.allclose(ref_output, op_output, atol=atol)) @@ -393,6 +402,18 @@ def test_sdpa_with_custom_quantized(self): seq_len, True, atol=1e-4, + is_seq_at_dim_2=True, + ) + self._test_sdpa_common( + n_heads_kv, + n_heads_q, + head_dim, + max_seq_len, + start_pos, + seq_len, + True, + atol=1e-4, + is_seq_at_dim_2=False, ) def test_sdpa_with_custom_quantized_seq_len_1(self): @@ -403,7 +424,22 @@ def test_sdpa_with_custom_quantized_seq_len_1(self): seq_len = 1 start_pos = 0 self._test_sdpa_common( - n_heads_kv, n_heads_q, head_dim, max_seq_len, start_pos, seq_len + n_heads_kv, + n_heads_q, + head_dim, + max_seq_len, + start_pos, + seq_len, + is_seq_at_dim_2=True, + ) + self._test_sdpa_common( + n_heads_kv, + n_heads_q, + head_dim, + max_seq_len, + start_pos, + seq_len, + is_seq_at_dim_2=False, ) def test_sdpa_with_custom_quantized_seq_len_small(self): @@ -414,7 +450,22 @@ def test_sdpa_with_custom_quantized_seq_len_small(self): seq_len = 4 start_pos = 0 self._test_sdpa_common( - n_heads_kv, n_heads_q, head_dim, max_seq_len, start_pos, seq_len + n_heads_kv, + n_heads_q, + head_dim, + max_seq_len, + start_pos, + seq_len, + is_seq_at_dim_2=True, + ) + self._test_sdpa_common( + n_heads_kv, + n_heads_q, + head_dim, + max_seq_len, + start_pos, + seq_len, + is_seq_at_dim_2=False, ) def test_sdpa_with_custom_quantized_seq_len_llava_example(self): @@ -466,5 +517,20 @@ def test_sdpa_with_cache_mqa(self): seq_len = 24 start_pos = 0 self._test_sdpa_common( - n_heads_kv, n_heads_q, head_dim, max_seq_len, start_pos, seq_len + n_heads_kv, + n_heads_q, + head_dim, + max_seq_len, + start_pos, + seq_len, + is_seq_at_dim_2=True, + ) + self._test_sdpa_common( + n_heads_kv, + n_heads_q, + head_dim, + max_seq_len, + start_pos, + seq_len, + is_seq_at_dim_2=False, ) From aa0b2b8e8713225767623ae474ec95bccb4cab2f Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Wed, 9 Apr 2025 21:15:08 -0700 Subject: [PATCH 2/2] [Executorch][llama] Renamed quantized_kv_cache to custom_kv_cache Pull Request resolved: https://github.com/pytorch/executorch/pull/9944 Because old name was misnomer ghstack-source-id: 277233486 @exported-using-ghexport Differential Revision: [D71833067](https://our.internmc.facebook.com/intern/diff/D71833067/) --- examples/models/llama/TARGETS | 10 +++++----- examples/models/llama/export_llama_lib.py | 8 ++++---- .../{quantized_kv_cache.py => 
custom_kv_cache.py} | 0 .../source_transformation/test_quantized_kv_cache.py | 2 +- .../test_sdpa_with_quantized_kv_cache.py | 2 +- examples/models/llava/export_llava.py | 6 +++--- examples/models/llava/model.py | 2 +- 7 files changed, 15 insertions(+), 15 deletions(-) rename examples/models/llama/source_transformation/{quantized_kv_cache.py => custom_kv_cache.py} (100%) diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS index 93ac18c993d..12eb5fd13dc 100644 --- a/examples/models/llama/TARGETS +++ b/examples/models/llama/TARGETS @@ -108,7 +108,7 @@ runtime.python_library( "source_transformation/pre_quantization.py", "source_transformation/prune_vocab.py", "source_transformation/quantize.py", - "source_transformation/quantized_kv_cache.py", + "source_transformation/custom_kv_cache.py", "source_transformation/rms_norm.py", "source_transformation/rope.py", "source_transformation/sdpa.py", @@ -208,9 +208,9 @@ runtime.python_library( ) runtime.python_library( - name = "quantized_kv_cache", + name = "custom_kv_cache", srcs = [ - "source_transformation/quantized_kv_cache.py", + "source_transformation/custom_kv_cache.py", ], _is_external_target = True, visibility = ["//executorch/..."], @@ -240,7 +240,7 @@ runtime.python_test( "//executorch/extension/llm/custom_ops:custom_ops_aot_lib", ], deps = [ - ":quantized_kv_cache", + ":custom_kv_cache", "//caffe2:torch", "//executorch/examples/models/llama:llama_transformer", ], @@ -255,7 +255,7 @@ runtime.python_test( "//executorch/extension/llm/custom_ops:custom_ops_aot_lib", ], deps = [ - ":quantized_kv_cache", + ":custom_kv_cache", ":sdpa", "//caffe2:torch", "//executorch/examples/models/llama:llama_transformer", diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 249a25f23c4..01179e8ee56 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -59,14 +59,14 @@ ) from .source_transformation.attention import replace_attention_to_attention_sha +from .source_transformation.custom_kv_cache import ( + replace_kv_cache_with_custom_kv_cache, + replace_kv_cache_with_quantized_kv_cache, +) from .source_transformation.quantize import ( get_quant_embedding_transform, get_quant_weight_transform, ) -from .source_transformation.quantized_kv_cache import ( - replace_kv_cache_with_custom_kv_cache, - replace_kv_cache_with_quantized_kv_cache, -) from .source_transformation.rms_norm import replace_rms_norm_with_native_rms_norm from .source_transformation.rope import materialze_broadcast_of_rope_freq_cis diff --git a/examples/models/llama/source_transformation/quantized_kv_cache.py b/examples/models/llama/source_transformation/custom_kv_cache.py similarity index 100% rename from examples/models/llama/source_transformation/quantized_kv_cache.py rename to examples/models/llama/source_transformation/custom_kv_cache.py diff --git a/examples/models/llama/source_transformation/test_quantized_kv_cache.py b/examples/models/llama/source_transformation/test_quantized_kv_cache.py index 4252518a4ee..07c8e1bf9a0 100644 --- a/examples/models/llama/source_transformation/test_quantized_kv_cache.py +++ b/examples/models/llama/source_transformation/test_quantized_kv_cache.py @@ -10,7 +10,7 @@ from executorch.examples.models.llama.attention import KVCache -from executorch.examples.models.llama.source_transformation.quantized_kv_cache import ( +from executorch.examples.models.llama.source_transformation.custom_kv_cache import ( QuantizedCacheType, QuantizedKVCache, ) 
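A note on call sites, since the hunks around this point only touch import paths: the rename moves the module but keeps every public symbol, so downstream code changes in exactly one place, the import. The sketch below shows the new-style import; the swap_kv_cache wrapper, and its assumption that each replace_* helper takes the eager model and returns the transformed model, are illustrative and not part of this patch.

# Old module: examples/models/llama/source_transformation/quantized_kv_cache.py
# New module: examples/models/llama/source_transformation/custom_kv_cache.py
from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
    CustomKVCache,
    QuantizedCacheType,
    QuantizedKVCache,
    replace_kv_cache_with_custom_kv_cache,
    replace_kv_cache_with_quantized_kv_cache,
)


def swap_kv_cache(model, quantized: bool = False):
    # Hypothetical convenience wrapper: pick one of the two source
    # transforms exposed by the renamed module.
    if quantized:
        return replace_kv_cache_with_quantized_kv_cache(model)
    return replace_kv_cache_with_custom_kv_cache(model)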
diff --git a/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py b/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py index 35c88e10b6b..b2c93d7d93d 100644 --- a/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py +++ b/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py @@ -10,7 +10,7 @@ from executorch.examples.models.llama.attention import KVCache -from executorch.examples.models.llama.source_transformation.quantized_kv_cache import ( +from executorch.examples.models.llama.source_transformation.custom_kv_cache import ( CustomKVCache, QuantizedCacheType, QuantizedKVCache, diff --git a/examples/models/llava/export_llava.py b/examples/models/llava/export_llava.py index 63ae0f4a118..5fcddb610b7 100644 --- a/examples/models/llava/export_llava.py +++ b/examples/models/llava/export_llava.py @@ -20,13 +20,13 @@ build_args_parser, get_quantizer_and_quant_params, ) +from executorch.examples.models.llama.source_transformation.custom_kv_cache import ( + replace_kv_cache_with_custom_kv_cache, +) from executorch.examples.models.llama.source_transformation.quantize import ( EmbeddingQuantHandler, get_quant_weight_transform, ) -from executorch.examples.models.llama.source_transformation.quantized_kv_cache import ( - replace_kv_cache_with_custom_kv_cache, -) from executorch.examples.models.llama.source_transformation.sdpa import ( replace_sdpa_with_custom_op, ) diff --git a/examples/models/llava/model.py b/examples/models/llava/model.py index 6ce4b701bbe..351356607c8 100644 --- a/examples/models/llava/model.py +++ b/examples/models/llava/model.py @@ -15,7 +15,7 @@ from executorch.examples.models.llama.llama_transformer import Transformer from executorch.examples.models.llama.model_args import ModelArgs -from executorch.examples.models.llama.source_transformation.quantized_kv_cache import ( +from executorch.examples.models.llama.source_transformation.custom_kv_cache import ( replace_kv_cache_with_custom_kv_cache, ) from executorch.examples.models.llama.source_transformation.sdpa import (
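Layout note for the first patch: the new is_seq_at_dim_2 argument (default False in both schema strings) selects the SeqDim enum inside the kernel. False maps to SeqDim::ONE, where query, key and value carry the sequence at dim 1, i.e. [Batch, Seq_len, Num_heads, Head_dim]; True maps to SeqDim::TWO, the usual [Batch, Num_heads, Seq_len, Head_dim] layout. Under SeqDim::ONE the kernel swaps the head and sequence strides, including the strides of the quantization-parameter tensors. The float-only sketch below mirrors the layout handling added to _sdpa_ref in test_quantized_sdpa.py; sdpa_reference is an illustrative name, it does not call the quantized kernel, and it only demonstrates the tensor-layout contract.

import torch
import torch.nn.functional as F


def sdpa_reference(q, k, v, attn_mask=None, is_seq_at_dim_2=False):
    # Float reference mirroring the test's _sdpa_ref layout handling:
    #   is_seq_at_dim_2=True  -> q/k/v are [B, H, S, D] (seq at dim 2)
    #   is_seq_at_dim_2=False -> q/k/v are [B, S, H, D] (seq at dim 1)
    if not is_seq_at_dim_2:
        # Normalize to [B, H, S, D] for the builtin SDPA.
        q, k, v = (t.transpose(1, 2).contiguous() for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    if not is_seq_at_dim_2:
        # Hand the result back in the caller's [B, S, H, D] layout.
        out = out.transpose(1, 2).contiguous()
    return out


# Same math, two layouts.
B, H, S, D = 1, 8, 16, 64
q2, k2, v2 = (torch.randn(B, H, S, D) for _ in range(3))  # seq at dim 2
out2 = sdpa_reference(q2, k2, v2, is_seq_at_dim_2=True)

q1, k1, v1 = (t.transpose(1, 2).contiguous() for t in (q2, k2, v2))
out1 = sdpa_reference(q1, k1, v1, is_seq_at_dim_2=False)  # seq at dim 1
assert torch.allclose(out2, out1.transpose(1, 2), atol=1e-5)

The quantized op follows the same contract: with is_seq_at_dim_2=False, torch.ops.llama.custom_quantized_sdpa expects q, k, v and their quantization parameters with the sequence at dim 1, which is exactly what the new is_seq_at_dim_2=False variants added to test_quantized_sdpa.py exercise.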