Skip to content

Commit

Permalink
Deploying to gh-pages from @ 3637759 🚀
Browse files Browse the repository at this point in the history
  • Loading branch information
onadegibert committed Sep 26, 2023
1 parent ae2e936 commit d06485e
Show file tree
Hide file tree
Showing 9 changed files with 399 additions and 253 deletions.
21 changes: 7 additions & 14 deletions _modules/onmt/trainer.html
Original file line number Diff line number Diff line change
Expand Up @@ -558,12 +558,12 @@ <h1>Source code for onmt.trainer</h1><div class="highlight"><pre>
<span class="n">valid_stats</span><span class="o">=</span><span class="n">valid_stats</span><span class="p">,</span>
<span class="p">)</span>

<span class="c1"># Run patience mechanism</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">earlystopper</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">earlystopper</span><span class="p">(</span><span class="n">valid_stats</span><span class="p">,</span> <span class="n">step</span><span class="p">)</span>
<span class="c1"># If the patience has reached the limit, stop training</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">earlystopper</span><span class="o">.</span><span class="n">has_stopped</span><span class="p">():</span>
<span class="k">break</span>
<span class="c1"># # Run patience mechanism</span>
<span class="c1"># if self.earlystopper is not None:</span>
<span class="c1"># self.earlystopper(valid_stats, step)</span>
<span class="c1"># # If the patience has reached the limit, stop training</span>
<span class="c1"># if self.earlystopper.has_stopped():</span>
<span class="c1"># break</span>

<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_saver</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="p">(</span><span class="n">save_checkpoint_steps</span> <span class="o">!=</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">step</span> <span class="o">%</span> <span class="n">save_checkpoint_steps</span> <span class="o">==</span> <span class="mi">0</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">model_saver</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">step</span><span class="p">,</span> <span class="n">moving_average</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">moving_average</span><span class="p">)</span>
Expand Down Expand Up @@ -625,10 +625,6 @@ <h1>Source code for onmt.trainer</h1><div class="highlight"><pre>
<span class="c1"># Set model back to training mode.</span>
<span class="n">valid_model</span><span class="o">.</span><span class="n">train</span><span class="p">()</span>

<span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">():</span>
<span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="s1">&#39;has_grad&#39;</span><span class="p">):</span>
<span class="n">p</span><span class="o">.</span><span class="n">has_grad</span> <span class="o">=</span> <span class="kc">False</span>

<span class="k">return</span> <span class="n">stats</span></div>

<span class="k">def</span> <span class="nf">_gradient_accumulation_over_lang_pairs</span><span class="p">(</span>
Expand All @@ -643,7 +639,7 @@ <h1>Source code for onmt.trainer</h1><div class="highlight"><pre>
<span class="n">seen_comm_batches</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">comm_batch</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_method</span> <span class="o">==</span> <span class="s2">&quot;tokens&quot;</span><span class="p">:</span>
<span class="n">num_tokens</span> <span class="o">=</span> <span class="p">(</span>
<span class="n">batch</span><span class="o">.</span><span class="n">labels</span><span class="p">[</span><span class="mi">1</span><span class="p">:,</span> <span class="p">:,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">ne</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">train_loss_md</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;trainloss</span><span class="si">{</span><span class="n">metadata</span><span class="o">.</span><span class="n">tgt_lang</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">padding_idx</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span>
<span class="n">batch</span><span class="o">.</span><span class="n">tgt</span><span class="p">[</span><span class="mi">1</span><span class="p">:,</span> <span class="p">:,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">ne</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">train_loss_md</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;trainloss</span><span class="si">{</span><span class="n">metadata</span><span class="o">.</span><span class="n">tgt_lang</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">padding_idx</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span>
<span class="p">)</span>
<span class="n">normalization</span> <span class="o">+=</span> <span class="n">num_tokens</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
Expand All @@ -663,9 +659,6 @@ <h1>Source code for onmt.trainer</h1><div class="highlight"><pre>
<span class="k">if</span> <span class="n">src_lengths</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">report_stats</span><span class="o">.</span><span class="n">n_src_words</span> <span class="o">+=</span> <span class="n">src_lengths</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>

<span class="c1"># tgt_outer corresponds to the target-side input. The expected</span>
<span class="c1"># decoder output will be read directly from the batch:</span>
<span class="c1"># cf. `onmt.utils.loss.CommonLossCompute._make_shard_state`</span>
<span class="n">tgt_outer</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">tgt</span>

<span class="n">bptt</span> <span class="o">=</span> <span class="kc">False</span>
Expand Down
Loading

0 comments on commit d06485e

Please sign in to comment.