From e01027a4db71f29cc791fac48bbe78f36f8b9e76 Mon Sep 17 00:00:00 2001
From: HuangFuSL
Date: Mon, 25 Nov 2024 13:35:42 +0800
Subject: [PATCH] Update: LLM Inference

---
 docs/coding/dl-from-scratch/.pages           |   1 +
 docs/coding/dl-from-scratch/index.md         |   2 +
 docs/coding/dl-from-scratch/llama-2.ipynb    | 139 ++++++++++++++++++-
 docs/coding/dl-from-scratch/llm-inference.md |  41 ++++++
 4 files changed, 177 insertions(+), 6 deletions(-)
 create mode 100644 docs/coding/dl-from-scratch/llm-inference.md

diff --git a/docs/coding/dl-from-scratch/.pages b/docs/coding/dl-from-scratch/.pages
index 0b9b2f3f6..3fd024f2a 100644
--- a/docs/coding/dl-from-scratch/.pages
+++ b/docs/coding/dl-from-scratch/.pages
@@ -13,3 +13,4 @@ nav:
     - Positional encoding: positional-embedding.md
     - Implementing BERT: bert.ipynb
     - Implementing Llama-2: llama-2.ipynb
+    - Model inference: llm-inference.md
diff --git a/docs/coding/dl-from-scratch/index.md b/docs/coding/dl-from-scratch/index.md
index 572c57730..1c36d7441 100644
--- a/docs/coding/dl-from-scratch/index.md
+++ b/docs/coding/dl-from-scratch/index.md
@@ -24,3 +24,5 @@ hide:
 4. [RoPE](positional-embedding.md#_3)
 5. [Transformer variants](transformer-variants.md)
     * [Validating the Llama-2 implementation](llama-2.ipynb)
+6. [LLM inference](llm-inference.md)
+
diff --git a/docs/coding/dl-from-scratch/llama-2.ipynb b/docs/coding/dl-from-scratch/llama-2.ipynb
index f341b6147..4fe515037 100644
--- a/docs/coding/dl-from-scratch/llama-2.ipynb
+++ b/docs/coding/dl-from-scratch/llama-2.ipynb
@@ -259,9 +259,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "### Building the Llama model\n",
-    "\n",
-    "Putting the components above together gives the complete Llama model."
+    "## Building the Llama model"
   ]
  },
  {
@@ -347,7 +345,22 @@
    "cell_type": "code",
    "execution_count": 9,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "daf152debbe1474c852dfcde91c1100a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+      " User: Given two numbers, 13.11 and 13.8, which is larger?\n",
+      "Assistant:\n",
+      " User: Given two numbers, 13.11 and 13.8, which is larger?\n",
+      "Assistant: The\n",
+      " User: Given two numbers, 13.11 and 13.8, which is larger?\n",
+      "Assistant: The larger\n",
+      " User: Given two numbers, 13.11 and 13.8, which is larger?\n",
+      "Assistant: The larger number\n",
+      " User: Given two numbers, 13.11 and 13.8, which is larger?\n",
+      "Assistant: The larger number is\n",
+      " User: Given two numbers, 13.11 and 13.8, which is larger?\n",
+      "Assistant: The larger number is \n",
+      " User: Given two numbers, 13.11 and 13.8, which is larger?\n",
+      "Assistant: The larger number is 1\n",
+      " User: Given two numbers, 13.11 and 13.8, which is larger?\n",
+      "Assistant: The larger number is 13\n",
+      " User: Given two numbers, 13.11 and 13.8, which is larger?\n",
+      "Assistant: The larger number is 13.\n",
+      " User: Given two numbers, 13.11 and 13.8, which is larger?\n",
+      "Assistant: The larger number is 13.8\n",
+      " User: Given two numbers, 13.11 and 13.8, which is larger?\n",
+      "Assistant: The larger number is 13.8.\n",
+      " User: Given two numbers, 13.11 and 13.8, which is larger?\n",
+      "Assistant: The larger number is 13.8.\n",
+      "\n",
+      "Predicted 11 in 37.42 seconds, 0.29 tokens/s\n"
+     ]
+    }
+   ],
+   "source": [
+    "import time\n",
+    "query = 'User: Given two numbers, 13.11 and 13.8, which is larger?\\nAssistant:'\n",
+    "\n",
+    "start_time = time.time()\n",
+    "num_tokens = complete_string(llama, llama_tokenizer, query, 10)\n",
+    "end_time = time.time()\n",
+    "\n",
+    "time_span = end_time - start_time\n",
"print(f'Predicted {num_tokens} in {time_span:.2f} seconds, {num_tokens / time_span:.2f} tokens/s')" + ] } ], "metadata": { @@ -564,7 +691,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.11.7" } }, "nbformat": 4, diff --git a/docs/coding/dl-from-scratch/llm-inference.md b/docs/coding/dl-from-scratch/llm-inference.md new file mode 100644 index 000000000..0ffaa29c0 --- /dev/null +++ b/docs/coding/dl-from-scratch/llm-inference.md @@ -0,0 +1,41 @@ +# 大模型推理 + +在推理阶段,大模型根据输入的文本序列,不断预测输入文本的下一个词,直到预测出结束符号为止。不过,在预训练阶段中,不包含对bos(begin of sentence)和eos(end of sentence)的预测。所以,如果使用基座模型,模型会不断输出token。 + +## KV-cache + +把输入序列重新输出LLM,预测下一个单词推理流程的计算效率非常低,原因在于每一次预测都需要对整个序列计算注意力得分。然而,由于每一个token的表示只依赖于前面的token,在推理阶段,我们不必重算前边的token,只需要存储前边的token的KV表示即可。只需用最后一个token的Q计算注意力即可得到对下一个token的预测。这种技术即称为KV缓存。包含KV缓存的推理过程分为两个阶段,即缓存填充(prefill)阶段和解码(decode)阶段。 + +!!! note "为何只需要KV-cache而不需要Q-cache" + + 假设输入序列经过$QKV$计算后的序列为$S^Q, S^K, S^V$,并且可以拆分为列向量表示: + + $$ + \begin{aligned} + S^Q &= \begin{bmatrix} Q_1 & Q_2 & \cdots & Q_n \end{bmatrix} \\ + S^K &= \begin{bmatrix} K_1 & K_2 & \cdots & K_n \end{bmatrix} \\ + S^V &= \begin{bmatrix} V_1 & V_2 & \cdots & V_n \end{bmatrix} + $$ + + 注意力得分矩阵为 + + $$ + \begin{bmatrix} + Q_1K_1 & Q_1K_2 & \cdots & Q_1K_n \\ + Q_2K_1 & Q_2K_2 & \cdots & Q_2K_n \\ + \vdots & \vdots & \ddots & \vdots \\ + Q_nK_1 & Q_nK_2 & \cdots & Q_nK_n + $$ + + 由于decoder的causal mask限制,最终计算的表示为 + + $$ + \begin{aligned} + O_1 &= \sum_{i=1}^1 \text{softmax}(Q_1K_i)V_i \\ + O_2 &= \sum_{i=1}^2 \text{softmax}(Q_2K_i)V_i \\ + \end{aligned} + $$ + + 也即,每个位置的表示,只依赖该位置的Q,和该位置及之前的K和V。因此预测下一个单词时,不需要重新计算KV,将KV存储下来即可。 + +将输入序列输入LLM时,首先进行缓存填充,保存下所有层的KV-cache,然后进行解码,不断预测下一个单词,并记录预测单词的KV表示。对于decoder的每一层,都需要保存一个KV-cache,每一层cache的大小为$2\times N\times L\times D$。对于长序列,KV-cache的大小会非常大。 \ No newline at end of file