Skip to content

Commit

Permalink
Make StaticCache configurable at model construct time
Browse files Browse the repository at this point in the history
  • Loading branch information
Guang Yang committed Sep 3, 2024
1 parent 0b066be commit 7d4092b
Show file tree
Hide file tree
Showing 5 changed files with 230 additions and 46 deletions.
41 changes: 41 additions & 0 deletions src/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,47 @@ def validate(self):
)


@dataclass
class StaticCacheConfig(CacheConfig):
"""
Configuration class for static cache settings.
"""

cache_implementation = "static"

def __init__(self, batch_size: int, max_cache_len: int, device="cpu"):
self.batch_size = batch_size
self.max_cache_len = max_cache_len
self.device = device

def validate(self):
"""Validates if the arguments passed are correct"""

incorrect_arg_msg = (
"Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value}` "
"but found {found_value}"
)
# Check that the values are reasonable in general (nbits, axis)
# Later in QuantizedCache init we check if they are supported for that particular backend
if self.batch_size <= 0:
raise ValueError(
incorrect_arg_msg.format(
key="batch_size",
correct_value="> 0",
found_value=self.batch_size,
),
)

if self.max_cache_len <= 0:
raise ValueError(
incorrect_arg_msg.format(
key="max_cache_len",
correct_value="> 0",
found_value=self.max_cache_len,
),
)


