
Request for Head-Specific KV Cache Compression Feature #7

Closed
FFY0 opened this issue Nov 21, 2024 · 10 comments
Labels: feature request, good first issue

FFY0 commented Nov 21, 2024

🚀 Feature

Adding support for head-specific KV cache compression, which employs variable compression rates for each attention head.

Motivation

Ada-KV[1] has demonstrated that employing different compression rates across attention heads can significantly enhance cache compression methods. Recently, numerous head-specific approaches, such as DuoAttention[2], RazorAttention[3], and HeadKV[4], have emerged, each introducing unique techniques to improve compression quality through head-specific methods. However, these methods involve handling variable-length cache entries across different heads, a feature that KVPress currently does not support. We believe supporting this feature will significantly enhance the flexibility of KVPress and align it with emerging head-specific compression strategies.

[1] Feng, Y., Lv, J., Cao, Y., Xie, X., & Zhou, S. K. (2024). Ada-KV: Optimizing KV Cache Eviction by Adaptive Budget Allocation for Efficient LLM Inference. arXiv preprint arXiv:2407.11550.
[2] Xiao, G., Tang, J., Zuo, J., Guo, J., Yang, S., Tang, H., ... & Han, S. (2024). DuoAttention: Efficient Long-Context LLM Inference with Retrieval and Streaming Heads. arXiv preprint arXiv:2410.10819.
[3] Tang, H., Lin, Y., Lin, J., Han, Q., Hong, S., Yao, Y., & Wang, G. (2024). Razorattention: Efficient kv cache compression through retrieval heads. arXiv preprint arXiv:2407.15891.
[4] Fu, Y., Cai, Z., Asi, A., Xiong, W., Dong, Y., & Xiao, W. (2024). Not All Heads Matter: A Head-Level KV Cache Compression Method with Integrated Retrieval and Reasoning. arXiv preprint arXiv:2410.19258.

SimJeg added the good first issue label Nov 21, 2024
SimJeg (Collaborator) commented Nov 21, 2024

Hi @FFY0,

Definitely a good issue, that's a key feature for several compression techniques. However, it requires implementing a new kernel to be efficient, so it's a significant effort (unless we find a trick... I do have some ideas ^^)

FFY0 (Author) commented Nov 21, 2024

Thanks, @SimJeg!
Looking forward to the head-specific KV cache compression feature. This will effectively drive progress in the field of head-wise adaptive compression! 🚀

SimJeg added the feature request label Nov 26, 2024
FFY0 (Author) commented Nov 30, 2024

Hi, @SimJeg.

Recently, I tried to implement a Head-Specific KV Cache compression solution within the current project architecture and developed the Ada-SnapKV compression method as described in the AdaKV paper. This solution introduces several new components while minimizing intrusive changes to the existing architecture. The main modifications include:

  1. To support head-specific cache management, I created a new cache class, DynamicCacheSplitHeadFlatten, along with the corresponding CUDA kernel, update_flatten_klenN_view, to manage and update a flattened KV cache layout.
  2. For efficient attention computation in head-specific methods, I extended the LlamaAttention class with a new AdaLlamaFlashAttention class. This class manages metadata for the flattened KV cache layout and uses FlashAttention to perform the attention computation over it.
  3. Introduced a new base press class for head-specific KV cache compression methods, AdaBasePress, inheriting from the existing BasePress. This class performs compression on the flattened KV cache layout and updates the corresponding metadata after compression.
  4. Developed a specific subclass, AdaSnapKVPress, based on AdaBasePress, which implements the Ada-SnapKV method proposed in the AdaKV paper.

Once a subclass of AdaBasePress is used (e.g. AdaSnapKVPress), AdaLlamaFlashAttention and DynamicCacheSplitHeadFlatten, along with the associated CUDA kernel, are automatically integrated to support head-specific KV cache compression.
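
To make the flattened layout more concrete, here is a minimal, hypothetical sketch of the underlying data structure (variable names are illustrative, not taken from the branch): per-head caches of different lengths are concatenated along one flat dimension and addressed through cumulative offsets, which is what the kernel and the attention class operate on.

```python
# Minimal sketch of a flattened, head-specific KV layout (illustrative names only).
import torch

head_dim = 8
head_lens = torch.tensor([5, 3, 7, 2])  # hypothetical per-head cache lengths after eviction

# cu_seqlens[h] is the offset where head h starts in the flattened cache
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.long), head_lens.cumsum(0)])

# Flattened keys: shape (total_kv, head_dim) instead of (num_heads, seq_len, head_dim)
flat_keys = torch.randn(int(head_lens.sum()), head_dim)

# Recover the cache of a single head by slicing with the offsets
h = 2
keys_h = flat_keys[cu_seqlens[h]:cu_seqlens[h + 1]]  # shape (7, head_dim)
```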

So far, I have obtained some preliminary results for Ada-SnapKV on the ruler benchmark, and the performance looks promising. Moving forward, I plan to conduct some tests on corner cases. The code is currently available in a branch of my forked repository. I would appreciate your feedback or suggestions. If progress aligns with expectations, I would be happy to continue working on this and eventually attempt to merge the changes into the main branch.

Commit Details

SimJeg (Collaborator) commented Dec 4, 2024

Thanks @FFY0 for the hard work! We need to decide internally if we want to host kernels in this repository. Is the kernel you propose here already available via pip install somewhere else?

SimJeg mentioned this issue Dec 4, 2024
FFY0 (Author) commented Dec 5, 2024

Hi @SimJeg,

This kernel is a modified version of the original AdaKV kernel. It is currently compiled within the adakvpress/kvpress/csrc folder and is not hosted elsewhere. If hosting this kernel in the repository is preferred, I am happy to follow your decision.

I will also make further adjustments to the code and merge the branch you mentioned into mine; it seems the two could integrate easily.

SimJeg (Collaborator) commented Dec 18, 2024

I created a branch introducing another way to do head-wise compression here. It does not contain AdaKVPress but a simple RandomHeadPress and a related notebook to show how to use it.

How it works:

  • it introduces a new DynamicHeadCache which has an indices argument pointing to the indices of the KV pairs that should be masked during decoding. This cache does not reduce peak memory usage, but the attention outputs are the same as if these KV pairs had been removed.
  • during decoding, the DynamicHeadCache overwrites the keys and values at those indices with fake keys k and null values v=0, where k is picked so that $e^{q \cdot k} \approx 0$ for the new input queries q (k is computed using least squares; a rough sketch follows below)
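
For illustration, here is a rough sketch of that masking trick under my own assumptions (the helper name and the target logit value are made up, and this is not the code from the branch): a least-squares problem is solved so that q · k is a large negative number for the observed queries, driving the softmax weight of the masked entry to ~0; the corresponding value is zeroed as well.

```python
import torch

def fake_key(queries: torch.Tensor, target_logit: float = -1e4) -> torch.Tensor:
    """queries: (n, d) query states of one head; returns a (d,) fake key k
    such that queries @ k ~ target_logit, i.e. exp(q . k) ~ 0 in the softmax."""
    q32 = queries.float()  # lstsq requires float32/float64
    rhs = torch.full((q32.shape[0], 1), target_logit, dtype=q32.dtype, device=q32.device)
    # Least-squares solution of q32 @ k = target_logit over all observed queries
    k = torch.linalg.lstsq(q32, rhs).solution.squeeze(-1)
    return k.to(queries.dtype)

# Masked KV pairs would then be overwritten with k = fake_key(q) and v = 0.
```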

This implementation is very short (~ 60 loc) and fits nicely with the kvpress package, however:

  • it requires adding the query_states to the cache_kwargs, and the only concise way to do that is to use exec, which is not safe (see __init__.py)
  • it does not effectively reduce peak memory usage (somewhat similar to ThinKPress, it fakes compression)

@FFY0 the reason I investigated it is that your current PR implies many changes to the current code:

  1. adds a kernel
  2. adds a cache
  3. adds new attention classes that are not model agnostic
  4. updates the pipeline

Anyway, I don't think I will merge this branch as is: the exec workaround (used to avoid point 3) might not be safe, and it goes against the repeat-yourself philosophy of transformers.

SimJeg (Collaborator) commented Dec 18, 2024

I just pushed an updated version (commit) without exec. It's a bit cleaner, but the downside is that it adds a lot of lines of code (i.e. thousands!).

How it works:

  • For each model identified by a name (e.g. llama), a modeling_{name}.py is automatically generated from a given version of transformers, adding query_states to the cache_kwargs. This is done once using kvpress.models.utils.rewrite_modeling_scripts
  • During kvpress initialization, the transformers {NAME}_ATTENTION_CLASSES dictionary is updated with the same dictionary coming from kvpress.models.modeling_{name} (a hypothetical sketch of this swap follows)
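
As a hypothetical illustration of that second bullet (assuming a transformers release older than 4.48 that still exposes LLAMA_ATTENTION_CLASSES, and assuming kvpress.models.modeling_llama is the generated module), the swap could look roughly like this:

```python
# Sketch only: swap the stock attention registry for the generated one so that
# query_states is passed through cache_kwargs. Import paths are assumptions.
from transformers.models.llama import modeling_llama
from kvpress.models import modeling_llama as patched

# Models instantiated after this point resolve their attention implementation
# ("eager", "sdpa", "flash_attention_2") through the patched classes.
modeling_llama.LLAMA_ATTENTION_CLASSES.update(patched.LLAMA_ATTENTION_CLASSES)
```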

SimJeg (Collaborator) commented Dec 22, 2024

The above proposal will be deprecated with v4.48 of transformers.

Other idea: https://docs.flashinfer.ai/api/decode.html#batch-decoding

FFY0 (Author) commented Dec 22, 2024

Hi @SimJeg,

The approximate masking method you mentioned is indeed a clever way to simulate head-specific compression. It minimizes additional code requirements and is a feasible approach. From my understanding of the Transformers library, it seems to prioritize a single-model-file policy over strictly adhering to the DRY principle. Therefore, introducing a new modeling file or class does not appear to conflict with this policy.

In my forked repository, the primary objective was to make head-specific compression as computationally efficient as the standard (non-head-specific) path. However, this inevitably resulted in increased code complexity. If we are open to trading off some computational efficiency for cleaner code, several potential optimization points could be considered:

  1. Modifying the current PR to remove custom kernels: head-granular cache tensor management can still be implemented with plain PyTorch operations (see the sketch after this list). A straightforward cache management method was explored during the early development of AdaKV, which avoided kernel-level cache management; see FFY0/AdaKV#2 (Speed diff between DynamicCacheSplitHeadFlatten and DynamicCacheSplitHead?) and https://github.com/FFY0/AdaKV/blob/d2252fa9870a2eb233b8512ce67242187dbed987/adaptive_snapkv/monkeypatch/snapkv_utils.py#L101.
  2. Rebuilding flattened cache metadata on demand rather than maintaining it persistently, thereby reducing code complexity.
  3. Decoupling pipeline modifications in the current PR from the pipeline class, as these changes are not strictly necessary.
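
As a sketch of option 1 (names assumed, not taken from the AdaKV code): one tensor per head plus plain index selection is enough for head-granular eviction, at the cost of a Python loop over heads.

```python
import torch

def evict_per_head(keys: list[torch.Tensor],
                   values: list[torch.Tensor],
                   keep_indices: list[torch.Tensor]):
    """keys/values: per-head tensors of shape (len_h, head_dim);
    keep_indices: per-head 1D tensors with the indices of entries to retain."""
    new_keys = [k.index_select(0, idx) for k, idx in zip(keys, keep_indices)]
    new_values = [v.index_select(0, idx) for v, idx in zip(values, keep_indices)]
    return new_keys, new_values
```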

Regarding the flash_attn_varlen_func used in my attention class implementation, I think it aligns well with the concept in the link you shared. The link appears to showcase a similar idea under the paged attention framework. In contrast, my implementation uses a flattened cache layout, primarily because the Transformers library does not seem to support paged attention. This idea is adopted in many LLM inference frameworks for handling variable-length sequence inputs (e.g., dynamic batch decoding). However, such methods seem inherently tied to model-specific modifications in the Transformers library, necessitating the addition of a new attention class for each modeling file.
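
To show what I mean, here is a rough sketch of such a call during decoding (the shapes, the single-query setup, and the idea of treating each head as its own variable-length sequence reflect my reading of the approach rather than the exact branch code):

```python
import torch
from flash_attn import flash_attn_varlen_func  # from the flash-attn package

head_dim = 128
head_lens = torch.tensor([5, 3, 7, 2], device="cuda")  # hypothetical per-head KV lengths
cu_seqlens_k = torch.nn.functional.pad(head_lens.cumsum(0), (1, 0)).to(torch.int32)
cu_seqlens_q = torch.arange(len(head_lens) + 1, dtype=torch.int32, device="cuda")  # one query per head (decoding)

# Each head is treated as its own variable-length "sequence" with a single attention head
q = torch.randn(len(head_lens), 1, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(int(head_lens.sum()), 1, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn_like(k)

out = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k,
                             max_seqlen_q=1, max_seqlen_k=int(head_lens.max()),
                             causal=True)  # (total_q, 1, head_dim)
```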

I believe the trade-off primarily revolves around two aspects: code simplicity and computational efficiency. Based on this, it would be worth exploring a more balanced implementation approach. Regarding the additional code, it mainly involves adding new subclasses to existing modeling files, which seems to align with the commonly pursued open-closed principle. What are your thoughts on this?

SimJeg (Collaborator) commented Jan 13, 2025

The issue has been solved by #38, although a better implementation could be envisioned using a proper kernel and cache implementation as proposed by @FFY0 (to be revisited depending on user requests).

SimJeg closed this as completed Jan 13, 2025