Skip to content

Commit

Permalink
update sphinx docs
Browse files Browse the repository at this point in the history
  • Loading branch information
kristian-georgiev committed Mar 23, 2023
1 parent 918d546 commit 32862db
Show file tree
Hide file tree
Showing 21 changed files with 195 additions and 64 deletions.
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
author = 'Kristian Georgiev'

# The full version, including alpha/beta/rc tags
release = '0.1.0'
release = '0.1.1'


# -- General configuration ---------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion docs/html/.buildinfo
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Sphinx build info version 1
# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
config: def77ad9866373dc2cf63179c7913729
config: 1a000b88aa8f9cf78e81853015c446f0
tags: 645f666f9bcd5a90fca523b33c5a78b7
Binary file modified docs/html/.doctrees/environment.pickle
Binary file not shown.
Binary file modified docs/html/.doctrees/trak.doctree
Binary file not shown.
6 changes: 3 additions & 3 deletions docs/html/_modules/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
<meta name="color-scheme" content="light dark"><link rel="index" title="Index" href="../genindex.html" /><link rel="search" title="Search" href="../search.html" />

<!-- Generated with Sphinx 4.4.0 and Furo 2022.12.07 -->
<title>Overview: module code - TRAK 0.1.0 documentation</title>
<title>Overview: module code - TRAK 0.1.1 documentation</title>
<link rel="stylesheet" type="text/css" href="../_static/pygments.css" />
<link rel="stylesheet" type="text/css" href="../_static/styles/furo.css?digest=91d0f0d1c444bdcb17a68e833c7a53903343c195" />
<link rel="stylesheet" type="text/css" href="../_static/styles/furo-extensions.css?digest=30d1aed668e5c3a91c3e3bf6a60b675221979f0e" />
Expand Down Expand Up @@ -122,7 +122,7 @@
</label>
</div>
<div class="header-center">
<a href="../index.html"><div class="brand">TRAK 0.1.0 documentation</div></a>
<a href="../index.html"><div class="brand">TRAK 0.1.1 documentation</div></a>
</div>
<div class="header-right">
<div class="theme-toggle-container theme-toggle-header">
Expand All @@ -145,7 +145,7 @@
<div class="sidebar-sticky"><a class="sidebar-brand" href="../index.html">


<span class="sidebar-brand-text">TRAK 0.1.0 documentation</span>
<span class="sidebar-brand-text">TRAK 0.1.1 documentation</span>

</a><form class="sidebar-search-container" method="get" action="../search.html" role="search">
<input class="sidebar-search" placeholder="Search" name="q" aria-label="Search">
Expand Down
51 changes: 51 additions & 0 deletions docs/html/_modules/trak/modelout_functions.html
Original file line number Diff line number Diff line change
Expand Up @@ -591,9 +591,60 @@ <h1>Source code for trak.modelout_functions</h1><div class="highlight"><pre>
<span class="k">return</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">ps</span><span class="p">)</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span></div></div>