class DynamicCache(Cache):
"""
A cache that grows dynamically as more tokens are generated. This is the default for generative models.
Expand Down
3 changes: 2 additions & 1 deletion src/transformers/generation/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,10 @@
NEEDS_CACHE_CONFIG = {}

if is_torch_available():
from ..cache_utils import QuantizedCacheConfig
from ..cache_utils import QuantizedCacheConfig, StaticCacheConfig

NEEDS_CACHE_CONFIG["quantized"] = QuantizedCacheConfig
NEEDS_CACHE_CONFIG["static"] = StaticCacheConfig


class GenerationMode(ExplicitEnum):
Expand Down
140 changes: 140 additions & 0 deletions src/transformers/integrations/executorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
from packaging import version

from transformers.testing_utils import (
is_torch_available,
require_torch,
)


if is_torch_available():
import torch

from transformers import (
PreTrainedModel,
StaticCache,
)


@require_torch
class TorchExportatibleModuleWithStaticCache(torch.nn.Module):
"""
A wrapper module designed to make a `PreTrainedModel` exportable with `torch.export`,
specifically for use with static caching. This module ensures that the exported model
is compatible with further lowering and execution in `ExecuTorch`.
Args:
model (PreTrainedModel): The pretrained model to wrap. The model must have caching
enabled and use a 'static' caching implementation.
Attributes:
model (PreTrainedModel): The underlying pretrained model.
static_cache (StaticCache): A static cache instance used to store past key values for faster inference.
is_causal (bool): Indicates whether the model architecture supports causal masking (e.g., causal language models).
Note:
This class is specifically designed to support export process using `torch.export`
in a way that ensures the model can be further lowered and run efficiently in `ExecuTorch`.
Raises:
AssertionError: If `model` does not have caching enabled or if it does not use a 'static' caching implementation.
"""

def __init__(self, model: PreTrainedModel):
super().__init__()
assert model.generation_config.use_cache is True
assert model.generation_config.cache_implementation == "static"

self.model = model
self.static_cache = StaticCache(
config=self.model.config,
batch_size=self.model.generation_config.cache_config.batch_size,
max_cache_len=self.model.generation_config.cache_config.max_cache_len,
dtype=self.model.config.torch_dtype,
)
self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures)
if self.is_causal:
causal_mask = torch.tril(
torch.ones(
self.static_cache.max_cache_len,
self.static_cache.max_cache_len,
dtype=torch.bool,
)
)
self.register_buffer("mask", causal_mask, persistent=False)

def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor):
"""
Forward pass of the module, which is compatible with the ExecuTorch runtime.
Args:
input_ids (torch.Tensor): Tensor representing current input token id to the module.
cache_position (torch.Tensor): Tensor representing current input position in the cache.
Returns:
torch.Tensor: Logits output from the model.
This forward adapter serves two primary purposes:
1. **Making the Model `torch.export`-Compatible**:
The adapter hides unsupported objects, such as the `Cache`, from the graph inputs and outputs,
enabling the model to be exportable using `torch.export` without encountering issues.
2. **Ensuring Compatibility with `ExecuTorch` runtime**:
The adapter matches the model's forward signature with that in `executorch/extension/llm/runner`,
ensuring that the exported model can be executed in `ExecuTorch` out-of-the-box.
"""
_, seqlen = input_ids.shape
attn_mask = self.mask[cache_position, :seqlen] if self.is_causal else None
outs = self.model(
input_ids=input_ids,
attention_mask=attn_mask,
position_ids=cache_position.unsqueeze(0),
cache_position=cache_position,
past_key_values=self.static_cache,
use_cache=True,
)
return outs.logits


@require_torch
def convert_and_export(
model: PreTrainedModel,
example_input_ids: torch.Tensor = None,
example_cache_position: torch.Tensor = None,
):
"""
Convert a `PreTrainedModel` into an exportable module and export it using `torch.export`,
ensuring the exported model is compatible with `ExecuTorch`.
Args:
model (PreTrainedModel): The pretrained model to be exported.
example_input_ids (torch.Tensor): Example input token id used by `torch.export`.
example_cache_position (torch.Tensor): Example current cache position used by `torch.export`.
Returns:
torch.export.ExportedProgram: The exported program generated via `torch.export`.
"""
import torch

assert version.parse(torch.__version__) >= version.parse("2.3"), "VersionError: torch >= 2.3 is required."

with torch.no_grad():
# Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal
# export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release.
import torch.export._trace

example_input_ids = (
example_input_ids if example_input_ids is not None else torch.tensor([[1]], dtype=torch.long)
)
example_cache_position = (
example_cache_position if example_cache_position is not None else torch.tensor([0], dtype=torch.long)
)

exported_program = torch.export._trace._export(
TorchExportatibleModuleWithStaticCache(model),
args=(example_input_ids,),
kwargs={"cache_position": example_cache_position},
pre_dispatch=False,
strict=True,
)
return exported_program
7 changes: 7 additions & 0 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3184,6 +3184,7 @@ def from_pretrained(
adapter_kwargs = kwargs.pop("adapter_kwargs", {})
adapter_name = kwargs.pop("adapter_name", "default")
use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
generation_config = kwargs.pop("generation_config", None)

gguf_file = kwargs.pop("gguf_file", None)
# Cache path to the GGUF file
Expand Down Expand Up @@ -3975,6 +3976,12 @@ def from_pretrained(
_from_pipeline=from_pipeline,
**kwargs,
)
if generation_config is not None:
logger.info(
"Both `pretrained_model_name_or_path` and `generation_config` are provided. The `generation_config`"
" will be used to override the `pretrained_model_name_or_path` generation config."
)
model.generation_config = model.generation_config.from_dict(generation_config.to_dict())
except OSError:
logger.info(
"Generation config file not found, using a generation config created from the model config."
Expand Down
85 changes: 40 additions & 45 deletions tests/utils/test_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
import torch

from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
DynamicCache,
Expand All @@ -44,6 +43,7 @@
SinkCache,
StaticCache,
)
from transformers.integrations.executorch import convert_and_export


@require_torch
Expand Down Expand Up @@ -179,56 +179,51 @@ def test_static_cache_exportability(self):
if version.parse(torch.__version__) < version.parse("2.3"):
self.skipTest(reason="This test requires torch >= 2.3 to run.")

set_seed(0)
device = "cpu"
dtype = torch.float32
cache_implementation = "static"
attn_implementation = "sdpa" # Export and ExecuTorch only works for SdpaAttention
batch_size = 1

config = AutoConfig.from_pretrained(
max_cache_len = 1234
model = AutoModelForCausalLM.from_pretrained(
"google/gemma-2b",
device_map=device,
torch_dtype=dtype,
use_cache=True,
attn_implementation=attn_implementation,
generation_config=GenerationConfig(
use_cache=True,
cache_implementation=cache_implementation,
max_length=max_cache_len,
cache_config={
"batch_size": batch_size,
"max_cache_len": max_cache_len,
},
),
)
m = AutoModelForCausalLM.from_pretrained(
"google/gemma-2b",
config=config,
torch_dtype=dtype,
attn_implementation="sdpa", # Export and ExecuTorch only works for SdpaAttention
).to(device)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
inputs = tokenizer(["The best color is"], return_tensors="pt").to(device)["input_ids"]

class ExportatibleModelWithStaticCache(torch.nn.Module):
def __init__(self, config, model):
super().__init__()
self.config = config
self.model = model
self.static_cache = StaticCache(
config=config, batch_size=batch_size, max_cache_len=config.max_length, device=device
)

def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor):
outs = self.model(
input_ids=tokens,
attention_mask=None,
position_ids=input_pos.unsqueeze(0),
cache_position=input_pos,
past_key_values=self.static_cache,
use_cache=True,
)
return outs.logits

set_seed(0)
with torch.no_grad():
import torch.export._trace
from torch.export import ExportedProgram

model = ExportatibleModelWithStaticCache(config, m)
# Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal
# export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.4.1+ release.
exported_program = torch.export._trace._export(
model, args=(inputs,), kwargs={"input_pos": torch.arange(1)}, pre_dispatch=False, strict=True
)
self.assertTrue(isinstance(exported_program, ExportedProgram))
# Check if cache config is passed through correctly
self.assertEqual(model.generation_config.use_cache, True)
self.assertEqual(model.generation_config.cache_implementation, cache_implementation)
self.assertEqual(model.generation_config.max_length, max_cache_len)
self.assertTrue(model.generation_config.cache_config is not None)
self.assertEqual(model.generation_config.cache_config.batch_size, batch_size)
self.assertEqual(model.generation_config.cache_config.max_cache_len, max_cache_len)

exported_program = convert_and_export(model)

# Check if the exported model is configured with the `StaticCache` correctly
n_static_key_caches = n_static_value_caches = 0
for buffer_name, buffer in exported_program.named_buffers():
if buffer_name.startswith("static_cache.key_cache"):
self.assertTrue(buffer.shape[0] == batch_size)
self.assertTrue(buffer.shape[2] == max_cache_len)
n_static_key_caches = n_static_key_caches + 1
if buffer_name.startswith("static_cache.value_cache"):
self.assertTrue(buffer.shape[0] == batch_size)
self.assertTrue(buffer.shape[2] == max_cache_len)
n_static_value_caches = n_static_value_caches + 1
self.assertEqual(n_static_key_caches, model.config.num_hidden_layers)
self.assertEqual(n_static_value_caches, model.config.num_hidden_layers)


@require_torch_gpu
Expand Down

0 comments on commit 7d4092b

Please sign in to comment.