Merge pull request #413 from HuangFuSL/ml-from-scratch
Update: LLM Inference
github-actions[bot] authored Nov 25, 2024
2 parents 7ef3f93 + e01027a commit 60a666f
Showing 4 changed files with 177 additions and 6 deletions.
1 change: 1 addition & 0 deletions docs/coding/dl-from-scratch/.pages
@@ -13,3 +13,4 @@ nav:
- Positional encoding: positional-embedding.md
- Implementing BERT: bert.ipynb
- Implementing Llama-2: llama-2.ipynb
- Model inference: llm-inference.md
2 changes: 2 additions & 0 deletions docs/coding/dl-from-scratch/index.md
@@ -24,3 +24,5 @@ hide:
4. [RoPE](positional-embedding.md#_3)
5. [Transformer variants](transformer-variants.md)
* [Verifying the Llama-2 implementation](llama-2.ipynb)
6. [LLM inference](llm-inference.md)

139 changes: 133 additions & 6 deletions docs/coding/dl-from-scratch/llama-2.ipynb
@@ -259,9 +259,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Llama模型构建\n",
"\n",
"将如上结构拼起来,即可构建出Llama模型。"
"## Llama模型构建"
]
},
{
@@ -347,7 +345,22 @@
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "daf152debbe1474c852dfcde91c1100a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"hf_llama = transformers.LlamaForCausalLM.from_pretrained('meta-llama/Llama-2-7b-hf')\n",
"\n",
@@ -370,7 +383,7 @@
" (custom_layer.ffn.fc2, hf_layer.mlp.down_proj)\n",
" ])\n",
" for custom_layer, hf_layer in layer_pairs:\n",
" custom_layer.weight.data.copy_(hf_layer.weight.data)\n",
" custom_layer.weight.data = hf_layer.weight.data.clone().detach()\n",
" return custom_model\n",
"\n",
"llama = load_params(llama, hf_llama)"
@@ -546,6 +559,120 @@
"for _ in max_prob_tokens:\n",
" print(llama_tokenizer.decode(_))"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"21"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"del hf_llama\n",
"del llama_layers_output\n",
"del hf_logits\n",
"import gc\n",
"gc.collect()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 模型推理\n",
"\n",
"在推理阶段,模型根据输入序列,不断预测序列中的下一个单词,并且将其加入序列中。重复这个过程,直到出现结束token。"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"def complete_ids(model: LlamaModel, input_ids: torch.Tensor):\n",
" assert input_ids.size(0) == 1\n",
"\n",
" next_token = model(input_ids, torch.ones_like(input_ids))[0, -1, :].topk(1).indices[0].item()\n",
" return next_token\n",
"\n",
"def complete_string(model: LlamaModel, tokenizer: transformers.LlamaTokenizer, prompt: str, max_length: int):\n",
" model.eval()\n",
" gc.collect()\n",
" tokenized_sentence = tokenizer([prompt], return_tensors='pt')['input_ids'].tolist()\n",
" next_token = None\n",
" print(tokenizer.decode(tokenized_sentence[0]))\n",
" num_tokens = 0\n",
"\n",
" while next_token != tokenizer.eos_token_id and num_tokens <= max_length:\n",
" num_tokens += 1\n",
" next_token = complete_ids(model, torch.tensor(tokenized_sentence))\n",
" tokenized_sentence[0].append(next_token)\n",
" print(tokenizer.decode(tokenized_sentence[0]))\n",
"\n",
" gc.collect()\n",
"\n",
" return num_tokens"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<s> User: Given two numbers, 13.11 and 13.8, which is larger?\n",
"Assistant:\n",
"<s> User: Given two numbers, 13.11 and 13.8, which is larger?\n",
"Assistant: The\n",
"<s> User: Given two numbers, 13.11 and 13.8, which is larger?\n",
"Assistant: The larger\n",
"<s> User: Given two numbers, 13.11 and 13.8, which is larger?\n",
"Assistant: The larger number\n",
"<s> User: Given two numbers, 13.11 and 13.8, which is larger?\n",
"Assistant: The larger number is\n",
"<s> User: Given two numbers, 13.11 and 13.8, which is larger?\n",
"Assistant: The larger number is \n",
"<s> User: Given two numbers, 13.11 and 13.8, which is larger?\n",
"Assistant: The larger number is 1\n",
"<s> User: Given two numbers, 13.11 and 13.8, which is larger?\n",
"Assistant: The larger number is 13\n",
"<s> User: Given two numbers, 13.11 and 13.8, which is larger?\n",
"Assistant: The larger number is 13.\n",
"<s> User: Given two numbers, 13.11 and 13.8, which is larger?\n",
"Assistant: The larger number is 13.8\n",
"<s> User: Given two numbers, 13.11 and 13.8, which is larger?\n",
"Assistant: The larger number is 13.8.\n",
"<s> User: Given two numbers, 13.11 and 13.8, which is larger?\n",
"Assistant: The larger number is 13.8.\n",
"\n",
"Predicted 11 in 37.42 seconds, 0.29 tokens/s\n"
]
}
],
"source": [
"import time\n",
"query = 'User: Given two numbers, 13.11 and 13.8, which is larger?\\nAssistant:'\n",
"\n",
"start_time = time.time()\n",
"num_tokens = complete_string(llama, llama_tokenizer, query, 10)\n",
"end_time = time.time()\n",
"\n",
"time_span = end_time - start_time\n",
"print(f'Predicted {num_tokens} in {time_span:.2f} seconds, {num_tokens / time_span:.2f} tokens/s')"
]
}
],
"metadata": {
@@ -564,7 +691,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.11.7"
}
},
"nbformat": 4,
41 changes: 41 additions & 0 deletions docs/coding/dl-from-scratch/llm-inference.md
@@ -0,0 +1,41 @@
# LLM Inference

During inference, an LLM takes an input text sequence and repeatedly predicts the next token of that sequence until it produces the end-of-sequence symbol. Note, however, that pretraining does not include predicting the bos (begin of sentence) and eos (end of sentence) tokens, so a base model will simply keep emitting tokens.
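
A minimal greedy-decoding sketch of this loop (the `model`/`tokenizer` interface here follows the Hugging Face style used in the notebook; `greedy_generate` itself is illustrative and not part of the original code):

```python
import torch

@torch.no_grad()
def greedy_generate(model, tokenizer, prompt: str, max_new_tokens: int = 32) -> str:
    # Encode the prompt into token ids.
    input_ids = tokenizer(prompt, return_tensors='pt')['input_ids']
    for _ in range(max_new_tokens):
        # Run the whole sequence and keep only the logits of the last position.
        logits = model(input_ids).logits[:, -1, :]
        next_token = logits.argmax(dim=-1, keepdim=True)
        # Append the predicted token and feed the extended sequence back in.
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        # Stop once the end-of-sequence token is produced.
        if next_token.item() == tokenizer.eos_token_id:
            break
    return tokenizer.decode(input_ids[0])
```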

## KV-cache

Re-feeding the whole sequence into the LLM to predict each next token is computationally very inefficient, because every prediction recomputes attention scores over the entire sequence. Since each token's representation depends only on the tokens before it, during inference we do not need to recompute the earlier tokens: it is enough to store their K and V representations and compute attention with only the last token's Q to obtain the prediction for the next token. This technique is called the KV cache. Inference with a KV cache runs in two stages: the prefill stage, which fills the cache, and the decode stage.

!!! note "Why only a KV cache is needed, and not a Q cache"

    Suppose that after the $QKV$ projections the input sequence becomes $S^Q, S^K, S^V$, which can be split into column vectors:

    $$
    \begin{aligned}
    S^Q &= \begin{bmatrix} Q_1 & Q_2 & \cdots & Q_n \end{bmatrix} \\
    S^K &= \begin{bmatrix} K_1 & K_2 & \cdots & K_n \end{bmatrix} \\
    S^V &= \begin{bmatrix} V_1 & V_2 & \cdots & V_n \end{bmatrix}
    \end{aligned}
    $$

    The attention score matrix is

    $$
    \begin{bmatrix}
    Q_1K_1 & Q_1K_2 & \cdots & Q_1K_n \\
    Q_2K_1 & Q_2K_2 & \cdots & Q_2K_n \\
    \vdots & \vdots & \ddots & \vdots \\
    Q_nK_1 & Q_nK_2 & \cdots & Q_nK_n
    \end{bmatrix}
    $$

    Because of the decoder's causal mask, the final representations are

    $$
    \begin{aligned}
    O_1 &= \sum_{i=1}^1 \text{softmax}(Q_1K_i)V_i \\
    O_2 &= \sum_{i=1}^2 \text{softmax}(Q_2K_i)V_i \\
    \end{aligned}
    $$

    That is, the representation at each position depends only on that position's Q and on the K and V of that position and the positions before it. When predicting the next token, therefore, the earlier K and V need not be recomputed; storing them is enough.

When the input sequence is fed into the LLM, the prefill stage runs first and saves the KV cache of every layer; the decode stage then repeatedly predicts the next token and records each predicted token's K and V representations in the cache. Every decoder layer keeps its own KV cache, and each layer's cache has size $2\times N\times L\times D$. For long sequences the KV cache therefore becomes very large.
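
A toy single-head, single-layer sketch of prefill and decode with a KV cache (the dimensions, weights, and the `attend` helper below are made up for illustration; a real decoder also applies the causal mask during prefill and keeps one such cache per layer and per head):

```python
import torch
import torch.nn.functional as F

def attend(q, k, v):
    # q: (m, d); k, v: (t, d). Softmax over cached positions, then weighted sum of V.
    scores = q @ k.T / (k.size(-1) ** 0.5)
    return F.softmax(scores, dim=-1) @ v

d_model = 8
Wq, Wk, Wv = (torch.randn(d_model, d_model) for _ in range(3))

# Prefill: run the whole prompt once and cache K and V for every position.
prompt = torch.randn(5, d_model)                 # embeddings of 5 prompt tokens
k_cache, v_cache = prompt @ Wk, prompt @ Wv
out = attend(prompt @ Wq, k_cache, v_cache)      # a real model masks future positions here

# Decode: each new token computes only its own Q/K/V and appends K, V to the cache.
for _ in range(3):
    x = torch.randn(1, d_model)                  # embedding of the newly generated token
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    k_cache = torch.cat([k_cache, k], dim=0)
    v_cache = torch.cat([v_cache, v], dim=0)
    out = attend(q, k_cache, v_cache)            # only the last token's Q is needed

# Per-layer cache size in elements: 2 (K and V) x batch x sequence length x hidden size.
batch, seq_len, hidden = 1, 4096, 4096
print(2 * batch * seq_len * hidden)              # 33,554,432 elements per layer at this length
```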
