diff --git a/README.md b/README.md
index 2b9e59c..6ada295 100644
--- a/README.md
+++ b/README.md
@@ -65,6 +65,7 @@ Some presses rely on a different logic:
 - `SimLayerKVPress`: identify "lazy" layers, and apply the StreamingLLM approach to them ([paper](https://arxiv.org/abs/2410.13846))
 
 Finally we provide special presses:
+- `AdaKVPress`: prune the lowest-scoring KV pairs of any `ScorerPress` across all heads instead of per head, achieving head-wise compression (see [paper](https://arxiv.org/abs/2407.11550))
 - `PerLayerCompressionPress`: compress each layer with a different compression ratio (experimental). This press can be used with any other press that allows to set a compression_ratio
 - `ComposedPress`: compose multiple presses together by chaining their forward hooks
 - `KeyRerotationPress`: rerotate pruned keys to have continuous RoPE embeddings. This press can be used with any other press that inherits from `ScorerPress`.
diff --git a/evaluation/evaluate.py b/evaluation/evaluate.py
index 3b2b9b2..d165fda 100644
--- a/evaluation/evaluate.py
+++ b/evaluation/evaluate.py
@@ -17,6 +17,7 @@
 from zero_scrolls.calculate_metrics import calculate_metrics as zero_scrolls_scorer
 
 from kvpress import (
+    AdaKVPress,
     ExpectedAttentionPress,
     KnormPress,
     ObservedAttentionPress,
@@ -44,6 +45,8 @@
 }
 
 PRESS_DICT = {
+    "adasnapkv": AdaKVPress(SnapKVPress()),
+    "ada_expected_attention": AdaKVPress(ExpectedAttentionPress()),
     "expected_attention": ExpectedAttentionPress(),
     "knorm": KnormPress(),
     "observed_attention": ObservedAttentionPress(),
diff --git a/kvpress/__init__.py b/kvpress/__init__.py
index a285240..8674886 100644
--- a/kvpress/__init__.py
+++ b/kvpress/__init__.py
@@ -3,6 +3,7 @@
 
 from kvpress.pipeline import KVPressTextGenerationPipeline
+from kvpress.presses.adakv_press import AdaKVPress
 from kvpress.presses.base_press import BasePress
 from kvpress.presses.composed_press import ComposedPress
 from kvpress.presses.expected_attention_press import ExpectedAttentionPress
@@ -18,8 +19,12 @@
 from kvpress.presses.think_press import ThinKPress
 from kvpress.presses.tova_press import TOVAPress
+from kvpress.attention_patch import patch_attention_functions
+
+# Patch the attention functions to support head-wise compression
+patch_attention_functions()
 
 __all__ = [
+    "AdaKVPress",
     "BasePress",
     "ComposedPress",
     "ScorerPress",
diff --git a/kvpress/attention_patch.py b/kvpress/attention_patch.py
new file mode 100644
index 0000000..c40e1e4
--- /dev/null
+++ b/kvpress/attention_patch.py
@@ -0,0 +1,58 @@
+import torch
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+
+
+def search_hyperplane(X, max_iter: int = 1000):
+    """
+    Given a tensor X of shape (bsz, seq_len, head_dim), search for a hyperplane Y (bsz, head_dim)
+    such that for every i, <X[:, i], Y> <= 0. Returns -1e5 * Y / ||Y|| ** 2 to ensure exp(<X[:, i], Y>) = 0.
+    Raises a ValueError if no such hyperplane is found.
+    """
+    Y = X.mean(1)  # this initialization is enough for most cases
+    for _ in range(max_iter):
+        mask = torch.bmm(X, Y.unsqueeze(-1)) <= 0
+        if not mask.any():
+            return -1e5 * Y / Y.norm(dim=-1, keepdim=True) ** 2
+        Y += (X * mask).sum(1) / mask.sum(1).clamp(min=1)
+    raise ValueError("Could not find fake keys such that for every query q, exp(<q, k>) = 0")
+
+
+def attention_patch(func):
+    """
+    Decorator to update the keys before the attention computation at the indices provided in module.masked_key_indices.
+    The keys are updated with a fake key k such that exp(<q, k>) = 0, faking head-wise compression.
+    This solution is not optimal as it does not reduce peak memory and slightly increases runtime.
+    """
+
+    def wrapper(module, query, key, value, attention_mask, dropout, **kwargs):
+        if query.shape[2] == key.shape[2]:
+            # Prefilling
+            module.masked_key_indices = None
+        elif module.masked_key_indices is not None:
+            # Decoding: build fake keys k s.t. exp(<q, k>) = 0
+            bsz, num_heads, seq_len, head_dim = query.shape
+            num_key_value_heads = key.shape[1]
+            num_groups = num_heads // num_key_value_heads
+
+            # Build a fake key k per key group such that for every query q, exp(<q, k>) = 0
+            q = query.view(bsz, num_key_value_heads, num_groups, seq_len, head_dim)
+            q = q.reshape(bsz * num_key_value_heads, num_groups * seq_len, head_dim)
+            k = search_hyperplane(q)
+            k = k.view(bsz, num_key_value_heads, head_dim)
+
+            # At the masked indices, update the keys to the fake keys
+            batch_indices, head_indices, seq_indices = module.masked_key_indices
+            key[batch_indices, head_indices, seq_indices] = k[batch_indices, head_indices]
+
+        return func(module, query, key, value, attention_mask, dropout, **kwargs)
+
+    return wrapper
+
+
+def patch_attention_functions():
+    """
+    Add the attention_patch decorator to all functions in ALL_ATTENTION_FUNCTIONS
+    """
+
+    for name, func in ALL_ATTENTION_FUNCTIONS.items():
+        ALL_ATTENTION_FUNCTIONS[name] = attention_patch(func)
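The fake-key trick can be sanity-checked outside of a model. The sketch below is not part of the diff; it assumes `kvpress` is installed so that `search_hyperplane` is importable, and the tensor shapes are illustrative. It verifies that a key returned by `search_hyperplane` receives an attention weight of exactly zero after the softmax, which is what lets the patch fake head-wise compression without a custom kernel:

```python
# Standalone sanity sketch (illustrative shapes): a fake key from search_hyperplane
# underflows to exactly zero attention weight after the softmax.
import torch

from kvpress.attention_patch import search_hyperplane

q = torch.rand(1, 10, 64)  # (bsz, seq_len, head_dim) queries
k_real = torch.rand(1, 5, 64)  # five real keys
k_fake = search_hyperplane(q)  # (bsz, head_dim), scaled so exp(<q_i, k_fake>) = 0
keys = torch.cat([k_real, k_fake.unsqueeze(1)], dim=1)  # append the fake key

attn = torch.softmax(q @ keys.transpose(1, 2), dim=-1)  # (1, 10, 6) attention weights
assert attn[..., -1].max() == 0  # the fake key is never attended to
```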
diff --git a/kvpress/presses/adakv_press.py b/kvpress/presses/adakv_press.py
new file mode 100644
index 0000000..a4ea1c3
--- /dev/null
+++ b/kvpress/presses/adakv_press.py
@@ -0,0 +1,62 @@
+# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+
+from dataclasses import dataclass
+
+import torch
+
+from kvpress.presses.base_press import BasePress
+from kvpress.presses.scorer_press import ScorerPress
+
+
+@dataclass
+class AdaKVPress(BasePress):
+    """
+    AdaKV (https://arxiv.org/abs/2407.11550) selects the top-k keys and values among all heads in a layer
+    based on the scores, achieving head-specific compression.
+    A safeguard is applied to ensure a minimum fraction of KV pairs per head (alpha_safeguard parameter).
+    This press has been reviewed by Yuan Feng, first author of AdaKV.
+    """
+
+    scorer: ScorerPress
+    alpha_safeguard: float = 0.20
+
+    def __post_init__(self):
+        assert isinstance(self.scorer, ScorerPress), "AdaKVPress requires a ScorerPress as input"
+        assert 0 <= self.alpha_safeguard <= 1, "alpha_safeguard should be in [0, 1]"
+
+    @property
+    def compression_ratio(self):
+        return self.scorer.compression_ratio
+
+    @compression_ratio.setter
+    def compression_ratio(self, value):
+        self.scorer.compression_ratio = value
+
+    def compress(self, module, hidden_states, keys, values, attentions, kwargs):
+        if self.compression_ratio == 0:
+            return keys, values
+
+        assert module.config._attn_implementation != "eager", "eager mode not supported"
+
+        # Compute scores
+        scores = self.scorer.score(module, hidden_states, keys, values, attentions, kwargs)
+        bsz, num_key_value_heads, q_len = scores.shape
+
+        # Make sure to keep at least alpha_safeguard * n_kept KV pairs per head
+        n_kept = int(q_len * (1 - self.compression_ratio))  # ScorerPress definition
+        n_safe = int(n_kept * self.alpha_safeguard)
+        top_indices = torch.topk(scores, n_safe, dim=-1).indices
+        scores.scatter_(-1, top_indices, torch.finfo(scores.dtype).max)
+
+        # Compute bottom-k across heads
+        n_pruned = num_key_value_heads * (q_len - n_kept)
+        indices = torch.topk(-scores.reshape(bsz, -1), n_pruned, dim=1).indices.flatten()
+
+        # Save indices to mask during the attention mechanism. Please refer to attention_patch.py for more details
+        batch_indices = torch.arange(bsz).repeat_interleave(n_pruned)
+        head_indices = indices // q_len
+        seq_indices = indices % q_len
+        module.masked_key_indices = (batch_indices, head_indices, seq_indices)
+        return keys, values
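To make the budget arithmetic in `compress` concrete, here is a small worked example. The numbers are illustrative, not taken from the PR; they only replay the formulas above:

```python
# Worked example of the AdaKV budget split (illustrative numbers)
q_len, num_key_value_heads = 1000, 8
compression_ratio, alpha_safeguard = 0.7, 0.2

n_kept = int(q_len * (1 - compression_ratio))      # 300 pairs kept per head on average
n_safe = int(n_kept * alpha_safeguard)             # 60 pairs guaranteed per head
n_pruned = num_key_value_heads * (q_len - n_kept)  # 5600 pairs pruned across all heads

# The layer-wide budget matches the per-head definition used by ScorerPress:
assert num_key_value_heads * q_len - n_pruned == num_key_value_heads * n_kept  # 2400 kept in total
```

The safeguard only pins `n_safe` pairs per head; the remaining budget is allocated across heads purely by score, which is what makes the compression head-wise.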
+ """ + + scorer: ScorerPress + alpha_safeguard: float = 0.20 + + def __post_init__(self): + assert isinstance(self.scorer, ScorerPress), "AdaKVPress requires a ScorerPress as input" + assert 0 <= self.alpha_safeguard <= 1, "alpha_safeguard should be in [0, 1]" + + @property + def compression_ratio(self): + return self.scorer.compression_ratio + + @compression_ratio.setter + def compression_ratio(self, value): + self.scorer.compression_ratio = value + + def compress(self, module, hidden_states, keys, values, attentions, kwargs): + if self.compression_ratio == 0: + return keys, values + + assert module.config._attn_implementation != "eager", "eager mode not supported" + + # Compute scores + scores = self.scorer.score(module, hidden_states, keys, values, attentions, kwargs) + bsz, num_key_value_heads, q_len = scores.shape + + # Make sure to keep at least alpha * (1 - compression_ratio) KV pairs per head + n_kept = int(q_len * (1 - self.compression_ratio)) # ScorerPress definition + n_safe = int(n_kept * self.alpha_safeguard) + top_indices = torch.topk(scores, n_safe, dim=-1).indices + scores.scatter_(-1, top_indices, torch.finfo(scores.dtype).max) + + # Compute bottom-k across heads + n_pruned = num_key_value_heads * (q_len - n_kept) + indices = torch.topk(-scores.reshape(bsz, -1), n_pruned, dim=1).indices.flatten() + + # Save indices to mask during the attention mechanism. Please refer to attention_patch.py for more details + batch_indices = torch.arange(bsz).repeat_interleave(n_pruned) + head_indices = indices // q_len + seq_indices = indices % q_len + module.masked_key_indices = (batch_indices, head_indices, seq_indices) + return keys, values diff --git a/kvpress/presses/composed_press.py b/kvpress/presses/composed_press.py index 1c2b2bd..cb2b4aa 100644 --- a/kvpress/presses/composed_press.py +++ b/kvpress/presses/composed_press.py @@ -2,6 +2,7 @@ from kvpress.presses.base_press import BasePress from kvpress.presses.observed_attention_press import ObservedAttentionPress +from kvpress.presses.adakv_press import AdaKVPress @dataclass @@ -15,8 +16,8 @@ class ComposedPress(BasePress): def __post_init__(self): self.compression_ratio = None assert not any( - isinstance(press, (ObservedAttentionPress)) for press in self.presses - ), "ComposedPress cannot contains ObservedAttentionPress" + isinstance(press, (ObservedAttentionPress, AdaKVPress)) for press in self.presses + ), "ComposedPress cannot contains ObservedAttentionPress or AdaKVPress" def forward_hook(self, module, input, kwargs, output): self.compression_ratio = 1.0 diff --git a/notebooks/new_press.ipynb b/notebooks/new_press.ipynb index 3ffb279..a64ede7 100644 --- a/notebooks/new_press.ipynb +++ b/notebooks/new_press.ipynb @@ -11,11 +11,12 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from dataclasses import dataclass\n", + "from contextlib import contextmanager\n", "\n", "import torch\n", "from torch import nn\n", @@ -26,9 +27,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. 
diff --git a/tests/presses/test_presses.py b/tests/presses/test_presses.py
index 2a96e5c..d76192a 100644
--- a/tests/presses/test_presses.py
+++ b/tests/presses/test_presses.py
@@ -12,6 +12,7 @@
     KeyRerotationPress,
     KnormPress,
     ObservedAttentionPress,
+    AdaKVPress,
     ThinKPress,
     ScorerPress,
 )
@@ -29,7 +30,7 @@ def test_composed_press(unit_test_model):  # noqa: F811
 
 
 @pytest.mark.parametrize("press_dict", default_presses)
-@pytest.mark.parametrize("wrapper_press", [None, ComposedPress, KeyRerotationPress])
+@pytest.mark.parametrize("wrapper_press", [None, ComposedPress, KeyRerotationPress, AdaKVPress])
 def test_presses_run(unit_test_model, press_dict, wrapper_press):  # noqa: F811
     cls = press_dict["cls"]
     for kwargs in press_dict["kwargs"]:
@@ -38,6 +39,11 @@ def test_presses_run(unit_test_model, press_dict, wrapper_press):  # noqa: F811
             press = ComposedPress(presses=[press])
         if isinstance(wrapper_press, KeyRerotationPress):
             press = KeyRerotationPress(press=press)
+        if wrapper_press is AdaKVPress:
+            if not isinstance(press, ScorerPress):
+                return
+            else:
+                press = AdaKVPress(scorer=press)
 
         with press(unit_test_model):
             input_ids = unit_test_model.dummy_inputs["input_ids"]
diff --git a/tests/test_attention_patch.py b/tests/test_attention_patch.py
new file mode 100644
index 0000000..9333609
--- /dev/null
+++ b/tests/test_attention_patch.py
@@ -0,0 +1,9 @@
+import torch
+from kvpress.attention_patch import search_hyperplane
+
+
+def test_search_hyperplane():
+    bsz, seq_len, head_dim = 50, 500, 128
+    X = torch.rand(bsz, seq_len, head_dim)
+    Y = search_hyperplane(X)
+    assert torch.exp(torch.bmm(X, Y.unsqueeze(-1))).max() == 0
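Finally, a minimal end-to-end sketch of the new press through the kvpress pipeline. The model name, device, and compression settings below are arbitrary examples, not prescribed by this PR; note that `AdaKVPress.compress` asserts a non-eager attention implementation:

```python
# Usage sketch (model name and settings are illustrative examples)
from transformers import pipeline

from kvpress import AdaKVPress, SnapKVPress

# Any non-eager attention implementation works; flash_attention_2 is one option
pipe = pipeline(
    "kv-press-text-generation",
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    device="cuda",
    model_kwargs={"attn_implementation": "flash_attention_2"},
)

context = "The KV cache stores past keys and values to speed up autoregressive decoding."
question = "What does the KV cache store?"

# Wrap any ScorerPress; AdaKVPress re-allocates its KV budget across heads
press = AdaKVPress(SnapKVPress(compression_ratio=0.5), alpha_safeguard=0.2)
answer = pipe(context, question=question, press=press)["answer"]
print(answer)
```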