Make StaticCache configurable at model construct time
Guang Yang committed Sep 4, 2024
1 parent 0b066be commit a573552
Showing 7 changed files with 247 additions and 51 deletions.
5 changes: 5 additions & 0 deletions docs/source/en/internal/modeling_utils.md
@@ -42,6 +42,11 @@ Most of those are only useful if you are studying the code of the models in the
[[autodoc]] modeling_utils.SequenceSummary
- forward

[[autodoc]] integrations.executorch.TorchExportableModuleWithStaticCache
- forward

[[autodoc]] integrations.executorch.convert_and_export_with_cache

## PyTorch Helper Functions

[[autodoc]] pytorch_utils.apply_chunking_to_forward
4 changes: 4 additions & 0 deletions src/transformers/__init__.py
@@ -140,6 +140,7 @@
"is_tensorboard_available",
"is_wandb_available",
],
"integrations.executorch": [],
"modelcard": ["ModelCard"],
"modeling_tf_pytorch_utils": [
"convert_tf_weight_name_to_pt_weight_name",
@@ -4430,6 +4431,9 @@
_import_structure["models.musicgen_melody"].append("MusicgenMelodyFeatureExtractor")
_import_structure["models.musicgen_melody"].append("MusicgenMelodyProcessor")

if is_torch_available():
_import_structure["integrations.executorch"].append("TorchExportableModuleWithStaticCache")
_import_structure["integrations.executorch"].append("convert_and_export_with_cache")

# FLAX-backed objects
try:
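With these entries in the lazy import structure, the new ExecuTorch helpers become importable from the package root in a torch-enabled environment; a minimal sketch:

    # Both symbols resolve through the lazy import machinery when torch is installed.
    from transformers import TorchExportableModuleWithStaticCache, convert_and_export_with_cache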
40 changes: 40 additions & 0 deletions src/transformers/cache_utils.py
@@ -293,6 +293,46 @@ def validate(self):
)


@dataclass
class StaticCacheConfig(CacheConfig):
"""
Configuration class for static cache settings.
"""

cache_implementation = "static"

def __init__(self, batch_size: int, max_cache_len: int, device="cpu"):
self.batch_size = batch_size
self.max_cache_len = max_cache_len
self.device = device

def validate(self):
"""Validates if the arguments passed are correct"""

incorrect_arg_msg = (
"Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value}` "
"but found {found_value}"
)

if self.batch_size <= 0:
raise ValueError(
incorrect_arg_msg.format(
key="batch_size",
correct_value="> 0",
found_value=self.batch_size,
),
)

if self.max_cache_len <= 0:
raise ValueError(
incorrect_arg_msg.format(
key="max_cache_len",
correct_value="> 0",
found_value=self.max_cache_len,
),
)


class DynamicCache(Cache):
"""
A cache that grows dynamically as more tokens are generated. This is the default for generative models.
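For illustration, a minimal sketch of how the new `StaticCacheConfig` above can be constructed and validated (the sizes here are only examples):

    from transformers.cache_utils import StaticCacheConfig

    # Describe a static KV cache that holds one sequence of up to 1024 tokens on CPU.
    cache_config = StaticCacheConfig(batch_size=1, max_cache_len=1024, device="cpu")
    cache_config.validate()  # raises ValueError if batch_size or max_cache_len is <= 0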
3 changes: 2 additions & 1 deletion src/transformers/generation/configuration_utils.py
@@ -45,9 +45,10 @@
NEEDS_CACHE_CONFIG = {}

if is_torch_available():
-    from ..cache_utils import QuantizedCacheConfig
+    from ..cache_utils import QuantizedCacheConfig, StaticCacheConfig

    NEEDS_CACHE_CONFIG["quantized"] = QuantizedCacheConfig
+    NEEDS_CACHE_CONFIG["static"] = StaticCacheConfig


class GenerationMode(ExplicitEnum):
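With `"static"` registered in `NEEDS_CACHE_CONFIG`, a plain dict passed as `cache_config` to `GenerationConfig` is resolved into a `StaticCacheConfig`. A sketch of that construction path, mirroring the test added further down:

    from transformers import GenerationConfig

    # The dict is converted to a StaticCacheConfig because "static" now has an
    # entry in NEEDS_CACHE_CONFIG.
    generation_config = GenerationConfig(
        use_cache=True,
        cache_implementation="static",
        max_length=1234,
        cache_config={"batch_size": 1, "max_cache_len": 1234},
    )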
149 changes: 149 additions & 0 deletions src/transformers/integrations/executorch.py
@@ -0,0 +1,149 @@
from transformers import (
PreTrainedModel,
StaticCache,
)
from transformers.testing_utils import is_torch_available


if is_torch_available():
import torch

from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3

class TorchExportableModuleWithStaticCache(torch.nn.Module):
"""
A wrapper module designed to make a `PreTrainedModel` exportable with `torch.export`,
specifically for use with static caching. This module ensures that the exported model
is compatible with further lowering and execution in `ExecuTorch`.
Note:
            This class is specifically designed to support the export process using `torch.export`
in a way that ensures the model can be further lowered and run efficiently in `ExecuTorch`.
"""

def __init__(self, model: PreTrainedModel):
"""
Initializes the wrapper module with the pretrained model.
Args:
model (`PreTrainedModel`): The pretrained model to wrap. The model must have caching
enabled and use a 'static' caching implementation.
Raises:
AssertionError: If the pretrained model does not have caching enabled or if it does
not use a 'static' caching implementation in `model.generation_config`.
"""
super().__init__()

# Sanity checks
if model.generation_config is None:
raise AssertionError(
"The model must have a generation config to be exported with static caching. "
"Please set `generation_config`."
)

if not model.generation_config.use_cache:
raise AssertionError(
"The model must have caching enabled to be exported with static caching. "
"Please set `generation_config.use_cache=True`."
)

if model.generation_config.cache_implementation != "static":
raise AssertionError(
"The model must use a 'static' caching implementation to be exported with static caching. "
"Please set `generation_config.cache_implementation='static'`."
)

self.model = model
self.static_cache = StaticCache(
config=self.model.config,
batch_size=self.model.generation_config.cache_config.batch_size,
max_cache_len=self.model.generation_config.cache_config.max_cache_len,
dtype=self.model.config.torch_dtype,
)
self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures)
if self.is_causal:
causal_mask = torch.tril(
torch.ones(
self.static_cache.max_cache_len,
self.static_cache.max_cache_len,
dtype=torch.bool,
)
)
self.register_buffer("mask", causal_mask, persistent=False)

def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor):
"""
Forward pass of the module, which is compatible with the ExecuTorch runtime.
Args:
                input_ids (`torch.Tensor`): Tensor representing the current input token IDs to the module.
                cache_position (`torch.Tensor`): Tensor representing the current input position in the cache.
Returns:
torch.Tensor: Logits output from the model.
This forward adapter serves two primary purposes:
1. **Making the Model `torch.export`-Compatible**:
The adapter hides unsupported objects, such as the `Cache`, from the graph inputs and outputs,
enabling the model to be exportable using `torch.export` without encountering issues.
2. **Ensuring Compatibility with `ExecuTorch` runtime**:
The adapter matches the model's forward signature with that in `executorch/extension/llm/runner`,
ensuring that the exported model can be executed in `ExecuTorch` out-of-the-box.
"""
_, seqlen = input_ids.shape
attn_mask = self.mask[cache_position, :seqlen] if self.is_causal else None
outs = self.model(
input_ids=input_ids,
attention_mask=attn_mask,
position_ids=cache_position.unsqueeze(0),
cache_position=cache_position,
past_key_values=self.static_cache,
use_cache=True,
)
return outs.logits

def convert_and_export_with_cache(
model: PreTrainedModel,
example_input_ids: torch.Tensor = None,
example_cache_position: torch.Tensor = None,
):
"""
Convert a `PreTrainedModel` into an exportable module and export it using `torch.export`,
ensuring the exported model is compatible with `ExecuTorch`.
Args:
model (`PreTrainedModel`): The pretrained model to be exported.
example_input_ids (`torch.Tensor`): Example input token id used by `torch.export`.
example_cache_position (`torch.Tensor`): Example current cache position used by `torch.export`.
Returns:
Exported program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`.
"""

if not is_torch_greater_or_equal_than_2_3:
raise ImportError("torch >= 2.3 is required.")

import torch.export._trace

with torch.no_grad():
# TODO: The default inputs only work for text models. We need to add support for vision/audio models.
example_input_ids = (
example_input_ids if example_input_ids is not None else torch.tensor([[1]], dtype=torch.long)
)
example_cache_position = (
example_cache_position if example_cache_position is not None else torch.tensor([0], dtype=torch.long)
)

            # Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to use an internal
            # export API with pre_dispatch=False. Switch to the public API once the fix is included in the 2.5 release.
exported_program = torch.export._trace._export(
TorchExportableModuleWithStaticCache(model),
args=(example_input_ids,),
kwargs={"cache_position": example_cache_position},
pre_dispatch=False,
strict=True,
)
return exported_program
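Taken together, the intended workflow looks roughly like this (a sketch based on the new test further down; the checkpoint name and cache sizes are only examples):

    import torch
    from transformers import AutoModelForCausalLM, GenerationConfig, convert_and_export_with_cache

    model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2b",
        torch_dtype=torch.float32,
        attn_implementation="sdpa",  # export currently targets SDPA attention
        generation_config=GenerationConfig(
            use_cache=True,
            cache_implementation="static",
            cache_config={"batch_size": 1, "max_cache_len": 1234},
        ),
    )

    # Returns a torch.export.ExportedProgram whose named buffers include the
    # StaticCache key/value tensors; it can then be lowered to ExecuTorch.
    exported_program = convert_and_export_with_cache(model)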
6 changes: 5 additions & 1 deletion src/transformers/modeling_utils.py
@@ -3184,6 +3184,7 @@ def from_pretrained(
adapter_kwargs = kwargs.pop("adapter_kwargs", {})
adapter_name = kwargs.pop("adapter_name", "default")
use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
generation_config = kwargs.pop("generation_config", None)

gguf_file = kwargs.pop("gguf_file", None)
# Cache path to the GGUF file
@@ -3959,7 +3960,10 @@ def from_pretrained(
model.eval()

# If it is a model with generation capabilities, attempt to load the generation config
-        if model.can_generate() and pretrained_model_name_or_path is not None:
+        if model.can_generate() and generation_config is not None:
+            logger.info("The user-defined `generation_config` will be used to override the default generation config.")
+            model.generation_config = model.generation_config.from_dict(generation_config.to_dict())
+        elif model.can_generate() and pretrained_model_name_or_path is not None:
try:
model.generation_config = GenerationConfig.from_pretrained(
pretrained_model_name_or_path,
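In short, a `generation_config` passed to `from_pretrained` now takes precedence over the generation config that would otherwise be loaded from the checkpoint. A small sketch of the override (checkpoint name and cache sizes illustrative):

    from transformers import AutoModelForCausalLM, GenerationConfig

    # The user-defined config replaces the checkpoint's default generation config.
    model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2b",
        generation_config=GenerationConfig(
            use_cache=True,
            cache_implementation="static",
            cache_config={"batch_size": 1, "max_cache_len": 1234},
        ),
    )
    assert model.generation_config.cache_implementation == "static"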
91 changes: 42 additions & 49 deletions tests/utils/test_cache_utils.py
@@ -15,7 +15,6 @@

import unittest

from packaging import version
from parameterized import parameterized

from transformers import set_seed
@@ -34,7 +33,6 @@
import torch

from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
DynamicCache,
@@ -43,7 +41,9 @@
LlamaConfig,
SinkCache,
StaticCache,
convert_and_export_with_cache,
)
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3


@require_torch
@@ -174,61 +174,54 @@ def test_static_cache_exportability(self):
"""
Tests that static cache works with `torch.export()`
"""
-        import torch

-        if version.parse(torch.__version__) < version.parse("2.3"):
+        if not is_torch_greater_or_equal_than_2_3:
            self.skipTest(reason="This test requires torch >= 2.3 to run.")

        set_seed(0)
        device = "cpu"
        dtype = torch.float32
        cache_implementation = "static"
+        attn_implementation = "sdpa"  # Export and ExecuTorch only works for SdpaAttention
        batch_size = 1

-        config = AutoConfig.from_pretrained(
+        max_cache_len = 1234
+        model = AutoModelForCausalLM.from_pretrained(
            "google/gemma-2b",
+            device_map=device,
            torch_dtype=dtype,
-            use_cache=True,
+            attn_implementation=attn_implementation,
+            generation_config=GenerationConfig(
+                use_cache=True,
+                cache_implementation=cache_implementation,
+                max_length=max_cache_len,
+                cache_config={
+                    "batch_size": batch_size,
+                    "max_cache_len": max_cache_len,
+                },
+            ),
        )
-        m = AutoModelForCausalLM.from_pretrained(
-            "google/gemma-2b",
-            config=config,
-            torch_dtype=dtype,
-            attn_implementation="sdpa",  # Export and ExecuTorch only works for SdpaAttention
-        ).to(device)
-        tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
-        inputs = tokenizer(["The best color is"], return_tensors="pt").to(device)["input_ids"]

-        class ExportatibleModelWithStaticCache(torch.nn.Module):
-            def __init__(self, config, model):
-                super().__init__()
-                self.config = config
-                self.model = model
-                self.static_cache = StaticCache(
-                    config=config, batch_size=batch_size, max_cache_len=config.max_length, device=device
-                )

-            def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor):
-                outs = self.model(
-                    input_ids=tokens,
-                    attention_mask=None,
-                    position_ids=input_pos.unsqueeze(0),
-                    cache_position=input_pos,
-                    past_key_values=self.static_cache,
-                    use_cache=True,
-                )
-                return outs.logits

-        set_seed(0)
-        with torch.no_grad():
-            import torch.export._trace
-            from torch.export import ExportedProgram

-            model = ExportatibleModelWithStaticCache(config, m)
-            # Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal
-            # export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.4.1+ release.
-            exported_program = torch.export._trace._export(
-                model, args=(inputs,), kwargs={"input_pos": torch.arange(1)}, pre_dispatch=False, strict=True
-            )
-            self.assertTrue(isinstance(exported_program, ExportedProgram))
+        # Check if cache config is passed through correctly
+        self.assertEqual(model.generation_config.use_cache, True)
+        self.assertEqual(model.generation_config.cache_implementation, cache_implementation)
+        self.assertEqual(model.generation_config.max_length, max_cache_len)
+        self.assertTrue(model.generation_config.cache_config is not None)
+        self.assertEqual(model.generation_config.cache_config.batch_size, batch_size)
+        self.assertEqual(model.generation_config.cache_config.max_cache_len, max_cache_len)

+        exported_program = convert_and_export_with_cache(model)

+        # Check if the exported model is configured with the `StaticCache` correctly
+        n_static_key_caches = n_static_value_caches = 0
+        for buffer_name, buffer in exported_program.named_buffers():
+            if buffer_name.startswith("static_cache.key_cache"):
+                self.assertTrue(buffer.shape[0] == batch_size)
+                self.assertTrue(buffer.shape[2] == max_cache_len)
+                n_static_key_caches = n_static_key_caches + 1
+            if buffer_name.startswith("static_cache.value_cache"):
+                self.assertTrue(buffer.shape[0] == batch_size)
+                self.assertTrue(buffer.shape[2] == max_cache_len)
+                n_static_value_caches = n_static_value_caches + 1
+        self.assertEqual(n_static_key_caches, model.config.num_hidden_layers)
+        self.assertEqual(n_static_value_caches, model.config.num_hidden_layers)


@require_torch_gpu
