
Head-Specific KV Cache Compression Feature (Ada-SnapKV, AdaKV) #25

Closed
wants to merge 17 commits

Conversation

@FFY0 FFY0 commented Dec 7, 2024

Adds the Head-Specific KV Cache Compression feature discussed in the linked issue.

I have obtained some results for Ada-SnapKV on the 4K Ruler benchmark, and they look promising. I have placed them in a new notebook, which also includes a brief explanation of the flattened KV cache layout that head-specific KV cache compression uses during computation (sketched below) and a tutorial on how to build new head-specific methods on top of the latest AdaBasePress.
[Figure: Ruler benchmark results]
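
For readers of this thread, here is a hypothetical illustration of what a flattened, head-specific KV layout can look like (the actual layout used by Ada-SnapKV is documented in the notebook; all names and shapes below are made up):

import torch

# With head-specific compression each head retains a different number of
# tokens, so the per-head caches are concatenated into one flat tensor
# and indexed through cumulative offsets.
num_heads, head_dim = 4, 8
kept = torch.tensor([5, 3, 7, 2])  # tokens retained per head (example budget)
offsets = torch.cat([torch.zeros(1, dtype=torch.long), kept.cumsum(0)])

flat_keys = torch.randn(int(kept.sum()), head_dim)  # all heads' keys, flattened

# During attention, head h's keys are recovered by slicing:
h = 2
keys_h = flat_keys[offsets[h] : offsets[h + 1]]  # shape (kept[h], head_dim)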
Additionally, the head-specific KV cache compression feature may require a custom unit test workflow, such as instantiating new attention classes before loading models, so simply adding Ada-SnapKV to the current unit tests may cause failures. I will attempt to resolve this in the future. Feel free to let me know if there's anything else you'd like me to refine or if you need additional details!

@FFY0 FFY0 commented Dec 7, 2024

I also have some confusion regarding batch support in the current repository. Much of the code seems to assume a batch size of 1: the compression logic doesn't appear to account for padding tokens caused by varying sequence lengths across samples. Meanwhile, the current unit tests use dummy inputs with a batch size greater than 1.

To align with the other methods, the current implementation of Ada-SnapKV is limited to a batch size of 1. If necessary, I will explore support for larger batch sizes in the future.

@SimJeg SimJeg mentioned this pull request Dec 9, 2024
@FFY0 FFY0 marked this pull request as ready for review December 11, 2024 11:14
@FFY0 FFY0 marked this pull request as draft December 11, 2024 11:22
@FFY0 FFY0 commented Dec 11, 2024

Hi @SimJeg,

I have added unit tests for Ada-SnapKV in test_presses.py and validated them successfully. Additionally, I have updated the architecture of Ada-SnapKV to align with the refactored code on the main branch, and it seems to be working well. If you have any suggestions, please feel free to let me know.

It seems the current CI Action workflows require approval before they can run. I'm not very familiar with this process, so please let me know if there's anything I need to do. Additionally, CI might fail because the new kernel requires a build step; we may need to discuss how to manage the kernel going forward.

@FFY0 FFY0 marked this pull request as ready for review December 11, 2024 13:37
@SimJeg SimJeg commented Dec 11, 2024

Hi @FFY0, thanks for your hard work on this PR; the results you shared look really promising. One of the goals of the recent refactoring was to welcome more complex presses such as yours. We have started to look at your PR and will come back with feedback.

@SimJeg SimJeg commented Jan 7, 2025

@FFY0 beyond the discussion on the best way to implement head-wise compression, I tried to implement my own version of AdaKV here. What do you think of the interface:

press = AdaKVPress(scorer)

where scorer is any ScorerPress object.
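As a usage sketch, this might look like the following (SnapKVPress and compression_ratio are assumed from the repo's existing ScorerPress API, and the context-manager usage follows the library's README pattern; this is not the final API):

from kvpress import AdaKVPress, SnapKVPress  # assumed import paths

scorer = SnapKVPress(compression_ratio=0.5)  # any ScorerPress object
press = AdaKVPress(scorer)  # reallocates the pruning budget across heads

# Presses in this repo are applied as context managers around the forward pass:
# with press(model):
#     outputs = model.generate(input_ids)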

@FFY0 FFY0 commented Jan 9, 2025

Hi @SimJeg, this interface looks great. Using a wrapper avoids many modifications to the current architecture. I tried running the code, but I seem to have encountered a minor issue.

def search_hyperplane(X, max_iter=1000):
    """
    Search for an hyperplane Y such that for every Xi, <Xi, Y> <= 1 (simple perceptron)
    Returns LARGE_NEGATIVE_FLOAT * Y to ensure exp(<X, Y>) = 0
    """
    Y = X.mean(1)
    for _ in range(max_iter):
        mask = (X * Y.unsqueeze(1)).sum(dim=2, keepdim=True) <= 1
        if not mask.any():
            return LARGE_NEGATIVE_FLOAT * Y
        Y += (X * mask).sum(1) / mask.sum(1).clamp(min=1)
    raise ValueError("Could not find fake keys such that for every query q, exp(<q, k>) = 0")

It seems that the condition in the function should be if mask.all(). Exiting the search with if not mask.any() can lead to large values in the qk dot product.

After changing it to if mask.all(), I encountered another issue: "Could not find fake keys such that for every query q, exp(<q, k>) = 0". Since I don't fully understand the logic of the search_hyperplane function, I switched back to the least squares method.
However, least squares failed in the same way. Upon investigation, the problem arises when the input contains many query states, i.e. a large query length (e.g., 30), which makes solving for fake key states harder.
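
For illustration only, a minimal sketch of such a least-squares solve (dimensions invented; this is not the PR's code): it looks for a single fake key k with <q_i, k> ≈ -1e5 for every query row.

import torch

q_len, head_dim = 30, 128
Q = torch.randn(q_len, head_dim)            # query states for one head
target = torch.full((q_len, 1), -1e5)       # want <q_i, k> very negative
k = torch.linalg.lstsq(Q, target).solution  # fake key, shape (head_dim, 1)

# If the fit succeeds, every logit is hugely negative and exp(<q_i, k>) ~= 0:
print((Q @ k).max())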

I think one possible solution is to call the attention method sequentially for each query state within the wrapper function during decoding. This also makes it easy to solve for a fake key state per query state, for example:

key_states = - query_states * 1E5

@SimJeg SimJeg commented Jan 9, 2025

@FFY0 there are still several issues to fix with this approach; I'm working on it! Calling attention sequentially would indeed work, but I'm confident I can fix search_hyperplane.

@SimJeg SimJeg commented Jan 9, 2025

@FFY0 it should be fixed now, and early results on 1% of RULER with AdaSnapKV and AdaExpectedAttention are promising! I will run benchmarks to see if I can reproduce the ones you shared at the beginning of the PR.

@FFY0 FFY0 commented Jan 9, 2025

@SimJeg This implementation is really impressive! I'm looking forward to the test results.

@SimJeg SimJeg commented Jan 13, 2025

Closing this PR as #38 has been merged

@SimJeg SimJeg closed this Jan 13, 2025