The table below shows an excerpt of the KV cache size for FP16.
| Model Size | 1k tokens | 16k tokens | 128k tokens |
|---|---|---|---|
| 8B | 0.125 GB | 1.95 GB | 15.62 GB |
| 70B | 0.313 GB | 4.88 GB | 39.06 GB |
| 405B | 0.984 GB | 15.38 GB | 123.05 GB |
I used the formula in this article to do my own calculations.
The article's formula gives the size of the KV cache per token, where the first factor of 2 accounts for the K and V matrices.
The values of num_layers, num_heads, and dim_heads are taken from the Llama 3 paper.
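Written out, the per-token formula would look like the sketch below. This assumes that the second factor of 2 is the 2 bytes per FP16 value and that num_heads means the number of key/value heads (grouped-query attention), which the Llama 3 paper lists as 8 for every model size. The 8B line plugs in the paper's values (32 layers, 8 KV heads, head dimension 4096 / 32 = 128), and the total cache for $T$ tokens is $T$ times the per-token size.

$$
\text{KV cache per token} = 2 \times \text{num\_layers} \times \text{num\_heads} \times \text{dim\_heads} \times 2\ \text{bytes}
$$

$$
\text{8B:}\quad 2 \times 32 \times 8 \times 128 \times 2\ \text{bytes} = 131{,}072\ \text{bytes} = 128\ \text{KiB per token}
$$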
For example, calculating the 8B model with 16k and 128k tokens this way gives numbers that match the table above.
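A minimal Python sketch of the 8B calculation follows. It assumes FP16 (2 bytes per value), that "16k" and "128k" mean 16,000 and 128,000 tokens, and that the table's "GB" is GiB (2^30 bytes). Under those assumptions the results are 1.953125 and 15.625 GiB, which match the 8B row. The function name `kv_cache_gib` is hypothetical, chosen only for illustration.

```python
def kv_cache_gib(num_layers: int, num_kv_heads: int, head_dim: int,
                 num_tokens: int, bytes_per_value: int = 2) -> float:
    """Return the KV cache size in GiB (2**30 bytes).

    The leading factor of 2 counts the K and V tensors;
    bytes_per_value=2 assumes FP16.
    """
    bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_value
    return num_tokens * bytes_per_token / 2**30


# Llama 3 8B per the Llama 3 paper: 32 layers, 8 KV heads, head dim 128 (4096 / 32).
# Assumption: "16k" = 16,000 tokens and "128k" = 128,000 tokens.
for tokens in (16_000, 128_000):
    print(f"8B, {tokens:,} tokens: {kv_cache_gib(32, 8, 128, tokens):.2f} GB")
```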
However, if I calculate the 16k- and 128k-token cases for the 405B model the same way, the numbers do not match the table above: the calculated values are half of the values in the table.
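The factor of two can be checked by running the same per-token formula for the 405B model with 8 KV heads (the Llama 3 paper value) and with 16. This sketch again assumes FP16, 16,000 and 128,000 tokens, and GiB as the unit; `kv_cache_gib` is a hypothetical helper. With 8 KV heads it gives about 7.69 and 61.52 GiB, exactly half of the table, while 16 KV heads gives about 15.38 and 123.05 GiB, which match the table.

```python
def kv_cache_gib(num_layers: int, num_kv_heads: int, head_dim: int,
                 num_tokens: int, bytes_per_value: int = 2) -> float:
    """Return the KV cache size in GiB; the leading 2 counts K and V, FP16 by default."""
    bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_value
    return num_tokens * bytes_per_token / 2**30


# Llama 3 405B per the Llama 3 paper: 126 layers, head dim 128 (16384 / 128).
# Compare the paper's 8 KV heads with 16 KV heads.
for num_kv_heads in (8, 16):
    sizes = [kv_cache_gib(126, num_kv_heads, 128, t) for t in (16_000, 128_000)]
    print(f"{num_kv_heads:>2} KV heads: " + ", ".join(f"{s:.2f} GB" for s in sizes))
```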
This table in the Llama 3 paper is wrong: the number of key/value heads for the 405B model is 16, not 8. You can find the answer at this link.
@ZeusXuan
Thank you for the comment!
I read the Reddit post.
Does this mean that the 405B model originally had 16 KV heads but has since been changed to 8, matching the paper?
I found the following link about the commit that changes it to 8 KV heads: https://huggingface.co/meta-llama/Meta-Llama-3.1-405B-FP8/discussions/15
https://huggingface.co/blog/llama31#inference-memory-requirements
I have a question about the calculation of inference memory requirements for Llama 3.1 in this post.
Am I misunderstanding something? Or is there another factor that needs to be taken into account for the 405B model?
Also, for 1k tokens, my numbers are slightly different from the table. Does the table use 1,024 tokens for 1k?
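One way to test the 1,024 hypothesis is to compute each model with both 1,000 and 1,024 tokens. This sketch assumes FP16, GiB as the unit, head dimension 128 and the Llama 3 paper's layer counts, and it uses 16 KV heads for 405B because that is the value that reproduces the table's other columns; `kv_cache_gib` is a hypothetical helper. With 1,024 tokens the results are exactly 0.125, 0.3125 and 0.984375 GiB, consistent with the table's 0.125, 0.313 and 0.984, while 1,000 tokens gives smaller values (about 0.122 GiB for 8B). So the 1k column appears to use 1,024 tokens even though 16k and 128k match 16,000 and 128,000.

```python
def kv_cache_gib(num_layers: int, num_kv_heads: int, head_dim: int,
                 num_tokens: int, bytes_per_value: int = 2) -> float:
    """Return the KV cache size in GiB; the leading 2 counts K and V, FP16 by default."""
    bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_value
    return num_tokens * bytes_per_token / 2**30


# (layers, KV heads) per model; 405B uses 16 KV heads, the value that matches the table.
models = {"8B": (32, 8), "70B": (80, 8), "405B": (126, 16)}
for name, (num_layers, num_kv_heads) in models.items():
    as_1000 = kv_cache_gib(num_layers, num_kv_heads, 128, 1_000)
    as_1024 = kv_cache_gib(num_layers, num_kv_heads, 128, 1_024)
    print(f"{name}: 1,000 tokens -> {as_1000:.4f} GB, 1,024 tokens -> {as_1024:.4f} GB")
```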
Thank you!