From 7d4092b2dfd89cd94c96119864b433167d2efde7 Mon Sep 17 00:00:00 2001 From: Guang Yang Date: Wed, 14 Aug 2024 16:59:54 -0700 Subject: [PATCH] Make StaticCache configurable at model construct time --- src/transformers/cache_utils.py | 41 +++++ .../generation/configuration_utils.py | 3 +- src/transformers/integrations/executorch.py | 140 ++++++++++++++++++ src/transformers/modeling_utils.py | 7 + tests/utils/test_cache_utils.py | 85 +++++------ 5 files changed, 230 insertions(+), 46 deletions(-) create mode 100644 src/transformers/integrations/executorch.py diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 56eb0c4080dde9..36b5cb5e860ae5 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -293,6 +293,47 @@ def validate(self): ) +@dataclass +class StaticCacheConfig(CacheConfig): + """ + Configuration class for static cache settings. + """ + + cache_implementation = "static" + + def __init__(self, batch_size: int, max_cache_len: int, device="cpu"): + self.batch_size = batch_size + self.max_cache_len = max_cache_len + self.device = device + + def validate(self): + """Validates if the arguments passed are correct""" + + incorrect_arg_msg = ( + "Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value}` " + "but found {found_value}" + ) + # Check that the values are reasonable in general (nbits, axis) + # Later in QuantizedCache init we check if they are supported for that particular backend + if self.batch_size <= 0: + raise ValueError( + incorrect_arg_msg.format( + key="batch_size", + correct_value="> 0", + found_value=self.batch_size, + ), + ) + + if self.max_cache_len <= 0: + raise ValueError( + incorrect_arg_msg.format( + key="max_cache_len", + correct_value="> 0", + found_value=self.max_cache_len, + ), + ) + + class DynamicCache(Cache): """ A cache that grows dynamically as more tokens are generated. This is the default for generative models. diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index aa5e77ac681740..601bf90d7c183f 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -45,9 +45,10 @@ NEEDS_CACHE_CONFIG = {} if is_torch_available(): - from ..cache_utils import QuantizedCacheConfig + from ..cache_utils import QuantizedCacheConfig, StaticCacheConfig NEEDS_CACHE_CONFIG["quantized"] = QuantizedCacheConfig + NEEDS_CACHE_CONFIG["static"] = StaticCacheConfig class GenerationMode(ExplicitEnum): diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py new file mode 100644 index 00000000000000..1c9cb067485b39 --- /dev/null +++ b/src/transformers/integrations/executorch.py @@ -0,0 +1,140 @@ +from packaging import version + +from transformers.testing_utils import ( + is_torch_available, + require_torch, +) + + +if is_torch_available(): + import torch + + from transformers import ( + PreTrainedModel, + StaticCache, + ) + + +@require_torch +class TorchExportatibleModuleWithStaticCache(torch.nn.Module): + """ + A wrapper module designed to make a `PreTrainedModel` exportable with `torch.export`, + specifically for use with static caching. This module ensures that the exported model + is compatible with further lowering and execution in `ExecuTorch`. + + Args: + model (PreTrainedModel): The pretrained model to wrap. The model must have caching + enabled and use a 'static' caching implementation. + + Attributes: + model (PreTrainedModel): The underlying pretrained model. + static_cache (StaticCache): A static cache instance used to store past key values for faster inference. + is_causal (bool): Indicates whether the model architecture supports causal masking (e.g., causal language models). + + Note: + This class is specifically designed to support export process using `torch.export` + in a way that ensures the model can be further lowered and run efficiently in `ExecuTorch`. + + Raises: + AssertionError: If `model` does not have caching enabled or if it does not use a 'static' caching implementation. + """ + + def __init__(self, model: PreTrainedModel): + super().__init__() + assert model.generation_config.use_cache is True + assert model.generation_config.cache_implementation == "static" + + self.model = model + self.static_cache = StaticCache( + config=self.model.config, + batch_size=self.model.generation_config.cache_config.batch_size, + max_cache_len=self.model.generation_config.cache_config.max_cache_len, + dtype=self.model.config.torch_dtype, + ) + self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures) + if self.is_causal: + causal_mask = torch.tril( + torch.ones( + self.static_cache.max_cache_len, + self.static_cache.max_cache_len, + dtype=torch.bool, + ) + ) + self.register_buffer("mask", causal_mask, persistent=False) + + def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor): + """ + Forward pass of the module, which is compatible with the ExecuTorch runtime. + + Args: + input_ids (torch.Tensor): Tensor representing current input token id to the module. + cache_position (torch.Tensor): Tensor representing current input position in the cache. + + Returns: + torch.Tensor: Logits output from the model. + + This forward adapter serves two primary purposes: + + 1. **Making the Model `torch.export`-Compatible**: + The adapter hides unsupported objects, such as the `Cache`, from the graph inputs and outputs, + enabling the model to be exportable using `torch.export` without encountering issues. + + 2. **Ensuring Compatibility with `ExecuTorch` runtime**: + The adapter matches the model's forward signature with that in `executorch/extension/llm/runner`, + ensuring that the exported model can be executed in `ExecuTorch` out-of-the-box. + """ + _, seqlen = input_ids.shape + attn_mask = self.mask[cache_position, :seqlen] if self.is_causal else None + outs = self.model( + input_ids=input_ids, + attention_mask=attn_mask, + position_ids=cache_position.unsqueeze(0), + cache_position=cache_position, + past_key_values=self.static_cache, + use_cache=True, + ) + return outs.logits + + +@require_torch +def convert_and_export( + model: PreTrainedModel, + example_input_ids: torch.Tensor = None, + example_cache_position: torch.Tensor = None, +): + """ + Convert a `PreTrainedModel` into an exportable module and export it using `torch.export`, + ensuring the exported model is compatible with `ExecuTorch`. + + Args: + model (PreTrainedModel): The pretrained model to be exported. + example_input_ids (torch.Tensor): Example input token id used by `torch.export`. + example_cache_position (torch.Tensor): Example current cache position used by `torch.export`. + + Returns: + torch.export.ExportedProgram: The exported program generated via `torch.export`. + """ + import torch + + assert version.parse(torch.__version__) >= version.parse("2.3"), "VersionError: torch >= 2.3 is required." + + with torch.no_grad(): + # Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal + # export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release. + import torch.export._trace + + example_input_ids = ( + example_input_ids if example_input_ids is not None else torch.tensor([[1]], dtype=torch.long) + ) + example_cache_position = ( + example_cache_position if example_cache_position is not None else torch.tensor([0], dtype=torch.long) + ) + + exported_program = torch.export._trace._export( + TorchExportatibleModuleWithStaticCache(model), + args=(example_input_ids,), + kwargs={"cache_position": example_cache_position}, + pre_dispatch=False, + strict=True, + ) + return exported_program diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index bd3c3279ed19e3..9bde7aa1d54e53 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3184,6 +3184,7 @@ def from_pretrained( adapter_kwargs = kwargs.pop("adapter_kwargs", {}) adapter_name = kwargs.pop("adapter_name", "default") use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False) + generation_config = kwargs.pop("generation_config", None) gguf_file = kwargs.pop("gguf_file", None) # Cache path to the GGUF file @@ -3975,6 +3976,12 @@ def from_pretrained( _from_pipeline=from_pipeline, **kwargs, ) + if generation_config is not None: + logger.info( + "Both `pretrained_model_name_or_path` and `generation_config` are provided. The `generation_config`" + " will be used to override the `pretrained_model_name_or_path` generation config." + ) + model.generation_config = model.generation_config.from_dict(generation_config.to_dict()) except OSError: logger.info( "Generation config file not found, using a generation config created from the model config." diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 4a9acf4a271f6a..09de298dc13b89 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -34,7 +34,6 @@ import torch from transformers import ( - AutoConfig, AutoModelForCausalLM, AutoTokenizer, DynamicCache, @@ -44,6 +43,7 @@ SinkCache, StaticCache, ) + from transformers.integrations.executorch import convert_and_export @require_torch @@ -179,56 +179,51 @@ def test_static_cache_exportability(self): if version.parse(torch.__version__) < version.parse("2.3"): self.skipTest(reason="This test requires torch >= 2.3 to run.") + set_seed(0) device = "cpu" dtype = torch.float32 + cache_implementation = "static" + attn_implementation = "sdpa" # Export and ExecuTorch only works for SdpaAttention batch_size = 1 - - config = AutoConfig.from_pretrained( + max_cache_len = 1234 + model = AutoModelForCausalLM.from_pretrained( "google/gemma-2b", + device_map=device, torch_dtype=dtype, - use_cache=True, + attn_implementation=attn_implementation, + generation_config=GenerationConfig( + use_cache=True, + cache_implementation=cache_implementation, + max_length=max_cache_len, + cache_config={ + "batch_size": batch_size, + "max_cache_len": max_cache_len, + }, + ), ) - m = AutoModelForCausalLM.from_pretrained( - "google/gemma-2b", - config=config, - torch_dtype=dtype, - attn_implementation="sdpa", # Export and ExecuTorch only works for SdpaAttention - ).to(device) - tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b") - inputs = tokenizer(["The best color is"], return_tensors="pt").to(device)["input_ids"] - - class ExportatibleModelWithStaticCache(torch.nn.Module): - def __init__(self, config, model): - super().__init__() - self.config = config - self.model = model - self.static_cache = StaticCache( - config=config, batch_size=batch_size, max_cache_len=config.max_length, device=device - ) - - def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor): - outs = self.model( - input_ids=tokens, - attention_mask=None, - position_ids=input_pos.unsqueeze(0), - cache_position=input_pos, - past_key_values=self.static_cache, - use_cache=True, - ) - return outs.logits - - set_seed(0) - with torch.no_grad(): - import torch.export._trace - from torch.export import ExportedProgram - - model = ExportatibleModelWithStaticCache(config, m) - # Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal - # export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.4.1+ release. - exported_program = torch.export._trace._export( - model, args=(inputs,), kwargs={"input_pos": torch.arange(1)}, pre_dispatch=False, strict=True - ) - self.assertTrue(isinstance(exported_program, ExportedProgram)) + # Check if cache config is passed through correctly + self.assertEqual(model.generation_config.use_cache, True) + self.assertEqual(model.generation_config.cache_implementation, cache_implementation) + self.assertEqual(model.generation_config.max_length, max_cache_len) + self.assertTrue(model.generation_config.cache_config is not None) + self.assertEqual(model.generation_config.cache_config.batch_size, batch_size) + self.assertEqual(model.generation_config.cache_config.max_cache_len, max_cache_len) + + exported_program = convert_and_export(model) + + # Check if the exported model is configured with the `StaticCache` correctly + n_static_key_caches = n_static_value_caches = 0 + for buffer_name, buffer in exported_program.named_buffers(): + if buffer_name.startswith("static_cache.key_cache"): + self.assertTrue(buffer.shape[0] == batch_size) + self.assertTrue(buffer.shape[2] == max_cache_len) + n_static_key_caches = n_static_key_caches + 1 + if buffer_name.startswith("static_cache.value_cache"): + self.assertTrue(buffer.shape[0] == batch_size) + self.assertTrue(buffer.shape[2] == max_cache_len) + n_static_value_caches = n_static_value_caches + 1 + self.assertEqual(n_static_key_caches, model.config.num_hidden_layers) + self.assertEqual(n_static_value_caches, model.config.num_hidden_layers) @require_torch_gpu