diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS
index 93ac18c993d..12eb5fd13dc 100644
--- a/examples/models/llama/TARGETS
+++ b/examples/models/llama/TARGETS
@@ -108,7 +108,7 @@ runtime.python_library(
         "source_transformation/pre_quantization.py",
         "source_transformation/prune_vocab.py",
         "source_transformation/quantize.py",
-        "source_transformation/quantized_kv_cache.py",
+        "source_transformation/custom_kv_cache.py",
         "source_transformation/rms_norm.py",
         "source_transformation/rope.py",
         "source_transformation/sdpa.py",
@@ -208,9 +208,9 @@ runtime.python_library(
 )
 
 runtime.python_library(
-    name = "quantized_kv_cache",
+    name = "custom_kv_cache",
     srcs = [
-        "source_transformation/quantized_kv_cache.py",
+        "source_transformation/custom_kv_cache.py",
     ],
     _is_external_target = True,
     visibility = ["//executorch/..."],
@@ -240,7 +240,7 @@ runtime.python_test(
         "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
     ],
     deps = [
-        ":quantized_kv_cache",
+        ":custom_kv_cache",
         "//caffe2:torch",
         "//executorch/examples/models/llama:llama_transformer",
     ],
@@ -255,7 +255,7 @@ runtime.python_test(
         "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
     ],
     deps = [
-        ":quantized_kv_cache",
+        ":custom_kv_cache",
         ":sdpa",
         "//caffe2:torch",
         "//executorch/examples/models/llama:llama_transformer",
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 249a25f23c4..01179e8ee56 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -59,14 +59,14 @@
 )
 
 from .source_transformation.attention import replace_attention_to_attention_sha
+from .source_transformation.custom_kv_cache import (
+    replace_kv_cache_with_custom_kv_cache,
+    replace_kv_cache_with_quantized_kv_cache,
+)
 from .source_transformation.quantize import (
     get_quant_embedding_transform,
     get_quant_weight_transform,
 )
-from .source_transformation.quantized_kv_cache import (
-    replace_kv_cache_with_custom_kv_cache,
-    replace_kv_cache_with_quantized_kv_cache,
-)
 from .source_transformation.rms_norm import replace_rms_norm_with_native_rms_norm
 from .source_transformation.rope import materialze_broadcast_of_rope_freq_cis
 
diff --git a/examples/models/llama/source_transformation/quantized_kv_cache.py b/examples/models/llama/source_transformation/custom_kv_cache.py
similarity index 100%
rename from examples/models/llama/source_transformation/quantized_kv_cache.py
rename to examples/models/llama/source_transformation/custom_kv_cache.py
diff --git a/examples/models/llama/source_transformation/test_quantized_kv_cache.py b/examples/models/llama/source_transformation/test_quantized_kv_cache.py
index 4252518a4ee..07c8e1bf9a0 100644
--- a/examples/models/llama/source_transformation/test_quantized_kv_cache.py
+++ b/examples/models/llama/source_transformation/test_quantized_kv_cache.py
@@ -10,7 +10,7 @@
 
 from executorch.examples.models.llama.attention import KVCache
 
-from executorch.examples.models.llama.source_transformation.quantized_kv_cache import (
+from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
     QuantizedCacheType,
     QuantizedKVCache,
 )
diff --git a/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py b/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py
index 35c88e10b6b..b2c93d7d93d 100644
--- a/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py
+++ b/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py
@@ -10,7 +10,7 @@
 
 from executorch.examples.models.llama.attention import KVCache
 
-from executorch.examples.models.llama.source_transformation.quantized_kv_cache import (
+from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
     CustomKVCache,
     QuantizedCacheType,
     QuantizedKVCache,
diff --git a/examples/models/llava/export_llava.py b/examples/models/llava/export_llava.py
index 63ae0f4a118..5fcddb610b7 100644
--- a/examples/models/llava/export_llava.py
+++ b/examples/models/llava/export_llava.py
@@ -20,13 +20,13 @@
     build_args_parser,
     get_quantizer_and_quant_params,
 )
+from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
+    replace_kv_cache_with_custom_kv_cache,
+)
 from executorch.examples.models.llama.source_transformation.quantize import (
     EmbeddingQuantHandler,
     get_quant_weight_transform,
 )
-from executorch.examples.models.llama.source_transformation.quantized_kv_cache import (
-    replace_kv_cache_with_custom_kv_cache,
-)
 from executorch.examples.models.llama.source_transformation.sdpa import (
     replace_sdpa_with_custom_op,
 )
diff --git a/examples/models/llava/model.py b/examples/models/llava/model.py
index 6ce4b701bbe..351356607c8 100644
--- a/examples/models/llava/model.py
+++ b/examples/models/llava/model.py
@@ -15,7 +15,7 @@
 from executorch.examples.models.llama.llama_transformer import Transformer
 from executorch.examples.models.llama.model_args import ModelArgs
 
-from executorch.examples.models.llama.source_transformation.quantized_kv_cache import (
+from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
     replace_kv_cache_with_custom_kv_cache,
 )
 from executorch.examples.models.llama.source_transformation.sdpa import (