<?xml version="1.0" encoding="utf-8"?>
<search>
<entry>
<title>在 WSL 中优雅地使用宿主机的代理</title>
<link href="/2023/11/28/wsl-proxy/"/>
<url>/2023/11/28/wsl-proxy/</url>
<content type="html"><![CDATA[<p>打开 WSL 终端的时候,会出现下面的一个报错:</p><figure class="highlight text"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><code class="hljs text">wsl: 检测到 localhost 代理配置,但未镜像到 WSL。NAT 模式下的 WSL 不支持 localhost 代理。<br></code></pre></td></tr></table></figure><p>这个报错告诉我们,不能用 localhost 直接访问宿主机的代理。 但是如果在 WSL 中能够访问宿主机的 IP,并且宿主机的代理服务“允许来自局域网的连接”,那么 WSL 就可以轻松使用宿主机的代理了。</p><p>宿主机的 IP 可以在 <code>/etc/resolv.conf</code> 文件里面找到,但这个 IP 可能会变动。因此只要在使用的时候,去解析出这个 IP 即可。 我们可以通过如下的命令解析出宿主机 IP:</p><figure class="highlight shell"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><code class="hljs shell">cat /etc/resolv.conf |grep "nameserver" |cut -f 2 -d " "<br></code></pre></td></tr></table></figure><p>有了宿主机 IP,我们就可以在<code>~/.bashrc</code> 或 <code>~/.zshrc</code> 直接定义一个<code>alias proxy</code>,</p><figure class="highlight shell"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><code class="hljs shell">host_ip=$(cat /etc/resolv.conf |grep "nameserver" |cut -f 2 -d " ")<br>proxy_port=10809<br>alias proxy="all_proxy=http://$host_ip:$proxy_port"<br></code></pre></td></tr></table></figure><p><code>proxy_port</code> 要改成宿主机代理端口。现在就可以在 WSL 中使用宿主机代理执行 HTTP 请求了:</p><figure class="highlight shell"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><code class="hljs shell">proxy curl www.some-proxy-domain.com<br></code></pre></td></tr></table></figure>]]></content>
<categories>
<category>O&M</category>
</categories>
<tags>
<tag>wsl</tag>
<tag>proxy</tag>
</tags>
</entry>
<entry>
<title>一文梳理 RNN、GRU 和 LSTM</title>
<link href="/2023/06/08/rnn-gru-lstm/"/>
<url>/2023/06/08/rnn-gru-lstm/</url>
<content type="html"><![CDATA[<h2 id="rnn">RNN</h2><p>RNN 的全称是 <strong>R</strong>ecurrent <strong>N</strong>eural <strong>N</strong>etwork,翻译成<strong>“循环神经网络”</strong>。它是一种用于处理序列数据的神经网络模型。与传统的前馈神经网络不同,RNN 具有向后连接的特性,可以<strong>将前面的输出作为后面的输入</strong>,从而实现对序列数据的建模和处理。RNN 被广泛应用于自然语言处理、语音识别、机器翻译等领域。一个典型的 RNN 结构图如下:</p><img src="/2023/06/08/rnn-gru-lstm/rnn.png" class="" title="RNN 结构图"><p>为了统一描述,本文规定:上标 <span class="math inline">\(\langle t \rangle\)</span> 表示处在时间步第 <span class="math inline">\(t\)</span> 步的对象或操作。</p><p>在这个图中,</p><ul><li><p>为了便于说明,我们假设输入和输出具有相同的时间步长,即 <span class="math inline">\(T_y = T_x\)</span>。</p></li><li><p><span class="math inline">\(x^{\langle t \rangle}\)</span> 是输入词汇通过 One-Hot 处理后的一维数组,用 <span class="math inline">\(n_x\)</span> 表示数组大小(它也是词汇表中词汇的数量),也可以用 <span class="math inline">\(n_x \times 1\)</span> 的矩阵来表示。</p></li><li><p><span class="math inline">\(y^{\langle t \rangle}\)</span> 是输出的结果。它一般是通过 Sigmoid(二分类)或 Softmax(多元分类)等激活函数处理后的一维数组,用 <span class="math inline">\(n_y\)</span> 表示数组大小,也可以用 <span class="math inline">\(n_y \times 1\)</span> 的矩阵来表示。</p></li><li><p><span class="math inline">\(h^{\langle t \rangle}\)</span> 是隐藏状态 (Hidden State), <span class="math inline">\(n_h \times 1\)</span> 矩阵,或大小为 <span class="math inline">\(n_h\)</span> 的一维数组。每执行一步,其数值都会更新,并把结果传递给下一步。</p></li></ul><p>为什么是这样的结构?没有隐藏状态可以不可以?答案是否定的。如果没有 Hidden State 这个过程媒介,前一次的推理结果不能反馈给下一次推理,相当于所有输入的词汇都没有进行推理而独立生成结果。这显然不符合文本序列的生成规律。</p><h3 id="one-hot-编码处理输入词汇">One-Hot 编码处理输入词汇</h3><p>我们说,<span class="math inline">\(x^{\langle t \rangle}\)</span> 是 One-Hot 编码后的一维数组,这个数组的长度是词汇表的大小。所谓 One-Hot 编码,是一种将分类变量(在这里是词汇表里定义的各个不同的词汇)转换为数值变量的方法。它将每个分类变量的可能取值映射到一个二进制向量的位置。在这个向量中,只有对应于该分类变量取值的位置为 1,其他位置为 0。</p><p>为了简化说明,假设词汇表只有 3 个词汇,分别是 <code>["a", "b", "c"]</code>,那么词汇 <code>"a"</code> 的索引值是 0,One-Hot 后变成数组 <code>[1, 0, 0]</code>; <code>"b"</code> 的索引值是 1,One-Hot 后变成数组 <code>[0, 1, 0]</code>。依此类推,<code>"c"</code> 用数组 <code>[0, 0, 1]</code>来表示。</p><p>实际的词汇表可能长这样:<code>["a", "aaron", ..., 
"and", ..., "harry", ..., "potter", ..., "zulu"]</code>。更复杂的,还要处理词汇大小写问题(比如处理 green 和 Green),不认识的词汇(比如用 <code><UNK></code> 表示),生成结束的标识(比如用 <code><EOS></code> 表示),掩码标识(比如用 <code><MASK></code> 表示),等等。这些就涉及到分词 (Tokenization) 领域了,本文不做展开。</p><p>One-Hot 的好处是<strong>将不可计算的分类变量变成了可计算的数组</strong>,以便将分类变量输入到模型中进行训练和预测。</p><h3 id="softmax-函数输出概率分布">Softmax 函数输出概率分布</h3><p>那么输出的 <span class="math inline">\(y^{\langle t \rangle}\)</span> 是什么呢?这个取决于不同的应用场合。假如我们希望得到每个词汇被选中的概率,那么 <span class="math inline">\(n_y\)</span> 就有可能等于 <span class="math inline">\(n_x\)</span>(当然,有些情况两者也不相等。比如词汇表的末尾几个词汇是特殊 Token,那么 <span class="math inline">\(y\)</span> 就可能直接去掉它们,从而使得 <span class="math inline">\(n_y < n_x\)</span>)。</p><p>现在就假设 <span class="math inline">\(n_y = n_x = n\)</span>,即词汇表中有 <span class="math inline">\(n\)</span> 个词汇,我们希望:</p><ul><li><p>输出的 <span class="math inline">\(y^{\langle t \rangle}\)</span> 是一个长度为 <span class="math inline">\(n\)</span> 的数组。</p></li><li><p>数组中的每个元素的索引值是对应词汇的索引值,而索引值对应的数值是该词汇被选中的概率,每个值得取值范围是 <code>[0, 1]</code> 。</p></li><li><p>所有元素的数值加起来等于 1。</p></li></ul><p>例如,在上面 <code>["a", "b", "c"]</code> 的例子中,若输出的结果是 <code>[0.14, 0.08, 0.78]</code>,那么就意味着 <code>"a"</code> 被选中的概率是 0.14,<code>"b"</code>被选中的概率是 0.08,<code>"c"</code> 被选中的概率是 0.78。数组只告诉了我们每个词汇被选中的概率,还不是我们的最终想要输出的词汇,因此还需要进一步处理。例如,可以根据概率随机选择。以下是使用 <code>numpy</code>库随机选择的例子:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br></pre></td><td class="code"><pre><code class="hljs python"><span 
class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np<br><br><span class="hljs-comment"># 词汇表</span><br>vocab = [<span class="hljs-string">"a"</span>, <span class="hljs-string">"b"</span>, <span class="hljs-string">"c"</span>]<br><br><span class="hljs-comment"># 概率数组</span><br>probabilities = [<span class="hljs-number">0.14</span>, <span class="hljs-number">0.08</span>, <span class="hljs-number">0.78</span>]<br><br><span class="hljs-comment"># 词汇表长度</span><br>vocab_size = <span class="hljs-built_in">len</span>(vocab) <br><br><span class="hljs-comment"># 按照概率输出选中词汇的索引值</span><br>index = np.random.choice(vocab_size, p=probabilities) <br><br><span class="hljs-comment"># 输出词汇</span><br>word = vocab[index]<br></code></pre></td></tr></table></figure><p>在深度学习中,通常可以用 Softmax 函数作为输出层的激活函数,将输出值映射到 <code>[0, 1]</code> 之间,同时满足所有输出值之和为 1,表示每个类别的概率,用于判断样本属于那个类别的概率。</p><p>Softmax 函数的公式如下所示:</p><p><span class="math display">\[\operatorname{softmax}(z_i) = \frac{e^{z_i}}{\sum_{j=1}^{n}e^{z_j}}\]</span></p><p>其中,<span class="math inline">\(z_i\)</span> 是输入向量的第 <span class="math inline">\(i\)</span> 个分量,<span class="math inline">\(n\)</span> 是向量的维度。Softmax 函数将输入向量中的每个分量映射到 <span class="math inline">\([0, 1]\)</span> 区间内,并且所有分量的和为 <span class="math inline">\(1\)</span>,因此可以将其解释为概率分布。</p><p>我曾经尝试询问 ChatGPT,为什么 Softmax 函数需要 e 指数。以下是它的回答,仅供参考:</p><blockquote><p>在 Softmax 函数中,使用 e 指数是因为它具有以下两个性质:</p><ul><li>e 指数是一个正数,它可以保证 Softmax 函数的输出也是一个正数。</li><li>e 指数具有快速增长的特点,这可以使得输出向量中的某些元素比其他元素更突出。</li></ul><p>因此,使用 e 指数可以将输入向量中的每个元素转化为一个正数,并且可以突出某些元素的重要性,从而更好地表达输入向量的特征。同时,使用 e 指数也可以使 Softmax 函数具有平滑性和可微性,使得神经网络的训练更加稳定和高效。</p></blockquote><h3 id="定义隐藏状态-hidden-state">定义隐藏状态 (Hidden State)</h3><p>隐藏状态的维度实质上是定义了 <span class="math inline">\(n_h\)</span> 个可以在状态之间进行推理传递的独立特征。我们并不清楚这些特征是什么,也不知道总共应该有多少个这样的独立特征。我们能做的,就是假设有 <span class="math inline">\(n_h\)</span> 个独立特征,然后让机器来学习拟合这些特征。当然,一般来说,<span class="math inline">\(n_h\)</span>、<span class="math inline">\(n_x\)</span> 和 <span 
class="math inline">\(n_y\)</span> 之间没有直接关系。</p><h3 id="前向传播">前向传播</h3><p>结合上面的信息,我们可以推断隐藏状态的更新计算方式如下: <span class="math display">\[h^{\langle t \rangle} = \operatorname{g}(W_{hh}h^{\langle t-1 \rangle} + W_{hx}x^{\langle t \rangle} + b_h)\]</span> 其中,</p><ul><li><span class="math inline">\(\operatorname{g}\)</span> 是激活函数,可以选择 <span class="math inline">\(\operatorname{tanh}(\cdot)\)</span>, <span class="math inline">\(\operatorname{relu}(\cdot)\)</span> 等。例如,如果选择的是<span class="math inline">\(\operatorname{tanh}(\cdot)\)</span>,上面公式就变成:</li></ul><p><span class="math display">\[h^{\langle t \rangle} = \operatorname{tanh}(W_{hh}h^{\langle t-1 \rangle} + W_{hx}x^{\langle t \rangle} + b_h)\]</span></p><ul><li>参数矩阵 <span class="math inline">\(W_{hh}\)</span> 的维度是 <span class="math inline">\((n_h, n_h)\)</span>, 而 <span class="math inline">\(W_{hx}\)</span> 的维度是 <span class="math inline">\((n_h, n_x)\)</span>。</li><li>偏置矩阵 <span class="math inline">\(b_h\)</span> 的维度和隐藏状态 <span class="math inline">\(h^{\langle t \rangle}\)</span> 的维度一致,是 <span class="math inline">\((n_h, 1)\)</span>。</li></ul><p>然后用输出的隐藏状态来计算预测值 (Prediction) <span class="math inline">\(\hat{y}\)</span>: <span class="math display">\[\hat{y}^{\langle t \rangle} = \operatorname{softmax}(W_{yh}h^{\langle t \rangle} + b_y)\]</span> 其中,</p><ul><li>我们使用 <span class="math inline">\(\hat{y}\)</span> 表示预测值,而不是标定值。</li><li>参数矩阵 <span class="math inline">\(W_{yh}\)</span> 的维度是 <span class="math inline">\((n_y, n_h)\)</span>。</li><li>偏置矩阵 <span class="math inline">\(b_y\)</span> 的维度和预测值 <span class="math inline">\(\hat{y}^{\langle t \rangle}\)</span> 的维度一致,是 <span class="math inline">\((n_y, 1)\)</span>。</li></ul><p>把这两个公式结合起来,就可以得到如下 RNN 单元的内部结构:</p><img src="/2023/06/08/rnn-gru-lstm/rnn_computation.png" class="" title="RNN 单元计算图"><p>我们也可以借助 <code>NumPy</code> 库,来实现前向传播:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span 
class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br></pre></td><td class="code"><pre><code class="hljs python"><span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np<br><br><span class="hljs-comment"># 定义 Softmax 激活函数</span><br><span class="hljs-keyword">def</span> <span class="hljs-title function_">softmax</span>(<span class="hljs-params">z</span>):<br> exp_z = np.exp(z)<br> <span class="hljs-keyword">return</span> exp_z / np.<span class="hljs-built_in">sum</span>(exp_z)<br><br><span class="hljs-comment"># 定义 RNN 类</span><br><span class="hljs-keyword">class</span> <span class="hljs-title class_">RNN</span>:<br> <span class="hljs-keyword">def</span> <span class="hljs-title function_">__init__</span>(<span class="hljs-params">self, input_size, hidden_size, output_size</span>):<br> <span class="hljs-comment"># 初始化权重和偏置</span><br> self.Whx = np.random.randn(hidden_size, input_size) * <span class="hljs-number">0.01</span><br> self.Whh = np.random.randn(hidden_size, hidden_size) * <span class="hljs-number">0.01</span><br> self.Wyh = np.random.randn(output_size, hidden_size) * <span class="hljs-number">0.01</span><br> 
self.bh = np.zeros((hidden_size, <span class="hljs-number">1</span>))<br> self.by = np.zeros((output_size, <span class="hljs-number">1</span>))<br> <br> <span class="hljs-keyword">def</span> <span class="hljs-title function_">forward</span>(<span class="hljs-params">self, x</span>):<br> <span class="hljs-comment"># 初始化隐藏状态和输出</span><br> h = np.zeros((self.Whh.shape[<span class="hljs-number">0</span>], <span class="hljs-number">1</span>))<br> y = np.zeros((self.Wyh.shape[<span class="hljs-number">0</span>], <span class="hljs-number">1</span>))<br> <br> <span class="hljs-comment"># 遍历输入序列</span><br> <span class="hljs-keyword">for</span> t <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(<span class="hljs-built_in">len</span>(x)):<br> <span class="hljs-comment"># 计算隐藏状态</span><br> h = np.tanh(np.dot(self.Whx, x[t]) + np.dot(self.Whh, h) + self.bh)<br> <span class="hljs-comment"># 计算输出</span><br> y = softmax(np.dot(self.Wyh, h) + self.by)<br> <br> <span class="hljs-keyword">return</span> y<br></code></pre></td></tr></table></figure><h2 id="gru">GRU</h2><p>从前面 RNN 的结构来看,hidden state 携带的特征信息有助于下一步的推导,因此具备一定的记忆能力。记忆是序列生成的一项重要能力。试想一下,如果你在说一段话的时候,说着说着就不记得前面说了什么,那么后面说的内容可能就与上文没有任何关系,表现出来的就是在胡说八道。那么,传统 RNN 结构的 hidden state 到底能携带多少记忆信息呢?直接增大 hidden state 的维度,携带的记忆信息是否更多?这些可以在实验过程中去尝试。但总的来说,根据以往的经验,随着时间步的增加,这种结构的 RNN 可能会出现梯度消失 (Vanishing gradients) 或梯度爆炸 (Exploding gradients) 的问题。因此,我们也需要根据具体的任务和数据情况,来改进 RNN 架构。GRU 和 LSTM 就是其中的两种。从历史来看,LSTM 比 GRU 更早提出来;但考虑到 GRU 更简单一些,我们就先讨论 GRU。</p><p>GRU 是 <strong>G</strong>ated <strong>R</strong>ecurrent <strong>U</strong>nit 的简称,由 Cho 等人在 2014 年<sup id="fnref:1" class="footnote-ref"><a href="#fn:1" rel="footnote"><span class="hint--top hint--rounded" aria-label="Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Dzmitry Bahdanau, Fethi Bougares, Holger Schwenk, Yoshua Bengio. Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation. 
arXiv preprint [arXiv: 1406.1078](https://arxiv.org/abs/1406.1078), 2014.">[1]</span></a></sup>提出。与传统的 RNN 相比,GRU 可以更好地处理梯度消失和梯度爆炸等问题,同时也具有更强的记忆能力和表达能力。</p><p>GRU 与传统 RNN 相比,增加了一个门控机制,用于控制 hidden state 中的信息流动和遗忘。具体来说,GRU 通过引入<strong>更新门</strong>和<strong>重置门</strong>,来控制 hidden state 中的信息更新和遗忘。</p><ul><li>更新门 (update gate) 用于控制当前时刻输入和上一时刻 hidden state 的<strong>“信息更新”</strong>。</li><li>重置门 (reset gate) 用于控制当前时刻输入和上一时刻 hidden state 的<strong>“信息遗忘”</strong>。</li></ul><p>通过这种方式,GRU 可以更好地控制 hidden state 中的信息流动,从而提高模型的性能和记忆能力。</p><p>首先,通过 sigmoid 激活函数,定义两个控制门,分别是更新门 <span class="math inline">\(\Gamma_u^{\langle t \rangle}\)</span> 和 重置门 <span class="math inline">\(\Gamma_r^{\langle t \rangle}\)</span>: <span class="math display">\[\Gamma_u^{\langle t \rangle} = \sigma(W_{uh}h^{\langle t-1 \rangle} + W_{ux}x^{\langle t \rangle} + b_u)\]</span></p><p><span class="math display">\[\Gamma_r^{\langle t \rangle} = \sigma(W_{rh}h^{\langle t-1 \rangle} + W_{rx}x^{\langle t \rangle} + b_r)\]</span></p><p>将重置门引入 hidden state,来控制 hidden state 的各个特征是否要重置,如下: <span class="math display">\[\tilde{h}^{\langle t \rangle} = \operatorname{tanh}(W_{hh}(\Gamma_r^{\langle t \rangle} \odot h^{\langle t-1 \rangle}) + W_{hx}x^{\langle t \rangle} + b_h)\]</span> 其中 <span class="math inline">\(\odot\)</span> 表示逐元素相乘。和前面 hidden state 的更新公式相比,我们用 <span class="math inline">\(\Gamma_r^{\langle t \rangle} \odot h^{\langle t-1 \rangle}\)</span> 来替换 <span class="math inline">\(h^{\langle t-1 \rangle}\)</span>。这样做带来的好处是:如果 <span class="math inline">\(\Gamma_r^{\langle t \rangle}\)</span> 的某个元素是 0,则代表其对应的特征被重置为 0,也就是被遗忘了。另外,我们用 <span class="math inline">\(\tilde{h}^{\langle t \rangle}\)</span> 而不是 <span class="math inline">\(h^{\langle t \rangle}\)</span> 来表示输出的结果,是因为它还不是最终的 hidden state,而是 candidate hidden state(候选隐藏状态)。最终的 hidden state 应该是: <span class="math display">\[h^{\langle t \rangle} = (1-\Gamma_u^{\langle t \rangle}) \odot h^{\langle t-1 \rangle} + \Gamma_u^{\langle t \rangle} \odot 
\tilde{h}^{\langle t \rangle}\]</span> 它通过更新门 <span class="math inline">\(\Gamma_u^{\langle t \rangle}\)</span> 来控制是否更新到 candidate hidden state 对应的特征。如果 <span class="math inline">\(\Gamma_u^{\langle t \rangle}\)</span> 的某个元素为 0,那么 <span class="math inline">\(h^{\langle t-1 \rangle}\)</span> 对应元素的特征值就保留下来了,即不更新;如果为 1,那么就要完全用 <span class="math inline">\(\tilde{h}^{\langle t \rangle}\)</span> 对应元素的特征值,即完全更新。</p><h2 id="lstm">LSTM</h2><p>LSTM 是 <strong>L</strong>ong <strong>S</strong>hort-<strong>T</strong>erm <strong>M</strong>emory 的简称,是 Hochreiter 和 Schmidhuber 于 1997 年提出的<sup id="fnref:2" class="footnote-ref"><a href="#fn:2" rel="footnote"><span class="hint--top hint--rounded" aria-label="S. Hochreiter and J. Schmidhuber. 1997. Long short-term memory. Neural Computation, 9(8):1735–1780.">[2]</span></a></sup>,比 GRU 早得多。同样为了更好地控制信息流动,LSTM 也拥有门控逻辑。具体来说,LSTM 使用了三个控制门,分别是:</p><ul><li>输入门 (input gate) <span class="math inline">\(\Gamma_i\)</span></li><li>遗忘门 (forget gate) <span class="math inline">\(\Gamma_f\)</span></li><li>输出门 (output gate) <span class="math inline">\(\Gamma_o\)</span></li></ul><p>和 GRU 里对门的定义类似,这三个门的公式分别是: <span class="math display">\[\Gamma_i^{\langle t \rangle} = \sigma(W_{ih}h^{\langle t-1 \rangle} + W_{ix}x^{\langle t \rangle} + b_i)\]</span></p><p><span class="math display">\[\Gamma_f^{\langle t \rangle} = \sigma(W_{fh}h^{\langle t-1 \rangle} + W_{fx}x^{\langle t \rangle} + b_f)\]</span></p><p><span class="math display">\[\Gamma_o^{\langle t \rangle} = \sigma(W_{oh}h^{\langle t-1 \rangle} + W_{ox}x^{\langle t \rangle} + b_o)\]</span></p><p>现在我们要多引入一个名为 cell state(单元状态,也被称为记忆细胞状态,memory cell state)的项来存储一部分记忆信息,记为 <span class="math inline">\(c^{\langle t \rangle}\)</span>。和 GRU 的候选隐藏状态类似,cell state 也有它的候选 (candidate) 形式 <span class="math inline">\(\tilde{c}^{\langle t \rangle}\)</span>,如下: <span class="math display">\[\tilde{c}^{\langle t \rangle} = \tanh(W_{ch}h^{\langle t-1 \rangle} + W_{cx}x^{\langle t \rangle} + b_c)\]</span> 和 GRU 中候选隐藏状态 
(candidate hidden state) 不同的是,上述这个操作没有引入新的门控。最终的 cell state 如下: <span class="math display">\[c^{\langle t \rangle} = \Gamma_f^{\langle t \rangle} \odot c^{\langle t-1 \rangle} + \Gamma_i^{\langle t \rangle} \odot \tilde{c}^{\langle t \rangle}\]</span> 可以看到,<strong>遗忘门 <span class="math inline">\(\Gamma_f\)</span> 控制前一个状态 <span class="math inline">\(c^{\langle t-1 \rangle}\)</span> 的遗忘量,输入门 <span class="math inline">\(\Gamma_i\)</span> 则控制当前新输入状态 <span class="math inline">\(\tilde{c}^{\langle t \rangle}\)</span> 的输入量。</strong></p><p>最后是 hidden state。有了 cell state,我们再通过输出门控制输出量,让 hidden state 用 cell state 来表示: <span class="math display">\[h^{\langle t \rangle} = \Gamma_o \odot \operatorname{tanh}(c^{\langle t \rangle})\]</span> 仔细对比 GRU 和 LSTM 的形式,你会发现:GRU 是将具有记忆功能的 cell state,直接等价于 hidden state;而 LSTM 则是将两者分离开来。</p><h2 id="lstm-和-gru-优缺点对比">LSTM 和 GRU 优缺点对比</h2><h3 id="lstm-的优缺点">LSTM 的优缺点</h3><h4 id="优点">优点</h4><ul><li>LSTM 具有记忆单元,能够有效地处理长序列数据,缓解了梯度消失或梯度爆炸的问题。</li><li>LSTM 的门控机制可以有效地控制信息流的进出,从而增强了模型的泛化能力。</li><li>LSTM 可以处理多种类型的输入,如文本、音频和图像等。</li></ul><h4 id="缺点">缺点</h4><ul><li>LSTM 的计算复杂度较高,训练时间较长。</li><li>LSTM 的模型结构较为复杂,不易理解和调试。</li><li>LSTM 对于输入序列的长度敏感,需要对序列进行截断或填充。</li></ul><h3 id="gru-的优缺点">GRU 的优缺点</h3><h4 id="优点-1">优点</h4><ul><li>GRU 的计算复杂度较低,训练时间较短。</li><li>GRU 的模型结构较简单,易于理解和调试。</li><li>GRU 的门控机制比 LSTM 更简单,但仍能够有效地控制信息流的进出。</li></ul><h4 id="缺点-1">缺点</h4><ul><li>GRU 的记忆单元较为简单,可能无法处理特别长的序列数据。</li><li>GRU 的泛化能力可能不如 LSTM。</li><li>GRU 对于输入序列的长度也比较敏感,需要对序列进行截断或填充。</li></ul><h2 id="参考">参考</h2><section class="footnotes"><div class="footnote-list"><ol><li><span id="fn:1" class="footnote-text"><span>Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Dzmitry Bahdanau, Fethi Bougares, Holger Schwenk, Yoshua Bengio. Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation. arXiv preprint <a href="https://arxiv.org/abs/1406.1078">arXiv: 1406.1078</a>, 2014. 
<a href="#fnref:1" rev="footnote" class="footnote-backref"> ↩︎</a></span></span></li><li><span id="fn:2" class="footnote-text"><span>S. Hochreiter and J. Schmidhuber. 1997. Long short-term memory. Neural Computation, 9(8):1735–1780. <a href="#fnref:2" rev="footnote" class="footnote-backref"> ↩︎</a></span></span></li></ol></div></section>]]></content>
<categories>
<category>NLP</category>
</categories>
<tags>
<tag>LLM</tag>
<tag>RNN</tag>
<tag>GRU</tag>
<tag>LSTM</tag>
</tags>
</entry>
<entry>
<title>PyTorch 随机 Mask 的技巧</title>
<link href="/2023/04/27/pytorch-random-mask/"/>
<url>/2023/04/27/pytorch-random-mask/</url>
<content type="html"><![CDATA[<p>本文记录一下用 PyTorch 随机 Mask 的技巧。</p><p>这里假设数值低于 2 的 token 都是特殊 token,不做处理。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br></pre></td><td class="code"><pre><code class="hljs python"><span class="hljs-keyword">import</span> torch<br><br><span class="hljs-comment"># 定义 mask 的 token id</span><br>mask_token_id = <span class="hljs-number">4</span><br><br><span class="hljs-keyword">def</span> <span class="hljs-title function_">mlm</span>(<span class="hljs-params">tensor</span>):<br> <span class="hljs-comment"># 克隆一份数据,避免修改原始输入数据</span><br> tensor = tensor.detach().clone()<br><br> rand = torch.rand(tensor.shape)<br> <span class="hljs-comment"># 50% 的概率随机 mask</span><br> <span class="hljs-comment"># 忽略掉数值低于 2 的特殊 token</span><br> mask_arr = (rand < <span class="hljs-number">0.5</span>) * (tensor > <span class="hljs-number">2</span>) <br> <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(tensor.shape[<span class="hljs-number">0</span>]):<br> selection = torch.flatten(mask_arr[i].nonzero()).tolist()<br> tensor[i, selection] = mask_token_id<br><br> <span class="hljs-keyword">return</span> tensor<br></code></pre></td></tr></table></figure><p>其中 <code>mask_arr = (rand < 0.5) * (tensor > 2)</code> 只是一个示例,具体应根据实际情形来调整。</p><p>简单测试一下:</p><figure class="highlight python"><table><tr><td 
class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><code class="hljs python">samples = torch.tensor([<br> [<span class="hljs-number">0</span>, <span class="hljs-number">1652</span>, <span class="hljs-number">233</span>, <span class="hljs-number">3252</span>, <span class="hljs-number">1234</span>, <span class="hljs-number">634</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>],<br> [<span class="hljs-number">0</span>, <span class="hljs-number">223</span>, <span class="hljs-number">1530</span>, <span class="hljs-number">232</span>, <span class="hljs-number">4134</span>, <span class="hljs-number">832</span>, <span class="hljs-number">20</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>],<br>])<br><br>labels = samples<br>input_ids = mlm(samples)<br><br><span class="hljs-built_in">print</span>(<span class="hljs-string">"labels: \n"</span>, labels)<br><span class="hljs-built_in">print</span>(<span class="hljs-string">"input_ids: \n"</span>, input_ids)<br></code></pre></td></tr></table></figure><p>输出结果(因为是随机 mask,每个人的输出会有所不同):</p><figure class="highlight text"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><code class="hljs text">labels:<br> 
tensor([<br> [0, 1652, 233, 3252, 1234, 634, 1, 1, 1, 1],<br> [0, 223, 1530, 232, 4134, 832, 20, 1, 1, 1]<br> ])<br><br>input_ids:<br> tensor([<br> [0, 4, 4, 4, 1234, 634, 1, 1, 1, 1],<br> [0, 4, 4, 232, 4, 832, 4, 1, 1, 1]<br>])<br></code></pre></td></tr></table></figure><p>凡是出现数值 4 的,都是被 mask 掉的。</p><p>下面是更一般的使用方式,注意最后一行用了 <code>mlm</code> 方法。由于 <code>mlm</code> 内部使用的是 PyTorch 张量操作,这里让 tokenizer 以 <code>return_tensors='pt'</code> 返回张量。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br></pre></td><td class="code"><pre><code class="hljs python"><span class="hljs-keyword">from</span> tqdm.auto <span class="hljs-keyword">import</span> tqdm<br><br>paths = [<span class="hljs-string">"sample_0.txt"</span>, <span class="hljs-string">"sample_1.txt"</span>, <span class="hljs-string">"sample_2.txt"</span>, <span class="hljs-string">"sample_3.txt"</span>, <span class="hljs-string">"sample_4.txt"</span>]<br><br>input_ids = []<br>mask = []<br>labels = []<br><br><span class="hljs-keyword">for</span> path <span class="hljs-keyword">in</span> tqdm(paths):<br> <span class="hljs-keyword">with</span> <span class="hljs-built_in">open</span>(path, <span class="hljs-string">'r'</span>, encoding=<span class="hljs-string">'utf-8'</span>) <span class="hljs-keyword">as</span> f:<br> lines = f.read().split(<span class="hljs-string">'\n'</span>)<br> sample = tokenizer(lines, max_length=<span class="hljs-number">512</span>, padding=<span class="hljs-string">'max_length'</span>, truncation=<span class="hljs-literal">True</span>, return_tensors=<span class="hljs-string">'pt'</span>)<br> labels.append(sample.input_ids)<br> 
mask.append(sample.attention_mask)<br> input_ids.append(mlm(sample.input_ids))<br></code></pre></td></tr></table></figure>]]></content>
<tags>
<tag>PyTorch</tag>
</tags>
</entry>
<entry>
<title>【学习笔记】强化学习基本概念</title>
<link href="/2023/04/19/drl-study-notes-basic-concept/"/>
<url>/2023/04/19/drl-study-notes-basic-concept/</url>
<content type="html"><![CDATA[<p>为了加强理解,我们以《超级玛丽》游戏为例,让 AI 使用强化学习,操作马里奥走动和跳跃,完成通关任务。</p><p><a href="https://youtu.be/vmkRMvhCW5c" title="深度强化学习(1/5):基本概念 Deep Reinforcement Learning (1/5)"><img src="https://res.cloudinary.com/marcomontalbano/image/upload/v1681899633/video_to_markdown/images/youtube--vmkRMvhCW5c-c05b58ac6eb4c4700831b2b3070cd403.jpg" alt="深度强化学习(1/5):基本概念 Deep Reinforcement Learning (1/5)" /></a></p><h2 id="术语">术语</h2><h3 id="agent">Agent</h3><p>用强化学习驱动的作用对象,也就是下面 Action 的作用对象。在《超级玛丽》游戏中,这个 Agent 就是 AI 控制的马里奥。</p><h3 id="state">State</h3><p>当前系统状态,用 <span class="math inline">\(s\)</span> 表示。</p><h3 id="action">Action</h3><p>作用对象执行的操作或动作,用 <span class="math inline">\(a\)</span> 表示。例如,在《超级玛丽》游戏中,所有的操作包括往左走 (left)、往右走 (right)、往上走 (up)。那么,</p><p><span class="math display">\[a \in \{left, right, up\}\]</span></p><h3 id="policy">Policy</h3><p>策略函数,用 <span class="math inline">\(\pi\)</span> 表示。策略函数是根据当前观测到的状态,控制 Agent 如何从动作库(Action)中选择一种动作来执行。策略函数是一个条件概率密度函数,即 <span class="math inline">\(\pi :(s, a) \mapsto [0, 1]\)</span></p><p><span class="math display">\[\pi (a | s) = \mathbb{P}(A = a | S = s)\]</span></p><p>例如,</p><ul><li><p><span class="math inline">\(\pi(left | s) = 0.5\)</span></p></li><li><p><span class="math inline">\(\pi(right | s) = 0.3\)</span></p></li><li><p><span class="math inline">\(\pi(up | s) = 0.2\)</span></p></li></ul><p><strong>强化学习就是学习出这个策略函数</strong>。基于策略函数,控制 Agent 每一步选择出一个最优动作,完成目标。</p><h3 id="reward">Reward</h3><p>奖励,用 <span class="math inline">\(R\)</span> 表示。Agent 每做一次动作,系统就会给与一定的奖励。<strong>奖励定义的好坏非常影响强化学习的结果。</strong>例如,在《超级玛丽》游戏中,</p><ul><li><p>搜集到一个金币:<span class="math inline">\(R = +1\)</span></p></li><li><p>赢得比赛:<span class="math inline">\(R = +10000\)</span></p></li><li><p>输掉比赛:<span class="math inline">\(R = -10000\)</span></p></li><li><p>什么都没发生:<span class="math inline">\(R = 0\)</span></p></li></ul><p>强化学习的目标,是<strong>获得的奖励总和要尽量高</strong>。</p><h3 id="state-transition">State 
Transition</h3><p>状态转移。在执行了一个动作后,系统从一个状态变成另一个状态的过程,就是状态转移。用下面的图表示: <pre><code class="mermaid" >graph LR
A[old state]--action--->B[new state]</code></pre></p><p>状态转移可以是随机的。这种随机性来源于环境 (Environment)。状态转移满足如下的公式:</p><p><span class="math display">\[p(s'|s, a) = \mathbb{P}(S' = s' | S = s, A = a)\]</span></p><h4 id="强化学习的随机性来源">强化学习的随机性来源</h4><ul><li><p>Action 有随机性</p></li><li><p>State Transition 有随机性</p></li></ul><h4 id="执行流程">执行流程</h4><pre><code class="mermaid" >graph TD
S1["s(1)"]-->A1{"a(1)"}
A1-->S2["s(2)"]
A1-.->R1("r(1)")
S2-->A2{"a(2)"}
A2-->S3["s(3)"]
A2-.->R2("r(2)")
S3--"……"--->ST["s(T)"]
ST-->AT{"a(T)"}
AT-->ST1["s(T+1)"]
AT-.->RT("r(T)")</code></pre><p>形成 <code>(state, action, reward)</code> 的轨迹 (trajectory): <span class="math display">\[(s_1,a_1,r_1), (s_2,a_2,r_2), \dots , (s_T, a_T, r_T)\]</span></p><h3 id="return">Return</h3><p>回报,未来的累计奖励 (Cumulative future reward)。<span class="math inline">\(t\)</span> 时刻的未来累计回报用 <span class="math inline">\(U_t\)</span> 表示,有如下公式:</p><p><span class="math display">\[U_t = R_t + R_{t+1} + R_{t+2} + R_{t+3} + \dots\]</span></p><p>从经验来看,未来的奖励 (reward) 有时不如当下的奖励来得有价值。 例如,当下的 100 元比一年后的 100 元更有价值。 为了体现这一点,我们可以给未来的每一项奖励乘上一个取值范围是 <code>[0, 1]</code> 的因子,变成</p><p><span class="math display">\[U_t = R_t + \gamma^{1}R_{t+1} + \gamma^{2}R_{t+2} + \gamma^{3}R_{t+3} + \dots\]</span></p><p>形成折扣回报 (<strong>Discounted</strong> return,即 Cumulative <strong>discounted</strong> future reward)。其中 <span class="math inline">\(\gamma\)</span> 是介于 0 到 1 之间的折扣率 (discount rate)。 折扣率是一项超参数 (Hyper-parameter),可以在训练时被调优。</p><p><span class="math inline">\(t\)</span> 时刻的回报是随机的,这是因为回报依赖于未来的奖励;而未来的奖励则依赖于未来的动作和状态。 后两者都具有随机性,这也使得回报必然具有随机性。用数学公式再来整理一遍:</p><p>由于:</p><ol type="1"><li>动作具有随机性: <span class="math display">\[\mathbb{P}(A = a | S = s) = \pi (a | s)\]</span></li><li>状态转移具有随机性: <span class="math display">\[\mathbb{P}(S' = s' | S = s, A = a) = p(s'|s, a)\]</span></li></ol><p>对于任意的 <span class="math inline">\(i \ge t\)</span>,奖励 <span class="math 
inline">\(R_i\)</span> 依赖于 <span class="math inline">\(S_i\)</span> 和 <span class="math inline">\(A_i\)</span>。因此,给定 <span class="math inline">\(S_t\)</span>, 回报 <span class="math inline">\(U_t\)</span> 应依赖于随机变量:<span class="math inline">\(A_t, A_{t+1}, A_{t+2}, \dots\)</span> 和 <span class="math inline">\(S_t, S_{t+1}, S_{t+2}, \dots\)</span></p><h3 id="value-functions">Value Functions</h3><h4 id="action-value-function">Action-Value Function</h4><p>动作价值函数,是在某一时刻 <span class="math inline">\(t\)</span>,对给定的策略函数 <span class="math inline">\(\pi\)</span>,输出某个动作 <span class="math inline">\(a_t\)</span>,得到未来回报期望的函数。用 <span class="math inline">\(Q\)</span> 表示。即</p><p><span class="math display">\[Q_{\pi}(s_t, a_t) = \mathbb{E}[U_t|S_t=s_t, A_t=a_t]\]</span></p><p><span class="math inline">\(U_t\)</span> 是<span class="math inline">\(a_t, a_{t+1}, a_{t+2}, \dots\)</span> 和 <span class="math inline">\(s_t, s_{t+1}, s_{t+2}, \dots\)</span> 的函数。</p><p><span class="math display">\[U_t = U(a_t, a_{t+1}, a_{t+2}, \dots, s_t, s_{t+1}, s_{t+2}, \dots)\]</span></p><p>计算期望时,所有的<span class="math inline">\(a_{t+1}, a_{t+2}, \dots\)</span> 和 <span class="math inline">\(s_{t+1}, s_{t+2}, \dots\)</span> 都应被积分(或累加)掉,最后只保留 <span class="math inline">\(a_t\)</span> 和 <span class="math inline">\(s_t\)</span>。</p><p><span class="math inline">\(t\)</span> 时刻对不同的策略函数分别求期望,找出期望值最大的那个动作价值函数,即可得到<strong>最优动作价值函数 (Optimal action-value function)</strong>。 <span class="math display">\[Q^*(s_t, a_t) = \max_{\pi}Q_{\pi}(s_t, a_t)\]</span></p><p><span class="math inline">\(Q^*\)</span> 与 <span class="math inline">\(\pi\)</span> 无关。</p><h4 id="state-value-function">State-Value Function</h4><p>对所有的动作 <span class="math inline">\(A\)</span>,计算 <span class="math inline">\(Q\)</span> 的期望,可得到状态价值函数(State-value function)。用下面公式表示:</p><p><span class="math display">\[V_{\pi}(s_t) = \mathbb{E}_A[Q_{\pi}(s_t, A)]\]</span></p><p>因为 <span class="math inline">\(A\)</span> 服从策略函数定义的概率分布,即 <span class="math inline">\(A \sim 
\pi(\cdot|s_t)\)</span>,因此上式也可以写成</p><p>(动作是离散的情况) <span class="math display">\[V_{\pi}(s_t) = \mathbb{E}_A[Q_{\pi}(s_t, A)] = \sum_a \pi(a|s_t)\cdot Q_{\pi}(s_t, a)\]</span> (动作是连续的情况) <span class="math display">\[V_{\pi}(s_t) = \mathbb{E}_A[Q_{\pi}(s_t, A)] = \int_a \pi(a|s_t)\cdot Q_{\pi}(s_t, a) da\]</span></p><h4 id="两种价值函数的作用">两种价值函数的作用</h4><ul><li><p>动作价值函数评估的是,对于给定的策略函数 <span class="math inline">\(\pi\)</span>,Agent 处在状态 <span class="math inline">\(s\)</span> 时,选择动作 <span class="math inline">\(a\)</span> 的好坏。</p></li><li><p>状态价值函数评估的是,对于给定的策略函数 <span class="math inline">\(\pi\)</span>,Agent 处在状态 <span class="math inline">\(s\)</span> 的好坏。</p></li></ul><p>如果要评估策略函数本身的好坏,可以对所有状态 <span class="math inline">\(S\)</span> 求状态价值函数的期望,即 <span class="math display">\[\mathbb{E}_S[V_{\pi}(S)]\]</span></p><h2 id="ai-如何控制-agent">AI 如何控制 Agent</h2><ul><li><p>假设我们有一个好的策略函数 <span class="math inline">\(\pi(a|s)\)</span>,那么我们只需要基于当前时刻的状态 <span class="math inline">\(s_t\)</span>,随机采样一个动作即可: <span class="math display">\[a_t \sim \pi(\cdot|s_t)\]</span></p></li><li><p>假设我们知道最优动作价值函数 <span class="math inline">\(Q^*(s, a)\)</span>,那么我们只需要基于当前时刻的状态 <span class="math inline">\(s_t\)</span>,选择一个可以最大化动作价值函数的动作即可:</p></li></ul><p><span class="math display">\[a_t = \arg\max_a Q^*(s_t, a)\]</span></p><p>由此可见,强化学习的目标既可以是<strong>策略函数</strong> <span class="math inline">\(\pi(a|s)\)</span>,也可以是<strong>最优动作价值函数</strong> <span class="math inline">\(Q^*(s, a)\)</span>。</p><h2 id="参考">参考</h2><ul><li><a href="https://github.com/wangshusen/DRL">Deep Reinforcement Learning</a></li></ul>]]></content>
<categories>
<category>DRL</category>
</categories>
<tags>
<tag>Agent</tag>
<tag>Environment</tag>
<tag>State</tag>
<tag>Action</tag>
<tag>Reward</tag>
<tag>Policy</tag>
<tag>StateTransition</tag>
<tag>Return</tag>
<tag>ValueFunctions</tag>
</tags>
</entry>
<entry>
<title>在生产环境中使用 Docker 部署 Flask 服务</title>
<link href="/2023/04/18/flask-uwsgi-docker/"/>
<url>/2023/04/18/flask-uwsgi-docker/</url>
<content type="html"><![CDATA[<h2 id="用-anaconda-开启一个新的环境">用 Anaconda 开启一个新的环境</h2><p>创建一个新的 conda 环境,并取名为 <code>flask</code></p><figure class="highlight shell"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><code class="hljs shell">conda create --name flask python=3<br></code></pre></td></tr></table></figure><p>激活 <code>flask</code> 环境</p><figure class="highlight shell"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><code class="hljs shell">conda activate flask<br></code></pre></td></tr></table></figure><h2 id="开启一个简单的-flask-服务">开启一个简单的 Flask 服务</h2><p>用 pip 安装 Flask</p><figure class="highlight shell"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><code class="hljs shell">pip install flask<br></code></pre></td></tr></table></figure><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br></pre></td><td class="code"><pre><code class="hljs python"><span class="hljs-comment"># main.py</span><br><span class="hljs-keyword">from</span> flask <span class="hljs-keyword">import</span> Flask, request, jsonify<br><br>app = Flask(__name__)<br>app.config[<span class="hljs-string">'JSON_AS_ASCII'</span>] = <span class="hljs-literal">False</span><br><br><span class="hljs-meta">@app.route(<span class="hljs-params"><span class="hljs-string">"/hello"</span>, methods=[<span 
class="hljs-string">"GET"</span>]</span>)</span><br><span class="hljs-keyword">def</span> <span class="hljs-title function_">hello_world</span>():<br> name = request.args.get(<span class="hljs-string">"name"</span>, <span class="hljs-string">""</span>)<br> <span class="hljs-keyword">return</span> jsonify({<br> <span class="hljs-string">"err_code"</span>: <span class="hljs-number">0</span>,<br> <span class="hljs-string">"ret"</span>: <span class="hljs-string">f"Hello, <span class="hljs-subst">{name}</span>"</span><br> })<br><br><span class="hljs-keyword">if</span> __name__ == <span class="hljs-string">"__main__"</span>:<br> app.run()<br></code></pre></td></tr></table></figure><p>运行程序,得到如下输出结果:</p><figure class="highlight pgsql"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><code class="hljs pgsql"> * Serving Flask app <span class="hljs-string">'main'</span><br> * <span class="hljs-keyword">Debug</span> mode: <span class="hljs-keyword">off</span><br><span class="hljs-built_in">WARNING</span>: This <span class="hljs-keyword">is</span> a development <span class="hljs-keyword">server</span>. <span class="hljs-keyword">Do</span> <span class="hljs-keyword">not</span> use it <span class="hljs-keyword">in</span> a production deployment. 
Use a production WSGI <span class="hljs-keyword">server</span> <span class="hljs-keyword">instead</span>.<br> * Running <span class="hljs-keyword">on</span> http://<span class="hljs-number">127.0</span><span class="hljs-number">.0</span><span class="hljs-number">.1</span>:<span class="hljs-number">5000</span><br>Press CTRL+C <span class="hljs-keyword">to</span> quit<br></code></pre></td></tr></table></figure><p>然后浏览器打开:</p><figure class="highlight awk"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><code class="hljs awk">http:<span class="hljs-regexp">//</span><span class="hljs-number">127.0</span>.<span class="hljs-number">0.1</span>:<span class="hljs-number">5000</span>/hello?name=AI探险家<br></code></pre></td></tr></table></figure><p>就会从浏览器中得到如下结果:</p><figure class="highlight json"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><code class="hljs json"><span class="hljs-punctuation">{</span><span class="hljs-attr">"err_code"</span><span class="hljs-punctuation">:</span><span class="hljs-number">0</span><span class="hljs-punctuation">,</span><span class="hljs-attr">"ret"</span><span class="hljs-punctuation">:</span><span class="hljs-string">"Hello, AI探险家"</span><span class="hljs-punctuation">}</span><br></code></pre></td></tr></table></figure><p>从前面开启服务的输出结果可以看出,我们的 Flask 应用目前运行的是一个开发服务器。要正式部署的话,建议换用生产级的 WSGI 服务器。那我们就接受建议,使用 <code>uWSGI</code> 搭配 <code>nginx</code> 来部署服务。</p><h2 id="安装-uwsgi">安装 uWSGI</h2><p>首先安装 uWSGI 相关的依赖环境。下面是 Ubuntu 系统下的安装命令:</p><figure class="highlight shell"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><code class="hljs shell">apt-get install build-essential python-dev<br></code></pre></td></tr></table></figure><p>其他操作系统可以参考官方安装教程:<a href="https://uwsgi-docs.readthedocs.io/en/latest/Install.html">Installing uWSGI</a></p><p>接下来是安装 <code>uwsgi</code>。理论上可以直接用 <code>pip</code> 安装。</p><figure 
class="highlight shell"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><code class="hljs shell">pip install uwsgi<br></code></pre></td></tr></table></figure><p>奇怪的是,我这边在 Anaconda 环境中用 <code>pip</code> 安装 <code>uwsgi</code> 会报错。用 <code>conda</code> 命令来安装则是正常的:</p><figure class="highlight shell"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><code class="hljs shell">conda install -c conda-forge uwsgi<br></code></pre></td></tr></table></figure><h2 id="配置-uwsgi">配置 uWSGI</h2><p>在 <code>main.py</code> 同个目录创建一个 <code>uwsgi.ini</code> 文件:</p><figure class="highlight ini"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><code class="hljs ini"><span class="hljs-section">[uwsgi]</span><br><span class="hljs-attr">module</span> = main:app<br><br><span class="hljs-attr">uid</span> = www-data<br><span class="hljs-attr">gid</span> = www-data<br><span class="hljs-attr">master</span> = <span class="hljs-literal">true</span><br><span class="hljs-attr">processes</span> = <span class="hljs-number">5</span><br><br><span class="hljs-attr">socket</span> = /tmp/uwsgi.socket<br><span class="hljs-attr">chmod-sock</span> = <span class="hljs-number">664</span><br><span class="hljs-attr">vacuum</span> = <span class="hljs-literal">true</span><br></code></pre></td></tr></table></figure><p>在这份配置文件中,调用的模块是 <code>main.py</code> 的 <code>app</code>。然后使用 <code>www-data</code>(WEB 服务的标准用户) 作为 <code>uwsgi</code> 进程的 <code>uid/gid</code>。 通过 <code>processes</code> 指定 5 个进程。另外,我们给 <code>uwsgi</code> 创建了一个 socket 文件 
<code>/tmp/uwsgi.socket</code>,并赋予 664 的访问权限。 作为对比,也可以直接给 socket 设置端口号,例如:</p><figure class="highlight ini"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><code class="hljs ini"><span class="hljs-section">[uwsgi]</span><br>...<br><span class="hljs-attr">socket</span> = :<span class="hljs-number">3032</span><br>...<br></code></pre></td></tr></table></figure><p>另外,当我们退出进程时,希望 <code>/tmp/uwsgi.socket</code> 文件能够被自动删除,因此可以设置 <code>vacuum = true</code> 来实现这个功能。</p><p>运行 <code>uwsgi</code>:</p><figure class="highlight shell"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><code class="hljs shell">uwsgi uwsgi.ini<br></code></pre></td></tr></table></figure><p>得到如下输出结果:</p><figure class="highlight pgsql"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br></pre></td><td class="code"><pre><code 
class="hljs pgsql">[uWSGI] getting INI <span class="hljs-keyword">configuration</span> <span class="hljs-keyword">from</span> uwsgi.ini<br>*** Starting uWSGI <span class="hljs-number">2.0</span><span class="hljs-number">.21</span> (<span class="hljs-number">64</span>bit) <span class="hljs-keyword">on</span> [Tue Apr <span class="hljs-number">18</span> <span class="hljs-number">21</span>:<span class="hljs-number">13</span>:<span class="hljs-number">37</span> <span class="hljs-number">2023</span>] ***<br>compiled <span class="hljs-keyword">with</span> <span class="hljs-keyword">version</span>: <span class="hljs-number">11.2</span><span class="hljs-number">.0</span> <span class="hljs-keyword">on</span> <span class="hljs-number">28</span> March <span class="hljs-number">2023</span> <span class="hljs-number">07</span>:<span class="hljs-number">32</span>:<span class="hljs-number">14</span><br>os: Linux<span class="hljs-number">-5.15</span><span class="hljs-number">.90</span><span class="hljs-number">.1</span>-microsoft-standard-WSL2 #<span class="hljs-number">1</span> SMP Fri Jan <span class="hljs-number">27</span> <span class="hljs-number">02</span>:<span class="hljs-number">56</span>:<span class="hljs-number">13</span> UTC <span class="hljs-number">2023</span><br>nodename: Airme<br>machine: x86_64<br>clock source: unix<br>pcre jit disabled<br>detected number <span class="hljs-keyword">of</span> CPU cores: <span class="hljs-number">20</span><br><span class="hljs-keyword">current</span> working directory: /home/aizpy/Workspace/FlaskTest<br>detected binary <span class="hljs-type">path</span>: /home/aizpy/miniconda3/envs/flask/bin/uwsgi<br>your processes number <span class="hljs-keyword">limit</span> <span class="hljs-keyword">is</span> <span class="hljs-number">63631</span><br>your memory page size <span class="hljs-keyword">is</span> <span class="hljs-number">4096</span> bytes<br>detected max file descriptor number: <span class="hljs-number">1024</span><br><span 
class="hljs-keyword">lock</span> engine: pthread robust mutexes<br>thunder <span class="hljs-keyword">lock</span>: disabled (you can <span class="hljs-keyword">enable</span> it <span class="hljs-keyword">with</span> <span class="hljs-comment">--thunder-lock)</span><br>uwsgi socket <span class="hljs-number">0</span> bound <span class="hljs-keyword">to</span> UNIX address /tmp/uwsgi.socket fd <span class="hljs-number">3</span><br>Python <span class="hljs-keyword">version</span>: <span class="hljs-number">3.11</span><span class="hljs-number">.2</span> (main, Mar <span class="hljs-number">27</span> <span class="hljs-number">2023</span>, <span class="hljs-number">23</span>:<span class="hljs-number">42</span>:<span class="hljs-number">44</span>) [GCC <span class="hljs-number">11.2</span><span class="hljs-number">.0</span>]<br>*** Python threads support <span class="hljs-keyword">is</span> disabled. You can <span class="hljs-keyword">enable</span> it <span class="hljs-keyword">with</span> <span class="hljs-comment">--enable-threads ***</span><br>Python main interpreter initialized at <span class="hljs-number">0xa01a98</span><br>your <span class="hljs-keyword">server</span> socket <span class="hljs-keyword">listen</span> backlog <span class="hljs-keyword">is</span> limited <span class="hljs-keyword">to</span> <span class="hljs-number">100</span> connections<br>your mercy <span class="hljs-keyword">for</span> graceful operations <span class="hljs-keyword">on</span> workers <span class="hljs-keyword">is</span> <span class="hljs-number">60</span> seconds<br>mapped <span class="hljs-number">437520</span> bytes (<span class="hljs-number">427</span> KB) <span class="hljs-keyword">for</span> <span class="hljs-number">5</span> cores<br>*** Operational MODE: preforking ***<br>WSGI app <span class="hljs-number">0</span> (mountpoint=<span class="hljs-string">''</span>) ready <span class="hljs-keyword">in</span> <span class="hljs-number">0</span> seconds <span 
class="hljs-keyword">on</span> interpreter <span class="hljs-number">0xa01a98</span> pid: <span class="hljs-number">13221</span> (<span class="hljs-keyword">default</span> app)<br>*** uWSGI <span class="hljs-keyword">is</span> running <span class="hljs-keyword">in</span> multiple interpreter mode ***<br>spawned uWSGI master process (pid: <span class="hljs-number">13221</span>)<br>spawned uWSGI worker <span class="hljs-number">1</span> (pid: <span class="hljs-number">13222</span>, cores: <span class="hljs-number">1</span>)<br>spawned uWSGI worker <span class="hljs-number">2</span> (pid: <span class="hljs-number">13223</span>, cores: <span class="hljs-number">1</span>)<br>spawned uWSGI worker <span class="hljs-number">3</span> (pid: <span class="hljs-number">13224</span>, cores: <span class="hljs-number">1</span>)<br>spawned uWSGI worker <span class="hljs-number">4</span> (pid: <span class="hljs-number">13225</span>, cores: <span class="hljs-number">1</span>)<br>spawned uWSGI worker <span class="hljs-number">5</span> (pid: <span class="hljs-number">13226</span>, cores: <span class="hljs-number">1</span>)<br></code></pre></td></tr></table></figure><p>确实创建了 5 个进程。当输入<code>Ctrl-C</code> 来退出进程时,可以看到如下输出:</p><figure class="highlight livecodeserver"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><code class="hljs livecodeserver">^CSIGINT/SIGTERM received...killing workers...<br>worker <span class="hljs-number">1</span> buried <span class="hljs-keyword">after</span> <span class="hljs-number">1</span> <span class="hljs-built_in">seconds</span><br>worker <span class="hljs-number">2</span> buried <span class="hljs-keyword">after</span> <span class="hljs-number">1</span> <span 
class="hljs-built_in">seconds</span><br>worker <span class="hljs-number">3</span> buried <span class="hljs-keyword">after</span> <span class="hljs-number">1</span> <span class="hljs-built_in">seconds</span><br>worker <span class="hljs-number">4</span> buried <span class="hljs-keyword">after</span> <span class="hljs-number">1</span> <span class="hljs-built_in">seconds</span><br>worker <span class="hljs-number">5</span> buried <span class="hljs-keyword">after</span> <span class="hljs-number">1</span> <span class="hljs-built_in">seconds</span><br>goodbye <span class="hljs-built_in">to</span> uWSGI.<br>VACUUM: unix <span class="hljs-built_in">socket</span> /tmp/uwsgi.<span class="hljs-built_in">socket</span> removed.<br></code></pre></td></tr></table></figure><p>此时,5 个进程全部退出了,并且 <code>/tmp/uwsgi.socket</code> 文件也被删除了。</p><h2 id="生成依赖环境">生成依赖环境</h2><figure class="highlight shell"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><code class="hljs shell">pip freeze > requirements.txt<br></code></pre></td></tr></table></figure><p>需要注意的是,如果你是通过 conda 安装的 uwsgi,那么执行这个命令后,uWSGI 这一条会变成类似这样的结果:</p><figure class="highlight awk"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><code class="hljs awk">uWSGI @ file:<span class="hljs-regexp">//</span><span class="hljs-regexp">/croot/u</span>wsgi_1679988297904/work<br></code></pre></td></tr></table></figure><p>这是一个绝对路径,无法部署到其他机器上。此时可以根据之前 conda 的安装结果,来指定 uwsgi 的版本号。例如: <figure class="highlight abnf"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><code class="hljs abnf"><span class="hljs-attribute">uWSGI</span><span class="hljs-operator">=</span><span class="hljs-operator">=</span><span class="hljs-number">2.0</span>.<span class="hljs-number">21</span><br></code></pre></td></tr></table></figure></p><p>最终的 <code>requirements.txt</code> 长这样:</p><figure class="highlight 
abnf"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><code class="hljs abnf"><span class="hljs-attribute">Flask</span><span class="hljs-operator">=</span><span class="hljs-operator">=</span><span class="hljs-number">2.2</span>.<span class="hljs-number">3</span><br><span class="hljs-attribute">uWSGI</span><span class="hljs-operator">=</span><span class="hljs-operator">=</span><span class="hljs-number">2.0</span>.<span class="hljs-number">21</span><br></code></pre></td></tr></table></figure><p>这里删掉了中间的依赖项,因为安装这两个库就会自动安装其他的。</p><h2 id="配置-nginx">配置 nginx</h2><p>创建一份 <code>nginx.conf</code> 文件:</p><figure class="highlight nginx"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br></pre></td><td class="code"><pre><code class="hljs nginx"><span 
class="hljs-attribute">user</span> www-data;<br><span class="hljs-attribute">worker_processes</span> auto;<br><span class="hljs-attribute">pid</span> /run/nginx.pid;<br><br><span class="hljs-section">events</span> {<br> <span class="hljs-attribute">worker_connections</span> <span class="hljs-number">1024</span>;<br> <span class="hljs-attribute">use</span> <span class="hljs-literal">epoll</span>;<br> <span class="hljs-attribute">multi_accept</span> <span class="hljs-literal">on</span>;<br>}<br><br><span class="hljs-section">http</span> {<br> <span class="hljs-attribute">access_log</span> /dev/stdout;<br> <span class="hljs-attribute">error_log</span> /dev/stdout;<br><br> <span class="hljs-attribute">sendfile</span> <span class="hljs-literal">on</span>;<br> <span class="hljs-attribute">tcp_nopush</span> <span class="hljs-literal">on</span>;<br> <span class="hljs-attribute">tcp_nodelay</span> <span class="hljs-literal">on</span>;<br> <span class="hljs-attribute">keepalive_timeout</span> <span class="hljs-number">65</span>;<br> <span class="hljs-attribute">types_hash_max_size</span> <span class="hljs-number">2048</span>;<br><br> <span class="hljs-attribute">include</span> /etc/nginx/mime.types;<br> <span class="hljs-attribute">default_type</span> application/octet-stream;<br><br> <span class="hljs-attribute">index</span> index.html index.htm;<br><br> <span class="hljs-section">server</span> {<br> <span class="hljs-attribute">listen</span> <span class="hljs-number">80</span> default_server;<br> <span class="hljs-attribute">listen</span> [::]:<span class="hljs-number">80</span> default_server;<br> <span class="hljs-attribute">server_name</span> localhost;<br> <span class="hljs-attribute">root</span> /var/www/html;<br><br> <span class="hljs-section">location</span> / {<br> <span class="hljs-attribute">include</span> uwsgi_params;<br> <span class="hljs-attribute">uwsgi_pass</span> unix:/tmp/uwsgi.socket;<br> }<br> }<br>}<br></code></pre></td></tr></table></figure><h2 
id="创建执行脚本">创建执行脚本</h2><p>创建一个启动脚本 <code>start.sh</code></p><figure class="highlight shell"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><code class="hljs shell"><span class="hljs-meta prompt_">#</span><span class="language-bash">!/usr/bin/env bash</span><br>service nginx start<br>uwsgi --ini uwsgi.ini<br></code></pre></td></tr></table></figure><h2 id="编写-dockerfile">编写 Dockerfile</h2><figure class="highlight docker"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br></pre></td><td class="code"><pre><code class="hljs docker"><span class="hljs-keyword">FROM</span> ubuntu:<span class="hljs-number">22.04</span><br><span class="hljs-keyword">MAINTAINER</span> aizpy <span class="hljs-string">"[email protected]"</span><br><br><span class="hljs-keyword">EXPOSE</span> <span class="hljs-number">80</span><br><span class="hljs-keyword">COPY</span><span class="language-bash"> . 
/srv/app</span><br><span class="hljs-keyword">WORKDIR</span><span class="language-bash"> /srv/app</span><br><br><span class="hljs-keyword">RUN</span><span class="language-bash"> apt-get clean && apt-get -y update</span><br><span class="hljs-keyword">RUN</span><span class="language-bash"> apt-get install -y nginx python3-dev build-essential pip</span><br><br><span class="hljs-keyword">RUN</span><span class="language-bash"> pip install --no-cache-dir -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple</span><br><br><span class="hljs-keyword">COPY</span><span class="language-bash"> nginx.conf /etc/nginx</span><br><span class="hljs-keyword">RUN</span><span class="language-bash"> <span class="hljs-built_in">chmod</span> +x ./start.sh</span><br><span class="hljs-keyword">CMD</span><span class="language-bash"> [<span class="hljs-string">"./start.sh"</span>]</span><br></code></pre></td></tr></table></figure><h2 id="通过-dockerfile-创建镜像">通过 Dockerfile 创建镜像</h2><figure class="highlight shell"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><code class="hljs shell">docker build . -t my_flask_app<br></code></pre></td></tr></table></figure><h2 id="生成容器并运行">生成容器并运行</h2><figure class="highlight shell"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><code class="hljs shell">docker run --name flask_app -p 80:80 my_flask_app<br></code></pre></td></tr></table></figure><h2 id="参考">参考</h2><ul><li><a href="https://smirnov-am.github.io/running-flask-in-production-with-docker/">Running Flask in production with Docker</a></li></ul>]]></content>
<categories>
<category>Docker</category>
</categories>
<tags>
<tag>Docker</tag>
<tag>Flask</tag>
<tag>uWSGI</tag>
</tags>
</entry>
<entry>
<title>对 ChatGLM-6B 做 LoRA fine tuning</title>
<link href="/2023/03/30/chatglm-6b-lora/"/>
<url>/2023/03/30/chatglm-6b-lora/</url>
<content type="html"><![CDATA[<p>ChatGLM-6B 是一个支持中英双语的对话语言模型,基于 GLM (General Language Model)。它只有 62 亿个参数,量化后最低 (INT4 量化) 只需要 6GB 的显存,完全可以部署到消费级显卡上。在实际使用这个模型一段时间以后,我们发现模型的对话表现能力确实非常不错。那么,基于这个模型做 Fine-tuning 就非常有价值了。</p><span id="more"></span><p><strong>声明:</strong></p><p>本文提供的所有技术信息,都基于 <a href="https://huggingface.co/THUDM/chatglm-6b">THUDM/chatglm-6b</a> 的历史版本: <code>096f3de6b4959ce38bef7bb05f3129c931a3084e</code>。</p><p>源码地址:</p><ul><li><a href="https://github.com/aizpy/chatglm-finetune">GitHub</a></li><li><a href="https://gitee.com/aizpy/chatglm-finetune">gitee</a></li></ul><h2 id="搭建依赖环境">搭建依赖环境</h2><p>安装 PyTorch 环境: <figure class="highlight shell"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><code class="hljs shell">pip install torch torchvision torchaudio<br></code></pre></td></tr></table></figure></p><p>按照 ChatGLM-6B 的官方指导,安装软件依赖环境: <figure class="highlight shell"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><code class="hljs shell">pip install protobuf==3.20.0 transformers==4.26.1 icetk cpm_kernels <br></code></pre></td></tr></table></figure></p><p>为了做 LoRA,还要安装 peft <figure class="highlight shell"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><code class="hljs shell">pip install peft<br></code></pre></td></tr></table></figure></p><h2 id="加载模型和-tokenizer">加载模型和 Tokenizer</h2><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><code class="hljs python"><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, AutoModel<br><br>checkpoint = <span class="hljs-string">"THUDM/chatglm-6b"</span><br>revision = <span 
class="hljs-string">"096f3de6b4959ce38bef7bb05f3129c931a3084e"</span><br>model = AutoModel.from_pretrained(checkpoint, revision=revision, trust_remote_code=<span class="hljs-literal">True</span>)<br>tokenizer = AutoTokenizer.from_pretrained(checkpoint, revision=revision, trust_remote_code=<span class="hljs-literal">True</span>)<br></code></pre></td></tr></table></figure><p>正如前面声明所述,本文使用的历史版本号是 <code>096f3de6b4959ce38bef7bb05f3129c931a3084e</code>。如果开发者需要其他版本号,只需要更改 <code>revision</code> 值,并重新训练即可。</p><h2 id="分析模型结构">分析模型结构</h2><p>模型加载完后,我们可以打印这个 <code>model</code> 和 <code>tokenizer</code>,建立对模型的基本认知。</p><p>首先打印<code>model</code>: <figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><code class="hljs python"><span class="hljs-built_in">print</span>(model)<br></code></pre></td></tr></table></figure> 得到如下结果: <figure class="highlight routeros"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br></pre></td><td class="code"><pre><code class="hljs routeros">ChatGLMForConditionalGeneration(<br> (transformer): ChatGLMModel(<br> (word_embeddings): Embedding(150528, 4096)<br> (layers): ModuleList(<br> (0-27): 28 x GLMBlock(<br> (input_layernorm): LayerNorm((4096,), <span class="hljs-attribute">eps</span>=1e-05, <span 
class="hljs-attribute">elementwise_affine</span>=<span class="hljs-literal">True</span>)<br> (attention): SelfAttention(<br> (rotary_emb): RotaryEmbedding()<br> (query_key_value): Linear(<span class="hljs-attribute">in_features</span>=4096, <span class="hljs-attribute">out_features</span>=12288, <span class="hljs-attribute">bias</span>=<span class="hljs-literal">True</span>)<br> (dense): Linear(<span class="hljs-attribute">in_features</span>=4096, <span class="hljs-attribute">out_features</span>=4096, <span class="hljs-attribute">bias</span>=<span class="hljs-literal">True</span>)<br> )<br> (post_attention_layernorm): LayerNorm((4096,), <span class="hljs-attribute">eps</span>=1e-05, <span class="hljs-attribute">elementwise_affine</span>=<span class="hljs-literal">True</span>)<br> (mlp): GLU(<br> (dense_h_to_4h): Linear(<span class="hljs-attribute">in_features</span>=4096, <span class="hljs-attribute">out_features</span>=16384, <span class="hljs-attribute">bias</span>=<span class="hljs-literal">True</span>)<br> (dense_4h_to_h): Linear(<span class="hljs-attribute">in_features</span>=16384, <span class="hljs-attribute">out_features</span>=4096, <span class="hljs-attribute">bias</span>=<span class="hljs-literal">True</span>)<br> )<br> )<br> )<br> (final_layernorm): LayerNorm((4096,), <span class="hljs-attribute">eps</span>=1e-05, <span class="hljs-attribute">elementwise_affine</span>=<span class="hljs-literal">True</span>)<br> )<br> (lm_head): Linear(<span class="hljs-attribute">in_features</span>=4096, <span class="hljs-attribute">out_features</span>=150528, <span class="hljs-attribute">bias</span>=<span class="hljs-literal">False</span>)<br>)<br></code></pre></td></tr></table></figure> 简单分析这个模型结构,至少可以得到如下一些信息:</p><ul><li>模型使用了 Transformer 结构,因此可以使用 LoRA 进行 Fine-tuning</li><li>从 Word Embedding 层可以看出,词汇表大小是 <code>150528</code></li><li>LoRA 可以操作的目标是:<code>query_key_value</code></li></ul><p>再打印<code>tokenizer</code>:</p><figure class="highlight python"><table><tr><td 
class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><code class="hljs python"><span class="hljs-built_in">print</span>(tokenizer)<br></code></pre></td></tr></table></figure><p>得到如下结果(为了便于阅读,已对结果做了分行处理): <figure class="highlight routeros"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br></pre></td><td class="code"><pre><code class="hljs routeros">ChatGLMTokenizer(<br><span class="hljs-attribute">name_or_path</span>=<span class="hljs-string">'THUDM/chatglm-6b'</span>, <br><span class="hljs-attribute">vocab_size</span>=150344, <br><span class="hljs-attribute">model_max_length</span>=2048, <br><span class="hljs-attribute">is_fast</span>=<span class="hljs-literal">False</span>, <br><span class="hljs-attribute">padding_side</span>=<span class="hljs-string">'left'</span>, <br><span class="hljs-attribute">truncation_side</span>=<span class="hljs-string">'right'</span>, <br>special_tokens={<br><span class="hljs-string">'bos_token'</span>: <span class="hljs-string">'<sop>'</span>, <br><span class="hljs-string">'eos_token'</span>: <span class="hljs-string">'</s>'</span>, <br><span class="hljs-string">'unk_token'</span>: <span class="hljs-string">'<unk>'</span>, <br><span class="hljs-string">'pad_token'</span>: <span class="hljs-string">'<pad>'</span>, <br><span class="hljs-string">'mask_token'</span>: <span class="hljs-string">'[MASK]'</span><br>}<br>)<br></code></pre></td></tr></table></figure></p><p>这里有几个可以关注的点:</p><ul><li>词汇表大小<code>vocab_size</code>是<code>150344</code></li><li>不是一个 fast 
Tokenizer(<code>is_fast</code> 的值是 <code>False</code>)</li><li>特殊 token 包括:<code>bos</code> <code>eos</code> <code>pad</code> 和 <code>mask</code></li></ul><p>为什么 model 中的词汇表大小是 <code>150528</code>,而 <code>tokenizer</code> 中定义的词汇表大小却是 <code>150344</code> 呢?读者可以带着这个疑问去读一读模型项目的源码,看看能不能找到答案。</p><h2 id="配置-lora">配置 LoRA</h2><p>借助 peft 库,我们可以很方便地对模型注入 LoRA。 <figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br></pre></td><td class="code"><pre><code class="hljs python"><span class="hljs-keyword">from</span> peft <span class="hljs-keyword">import</span> LoraConfig, get_peft_model, TaskType<br><br><span class="hljs-keyword">def</span> <span class="hljs-title function_">load_lora_config</span>(<span class="hljs-params">model</span>):<br>config = LoraConfig(<br> task_type=TaskType.CAUSAL_LM, <br> inference_mode=<span class="hljs-literal">False</span>,<br> r=<span class="hljs-number">8</span>, <br> lora_alpha=<span class="hljs-number">32</span>, <br> lora_dropout=<span class="hljs-number">0.1</span>,<br> target_modules=[<span class="hljs-string">"query_key_value"</span>]<br>)<br><span class="hljs-keyword">return</span> get_peft_model(model, config)<br><br>model = load_lora_config(model)<br></code></pre></td></tr></table></figure> 打印可训练的参数量: <figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><code class="hljs python">model.print_trainable_parameters()<br></code></pre></td></tr></table></figure> 得到如下结果: <figure class="highlight apache"><table><tr><td 
class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><code class="hljs apache"><span class="hljs-attribute">trainable</span> params: <span class="hljs-number">3670016</span> || <span class="hljs-literal">all</span> params: <span class="hljs-number">6258876416</span> || trainable%: <span class="hljs-number">0</span>.<span class="hljs-number">05863697820615348</span><br></code></pre></td></tr></table></figure></p><p>可以看到,总的参数量是 <code>6,258,876,416</code>,可训练的参数量是 <code>3,670,016</code>,占比 <code>0.0586%</code> 左右。训练参数量只是百万级别的,可谓相当友好了!另外需要注意的一点是,ChatGLM-6B 是一个因果语言模型 (Causal Language Model),因此我们这里选择的任务类型是 <code>CAUSAL_LM</code>。</p><h2 id="构建数据集">构建数据集</h2><h3 id="定义常量">定义常量</h3><p>构建之前,我们先定义几个特殊 Token 常量: <figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><code class="hljs python">bos = tokenizer.bos_token_id<br>eop = tokenizer.eop_token_id<br>pad = tokenizer.pad_token_id<br>mask = tokenizer.mask_token_id<br>gmask = tokenizer.sp_tokenizer[tokenizer.gMASK_token]<br></code></pre></td></tr></table></figure></p><p>将这几个值打印出来: <figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><code class="hljs python"><span class="hljs-built_in">print</span>(<span class="hljs-string">"bos = "</span>, bos)<br><span class="hljs-built_in">print</span>(<span class="hljs-string">"eop = "</span>, eop)<br><span class="hljs-built_in">print</span>(<span class="hljs-string">"pad = "</span>, pad)<br><span class="hljs-built_in">print</span>(<span class="hljs-string">"mask = "</span>, mask)<br><span class="hljs-built_in">print</span>(<span class="hljs-string">"gmask = 
"</span>, gmask)<br></code></pre></td></tr></table></figure> 得到如下结果: <figure class="highlight abnf"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><code class="hljs abnf"><span class="hljs-attribute">bos</span> <span class="hljs-operator">=</span> <span class="hljs-number">150004</span><br><span class="hljs-attribute">eop</span> <span class="hljs-operator">=</span> <span class="hljs-number">150005</span><br><span class="hljs-attribute">pad</span> <span class="hljs-operator">=</span> <span class="hljs-number">20003</span><br><span class="hljs-attribute">mask</span> <span class="hljs-operator">=</span> <span class="hljs-number">150000</span><br><span class="hljs-attribute">gmask</span> <span class="hljs-operator">=</span> <span class="hljs-number">150001</span><br></code></pre></td></tr></table></figure></p><p>我们也可以直接用这个常量结果替换动态计算的部分。常量修改后的结果变成: <figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><code class="hljs python">bos = <span class="hljs-number">150004</span><br>eop = <span class="hljs-number">150005</span><br>pad = <span class="hljs-number">20003</span><br>mask = <span class="hljs-number">150000</span><br>gmask = <span class="hljs-number">150001</span><br></code></pre></td></tr></table></figure> 除了上面定义的 Token 常量,我们还需要定义模型训练绑定的设备名,以及最大输入长度和最大输出长度等,如下: <figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><code class="hljs python">device = <span class="hljs-string">"cuda"</span><br>max_src_length = <span 
class="hljs-number">200</span><br>max_dst_length = <span class="hljs-number">500</span><br></code></pre></td></tr></table></figure> 开发者可以结合自己的显卡性能和要处理的数据集特点来确定这些最大长度。</p><h3 id="测试-tokenizer-的编解码">测试 Tokenizer 的编解码</h3><p>我们可以先做个简单的测试: <figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><code class="hljs python">text = <span class="hljs-string">"AI探险家"</span><br><span class="hljs-built_in">print</span>(tokenizer.encode(text, add_special_tokens = <span class="hljs-literal">True</span>))<br><span class="hljs-built_in">print</span>(tokenizer.encode(text, add_special_tokens = <span class="hljs-literal">False</span>))<br></code></pre></td></tr></table></figure> 输出结果是: <figure class="highlight accesslog"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><code class="hljs accesslog"><span class="hljs-string">[26738, 98715, 83920, 150001, 150004]</span><br><span class="hljs-string">[26738, 98715, 83920]</span><br></code></pre></td></tr></table></figure></p><p>从这个结果可以看出,“<a href="https://aizpy.com/about">AI探险家</a>”这几个字的裸编码是 <code>[26738, 98715, 83920]</code>。为什么是这样呢?我们可以对每一个数值再解码,看看输出结果: <figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><code class="hljs python"><span class="hljs-built_in">print</span>(tokenizer.decode([<span class="hljs-number">26738</span>]))<br><span class="hljs-built_in">print</span>(tokenizer.decode([<span class="hljs-number">98715</span>]))<br><span class="hljs-built_in">print</span>(tokenizer.decode([<span class="hljs-number">83920</span>]))<br></code></pre></td></tr></table></figure> 输出结果是: <figure class="highlight nginx"><table><tr><td class="gutter"><pre><span 
class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><code class="hljs nginx"><span class="hljs-attribute">AI</span><br>探险<br>家<br></code></pre></td></tr></table></figure> 观察这个结果,读者应该能对词汇表建立基本的认知了。读者如果有兴趣,还可以分别针对 “A” “I” “探” “险” 这几个字分别编码,看看编码结果是什么。</p><p>另外,当 <code>add_special_tokens = True</code> 时,编码结果会在末尾添加 <code>150001</code>和 <code>150004</code>,也就是 <code>gmask</code> 和 <code>bos</code>。请注意,我们的训练数据,要按照如下编码要求进行构造: <figure class="highlight gauss"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><code class="hljs gauss">[<span class="hljs-built_in">token</span>, ..., <span class="hljs-built_in">token</span>, gmask, bos, <span class="hljs-built_in">token</span>, ... <span class="hljs-built_in">token</span>, eop]<br></code></pre></td></tr></table></figure> 因此,前半部分文本的编码可以直接让 <code>add_special_tokens = True</code>,后半部分文本的编码则让 <code>add_special_tokens = False</code>,最后再拼接一个 <code>eop</code>。</p><h3 id="定义-prompt">定义 Prompt</h3><p>我们 Fine-tuning 的任务是问答任务(简称 QA),因此一个简单的 Prompt 是这样的:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><code class="hljs python">PROMPT_PATTERN = <span class="hljs-string">"问:{}\n答: "</span><br></code></pre></td></tr></table></figure><p><code>{}</code>里填入 QA 训练集的问题文本。在显存有限的情况下,如果不对长文本做限制处理,很容易出现类似 <code>CUDA out of memory</code> 这样的报错。处理长文本,在给定编码后的数组上限时,可能存在这么几种方式:</p><ul><li>截断<strong>末尾</strong>超出部分的编码</li><li>截断<strong>前面</strong>超出部分的编码</li><li>丢掉训练样本</li></ul><p>每一种方式都有各自的优劣,开发者可以根据自身数据的特点自行选择一种处理方式。当然,如果你的显存够大,也可以不处理。本文以上述第一种方式进行处理。 为了不把 <code>PROMPT_PATTERN</code> 中的 <code>\n答:</code> 这几个字截断掉,我们将整个 <code>PROMPT_PATTERN</code> 拆成两部分:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><code class="hljs python">PROMPT_PATTERN = 
<span class="hljs-string">"问:{}"</span><br>SEP_PATTERN = <span class="hljs-string">"\n答: "</span><br></code></pre></td></tr></table></figure><p>基于这份 Prompt 模板,我们定义下面三个辅助方法: <figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br></pre></td><td class="code"><pre><code class="hljs python"><span class="hljs-keyword">def</span> <span class="hljs-title function_">create_prompt</span>(<span class="hljs-params">question</span>):<br> <span class="hljs-keyword">return</span> PROMPT_PATTERN.<span class="hljs-built_in">format</span>(question), SEP_PATTERN<br><br><br><span class="hljs-keyword">def</span> <span class="hljs-title function_">create_prompt_ids</span>(<span class="hljs-params">tokenizer, question, max_src_length</span>):<br> prompt, sep = create_prompt(question)<br> sep_ids = tokenizer.encode(<br> 
sep, <br> add_special_tokens = <span class="hljs-literal">True</span><br> )<br> sep_len = <span class="hljs-built_in">len</span>(sep_ids)<br> special_tokens_num = <span class="hljs-number">2</span><br> prompt_ids = tokenizer.encode(<br> prompt, <br> max_length = max_src_length - (sep_len - special_tokens_num),<br> truncation = <span class="hljs-literal">True</span>,<br> add_special_tokens = <span class="hljs-literal">False</span><br> )<br><br> <span class="hljs-keyword">return</span> prompt_ids + sep_ids<br><br><br><span class="hljs-keyword">def</span> <span class="hljs-title function_">create_inputs_and_labels</span>(<span class="hljs-params">tokenizer, question, answer, device</span>):<br> prompt = create_prompt_ids(tokenizer, question, max_src_length)<br> completion = tokenizer.encode(<br> answer, <br> max_length = max_dst_length,<br> truncation = <span class="hljs-literal">True</span>,<br> add_special_tokens = <span class="hljs-literal">False</span><br> )<br><br> inputs = prompt + completion + [eop]<br> labels = [-<span class="hljs-number">100</span>] * <span class="hljs-built_in">len</span>(prompt) + completion + [eop] <br> <br> inputs = torch.tensor(inputs, dtype=torch.long, device=device)<br> labels = torch.tensor(labels, dtype=torch.long, device=device)<br> <span class="hljs-keyword">return</span> inputs, labels<br></code></pre></td></tr></table></figure> 值得注意的两点:</p><ul><li>从 <code>create_prompt_ids</code> 这个函数实现可以看出,我们编码分隔符 <code>SEP_PATTERN</code> 时自动添加了前面所述的 2 个特殊 Token。</li><li>对 <code>create_inputs_and_labels</code> 的函数实现中,我们将 <code>labels</code> 无需处理的部分用数值 <code>-100</code> 来表示。因为 <code>ChatGLMForConditionalGeneration</code> 内部在计算损失函数的时候,用的是 <code>torch.nn.CrossEntropyLoss</code>。该函数的参数之一 <code>ignore_index</code> 默认值是 <code>-100</code>。这就让我们在计算损失函数时,无需考虑非标识部分的数值。</li></ul><h3 id="构建-attention-mask-和-position-ids">构建 Attention Mask 和 Position IDs</h3><figure class="highlight python"><table><tr><td class="gutter"><pre><span 
class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br></pre></td><td class="code"><pre><code class="hljs python"><span class="hljs-keyword">def</span> <span class="hljs-title function_">get_attention_mask</span>(<span class="hljs-params">tokenizer, input_ids, device</span>):<br> seq = input_ids.tolist()<br> context_len = seq.index(bos)<br> seq_len = <span class="hljs-built_in">len</span>(seq)<br> attention_mask = torch.ones((seq_len, seq_len), device=device)<br> attention_mask.tril_()<br> attention_mask[..., :context_len] = <span class="hljs-number">1</span><br> attention_mask.unsqueeze_(<span class="hljs-number">0</span>)<br> attention_mask = (attention_mask < <span class="hljs-number">0.5</span>).<span class="hljs-built_in">bool</span>()<br> <span class="hljs-keyword">return</span> attention_mask<br><br><br><span class="hljs-keyword">def</span> <span class="hljs-title 
function_">get_position_ids</span>(<span class="hljs-params">tokenizer, input_ids, device, position_encoding_2d=<span class="hljs-literal">True</span></span>):<br> seq = input_ids.tolist()<br> context_len = seq.index(bos)<br> seq_len = <span class="hljs-built_in">len</span>(seq)<br><br> mask_token = mask <span class="hljs-keyword">if</span> mask <span class="hljs-keyword">in</span> seq <span class="hljs-keyword">else</span> gmask<br> use_gmask = <span class="hljs-literal">False</span> <span class="hljs-keyword">if</span> mask <span class="hljs-keyword">in</span> seq <span class="hljs-keyword">else</span> <span class="hljs-literal">True</span><br><br> mask_position = seq.index(mask_token)<br><br> <span class="hljs-keyword">if</span> position_encoding_2d:<br> position_ids = torch.arange(seq_len, dtype=torch.long, device=device)<br> <span class="hljs-keyword">if</span> <span class="hljs-keyword">not</span> use_gmask:<br> position_ids[context_len:] = mask_position<br> block_position_ids = torch.cat((<br> torch.zeros(context_len, dtype=torch.long, device=device),<br> torch.arange(seq_len - context_len, dtype=torch.long, device=device) + <span class="hljs-number">1</span><br> ))<br> position_ids = torch.stack((position_ids, block_position_ids), dim=<span class="hljs-number">0</span>)<br> <span class="hljs-keyword">else</span>:<br> position_ids = torch.arange(seq_len, dtype=torch.long, device=device)<br> <span class="hljs-keyword">if</span> <span class="hljs-keyword">not</span> use_gmask:<br> position_ids[context_len:] = mask_position<br> <br> <span class="hljs-keyword">return</span> position_ids<br></code></pre></td></tr></table></figure><p>在这个通用实现中,我们针对 <code>mask</code> 和 <code>gmask</code> 两种情况做了区分,同时也对是否执行 <code>position_encoding_2d</code> 分情况处理。本文的 QA 任务采用的是 <code>gmask</code>,并且使用 <code>position_encoding_2d = True</code>。</p><p>我们可以构建下面的问答,来验证下这几个函数的输出: <figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span 
class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br></pre></td><td class="code"><pre><code class="hljs python">test_data = {<br><span class="hljs-string">"question"</span>: <span class="hljs-string">"AI探险家帅不帅?"</span>,<br><span class="hljs-string">"answer"</span>: <span class="hljs-string">"非常帅!"</span><br>}<br><br>inputs, labels = create_inputs_and_labels(tokenizer, **test_data, device=device)<br>attention_mask = get_attention_mask(tokenizer, inputs, device=device)<br>position_ids = get_position_ids(tokenizer, inputs, device=device)<br><br><span class="hljs-built_in">print</span>(<span class="hljs-string">"inputs: \n"</span>, inputs.tolist())<br><span class="hljs-built_in">print</span>(<span class="hljs-string">"\nlabels: \n"</span>, labels.tolist())<br><span class="hljs-built_in">print</span>(<span class="hljs-string">"\nposition_ids: \n"</span>, position_ids.tolist())<br><span class="hljs-built_in">print</span>(<span class="hljs-string">"\nattention_mask: \n"</span>, attention_mask.tolist())<br></code></pre></td></tr></table></figure></p><p>输出结果(为了便于阅读,已对输出进行格式化操作): <figure class="highlight mathematica"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span 
class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br></pre></td><td class="code"><pre><code class="hljs mathematica"><span class="hljs-variable">inputs</span><span class="hljs-operator">:</span> <br> <span class="hljs-punctuation">[</span><span class="hljs-number">20005</span><span class="hljs-operator">,</span> <span class="hljs-number">84286</span><span class="hljs-operator">,</span> <span class="hljs-number">20012</span><span class="hljs-operator">,</span> <span class="hljs-number">31943</span><span class="hljs-operator">,</span> <span class="hljs-number">98715</span><span class="hljs-operator">,</span> <span class="hljs-number">83920</span><span class="hljs-operator">,</span> <span class="hljs-number">87359</span><span class="hljs-operator">,</span> <span class="hljs-number">83848</span><span class="hljs-operator">,</span> <span class="hljs-number">87359</span><span class="hljs-operator">,</span> <span class="hljs-number">20031</span><span class="hljs-operator">,</span> <span class="hljs-number">20005</span><span class="hljs-operator">,</span> <span class="hljs-number">20004</span><span class="hljs-operator">,</span> <span class="hljs-number">87342</span><span class="hljs-operator">,</span> <span class="hljs-number">20012</span><span class="hljs-operator">,</span> <span class="hljs-number">150001</span><span class="hljs-operator">,</span> <span class="hljs-number">150004</span><span class="hljs-operator">,</span> <span 
class="hljs-number">20005</span><span class="hljs-operator">,</span> <span class="hljs-number">84122</span><span class="hljs-operator">,</span> <span class="hljs-number">87359</span><span class="hljs-operator">,</span> <span class="hljs-number">20035</span><span class="hljs-operator">,</span> <span class="hljs-number">150005</span><span class="hljs-punctuation">]</span><br><br><span class="hljs-variable">labels</span><span class="hljs-operator">:</span> <br> <span class="hljs-punctuation">[</span><span class="hljs-operator">-</span><span class="hljs-number">100</span><span class="hljs-operator">,</span> <span class="hljs-operator">-</span><span class="hljs-number">100</span><span class="hljs-operator">,</span> <span class="hljs-operator">-</span><span class="hljs-number">100</span><span class="hljs-operator">,</span> <span class="hljs-operator">-</span><span class="hljs-number">100</span><span class="hljs-operator">,</span> <span class="hljs-operator">-</span><span class="hljs-number">100</span><span class="hljs-operator">,</span> <span class="hljs-operator">-</span><span class="hljs-number">100</span><span class="hljs-operator">,</span> <span class="hljs-operator">-</span><span class="hljs-number">100</span><span class="hljs-operator">,</span> <span class="hljs-operator">-</span><span class="hljs-number">100</span><span class="hljs-operator">,</span> <span class="hljs-operator">-</span><span class="hljs-number">100</span><span class="hljs-operator">,</span> <span class="hljs-operator">-</span><span class="hljs-number">100</span><span class="hljs-operator">,</span> <span class="hljs-operator">-</span><span class="hljs-number">100</span><span class="hljs-operator">,</span> <span class="hljs-operator">-</span><span class="hljs-number">100</span><span class="hljs-operator">,</span> <span class="hljs-operator">-</span><span class="hljs-number">100</span><span class="hljs-operator">,</span> <span class="hljs-operator">-</span><span class="hljs-number">100</span><span 
class="hljs-operator">,</span> <span class="hljs-operator">-</span><span class="hljs-number">100</span><span class="hljs-operator">,</span> <span class="hljs-operator">-</span><span class="hljs-number">100</span><span class="hljs-operator">,</span> <span class="hljs-number">20005</span><span class="hljs-operator">,</span> <span class="hljs-number">84122</span><span class="hljs-operator">,</span> <span class="hljs-number">87359</span><span class="hljs-operator">,</span> <span class="hljs-number">20035</span><span class="hljs-operator">,</span> <span class="hljs-number">150005</span><span class="hljs-punctuation">]</span><br><br><span class="hljs-type">position_ids</span><span class="hljs-operator">:</span> <br> <span class="hljs-punctuation">[</span><br> <span class="hljs-punctuation">[</span><span class="hljs-number">0</span><span class="hljs-operator">,</span> <span class="hljs-number">1</span><span class="hljs-operator">,</span> <span class="hljs-number">2</span><span class="hljs-operator">,</span> <span class="hljs-number">3</span><span class="hljs-operator">,</span> <span class="hljs-number">4</span><span class="hljs-operator">,</span> <span class="hljs-number">5</span><span class="hljs-operator">,</span> <span class="hljs-number">6</span><span class="hljs-operator">,</span> <span class="hljs-number">7</span><span class="hljs-operator">,</span> <span class="hljs-number">8</span><span class="hljs-operator">,</span> <span class="hljs-number">9</span><span class="hljs-operator">,</span> <span class="hljs-number">10</span><span class="hljs-operator">,</span> <span class="hljs-number">11</span><span class="hljs-operator">,</span> <span class="hljs-number">12</span><span class="hljs-operator">,</span> <span class="hljs-number">13</span><span class="hljs-operator">,</span> <span class="hljs-number">14</span><span class="hljs-operator">,</span> <span class="hljs-number">15</span><span class="hljs-operator">,</span> <span class="hljs-number">16</span><span 
class="hljs-operator">,</span> <span class="hljs-number">17</span><span class="hljs-operator">,</span> <span class="hljs-number">18</span><span class="hljs-operator">,</span> <span class="hljs-number">19</span><span class="hljs-operator">,</span> <span class="hljs-number">20</span><span class="hljs-punctuation">]</span><span class="hljs-operator">,</span> <br> <span class="hljs-punctuation">[</span><span class="hljs-number">0</span><span class="hljs-operator">,</span> <span class="hljs-number">0</span><span class="hljs-operator">,</span> <span class="hljs-number">0</span><span class="hljs-operator">,</span> <span class="hljs-number">0</span><span class="hljs-operator">,</span> <span class="hljs-number">0</span><span class="hljs-operator">,</span> <span class="hljs-number">0</span><span class="hljs-operator">,</span> <span class="hljs-number">0</span><span class="hljs-operator">,</span> <span class="hljs-number">0</span><span class="hljs-operator">,</span> <span class="hljs-number">0</span><span class="hljs-operator">,</span> <span class="hljs-number">0</span><span class="hljs-operator">,</span> <span class="hljs-number">0</span><span class="hljs-operator">,</span> <span class="hljs-number">0</span><span class="hljs-operator">,</span> <span class="hljs-number">0</span><span class="hljs-operator">,</span> <span class="hljs-number">0</span><span class="hljs-operator">,</span> <span class="hljs-number">0</span><span class="hljs-operator">,</span> <span class="hljs-number">1</span><span class="hljs-operator">,</span> <span class="hljs-number">2</span><span class="hljs-operator">,</span> <span class="hljs-number">3</span><span class="hljs-operator">,</span> <span class="hljs-number">4</span><span class="hljs-operator">,</span> <span class="hljs-number">5</span><span class="hljs-operator">,</span> <span class="hljs-number">6</span><span class="hljs-punctuation">]</span><br> <span class="hljs-punctuation">]</span><br><br><span class="hljs-type">attention_mask</span><span 
class="hljs-operator">:</span> <br> <span class="hljs-punctuation">[</span><span class="hljs-punctuation">[</span><br> <span class="hljs-punctuation">[</span><span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-punctuation">]</span><span class="hljs-operator">,</span> <br> <span class="hljs-punctuation">[</span><span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span 
class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-punctuation">]</span><span class="hljs-operator">,</span> <br> <span class="hljs-punctuation">[</span><span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span 
class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-punctuation">]</span><span class="hljs-operator">,</span> <br> <span class="hljs-punctuation">[</span><span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> 
<span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-punctuation">]</span><span class="hljs-operator">,</span> <br> <span class="hljs-punctuation">[</span><span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span 
class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-punctuation">]</span><span class="hljs-operator">,</span> <br> <span class="hljs-punctuation">[</span><span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span 
class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-punctuation">]</span><span class="hljs-operator">,</span> <br> <span class="hljs-punctuation">[</span><span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span 
class="hljs-punctuation">]</span><span class="hljs-operator">,</span> <br> <span class="hljs-punctuation">[</span><span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-punctuation">]</span><span class="hljs-operator">,</span> <br> <span class="hljs-punctuation">[</span><span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span 
class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-punctuation">]</span><span class="hljs-operator">,</span> <br> <span class="hljs-punctuation">[</span><span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> 
<span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-punctuation">]</span><span class="hljs-operator">,</span> <br> <span class="hljs-punctuation">[</span><span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span 
class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-punctuation">]</span><span class="hljs-operator">,</span> <br> <span class="hljs-punctuation">[</span><span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span 
class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-punctuation">]</span><span class="hljs-operator">,</span> <br> <span class="hljs-punctuation">[</span><span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span 
class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-punctuation">]</span><span class="hljs-operator">,</span> <br> <span class="hljs-punctuation">[</span><span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-punctuation">]</span><span 
class="hljs-operator">,</span> <br> <span class="hljs-punctuation">[</span><span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-punctuation">]</span><span class="hljs-operator">,</span> <br> <span class="hljs-punctuation">[</span><span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> 
<span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-punctuation">]</span><span class="hljs-operator">,</span> <br> <span class="hljs-punctuation">[</span><span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span 
class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-punctuation">]</span><span class="hljs-operator">,</span> <br> <span class="hljs-punctuation">[</span><span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span 
class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-punctuation">]</span><span class="hljs-operator">,</span> <br> <span class="hljs-punctuation">[</span><span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> 
<span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-punctuation">]</span><span class="hljs-operator">,</span> <br> <span class="hljs-punctuation">[</span><span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span 
class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">True</span><span class="hljs-punctuation">]</span><span class="hljs-operator">,</span> <br> <span class="hljs-punctuation">[</span><span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-operator">,</span> <span class="hljs-built_in">False</span><span class="hljs-punctuation">]</span><span 
class="hljs-punctuation">]</span><span class="hljs-punctuation">]</span><br></code></pre></td></tr></table></figure> Checked against the data shown in the paper, this output matches what we expect.</p><h3 id="创建数据集">Creating the dataset</h3><p>First, we define training data in the following format: <figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><code class="hljs python">train_data = [<br>    {<span class="hljs-string">"question"</span>: <span class="hljs-string">"question 1"</span>, <span class="hljs-string">"answer"</span>: <span class="hljs-string">"answer 1"</span>},<br>    {<span class="hljs-string">"question"</span>: <span class="hljs-string">"question 2"</span>, <span class="hljs-string">"answer"</span>: <span class="hljs-string">"answer 2"</span>},<br>]<br><br></code></pre></td></tr></table></figure> With the format defined, we then create a <code>QADataset</code> class, as follows: <figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br></pre></td><td 
class="code"><pre><code class="hljs python"><span class="hljs-keyword">from</span> torch.utils.data <span class="hljs-keyword">import</span> Dataset<br><br><span class="hljs-keyword">class</span> <span class="hljs-title class_">QADataset</span>(<span class="hljs-title class_ inherited__">Dataset</span>):<br>    <span class="hljs-keyword">def</span> <span class="hljs-title function_">__init__</span>(<span class="hljs-params">self, data, tokenizer</span>) -> <span class="hljs-literal">None</span>:<br>        <span class="hljs-built_in">super</span>().__init__()<br>        self.data = data<br>        self.tokenizer = tokenizer<br>    <br><br>    <span class="hljs-keyword">def</span> <span class="hljs-title function_">__getitem__</span>(<span class="hljs-params">self, index</span>):<br>        item_data = self.data[index]<br>        tokenizer = self.tokenizer<br>        input_ids, labels = create_inputs_and_labels(<br>            tokenizer, <br>            device=device,<br>            **item_data<br>        )<br>        <br>        attention_mask = get_attention_mask(tokenizer, input_ids, device)<br>        position_ids = get_position_ids(tokenizer, input_ids, device)<br><br>        <span class="hljs-keyword">return</span> {<br>            <span class="hljs-string">"input_ids"</span>: input_ids,<br>            <span class="hljs-string">"labels"</span>: labels,<br>            <span class="hljs-string">"attention_mask"</span>: attention_mask,<br>            <span class="hljs-string">"position_ids"</span>: position_ids<br>        }<br>        <br><br>    <span class="hljs-keyword">def</span> <span class="hljs-title function_">__len__</span>(<span class="hljs-params">self</span>):<br>        <span class="hljs-keyword">return</span> <span class="hljs-built_in">len</span>(self.data)<br></code></pre></td></tr></table></figure></p><p>Next, we create a data collator:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span 
class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br></pre></td><td class="code"><pre><code class="hljs python"><span class="hljs-keyword">def</span> <span class="hljs-title function_">collate_fn</span>(<span class="hljs-params">batch</span>):<br>    input_ids = []<br>    attention_mask = []<br>    labels = []<br>    position_ids = []<br>    <br>    <span class="hljs-keyword">for</span> obj <span class="hljs-keyword">in</span> batch:<br>        input_ids.append(obj[<span class="hljs-string">'input_ids'</span>])<br>        labels.append(obj[<span class="hljs-string">'labels'</span>])<br>        attention_mask.append(obj[<span class="hljs-string">'attention_mask'</span>])<br>        position_ids.append(obj[<span class="hljs-string">'position_ids'</span>])<br>    <br>    <span class="hljs-keyword">return</span> {<br>        <span class="hljs-string">'input_ids'</span>: torch.stack(input_ids),<br>        <span class="hljs-string">'attention_mask'</span>: torch.stack(attention_mask), <br>        <span class="hljs-string">'labels'</span>: torch.stack(labels),<br>        <span class="hljs-string">'position_ids'</span>: torch.stack(position_ids)<br>    }<br></code></pre></td></tr></table></figure><h2 id="开始训练">Training</h2><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span 
class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br></pre></td><td class="code"><pre><code class="hljs python"><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainingArguments, Trainer<br>model.to(device)<br><br>training_args = TrainingArguments(<br> <span class="hljs-string">"output"</span>,<br> fp16 =<span class="hljs-literal">True</span>,<br> save_steps = <span class="hljs-number">500</span>,<br> save_total_limit = <span class="hljs-number">3</span>,<br> gradient_accumulation_steps=<span class="hljs-number">1</span>,<br> per_device_train_batch_size = <span class="hljs-number">1</span>,<br> learning_rate = <span class="hljs-number">1e-4</span>,<br> max_steps=<span class="hljs-number">1500</span>,<br> logging_steps=<span class="hljs-number">50</span>,<br> remove_unused_columns=<span class="hljs-literal">False</span>,<br> seed=<span class="hljs-number">0</span>,<br> data_seed=<span class="hljs-number">0</span>,<br> group_by_length=<span class="hljs-literal">False</span>,<br> dataloader_pin_memory=<span class="hljs-literal">False</span><br>)<br><br><span class="hljs-keyword">class</span> <span class="hljs-title class_">ModifiedTrainer</span>(<span 
class="hljs-title class_ inherited__">Trainer</span>):<br><br>    <span class="hljs-keyword">def</span> <span class="hljs-title function_">compute_loss</span>(<span class="hljs-params">self, model, inputs, return_outputs=<span class="hljs-literal">False</span></span>):<br>        <span class="hljs-keyword">return</span> model(<br>            input_ids=inputs[<span class="hljs-string">"input_ids"</span>],<br>            attention_mask=inputs[<span class="hljs-string">"attention_mask"</span>],<br>            position_ids=inputs[<span class="hljs-string">"position_ids"</span>],<br>            labels=inputs[<span class="hljs-string">"labels"</span>],<br>        ).loss<br><br><br>train_dataset = QADataset(train_data, tokenizer=tokenizer)<br>trainer = ModifiedTrainer(<br>    model=model,<br>    train_dataset=train_dataset,<br>    args=training_args,<br>    data_collator=collate_fn,<br>    tokenizer=tokenizer<br>)<br><br>trainer.train()<br></code></pre></td></tr></table></figure><h2 id="预测">Prediction</h2><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><code class="hljs python">response, history = model.chat(tokenizer, <span class="hljs-string">"AI探险家的颜值如何?"</span>, history=[])<br><span class="hljs-built_in">print</span>(response)<br></code></pre></td></tr></table></figure><h2 id="保存训练模型">Saving the Trained Model</h2><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><code class="hljs python"><span class="hljs-keyword">import</span> os<br><br><span class="hljs-keyword">def</span> <span class="hljs-title function_">save_tuned_parameters</span>(<span 
class="hljs-params">model, path</span>):<br>    saved_params = {<br>        k: v.to(device)<br>        <span class="hljs-keyword">for</span> k, v <span class="hljs-keyword">in</span> model.named_parameters()<br>        <span class="hljs-keyword">if</span> v.requires_grad<br>    }<br>    torch.save(saved_params, path)<br><br>save_tuned_parameters(model, os.path.join(<span class="hljs-string">"/path/to/output"</span>, <span class="hljs-string">"chatglm-6b-lora.pt"</span>))<br></code></pre></td></tr></table></figure><h2 id="重载训练后的模型">Reloading the Trained Model</h2><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><code class="hljs python">checkpoint = <span class="hljs-string">"THUDM/chatglm-6b"</span><br>revision = <span class="hljs-string">"096f3de6b4959ce38bef7bb05f3129c931a3084e"</span><br>model = AutoModel.from_pretrained(checkpoint, revision=revision, trust_remote_code=<span class="hljs-literal">True</span>)<br>tokenizer = AutoTokenizer.from_pretrained(checkpoint, revision=revision, trust_remote_code=<span class="hljs-literal">True</span>)<br><br>model = load_lora_config(model)<br>model.load_state_dict(torch.load(<span class="hljs-string">"/path/to/output/chatglm-6b-lora.pt"</span>), strict=<span class="hljs-literal">False</span>)<br><br>model.half().cuda().<span class="hljs-built_in">eval</span>()<br>response, history = model.chat(tokenizer, <span class="hljs-string">"AI探险家的颜值如何?"</span>, history=[])<br><span class="hljs-built_in">print</span>(response)<br></code></pre></td></tr></table></figure>]]></content>
<categories>
<category>NLP</category>
</categories>
<tags>
<tag>LLM</tag>
<tag>LoRA</tag>
<tag>ChatGLM</tag>
<tag>Transformer</tag>
</tags>
</entry>
</search>