Commit

Deploying to gh-pages from @ 43d2460 🚀
TimotheeMickus committed Sep 25, 2023
1 parent f18d7f5 commit 90cd2a4
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions _modules/onmt/utils/loss.html
@@ -357,19 +357,19 @@ <h1>Source code for onmt.utils.loss</h1><div class="highlight"><pre>
<span class="n">batch_stats</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">stats</span><span class="p">)</span>
<span class="k">return</span> <span class="kc">None</span><span class="p">,</span> <span class="n">batch_stats</span>

-<span class="k">def</span> <span class="nf">_stats</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">loss</span><span class="p">,</span> <span class="n">scores</span><span class="p">,</span> <span class="n">target</span><span class="p">):</span>
+<span class="k">def</span> <span class="nf">_stats</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">loss</span><span class="p">,</span> <span class="n">scores</span><span class="p">,</span> <span class="n">labels</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Args:</span>
<span class="sd"> loss (:obj:`FloatTensor`): the loss computed by the loss criterion.</span>
<span class="sd"> scores (:obj:`FloatTensor`): a score for each possible output</span>
-<span class="sd"> target (:obj:`FloatTensor`): true targets</span>
+<span class="sd"> labels (:obj:`FloatTensor`): true targets</span>

<span class="sd"> Returns:</span>
<span class="sd"> :obj:`onmt.utils.Statistics` : statistics for this batch.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">pred</span> <span class="o">=</span> <span class="n">scores</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="mi">1</span><span class="p">)[</span><span class="mi">1</span><span class="p">]</span>
-<span class="n">non_padding</span> <span class="o">=</span> <span class="n">target</span><span class="o">.</span><span class="n">ne</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">padding_idx</span><span class="p">)</span>
-<span class="n">num_correct</span> <span class="o">=</span> <span class="n">pred</span><span class="o">.</span><span class="n">eq</span><span class="p">(</span><span class="n">target</span><span class="p">)</span><span class="o">.</span><span class="n">masked_select</span><span class="p">(</span><span class="n">non_padding</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
+<span class="n">non_padding</span> <span class="o">=</span> <span class="n">labels</span><span class="o">.</span><span class="n">ne</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">padding_idx</span><span class="p">)</span>
+<span class="n">num_correct</span> <span class="o">=</span> <span class="n">pred</span><span class="o">.</span><span class="n">eq</span><span class="p">(</span><span class="n">labels</span><span class="p">)</span><span class="o">.</span><span class="n">masked_select</span><span class="p">(</span><span class="n">non_padding</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
<span class="n">num_non_padding</span> <span class="o">=</span> <span class="n">non_padding</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
<span class="k">return</span> <span class="n">onmt</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">Statistics</span><span class="p">(</span><span class="n">loss</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span> <span class="n">num_non_padding</span><span class="p">,</span> <span class="n">num_correct</span><span class="p">)</span>
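
Note on the hunk above: `_stats` counts the non-padding positions and how many of them the arg-max prediction gets right. The following is a minimal dependency-free sketch of that counting, using plain Python lists in place of tensors. The function name `batch_stats` and the list-based inputs are assumptions made for illustration; they are not part of `onmt.utils.loss`.

```python
def batch_stats(scores, labels, padding_idx):
    """Count non-padding positions and correct predictions for one batch.

    scores: list of rows, each row holding one score per output class
    labels: list of integer class ids, one per row
    padding_idx: label value marking positions to ignore
    """
    num_non_padding = 0
    num_correct = 0
    for row, label in zip(scores, labels):
        if label == padding_idx:
            # Padding positions count toward neither total.
            continue
        num_non_padding += 1
        # Index of the highest score in the row (arg-max over classes).
        pred = max(range(len(row)), key=row.__getitem__)
        if pred == label:
            num_correct += 1
    return num_non_padding, num_correct
```

Accuracy for the batch would then be `num_correct / num_non_padding`, provided at least one position is not padding.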

@@ -399,14 +399,14 @@ <h1>Source code for onmt.utils.loss</h1><div class="highlight"><pre>

<span class="bp">self</span><span class="o">.</span><span class="n">confidence</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">-</span> <span class="n">label_smoothing</span>

-<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">):</span>
+<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">labels</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> output (FloatTensor): batch_size x n_classes</span>
-<span class="sd"> target (LongTensor): batch_size</span>
+<span class="sd"> labels (LongTensor): batch_size</span>
<span class="sd"> &quot;&quot;&quot;</span>
-<span class="n">model_prob</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">one_hot</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span><span class="n">target</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="mi">1</span><span class="p">)</span>
-<span class="n">model_prob</span><span class="o">.</span><span class="n">scatter_</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">target</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">confidence</span><span class="p">)</span>
-<span class="n">model_prob</span><span class="o">.</span><span class="n">masked_fill_</span><span class="p">((</span><span class="n">target</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">ignore_index</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="mi">0</span><span class="p">)</span>
+<span class="n">model_prob</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">one_hot</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span><span class="n">labels</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="mi">1</span><span class="p">)</span>
+<span class="n">model_prob</span><span class="o">.</span><span class="n">scatter_</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">labels</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">confidence</span><span class="p">)</span>
+<span class="n">model_prob</span><span class="o">.</span><span class="n">masked_fill_</span><span class="p">((</span><span class="n">labels</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">ignore_index</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="mi">0</span><span class="p">)</span>

<span class="k">return</span> <span class="n">F</span><span class="o">.</span><span class="n">kl_div</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">model_prob</span><span class="p">,</span> <span class="n">reduction</span><span class="o">=</span><span class="s1">&#39;sum&#39;</span><span class="p">)</span>
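
Note on the hunk above: `forward` builds a smoothed target distribution per row (the true label gets `confidence`, rows whose label equals `ignore_index` are zeroed) and sums the KL divergence against the model's log-probabilities. Below is a rough plain-Python sketch of that computation. How `self.one_hot` is initialised is not visible in this diff, so the uniform `label_smoothing / (n_classes - 2)` fill with a zero at `ignore_index` is an assumption, as are the helper names `smoothed_targets` and `kl_div_sum`. The sketch also assumes `0 <= ignore_index < n_classes` and that `output` holds log-probabilities.

```python
import math


def smoothed_targets(labels, n_classes, label_smoothing, ignore_index):
    """Build one smoothed target row per label (lists stand in for tensors)."""
    confidence = 1.0 - label_smoothing
    # Assumption: the remaining mass is spread evenly over every class
    # except the true label and ignore_index.
    smoothing_value = label_smoothing / (n_classes - 2)
    rows = []
    for label in labels:
        if label == ignore_index:
            # Mirrors masked_fill_: ignored rows carry no target mass.
            rows.append([0.0] * n_classes)
            continue
        row = [smoothing_value] * n_classes
        row[ignore_index] = 0.0
        row[label] = confidence  # mirrors scatter_ on the label column
        rows.append(row)
    return rows


def kl_div_sum(log_probs, targets):
    """Sum of t * (log t - log_prob) over all entries with t > 0."""
    total = 0.0
    for lp_row, t_row in zip(log_probs, targets):
        for lp, t in zip(lp_row, t_row):
            if t > 0:
                total += t * (math.log(t) - lp)
    return total
```

Under these assumptions each non-ignored row sums to 1, and ignored rows contribute nothing to the loss because every term with a zero target is skipped.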
