
Llama3.1 inference memory requirements #2345

Open
satojkovic opened this issue Sep 14, 2024 · 2 comments

Comments

@satojkovic

https://huggingface.co/blog/llama31#inference-memory-requirements
Could you explain how the inference memory requirements for Llama 3.1 in this post were calculated?

The table below is an excerpt of the FP16 KV cache sizes from that post.

| Model size | 1k tokens | 16k tokens | 128k tokens |
|------------|-----------|------------|-------------|
| 8B         | 0.125 GB  | 1.95 GB    | 15.62 GB    |
| 70B        | 0.313 GB  | 4.88 GB    | 39.06 GB    |
| 405B       | 0.984 GB  | 15.38 GB   | 123.05 GB   |

I used the formula in this article to do my own calculations.
The formula in the article is as follows:

KV cache size per token (bytes) = 2 * num_layers * (num_heads * dim_heads) * precision_in_bytes

This gives the size of the KV cache per token, where the first factor of 2 accounts for the K and V matrices.
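Here num_layers, num_heads, and dim_heads are taken from the Llama 3 paper: num_heads is the number of key/value heads (Llama 3 uses grouped-query attention, so this is 8 rather than the number of attention heads), and dim_heads is the model dimension divided by the number of attention heads (4096/32 = 128 for the 8B model).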
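As a minimal sketch (not the article's or the blog's own code), the per-token formula can be wrapped in a small Python helper. The function names `kv_cache_bytes_per_token` and `kv_cache_gib` are hypothetical, `num_kv_heads` plays the role of num_heads in the formula, and FP16 is assumed to take 2 bytes per value. Sizes are divided by 1024**3, as in the calculations in this issue.

```python
def kv_cache_bytes_per_token(num_layers: int, num_kv_heads: int, dim_head: int,
                             bytes_per_value: int = 2) -> int:
    """KV cache bytes per token; the leading 2 counts the K and V tensors."""
    return 2 * num_layers * num_kv_heads * dim_head * bytes_per_value


def kv_cache_gib(num_tokens: int, num_layers: int, num_kv_heads: int,
                 model_dim: int, num_attention_heads: int,
                 bytes_per_value: int = 2) -> float:
    """Total KV cache size in GiB (1024**3 bytes) for num_tokens tokens."""
    dim_head = model_dim // num_attention_heads  # e.g. 4096 // 32 = 128
    per_token = kv_cache_bytes_per_token(num_layers, num_kv_heads, dim_head,
                                         bytes_per_value)
    return num_tokens * per_token / 1024**3


if __name__ == "__main__":
    # Llama 3.1 8B values as used in this issue (FP16 = 2 bytes per value)
    print(kv_cache_gib(16000, num_layers=32, num_kv_heads=8,
                       model_dim=4096, num_attention_heads=32))
```

For the 8B model this gives 131072 bytes per token, i.e. 1.953125 GB at 16000 tokens, the same value as the hand calculation.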

For example, for the 8B model with 16k tokens and 128k tokens, the calculation is as follows and matches the numbers in the table above.

16000 * (2 * 32 * 8 * (4096/32) * 2) / 1024**3
# 1.953125
128000 * (2 * 32 * 8 * (4096/32) * 2) / 1024**3
# 15.625

However, if I calculate the 16k-token and 128k-token values for the 405B model in the same way, the results do not match the table above: they are exactly half of the table values.

16000 * (2 * 126 * 8 * (16384/128) * 2) / 1024**3
# 7.6904296875
128000 * (2 * 126 * 8 * (16384/128) * 2) / 1024**3
# 61.5234375

Am I misunderstanding something? Or is there another factor that needs to be taken into account for the 405B model?

Also, the 1k-token values differ slightly. Does the table use 1024 tokens for "1k"?

1000 * (2 * 32 * 8 * (4096/32) * 2) / 1024**3
# 0.1220703125
1024 * (2 * 32 * 8 * (4096/32) * 2) / 1024**3
# 0.125
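One way to probe the 1k question is to compute the 8B row with both decimal and binary token counts. This sketch reuses the same arithmetic as above; the helper name `gib` is hypothetical.

```python
# Bytes per token for Llama 3.1 8B in FP16, same arithmetic as above:
# 2 (K and V) * 32 layers * 8 KV heads * 128 head dim * 2 bytes
BYTES_PER_TOKEN_8B = 2 * 32 * 8 * (4096 // 32) * 2


def gib(num_tokens: int) -> float:
    """KV cache size in GiB (1024**3 bytes) for the 8B model."""
    return num_tokens * BYTES_PER_TOKEN_8B / 1024**3


# Decimal vs. binary interpretations of "1k", "16k", "128k"
for decimal_n, binary_n in [(1000, 1024), (16000, 16384), (128000, 131072)]:
    print(f"{decimal_n:>6}: {gib(decimal_n):.4f} GB   {binary_n:>6}: {gib(binary_n):.4f} GB")
```

Compared with the table (0.125, 1.95, 15.62 GB), only 1024 tokens reproduces the first column, while 16384 and 131072 tokens would give 2.0 GB and 16.0 GB. So the table seems to use 1024 for "1k" but 16000 and 128000 for "16k" and "128k".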

Thank you!

@ZeusXuan
Contributor

The table in the Llama 3 paper has an error: the number of key/value heads for the 405B model is 16, not 8. You can find the answer in this link.
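As a quick check of this answer, here is a sketch that recomputes the whole table with the same formula. It assumes 16 KV heads for 405B (the value suggested above) and 8 for the other two models. It also assumes 1024 tokens for "1k" and 16000/128000 for "16k"/"128k". The layer counts, model dimensions, and attention-head counts are the Llama 3 paper values. The names `MODELS`, `TOKEN_COUNTS`, and `kv_cache_size_gib` are hypothetical.

```python
# Model hyperparameters from the Llama 3 paper, except num_kv_heads for 405B,
# which is set to 16 here (the value suggested in the comment above), not 8.
MODELS = {
    "8B": {"num_layers": 32, "model_dim": 4096, "num_heads": 32, "num_kv_heads": 8},
    "70B": {"num_layers": 80, "model_dim": 8192, "num_heads": 64, "num_kv_heads": 8},
    "405B": {"num_layers": 126, "model_dim": 16384, "num_heads": 128, "num_kv_heads": 16},
}

# Token counts that appear to reproduce the blog table ("1k" taken as 1024)
TOKEN_COUNTS = {"1k": 1024, "16k": 16000, "128k": 128000}

BYTES_FP16 = 2


def kv_cache_size_gib(num_tokens: int, cfg: dict) -> float:
    """KV cache size in GiB; the leading 2 counts the K and V tensors."""
    dim_head = cfg["model_dim"] // cfg["num_heads"]
    per_token = 2 * cfg["num_layers"] * cfg["num_kv_heads"] * dim_head * BYTES_FP16
    return num_tokens * per_token / 1024**3


for name, cfg in MODELS.items():
    sizes = [f"{label}: {kv_cache_size_gib(n, cfg):.3f} GB" for label, n in TOKEN_COUNTS.items()]
    print(name, *sizes)
```

With 16 KV heads, the 405B row comes out as 0.984, 15.381, and 123.047 GB, which agrees with the table up to rounding. With 8 KV heads, every value halves, which matches the discrepancy in the original question.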

@satojkovic
Author

@ZeusXuan
Thank you for the comment!
I read the reddit post.
Does this mean that the 405B model originally had 16 KV heads, but this has since been changed to 8 to match the paper?
I found the following discussion, which links to the commit that changes it to 8 KV heads.
https://huggingface.co/meta-llama/Meta-Llama-3.1-405B-FP8/discussions/15
