Hello, thank you for the awesome work, and actively engaging in answering the issues.
I have two major questions:
As I understand it, delta_decay controls how STE and PV-Tuning are mixed, according to the following rule (quoted from the docstring of StraightThroughAdamW).
:param delta_decay: determines whether to use straight-through estimation, direct optimization or a mixture thereof
- if delta_decay == 1, do not use straight-through estimation. In this regime, the optimizer first updates
de-quantized weights as though they were continuous, then uses modified weights to update codes, codebooks and
scales; at the end of each step, the optimizer overwrites de-quantized weights to a de-quantization of the
possibly updated quantized representations (codes, codebooks, scales).
- if delta_decay == 0, use standard straight-through estimation. In this regime, the optimizer creates
an internal set of straight-through buffers in the shape of de-quantized weights. The optimizer trains these
buffers as though they were continuous; the quantized weights are then updated to minimize the L2 distance to
these straight-through buffers; finally, the optimizer updates de-quantized weights from the quantized versions.
- if delta_decay is between 0 and 1, use penalized straight-through estimation. The optimizer acts as though
using standard straight-through estimation (see delta_decay == 0), but after every step, the straight-through
buffers are set to (1 - delta_decay) * straight_through_buffer + delta_decay * quantized_weight.
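To make my reading of the last rule concrete, here is a minimal sketch of the buffer update it describes. The function name and the use of plain Python lists are my own illustration, not code from the repository; the actual optimizer presumably operates on tensors.

```python
def decay_straight_through_buffer(straight_through_buffer, quantized_weight, delta_decay):
    """Blend the straight-through buffer with the de-quantized weights.

    Hypothetical helper that only mirrors the docstring's formula:
    (1 - delta_decay) * straight_through_buffer + delta_decay * quantized_weight
    """
    if not 0.0 <= delta_decay <= 1.0:
        raise ValueError("delta_decay is expected to lie in [0, 1]")
    return [
        (1.0 - delta_decay) * buffer_value + delta_decay * quantized_value
        for buffer_value, quantized_value in zip(straight_through_buffer, quantized_weight)
    ]


# delta_decay == 0 leaves the buffer untouched (standard STE per the docstring)
ste = decay_straight_through_buffer([1.0, 2.0], [3.0, 4.0], 0.0)
# delta_decay == 1 snaps the buffer to the quantized weights (no STE memory)
direct = decay_straight_through_buffer([1.0, 2.0], [3.0, 4.0], 1.0)
# values in between pull the buffer partway toward the quantized weights
mixed = decay_straight_through_buffer([1.0, 2.0], [3.0, 4.0], 0.5)
```

Under this reading, the two endpoints reduce to the delta_decay == 0 and delta_decay == 1 regimes described above.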
From the PV-Tuning paper, I noticed that you compared (1) STE, (2) Subspace Linearized PV, and (3) Subspace Linearized PV + STE in Table 1.
Thus, I initially thought that (1) corresponds to setting delta_decay to 0, (2) corresponds to setting delta_decay to 1, and (3) corresponds to a value somewhere in between. However, the PV-Tuning fine-tuning command in the README.md sets delta_decay to 0 (--delta_decay 0), which I believe is just STE training.
Is there something I am missing? What is the exact delta_decay value (or full command) needed to reproduce each row of Table 1: (1) STE, (2) Subspace Linearized PV, and (3) Subspace Linearized PV + STE?
Which hyperparameters would reproduce the Llama-2-7B rows of Table 2 with average bits of 1.02 and 2.02, which report Wiki2 PPL of 8.28 and 5.84, respectively? I couldn't find the corresponding model paths in the README.md. I initially thought that 1.02 average bits would correspond to 1x16g16 and 2.02 to 2x8g8, but the perplexities listed on GitHub for those formats are worse than the numbers reported in the paper. I would also like to know how delta_decay was set for these runs. Finally, what is the average bit width of the 1x16g16 format when applied to Llama-2-7B?
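For context, my guess about the formats came from counting only the bits spent on codes. The sketch below assumes that a name like 2x8g8 means 2 codebooks of 8 bits each per group of 8 weights; that reading and the function name are my own assumptions, and the calculation ignores any overhead from codebooks and scales.

```python
def code_bits_per_weight(num_codebooks, nbits_per_codebook, group_size):
    """Bits per weight spent on codes only (hypothetical helper).

    Assumes each group of `group_size` weights stores one code of
    `nbits_per_codebook` bits for each of `num_codebooks` codebooks.
    Codebook and scale storage is NOT included.
    """
    return num_codebooks * nbits_per_codebook / group_size


bits_1x16g16 = code_bits_per_weight(num_codebooks=1, nbits_per_codebook=16, group_size=16)
bits_2x8g8 = code_bits_per_weight(num_codebooks=2, nbits_per_codebook=8, group_size=8)
print(bits_1x16g16, bits_2x8g8)
```

This gives 1 and 2 bits per weight before overhead, which is why I associated these formats with the 1.02 and 2.02 rows.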
Thank you for your time and effort.
@Vahe1994 @galqiwi @BlackSamorez @justheuristic