AdaKVPress (#38)
SimJeg authored Jan 13, 2025
1 parent 7260696 commit fe4610e
Showing 9 changed files with 222 additions and 7 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -65,6 +65,7 @@ Some presses rely on a different logic:
- `SimLayerKVPress`: identify "lazy" layers, and apply the StreamingLLM approach to them ([paper](https://arxiv.org/abs/2410.13846))

Finally, we provide special presses:
- `AdaKVPress`: prune the bottom scores of any `ScorerPress`, but across all heads, achieving head-wise compression (see [paper](https://arxiv.org/abs/2407.11550)); a usage sketch follows this list
- `PerLayerCompressionPress`: compress each layer with a different compression ratio (experimental). This press can be used with any other press that allows setting a `compression_ratio`
- `ComposedPress`: compose multiple presses together by chaining their forward hooks
- `KeyRerotationPress`: rerotate pruned keys to have continuous RoPE embeddings. This press can be used with any other press that inherits from `ScorerPress`.
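
A minimal usage sketch of `AdaKVPress` wrapping a `ScorerPress`; the pipeline task name is the one registered by kvpress, while the model name, context, and question below are illustrative placeholders:

```python
# Sketch: wrap any ScorerPress with AdaKVPress to obtain head-wise compression.
from transformers import pipeline

from kvpress import AdaKVPress, ComposedPress, ExpectedAttentionPress, KnormPress

# Head-wise budget allocation on top of an existing scorer
press = AdaKVPress(ExpectedAttentionPress(compression_ratio=0.5))

# Chaining presses (note: ComposedPress cannot contain ObservedAttentionPress or AdaKVPress)
composed_press = ComposedPress(presses=[KnormPress(compression_ratio=0.3)])

# Illustrative model name; any supported decoder-only model should work
pipe = pipeline("kv-press-text-generation", model="meta-llama/Llama-3.2-1B-Instruct", device="cuda")
context = "KVPress compresses the KV cache of transformer decoders during prefilling."
question = "What does KVPress do?"
answer = pipe(context, question=question, press=press)["answer"]
```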
3 changes: 3 additions & 0 deletions evaluation/evaluate.py
@@ -17,6 +17,7 @@
from zero_scrolls.calculate_metrics import calculate_metrics as zero_scrolls_scorer

from kvpress import (
AdaKVPress,
ExpectedAttentionPress,
KnormPress,
ObservedAttentionPress,
@@ -44,6 +45,8 @@
}

PRESS_DICT = {
"adasnapkv": AdaKVPress(SnapKVPress()),
"ada_expected_attention": AdaKVPress(ExpectedAttentionPress()),
"expected_attention": ExpectedAttentionPress(),
"knorm": KnormPress(),
"observed_attention": ObservedAttentionPress(),
5 changes: 5 additions & 0 deletions kvpress/__init__.py
@@ -3,6 +3,7 @@


from kvpress.pipeline import KVPressTextGenerationPipeline
from kvpress.presses.adakv_press import AdaKVPress
from kvpress.presses.base_press import BasePress
from kvpress.presses.composed_press import ComposedPress
from kvpress.presses.expected_attention_press import ExpectedAttentionPress
@@ -18,8 +19,12 @@
from kvpress.presses.think_press import ThinKPress
from kvpress.presses.tova_press import TOVAPress

from kvpress.attention_patch import patch_attention_functions
# Patch the attention functions to support head-wise compression
patch_attention_functions()

__all__ = [
"AdaKVPress",
"BasePress",
"ComposedPress",
"ScorerPress",
58 changes: 58 additions & 0 deletions kvpress/attention_patch.py
@@ -0,0 +1,58 @@
import torch
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


def search_hyperplane(X, max_iter: int = 1000):
"""
Given a tensor X of shape (bsz, seq_len, head_dim), search for a vector Y of shape (bsz, head_dim)
such that <X[:, i], Y> > 0 for every i (i.e. a hyperplane with all X on one side), and return
-1e5 * Y / ||Y|| ** 2 so that exp(<X[:, i], returned value>) = 0 for every i.
Raises a ValueError if no such vector is found
"""
Y = X.mean(1) # this initialization is enough for most cases
for _ in range(max_iter):
mask = torch.bmm(X, Y.unsqueeze(-1)) <= 0
if not mask.any():
return -1e5 * Y / Y.norm(dim=-1, keepdim=True) ** 2
Y += (X * mask).sum(1) / mask.sum(1).clamp(min=1)
raise ValueError("Could not find fake keys such that for every query q, exp(<q, k>) = 0")


def attention_patch(func):
"""
Decorator to update the keys before the attention computation, at the indices provided in module.masked_key_indices.
The keys are updated with a fake key k such that exp(<q, k>) = 0, to fake head-wise compression.
This solution is not optimal as it does not reduce peak memory and slightly increases runtime.
"""

def wrapper(module, query, key, value, attention_mask, dropout, **kwargs):
if query.shape[2] == key.shape[2]:
# Prefilling
module.masked_key_indices = None
elif module.masked_key_indices is not None:
# Decoding: build fake keys k s.t. exp(<q, k>) = 0
bsz, num_heads, seq_len, head_dim = query.shape
num_key_value_heads = key.shape[1]
num_groups = num_heads // num_key_value_heads

# Build a fake key k per key group such that for every query q, exp(<q, k>) = 0
q = query.view(bsz, num_key_value_heads, num_groups, seq_len, head_dim)
q = q.reshape(bsz * num_key_value_heads, num_groups * seq_len, head_dim)
k = search_hyperplane(q)
k = k.view(bsz, num_key_value_heads, head_dim)

# At indices, update the keys to the fake keys
batch_indices, head_indices, seq_indices = module.masked_key_indices
key[batch_indices, head_indices, seq_indices] = k[batch_indices, head_indices]

return func(module, query, key, value, attention_mask, dropout, **kwargs)

return wrapper


def patch_attention_functions():
"""
Add the attention_patch decorator to functions in ALL_ATTENTION_FUNCTIONS
"""

for name, func in ALL_ATTENTION_FUNCTIONS.items():
ALL_ATTENTION_FUNCTIONS[name] = attention_patch(func)
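
A quick sanity check of the fake-key construction, given as a sketch that mirrors `tests/test_attention_patch.py` added below (shapes are arbitrary):

```python
# Sketch: a fake key returned by search_hyperplane receives exactly zero attention
# weight for every query, which is what the patched attention relies on while decoding.
import torch

from kvpress.attention_patch import search_hyperplane

bsz, seq_len, head_dim = 2, 16, 64
q = torch.rand(bsz, seq_len, head_dim)              # toy queries
k = search_hyperplane(q)                            # one fake key per batch element
logits = torch.bmm(q, k.unsqueeze(-1)).squeeze(-1)  # <q_i, k> for every query i
assert torch.exp(logits).max() == 0                 # exp underflows to exactly 0.0
```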
62 changes: 62 additions & 0 deletions kvpress/presses/adakv_press.py
@@ -0,0 +1,62 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0


from dataclasses import dataclass

import torch

from kvpress.presses.base_press import BasePress
from kvpress.presses.scorer_press import ScorerPress


@dataclass
class AdaKVPress(BasePress):
"""
AdaKV (https://arxiv.org/abs/2407.11550) selects the top-k keys and values among all heads in a layer
based on the scores, achieving head-specific compression.
A safeguard is applied to ensure a minimum fraction of KV pairs per head (alpha_safeguard parameter).
This press has been reviewed by Yuan Feng, first author of AdaKV.
"""

scorer: ScorerPress
alpha_safeguard: float = 0.20

def __post_init__(self):
assert isinstance(self.scorer, ScorerPress), "AdaKVPress requires a ScorerPress as input"
assert 0 <= self.alpha_safeguard <= 1, "alpha_safeguard should be in [0, 1]"

@property
def compression_ratio(self):
return self.scorer.compression_ratio

@compression_ratio.setter
def compression_ratio(self, value):
self.scorer.compression_ratio = value

def compress(self, module, hidden_states, keys, values, attentions, kwargs):
if self.compression_ratio == 0:
return keys, values

assert module.config._attn_implementation != "eager", "eager mode not supported"

# Compute scores
scores = self.scorer.score(module, hidden_states, keys, values, attentions, kwargs)
bsz, num_key_value_heads, q_len = scores.shape

# Safeguard: keep at least alpha_safeguard * n_kept KV pairs per head
n_kept = int(q_len * (1 - self.compression_ratio)) # ScorerPress definition
n_safe = int(n_kept * self.alpha_safeguard)
top_indices = torch.topk(scores, n_safe, dim=-1).indices
scores.scatter_(-1, top_indices, torch.finfo(scores.dtype).max)

# Compute bottom-k across heads
n_pruned = num_key_value_heads * (q_len - n_kept)
indices = torch.topk(-scores.reshape(bsz, -1), n_pruned, dim=1).indices.flatten()

# Save indices to mask during the attention mechanism. Please refer to attention_patch.py for more details
batch_indices = torch.arange(bsz).repeat_interleave(n_pruned)
head_indices = indices // q_len
seq_indices = indices % q_len
module.masked_key_indices = (batch_indices, head_indices, seq_indices)
return keys, values
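
A numerical sketch of the budget allocation performed in `compress` above; the shapes and scores are toy placeholders:

```python
# Sketch: after safeguarding n_safe positions per head, the layer-wide pruning budget
# n_pruned is taken from the globally smallest scores, so each head may end up keeping
# a different number of KV pairs.
import torch

bsz, num_key_value_heads, q_len = 1, 2, 20
compression_ratio, alpha_safeguard = 0.5, 0.20
scores = torch.rand(bsz, num_key_value_heads, q_len)

n_kept = int(q_len * (1 - compression_ratio))  # 10 pairs kept per head on average
n_safe = int(n_kept * alpha_safeguard)         # at least 2 pairs kept per head
top_indices = torch.topk(scores, n_safe, dim=-1).indices
scores.scatter_(-1, top_indices, torch.finfo(scores.dtype).max)

n_pruned = num_key_value_heads * (q_len - n_kept)  # 20 pairs pruned across the layer
indices = torch.topk(-scores.reshape(bsz, -1), n_pruned, dim=1).indices.flatten()
head_indices, seq_indices = indices // q_len, indices % q_len
print(head_indices.bincount(minlength=num_key_value_heads))  # per-head pruning counts differ
```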
5 changes: 3 additions & 2 deletions kvpress/presses/composed_press.py
@@ -2,6 +2,7 @@

from kvpress.presses.base_press import BasePress
from kvpress.presses.observed_attention_press import ObservedAttentionPress
from kvpress.presses.adakv_press import AdaKVPress


@dataclass
@@ -15,8 +16,8 @@ class ComposedPress(BasePress):
def __post_init__(self):
self.compression_ratio = None
assert not any(
isinstance(press, (ObservedAttentionPress)) for press in self.presses
), "ComposedPress cannot contains ObservedAttentionPress"
isinstance(press, (ObservedAttentionPress, AdaKVPress)) for press in self.presses
), "ComposedPress cannot contains ObservedAttentionPress or AdaKVPress"

def forward_hook(self, module, input, kwargs, output):
self.compression_ratio = 1.0
78 changes: 74 additions & 4 deletions notebooks/new_press.ipynb
@@ -11,11 +11,12 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from dataclasses import dataclass\n",
"from contextlib import contextmanager\n",
"\n",
"import torch\n",
"from torch import nn\n",
@@ -26,9 +27,18 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.\n",
"Device set to use cuda:0\n"
]
}
],
"source": [
"# Load pipeline\n",
"\n",
@@ -40,7 +50,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@@ -229,6 +239,66 @@
"Note that in the `compress` method is itself used in the `forward_hook` method which ensures quantization is handled properly and that the compression is only performed during prefilling. While we don't recommend to change the `forward_hook` method directly, you can still modify it if you need to !"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2.3 Head-wise compression\n",
"\n",
"Since 0.2.0, kvpress support head-wise compression, where the KV cache of each head might be compressed by a different compression ratio. \n",
"\n",
"To achieve proper head-wise compression, one should implement a new kernel for attention along with a custom cache class. Instead, the current implementation fakes head-wise compression by updating the pruned keys by a fake key so that the output of the attention layer is not affected. This is implemented through `kvpress.attention_patch.patch_attention_functions`.\n",
"\n",
"To implement a method that compresses the KV cache head-wise, one should instantiate the `masked_key_indices` as outlined below."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"compression_ratio: 0\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Answer: The purpose of this step-by-step guide is to provide a comprehensive and easy-to-follow tutorial on how to create a new press in the KVPress platform. The guide is designed to help users understand the process of setting up a new press, including the\n",
"\n",
"compression_ratio: 0.25\n",
"Answer: The purpose of this guide is to provide a step-by-step process for creating a new press in KVPRESS, which is a popular open-source web server. The guide will cover the necessary steps to set up and configure a new press, including installing\n",
"\n",
"compression_ratio: 0.9\n",
"Answer: This guide is not a guide. It is a guide. It is a guide. It is a guide. It is a guide. It is a guide. It is a guide. It is a guide. It is a guide. It is a\n"
]
}
],
"source": [
"@dataclass\n",
"class RandomHeadPress(BasePress):\n",
"\n",
" compression_ratio: float = 0.0\n",
"\n",
" def compress(self, module, hidden_states, keys, values, attentions, kwargs):\n",
" assert keys.shape[0] == 1, \"Only batch size 1 is supported\"\n",
" scores = torch.rand(keys.shape[:-1], device=keys.device)\n",
" mask = scores < torch.quantile(scores, self.compression_ratio)\n",
" module.masked_key_indices = torch.nonzero(mask, as_tuple=True)\n",
" \n",
" return keys, values\n",
"\n",
"for compression_ratio in [0, 0.25, 0.9]:\n",
" press = RandomHeadPress(compression_ratio)\n",
" print(f\"\\ncompression_ratio: {compression_ratio}\")\n",
" print(f\"Answer: {pipe(context, question=question, press=press)['answer']}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
8 changes: 7 additions & 1 deletion tests/presses/test_presses.py
@@ -12,6 +12,7 @@
KeyRerotationPress,
KnormPress,
ObservedAttentionPress,
AdaKVPress,
ThinKPress,
ScorerPress,
)
@@ -29,7 +30,7 @@ def test_composed_press(unit_test_model):  # noqa: F811


@pytest.mark.parametrize("press_dict", default_presses)
@pytest.mark.parametrize("wrapper_press", [None, ComposedPress, KeyRerotationPress])
@pytest.mark.parametrize("wrapper_press", [None, ComposedPress, KeyRerotationPress, AdaKVPress])
def test_presses_run(unit_test_model, press_dict, wrapper_press): # noqa: F811
cls = press_dict["cls"]
for kwargs in press_dict["kwargs"]:
@@ -38,6 +39,11 @@ def test_presses_run(unit_test_model, press_dict, wrapper_press):  # noqa: F811
press = ComposedPress(presses=[press])
if isinstance(wrapper_press, KeyRerotationPress):
press = KeyRerotationPress(press=press)
if isinstance(wrapper_press, AdaKVPress):
if not isinstance(press, ScorerPress):
return
else:
press = AdaKVPress(press=press)

with press(unit_test_model):
input_ids = unit_test_model.dummy_inputs["input_ids"]
9 changes: 9 additions & 0 deletions tests/test_attention_patch.py
@@ -0,0 +1,9 @@
import torch
from kvpress.attention_patch import search_hyperplane


def test_search_hyperplane():
bsz, seq_len, head_dim = 50, 500, 128
X = torch.rand(bsz, seq_len, head_dim)
Y = search_hyperplane(X)
assert torch.exp(torch.bmm(X, Y.unsqueeze(-1))).max() == 0
