

Question regarding hyperparameter delta_decay and STE baseline in PV-Tuning paper #162

Open
jusjinuk opened this issue Dec 27, 2024 · 0 comments

jusjinuk commented Dec 27, 2024

@Vahe1994 @galqiwi @BlackSamorez @justheuristic

Hello, thank you for the awesome work and for actively engaging in answering issues.

I have two major questions, as follows:

  1. As I understand it, delta_decay controls how STE and PV-Tuning are mixed, according to the following rule (from the docstring of StraightThroughAdamW):
    :param delta_decay: determines whether to use straight-through estimation, direct optimization or a mixture thereof
        - if delta_decay == 1, do not use straight-through estimation. In this regime, the optimizer first updates
         de-quantized weights as though they were continuous, then uses modified weights to update codes, codebooks and
         scales; at the end of each step, the optimizer overwrites de-quantized weights to a de-quantization of the
         possibly updated quantized representations (codes, codebooks, scales).
        - if delta_decay == 0, use standard straight-through estimation. In this regime, the optimizer creates
        an internal set of straight-through buffers in the shape of de-quantized weights. The optimizer trains these
        buffers as though they were continuous; the quantized weights are then updated to minimize the L2 distance to
        these straight-through buffers; finally, the optimizer updates de-quantized weights from the quantized versions.
        - if delta_decay is between 0 and 1, use penalized straight-through estimation. The optimizer acts as though
        using standard straight-through estimation (see delta_decay == 0), but after every step, the straight-through
        buffers are set to (1 - delta_decay) * straight_through_buffer + delta_decay * quantized_weight.
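To check my reading of the three regimes, here is a minimal toy sketch of the buffer rule above. It is not the repository's StraightThroughAdamW: it uses plain SGD instead of Adam, and a hypothetical scalar grid with nearest-point rounding stands in for updating codes, codebooks and scales. The helper names `penalized_ste_step` and `run` are made up for illustration.

```python
def nearest_grid_value(value, grid):
    # Toy stand-in for the codes/codebooks/scales update: the quantized value
    # is the grid point with the smallest L2 distance to the buffer.
    return min(grid, key=lambda g: abs(g - value))


def penalized_ste_step(buffer, grad, lr, delta_decay, grid):
    """One hypothetical optimizer step following the docstring's rule."""
    # 1. Train the straight-through buffer as though it were continuous (SGD here).
    buffer = [b - lr * g for b, g in zip(buffer, grad)]
    # 2. Update the quantized weights to minimize the L2 distance to the buffer.
    quantized = [nearest_grid_value(b, grid) for b in buffer]
    # 3. Pull the buffer toward the quantized weights:
    #    delta_decay == 0 -> pure STE (buffer kept), == 1 -> buffer reset to quantized.
    buffer = [(1 - delta_decay) * b + delta_decay * q for b, q in zip(buffer, quantized)]
    return buffer, quantized


def run(delta_decay, num_steps, grad=1.0, lr=0.3, grid=(-1.0, 0.0, 1.0)):
    """Apply a constant gradient to a single weight and record its quantized value."""
    buffer, history = [0.0], []
    for _ in range(num_steps):
        buffer, quantized = penalized_ste_step(buffer, [grad], lr, delta_decay, grid)
        history.append(quantized[0])
    return history


if __name__ == "__main__":
    for delta_decay in (0.0, 0.5, 1.0):
        print(delta_decay, run(delta_decay, num_steps=4))
```

In this toy, delta_decay = 0 lets small updates accumulate in the buffer until the weight jumps to a new grid point, delta_decay = 1 resets the buffer every step so updates smaller than half a grid cell never move the weight, and intermediate values land in between.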

From the PV-Tuning paper, I noticed that you compared (1) STE, (2) Subspace Linearized PV, and (3) Subspace Linearized PV + STE in Table 1.

Thus, I initially thought (1) corresponds to setting delta_decay to 0, (2) to setting delta_decay to 1, and (3) to a value in between. However, in the fine-tuning command for PV-Tuning in the README.md, the delta_decay parameter is set to 0 (--delta_decay 0), which I believe is just STE training.

Is there something I am missing? What exact delta_decay value (or full command) would reproduce each of the rows (1) STE, (2) Subspace Linearized PV, and (3) Subspace Linearized PV + STE in Table 1?

  2. Which hyperparameters would reproduce the Llama-2-7B rows of Table 2 with average bit widths of 1.02 and 2.02, which have Wiki2 PPL of 8.28 and 5.84, respectively? I couldn't find the corresponding model paths in the README.md. I initially thought 1.02 bits would correspond to 1x16g16 and 2.02 bits to 2x8g8, but found that the perplexities listed on GitHub for those formats were worse than the numbers reported in the paper. I would also like to know how delta_decay was set for these runs. Finally, what is the average bit width of the 1x16g16 format when applied to Llama-2-7B?
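For the last part of question 2, here is the back-of-the-envelope estimate I tried. It rests on several assumptions that may not match how the paper counts bits: "AxBgC" is read as A codebooks of B-bit codes over input groups of C weights (output group size 1); every linear layer stores its own fp16 codebooks plus one fp16 scale per output row; and only the seven linear layers in each of the 32 Llama-2-7B decoder blocks are counted (embeddings and lm_head excluded). The function names are hypothetical.

```python
# Assumed Llama-2-7B linear shapes per decoder block, as (out_features, in_features).
HIDDEN, INTERMEDIATE, NUM_BLOCKS = 4096, 11008, 32
BLOCK_LINEARS = [
    (HIDDEN, HIDDEN),        # q_proj
    (HIDDEN, HIDDEN),        # k_proj
    (HIDDEN, HIDDEN),        # v_proj
    (HIDDEN, HIDDEN),        # o_proj
    (INTERMEDIATE, HIDDEN),  # gate_proj
    (INTERMEDIATE, HIDDEN),  # up_proj
    (HIDDEN, INTERMEDIATE),  # down_proj
]


def linear_bits(out_features, in_features, num_codebooks, nbits, in_group_size, fp_bits=16):
    """Return (num_weights, total_bits) for one linear layer under the assumed layout."""
    num_weights = out_features * in_features
    # One nbits-wide code per codebook for every group of in_group_size weights.
    code_bits = (num_weights // in_group_size) * num_codebooks * nbits
    # Each codebook holds 2**nbits vectors of length in_group_size in fp16.
    codebook_bits = num_codebooks * (2 ** nbits) * in_group_size * fp_bits
    # One fp16 scale per output row.
    scale_bits = out_features * fp_bits
    return num_weights, code_bits + codebook_bits + scale_bits


def average_bits(num_codebooks, nbits, in_group_size):
    """Average stored bits per quantized weight over all decoder-block linears."""
    total_weights = total_bits = 0
    for out_f, in_f in BLOCK_LINEARS:
        weights, bits = linear_bits(out_f, in_f, num_codebooks, nbits, in_group_size)
        total_weights += weights * NUM_BLOCKS
        total_bits += bits * NUM_BLOCKS
    return total_bits / total_weights


if __name__ == "__main__":
    for name, cfg in {"1x16g16": (1, 16, 16), "2x8g8": (2, 8, 8)}.items():
        print(f"{name}: {average_bits(*cfg):.3f} bits per weight (estimate)")
```

Under these assumptions, the 2^16-entry codebooks add a large per-layer overhead for 1x16g16, while the 256-entry codebooks of 2x8g8 are almost negligible, so the result depends heavily on whether and how codebook storage is counted.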

Thank you for your time and effort.