<div class="viewcode-block" id="TextClassificationModelOutput"><a class="viewcode-back" href="../../trak.html#trak.modelout_functions.TextClassificationModelOutput">[docs]</a><span class="k">class</span> <span class="nc">TextClassificationModelOutput</span><span class="p">(</span><span class="n">AbstractModelOutput</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Margin for text classification models. This assumes that the model takes in</span>
<span class="sd"> input_ids, token_type_ids, and attention_mask.</span>
<span class="sd"> .. math::</span>
<span class="sd"> \text{logit}[\text{correct}] - \log\left(\sum_{i \neq \text{correct}}</span>
<span class="sd"> \exp(\text{logit}[i])\right)</span>
<span class="sd"> Version of margin proposed in &#39;Understanding Influence Functions</span>
<span class="sd"> and Datamodels via Harmonic Analysis&#39;</span>
<span class="sd"> &quot;&quot;&quot;</span>

<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">temperature</span><span class="o">=</span><span class="mf">1.</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">softmax</span> <span class="o">=</span> <span class="n">ch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Softmax</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">loss_temperature</span> <span class="o">=</span> <span class="n">temperature</span>

<div class="viewcode-block" id="TextClassificationModelOutput.get_output"><a class="viewcode-back" href="../../trak.html#trak.modelout_functions.TextClassificationModelOutput.get_output">[docs]</a> <span class="nd">@staticmethod</span>
<span class="k">def</span> <span class="nf">get_output</span><span class="p">(</span><span class="n">func_model</span><span class="p">,</span>
<span class="n">weights</span><span class="p">:</span> <span class="n">Iterable</span><span class="p">[</span><span class="n">Tensor</span><span class="p">],</span>
<span class="n">buffers</span><span class="p">:</span> <span class="n">Iterable</span><span class="p">[</span><span class="n">Tensor</span><span class="p">],</span>
<span class="n">input_id</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">token_type_id</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">attention_mask</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">label</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="n">logits</span> <span class="o">=</span> <span class="n">func_model</span><span class="p">(</span><span class="n">weights</span><span class="p">,</span> <span class="n">buffers</span><span class="p">,</span> <span class="n">input_id</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
<span class="n">token_type_id</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
<span class="n">attention_mask</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
<span class="n">bindex</span> <span class="o">=</span> <span class="n">ch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">logits</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">logits</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">non_blocking</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="n">logits_correct</span> <span class="o">=</span> <span class="n">logits</span><span class="p">[</span><span class="n">bindex</span><span class="p">,</span> <span class="n">label</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)]</span>

<span class="n">cloned_logits</span> <span class="o">=</span> <span class="n">logits</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="n">cloned_logits</span><span class="p">[</span><span class="n">bindex</span><span class="p">,</span> <span class="n">label</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)]</span> <span class="o">=</span> <span class="n">ch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="o">-</span><span class="n">ch</span><span class="o">.</span><span class="n">inf</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">logits</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>

<span class="n">margins</span> <span class="o">=</span> <span class="n">logits_correct</span> <span class="o">-</span> <span class="n">cloned_logits</span><span class="o">.</span><span class="n">logsumexp</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">return</span> <span class="n">margins</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span></div>

<div class="viewcode-block" id="TextClassificationModelOutput.forward"><a class="viewcode-back" href="../../trak.html#trak.modelout_functions.TextClassificationModelOutput.forward">[docs]</a> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model</span><span class="p">:</span> <span class="n">Module</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="n">Iterable</span><span class="p">[</span><span class="n">Tensor</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="n">input_ids</span><span class="p">,</span> <span class="n">token_type_ids</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">batch</span>
<span class="k">return</span> <span class="n">model</span><span class="p">(</span><span class="n">input_ids</span><span class="o">=</span><span class="n">input_ids</span><span class="p">,</span>
<span class="n">token_type_ids</span><span class="o">=</span><span class="n">token_type_ids</span><span class="p">,</span>
<span class="n">attention_mask</span><span class="o">=</span><span class="n">attention_mask</span><span class="p">)</span></div>

<div class="viewcode-block" id="TextClassificationModelOutput.get_out_to_loss_grad"><a class="viewcode-back" href="../../trak.html#trak.modelout_functions.TextClassificationModelOutput.get_out_to_loss_grad">[docs]</a> <span class="k">def</span> <span class="nf">get_out_to_loss_grad</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">func_model</span><span class="p">,</span> <span class="n">weights</span><span class="p">,</span> <span class="n">buffers</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="n">Iterable</span><span class="p">[</span><span class="n">Tensor</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="n">input_ids</span><span class="p">,</span> <span class="n">token_type_ids</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">batch</span>
<span class="n">logits</span> <span class="o">=</span> <span class="n">func_model</span><span class="p">(</span><span class="n">weights</span><span class="p">,</span> <span class="n">buffers</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">,</span> <span class="n">token_type_ids</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">)</span>
<span class="n">ps</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">logits</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_temperature</span><span class="p">)[</span><span class="n">ch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">logits</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)),</span> <span class="n">labels</span><span class="p">]</span>
<span class="k">return</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">ps</span><span class="p">)</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span></div></div>


<span class="n">TASK_TO_MODELOUT</span> <span class="o">=</span> <span class="p">{</span>
<span class="p">(</span><span class="s1">&#39;image_classification&#39;</span><span class="p">,</span> <span class="kc">True</span><span class="p">):</span> <span class="n">ImageClassificationModelOutput</span><span class="p">,</span>
<span class="p">(</span><span class="s1">&#39;image_classification&#39;</span><span class="p">,</span> <span class="kc">False</span><span class="p">):</span> <span class="n">IterImageClassificationModelOutput</span><span class="p">,</span>
<span class="p">(</span><span class="s1">&#39;text_classification&#39;</span><span class="p">,</span> <span class="kc">True</span><span class="p">):</span> <span class="n">TextClassificationModelOutput</span><span class="p">,</span>
<span class="p">(</span><span class="s1">&#39;clip&#39;</span><span class="p">,</span> <span class="kc">True</span><span class="p">):</span> <span class="n">CLIPModelOutput</span><span class="p">,</span>
<span class="p">}</span>
</pre></div>
Expand Down
Loading

0 comments on commit 32862db

Please sign in to comment.