AdaKVPress #38

Merged: 27 commits, merged Jan 13, 2025

Changes from 1 commit
Some cleaning
SimJeg committed Jan 9, 2025
commit 31f6b12cb9003c17818a4225d293c7d38526e714
1 change: 1 addition & 0 deletions README.md
@@ -65,6 +65,7 @@ Some presses rely on a different logic:
- `SimLayerKVPress`: identify "lazy" layers, and apply the StreamingLLM approach to them ([paper](https://arxiv.org/abs/2410.13846))

Finally we provide special presses:
+- `AdaKVPress`: prune the bottom scores of any `ScorerPress`, selected across all heads rather than uniformly per head, achieving head-wise compression (see [paper](https://arxiv.org/abs/2407.11550))
- `PerLayerCompressionPress`: compress each layer with a different compression ratio (experimental). This press can be used with any other press that allows setting a `compression_ratio`
- `ComposedPress`: compose multiple presses together by chaining their forward hooks
- `KeyRerotationPress`: rerotate pruned keys to have continuous RoPE embeddings. This press can be used with any other press that inherits from `ScorerPress`.
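
As context for this new entry, here is a minimal usage sketch of `AdaKVPress` wrapping a `ScorerPress`. The scorer choice, model name, and pipeline arguments are illustrative assumptions, not part of this diff:

```python
from transformers import pipeline

from kvpress import AdaKVPress, SnapKVPress

# AdaKVPress wraps any ScorerPress; SnapKVPress is just one illustrative choice
press = AdaKVPress(press=SnapKVPress(compression_ratio=0.5))

# kvpress registers a custom text-generation pipeline that applies the press
pipe = pipeline("kv-press-text-generation", model="meta-llama/Llama-3.1-8B-Instruct", device="cuda")
context = "..."  # long document whose KV cache should be compressed
answer = pipe(context, question="...", press=press)["answer"]
```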
17 changes: 9 additions & 8 deletions kvpress/attention_patch.py
@@ -4,10 +4,11 @@

def search_hyperplane(X, max_iter: int = 1000):
"""
- Search for an hyperplane Y such that for every Xi, <Xi, Y> <= 0 (simple perceptron)
+ Search for a hyperplane Y such that for every Xi, <Xi, Y> <= 0
Returns - 1e5 * Y / ||Y|| ** 2 to ensure exp(<X, Y>) = 0
Raises a ValueError if no such hyperplane is found
"""
- Y = X.mean(1)
+ Y = X.mean(1)  # this initialization is enough for most cases
for _ in range(max_iter):
mask = torch.bmm(X, Y.unsqueeze(-1)) <= 0
if not mask.any():
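
The hunk above is folded mid-function. A sketch of the complete routine follows; the perceptron update rule, the exact return expression, and the error message are not visible in this diff and are assumptions:

```python
import torch


def search_hyperplane(X, max_iter: int = 1000):
    """
    Search for a hyperplane Y such that for every Xi, <Xi, Y> <= 0
    Returns -1e5 * Y / ||Y|| ** 2 to ensure exp(<X, Y>) = 0
    """
    Y = X.mean(1)  # this initialization is enough for most cases
    for _ in range(max_iter):
        # mask[b, i] is True where Xi still has a non-positive product with the search direction
        mask = torch.bmm(X, Y.unsqueeze(-1)) <= 0
        if not mask.any():
            # All <Xi, Y> are now strictly positive: scale and flip the sign so that
            # <Xi, returned Y> is hugely negative and exp(<Xi, Y>) underflows to 0
            return -1e5 * Y / Y.norm(dim=-1, keepdim=True) ** 2
        # Assumed perceptron-style update: pull Y toward the violating samples
        Y += (X * mask).sum(1)
    raise ValueError("No hyperplane found")
```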
@@ -18,16 +19,16 @@ def search_hyperplane(X, max_iter: int = 1000):

def attention_patch(func):
"""
- Decorator to udpate the keys before the attention computation at the indices provided in module.indices
- The keys are updated to a fake key k such that for the input queries q, exp(<q, k>) = 0
- This is used to fake head-wise compression. A more optimal solution would be to create a new kernel.
+ Decorator to update the keys before the attention computation at the indices provided in module.masked_key_indices
+ The keys are updated with a fake key k such that exp(<q, k>) = 0 to fake head-wise compression
+ This solution is not optimal as it does not reduce peak memory and slightly increases runtime
"""

def wrapper(module, query, key, value, attention_mask, dropout, scaling=None, is_causal=None, **kwargs):
if query.shape[2] == key.shape[2]:
# Prefilling
- module.indices = None
- elif module.indices is not None:
+ module.masked_key_indices = None
+ elif module.masked_key_indices is not None:
# Decoding: build fake keys k s.t. exp(<q, k>) = 0
bsz, num_heads, seq_len, head_dim = query.shape
num_key_value_heads = key.shape[1]
@@ -40,7 +41,7 @@ def wrapper(module, query, key, value, attention_mask, dropout, scaling=None, is
k = k.view(bsz, num_key_value_heads, head_dim)

# At indices, update the keys to the fake keys and the values to 0
- key[*module.indices] = k[*module.indices[:2]]
+ key[*module.masked_key_indices] = k[*module.masked_key_indices[:2]]

return func(module, query, key, value, attention_mask, dropout, scaling, is_causal, **kwargs)
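
A toy check of the intended effect, assuming the `search_hyperplane` sketch above is in scope; shapes and the underflow assertion are illustrative:

```python
import torch

# A fake key from search_hyperplane should receive exactly zero attention weight
q = torch.randn(1, 5, 8)                 # batch of 5 queries of dimension 8
k = search_hyperplane(q)                 # (1, 8), from the sketch above
scores = torch.bmm(q, k.unsqueeze(-1))   # <q, k> is a large negative number
assert torch.exp(scores).max() == 0      # the softmax numerator underflows to exactly 0
```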

2 changes: 1 addition & 1 deletion kvpress/presses/adakv_press.py
@@ -51,5 +51,5 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs):
indices = torch.topk(-scores.reshape(bsz, -1), n_pruned, dim=1).indices.flatten()

# Save indices for attention patching in the module
- module.indices = (torch.arange(bsz).repeat_interleave(n_pruned), indices // q_len, indices % q_len)
+ module.masked_key_indices = (torch.arange(bsz).repeat_interleave(n_pruned), indices // q_len, indices % q_len)
return keys, values
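
The flattened top-k indices encode (head, position) pairs, recovered by integer division and modulo. A toy illustration of how the tuple saved in `module.masked_key_indices` decomposes (shapes are made up for the example):

```python
import torch

bsz, num_kv_heads, q_len, n_pruned = 1, 2, 4, 3
scores = torch.rand(bsz, num_kv_heads, q_len)

# Lowest-scoring entries selected across all heads at once (head-wise budget)
indices = torch.topk(-scores.reshape(bsz, -1), n_pruned, dim=1).indices.flatten()

batch_idx = torch.arange(bsz).repeat_interleave(n_pruned)
head_idx = indices // q_len  # which KV head the pruned entry belongs to
pos_idx = indices % q_len    # position of the pruned entry within that head
masked_key_indices = (batch_idx, head_idx, pos_idx)
```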
5 changes: 3 additions & 2 deletions kvpress/presses/composed_press.py
@@ -2,6 +2,7 @@

from kvpress.presses.base_press import BasePress
from kvpress.presses.observed_attention_press import ObservedAttentionPress
+from kvpress.presses.adakv_press import AdaKVPress


@dataclass
@@ -15,8 +16,8 @@ class ComposedPress(BasePress):
def __post_init__(self):
self.compression_ratio = None
assert not any(
- isinstance(press, ObservedAttentionPress) for press in self.presses
- ), "ComposedPress cannot contains ObservedAttentionPress because attentions pruning is not handled"
+ isinstance(press, (ObservedAttentionPress, AdaKVPress)) for press in self.presses
+ ), "ComposedPress cannot contain ObservedAttentionPress or AdaKVPress"

def forward_hook(self, module, input, kwargs, output):
self.compression_ratio = 1.0
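
For reference, a minimal sketch of what this guard allows and rejects; the concrete presses and compression ratios are illustrative assumptions:

```python
from kvpress import AdaKVPress, ComposedPress, KnormPress

# Fine: chaining two compatible presses composes their forward hooks
press = ComposedPress(presses=[KnormPress(0.3), KnormPress(0.2)])

# Raises AssertionError: AdaKVPress (like ObservedAttentionPress) cannot be composed
try:
    ComposedPress(presses=[AdaKVPress(press=KnormPress(0.3))])
except AssertionError as e:
    print(e)  # "ComposedPress cannot contain ObservedAttentionPress or AdaKVPress"
```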