Showing 9 changed files with 222 additions and 7 deletions.
@@ -0,0 +1,58 @@
import torch
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


def search_hyperplane(X, max_iter: int = 1000):
    """
    Given a tensor X of shape (bsz, seq_len, head_dim), search for a vector Y of shape (bsz, head_dim)
    such that <X[:, i], Y> > 0 for every i, and return -1e5 * Y / ||Y|| ** 2 so that exp(<X[:, i], k>) = 0
    for the returned fake key k. Raises a ValueError if no such hyperplane is found.
    """
    Y = X.mean(1)  # this initialization is enough for most cases
    for _ in range(max_iter):
        mask = torch.bmm(X, Y.unsqueeze(-1)) <= 0  # True where <X[:, i], Y> is still non-positive
        if not mask.any():
            return -1e5 * Y / Y.norm(dim=-1, keepdim=True) ** 2
        Y += (X * mask).sum(1) / mask.sum(1).clamp(min=1)  # move Y toward the queries that still violate the condition
    raise ValueError("Could not find fake keys such that for every query q, exp(<q, k>) = 0")


def attention_patch(func):
    """
    Decorator to update the keys before the attention computation at the indices provided in module.masked_key_indices.
    The keys are updated with a fake key k such that exp(<q, k>) = 0 for every query q, which fakes head-wise compression.
    This solution is not optimal as it does not reduce peak memory and slightly increases runtime.
    """

    def wrapper(module, query, key, value, attention_mask, dropout, **kwargs):
        if query.shape[2] == key.shape[2]:
            # Prefilling
            module.masked_key_indices = None
        elif module.masked_key_indices is not None:
            # Decoding: build fake keys k s.t. exp(<q, k>) = 0
            bsz, num_heads, seq_len, head_dim = query.shape
            num_key_value_heads = key.shape[1]
            num_groups = num_heads // num_key_value_heads

            # Build a fake key k per key group such that for every query q, exp(<q, k>) = 0
            q = query.view(bsz, num_key_value_heads, num_groups, seq_len, head_dim)
            q = q.reshape(bsz * num_key_value_heads, num_groups * seq_len, head_dim)
            k = search_hyperplane(q)
            k = k.view(bsz, num_key_value_heads, head_dim)

            # At the masked indices, replace the keys with the fake keys
            batch_indices, head_indices, seq_indices = module.masked_key_indices
            key[batch_indices, head_indices, seq_indices] = k[batch_indices, head_indices]

        return func(module, query, key, value, attention_mask, dropout, **kwargs)

    return wrapper


def patch_attention_functions():
    """
    Apply the attention_patch decorator to every function in ALL_ATTENTION_FUNCTIONS.
    """

    for name, func in ALL_ATTENTION_FUNCTIONS.items():
        ALL_ATTENTION_FUNCTIONS[name] = attention_patch(func)
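
For context, a minimal sketch (not part of this commit) of what the wrapper above does on the decoding path. The GQA shapes, the pruned positions, and all variable names are illustrative assumptions; only search_hyperplane comes from this file (imported as in the test at the end of the diff).

import torch
from kvpress.attention_patch import search_hyperplane

# Hypothetical GQA shapes: 8 query heads sharing 4 KV heads, single decoding step over 6 cached positions
bsz, num_heads, num_key_value_heads, past_len, head_dim = 1, 8, 4, 6, 32
num_groups = num_heads // num_key_value_heads

query = torch.rand(bsz, num_heads, 1, head_dim)
key = torch.rand(bsz, num_key_value_heads, past_len, head_dim)

# One fake key per KV group, built from every query head that attends to it
q = query.view(bsz, num_key_value_heads, num_groups, 1, head_dim)
q = q.reshape(bsz * num_key_value_heads, num_groups, head_dim)
fake_key = search_hyperplane(q).view(bsz, num_key_value_heads, head_dim)

# Pretend a press pruned positions 0 and 3 of KV head 0 in batch 0
batch_indices, head_indices, seq_indices = torch.tensor([0, 0]), torch.tensor([0, 0]), torch.tensor([0, 3])
key[batch_indices, head_indices, seq_indices] = fake_key[batch_indices, head_indices]

# Those positions now receive exactly zero attention weight for the matching query head
weights = torch.softmax(query[0, 0, 0] @ key[0, 0].T, dim=-1)
assert weights[0] == 0 and weights[3] == 0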
@@ -0,0 +1,62 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0


from dataclasses import dataclass

import torch

from kvpress.presses.base_press import BasePress
from kvpress.presses.scorer_press import ScorerPress


@dataclass
class AdaKVPress(BasePress):
    """
    AdaKV (https://arxiv.org/abs/2407.11550) selects the top-k keys and values among all heads in a layer
    based on the scores, achieving head-specific compression.
    A safeguard is applied to ensure a minimum fraction of KV pairs is kept per head (alpha_safeguard parameter).
    This press has been reviewed by Yuan Feng, first author of AdaKV.
    """

    scorer: ScorerPress
    alpha_safeguard: float = 0.20

    def __post_init__(self):
        assert isinstance(self.scorer, ScorerPress), "AdaKVPress requires a ScorerPress as input"
        assert 0 <= self.alpha_safeguard <= 1, "alpha_safeguard should be in [0, 1]"

    @property
    def compression_ratio(self):
        return self.scorer.compression_ratio

    @compression_ratio.setter
    def compression_ratio(self, value):
        self.scorer.compression_ratio = value

    def compress(self, module, hidden_states, keys, values, attentions, kwargs):
        if self.compression_ratio == 0:
            return keys, values

        assert module.config._attn_implementation != "eager", "eager mode not supported"

        # Compute scores
        scores = self.scorer.score(module, hidden_states, keys, values, attentions, kwargs)
        bsz, num_key_value_heads, q_len = scores.shape

        # Make sure to keep at least n_safe = alpha_safeguard * n_kept KV pairs per head
        n_kept = int(q_len * (1 - self.compression_ratio))  # ScorerPress definition
        n_safe = int(n_kept * self.alpha_safeguard)
        top_indices = torch.topk(scores, n_safe, dim=-1).indices
        scores.scatter_(-1, top_indices, torch.finfo(scores.dtype).max)

        # Compute bottom-k across heads
        n_pruned = num_key_value_heads * (q_len - n_kept)
        indices = torch.topk(-scores.reshape(bsz, -1), n_pruned, dim=1).indices.flatten()

        # Save indices to mask during the attention mechanism. Please refer to attention_patch.py for more details
        batch_indices = torch.arange(bsz).repeat_interleave(n_pruned)
        head_indices = indices // q_len
        seq_indices = indices % q_len
        module.masked_key_indices = (batch_indices, head_indices, seq_indices)
        return keys, values
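
For reference, a minimal sketch (not part of this commit) of the budget and index arithmetic used in compress(), with illustrative numbers: with q_len = 10, compression_ratio = 0.4, and alpha_safeguard = 0.2, each head keeps 6 pairs on average but never fewer than 1 safeguarded pair, and 8 pairs are pruned across the layer. All sizes and random scores below are assumptions for illustration.

import torch

bsz, num_key_value_heads, q_len = 1, 2, 10
compression_ratio, alpha_safeguard = 0.4, 0.2
scores = torch.rand(bsz, num_key_value_heads, q_len)

n_kept = int(q_len * (1 - compression_ratio))      # 6 pairs kept per head on average
n_safe = int(n_kept * alpha_safeguard)             # at least 1 pair guaranteed per head
top_indices = torch.topk(scores, n_safe, dim=-1).indices
scores.scatter_(-1, top_indices, torch.finfo(scores.dtype).max)  # protect the safeguarded pairs

n_pruned = num_key_value_heads * (q_len - n_kept)  # 8 pairs pruned across the two heads
indices = torch.topk(-scores.reshape(bsz, -1), n_pruned, dim=1).indices.flatten()
batch_indices = torch.arange(bsz).repeat_interleave(n_pruned)
head_indices = indices // q_len                    # KV head of each pruned pair (budgets differ per head)
seq_indices = indices % q_len                      # position of each pruned pair within its head

The resulting (batch_indices, head_indices, seq_indices) triplet is what compress() stores in module.masked_key_indices so the patched attention functions can zero out those positions.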
@@ -0,0 +1,9 @@
import torch
from kvpress.attention_patch import search_hyperplane


def test_search_hyperplane():
    bsz, seq_len, head_dim = 50, 500, 128
    X = torch.rand(bsz, seq_len, head_dim)
    Y = search_hyperplane(X)
    assert torch.exp(torch.bmm(X, Y.unsqueeze(-1))).max() == 0
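
As a small worked example (not part of this commit), the search can also be followed by hand on a toy input: with X = [[1, 0], [1, 1]], the initial Y = X.mean(1) = [1.0, 0.5] already has a positive inner product with both rows, so the function returns -1e5 * Y / 1.25 = [-80000, -40000] on the first iteration, and both scores underflow to exp(...) = 0.

import torch
from kvpress.attention_patch import search_hyperplane

X = torch.tensor([[[1.0, 0.0], [1.0, 1.0]]])  # (bsz=1, seq_len=2, head_dim=2)
k = search_hyperplane(X)                      # -1e5 * [1.0, 0.5] / 1.25 = [-80000., -40000.]
print(torch.bmm(X, k.unsqueeze(-1)))          # tensor([[[ -80000.], [-120000.]]]) -> exp() == 0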