Skip to content

Commit fa2c5d6

Browse files
Circle CI authored and committed
CircleCI update of dev docs (2955).
1 parent e164e0e commit fa2c5d6

File tree

273 files changed

+733626
-731354
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

273 files changed

+733626
-731354
lines changed
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"\n# Different gradient computations for regularized optimal transport\n\nThis example illustrates the differences in terms of computation time between the gradient options for the Sinkhorn solver.\n"
8+
]
9+
},
10+
{
11+
"cell_type": "code",
12+
"execution_count": null,
13+
"metadata": {
14+
"collapsed": false
15+
},
16+
"outputs": [],
17+
"source": [
18+
"# Author: Sonia Mazelet <[email protected]>\n#\n# License: MIT License\n\n# sphinx_gallery_thumbnail_number = 1\n\nimport matplotlib.pylab as pl\nimport ot\nfrom ot.backend import torch"
19+
]
20+
},
21+
{
22+
"cell_type": "markdown",
23+
"metadata": {},
24+
"source": [
25+
"## Time comparison of the Sinkhorn solver for different gradient options\n\n"
26+
]
27+
},
28+
{
29+
"cell_type": "code",
30+
"execution_count": null,
31+
"metadata": {
32+
"collapsed": false
33+
},
34+
"outputs": [],
35+
"source": [
36+
"n_trials = 10\ntimes_autodiff = torch.zeros(n_trials)\ntimes_envelope = torch.zeros(n_trials)\ntimes_last_step = torch.zeros(n_trials)\n\nn_samples_s = 300\nn_samples_t = 300\nn_features = 5\nreg = 0.03\n\n# Time required for the Sinkhorn solver and gradient computations, for different gradient options over multiple Gaussian distributions\nfor i in range(n_trials):\n x = torch.rand((n_samples_s, n_features))\n y = torch.rand((n_samples_t, n_features))\n a = ot.utils.unif(n_samples_s)\n b = ot.utils.unif(n_samples_t)\n M = ot.dist(x, y)\n\n a = torch.tensor(a, requires_grad=True)\n b = torch.tensor(b, requires_grad=True)\n M = M.clone().detach().requires_grad_(True)\n\n # autodiff provides the gradient for all the outputs (plan, value, value_linear)\n ot.tic()\n res_autodiff = ot.solve(M, a, b, reg=reg, grad=\"autodiff\")\n res_autodiff.value.backward()\n times_autodiff[i] = ot.toq()\n\n a = a.clone().detach().requires_grad_(True)\n b = b.clone().detach().requires_grad_(True)\n M = M.clone().detach().requires_grad_(True)\n\n # envelope provides the gradient for value\n ot.tic()\n res_envelope = ot.solve(M, a, b, reg=reg, grad=\"envelope\")\n res_envelope.value.backward()\n times_envelope[i] = ot.toq()\n\n a = a.clone().detach().requires_grad_(True)\n b = b.clone().detach().requires_grad_(True)\n M = M.clone().detach().requires_grad_(True)\n\n # last_step provides the gradient for all the outputs, but only for the last iteration of the Sinkhorn algorithm\n ot.tic()\n res_last_step = ot.solve(M, a, b, reg=reg, grad=\"last_step\")\n res_last_step.value.backward()\n times_last_step[i] = ot.toq()\n\npl.figure(1, figsize=(5, 3))\npl.ticklabel_format(axis=\"y\", style=\"sci\", scilimits=(0, 0))\npl.boxplot(\n ([times_autodiff, times_envelope, times_last_step]),\n tick_labels=[\"autodiff\", \"envelope\", \"last_step\"],\n showfliers=False,\n)\npl.ylabel(\"Time (s)\")\npl.show()"
37+
]
38+
}
39+
],
40+
"metadata": {
41+
"kernelspec": {
42+
"display_name": "Python 3",
43+
"language": "python",
44+
"name": "python3"
45+
},
46+
"language_info": {
47+
"codemirror_mode": {
48+
"name": "ipython",
49+
"version": 3
50+
},
51+
"file_extension": ".py",
52+
"mimetype": "text/x-python",
53+
"name": "python",
54+
"nbconvert_exporter": "python",
55+
"pygments_lexer": "ipython3",
56+
"version": "3.10.15"
57+
}
58+
},
59+
"nbformat": 4,
60+
"nbformat_minor": 0
61+
}
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# -*- coding: utf-8 -*-
"""
=================================================================
Different gradient computations for regularized optimal transport
=================================================================

This example illustrates the differences in terms of computation time between the gradient options for the Sinkhorn solver.

"""

# Author: Sonia Mazelet <[email protected]>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 1

import matplotlib.pylab as pl
import ot
from ot.backend import torch


##############################################################################
# Time comparison of the Sinkhorn solver for different gradient options
# ----------------------------------------------------------------------

# %% parameters

n_trials = 10  # number of repetitions used for the timing boxplots
times_autodiff = torch.zeros(n_trials)
times_envelope = torch.zeros(n_trials)
times_last_step = torch.zeros(n_trials)

n_samples_s = 300  # number of source samples
n_samples_t = 300  # number of target samples
n_features = 5
reg = 0.03  # entropic regularization strength

# Time required for the Sinkhorn solver and gradient computations, for the
# different gradient options, over multiple pairs of uniformly distributed
# random samples (torch.rand draws from U[0, 1)).
for i in range(n_trials):
    x = torch.rand((n_samples_s, n_features))
    y = torch.rand((n_samples_t, n_features))
    a = ot.utils.unif(n_samples_s)
    b = ot.utils.unif(n_samples_t)
    M = ot.dist(x, y)

    a = torch.tensor(a, requires_grad=True)
    b = torch.tensor(b, requires_grad=True)
    M = M.clone().detach().requires_grad_(True)

    # autodiff provides the gradient for all the outputs (plan, value, value_linear)
    ot.tic()
    res_autodiff = ot.solve(M, a, b, reg=reg, grad="autodiff")
    res_autodiff.value.backward()
    times_autodiff[i] = ot.toq()

    # detach and re-enable gradients so each option starts from a fresh graph
    a = a.clone().detach().requires_grad_(True)
    b = b.clone().detach().requires_grad_(True)
    M = M.clone().detach().requires_grad_(True)

    # envelope provides the gradient for value
    ot.tic()
    res_envelope = ot.solve(M, a, b, reg=reg, grad="envelope")
    res_envelope.value.backward()
    times_envelope[i] = ot.toq()

    a = a.clone().detach().requires_grad_(True)
    b = b.clone().detach().requires_grad_(True)
    M = M.clone().detach().requires_grad_(True)

    # last_step provides the gradient for all the outputs, but only for the last iteration of the Sinkhorn algorithm
    ot.tic()
    res_last_step = ot.solve(M, a, b, reg=reg, grad="last_step")
    res_last_step.value.backward()
    times_last_step[i] = ot.toq()

pl.figure(1, figsize=(5, 3))
pl.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
pl.boxplot(
    ([times_autodiff, times_envelope, times_last_step]),
    tick_labels=["autodiff", "envelope", "last_step"],
    showfliers=False,
)
pl.ylabel("Time (s)")
pl.show()
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
753 Bytes
266 Bytes
289 Bytes
290 Bytes
198 Bytes
-49 Bytes
143 Bytes
565 Bytes

master/_modules/ot/solvers.html

+31-6
Original file line numberDiff line numberDiff line change
@@ -209,11 +209,13 @@ <h1>Source code for ot.solvers</h1><div class="highlight"><pre>
209209
<span class="sd"> verbose : bool, optional</span>
210210
<span class="sd"> Print information in the solver, by default False</span>
211211
<span class="sd"> grad : str, optional</span>
212-
<span class="sd"> Type of gradient computation, either or &#39;autodiff&#39; or &#39;envelope&#39; used only for</span>
212+
<span class="sd"> Type of gradient computation, either or &#39;autodiff&#39;, &#39;envelope&#39; or &#39;last_step&#39; used only for</span>
213213
<span class="sd"> Sinkhorn solver. By default &#39;autodiff&#39; provides gradients wrt all</span>
214214
<span class="sd"> outputs (`plan, value, value_linear`) but with important memory cost.</span>
215215
<span class="sd"> &#39;envelope&#39; provides gradients only for `value` and and other outputs are</span>
216-
<span class="sd"> detached. This is useful for memory saving when only the value is needed.</span>
216+
<span class="sd"> detached. This is useful for memory saving when only the value is needed. &#39;last_step&#39; provides</span>
217+
<span class="sd"> gradients only for the last iteration of the Sinkhorn solver, but provides gradient for both the OT plan and the objective values.</span>
218+
<span class="sd"> &#39;detach&#39; does not compute the gradients for the Sinkhorn solver.</span>
217219

218220
<span class="sd"> Returns</span>
219221
<span class="sd"> -------</span>
@@ -365,7 +367,6 @@ <h1>Source code for ot.solvers</h1><div class="highlight"><pre>
365367
<span class="sd"> linear regression. NeurIPS.</span>
366368

367369
<span class="sd"> &quot;&quot;&quot;</span>
368-
369370
<span class="c1"># detect backend</span>
370371
<span class="n">nx</span> <span class="o">=</span> <span class="n">get_backend</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span>
371372

@@ -496,7 +497,11 @@ <h1>Source code for ot.solvers</h1><div class="highlight"><pre>
496497
<span class="n">potentials</span> <span class="o">=</span> <span class="p">(</span><span class="n">log</span><span class="p">[</span><span class="s2">&quot;u&quot;</span><span class="p">],</span> <span class="n">log</span><span class="p">[</span><span class="s2">&quot;v&quot;</span><span class="p">])</span>
497498

498499
<span class="k">elif</span> <span class="n">reg_type</span><span class="o">.</span><span class="n">lower</span><span class="p">()</span> <span class="ow">in</span> <span class="p">[</span><span class="s2">&quot;entropy&quot;</span><span class="p">,</span> <span class="s2">&quot;kl&quot;</span><span class="p">]:</span>
499-
<span class="k">if</span> <span class="n">grad</span> <span class="o">==</span> <span class="s2">&quot;envelope&quot;</span><span class="p">:</span> <span class="c1"># if envelope then detach the input</span>
500+
<span class="k">if</span> <span class="n">grad</span> <span class="ow">in</span> <span class="p">[</span>
501+
<span class="s2">&quot;envelope&quot;</span><span class="p">,</span>
502+
<span class="s2">&quot;last_step&quot;</span><span class="p">,</span>
503+
<span class="s2">&quot;detach&quot;</span><span class="p">,</span>
504+
<span class="p">]:</span> <span class="c1"># if envelope, last_step or detach then detach the input</span>
500505
<span class="n">M0</span><span class="p">,</span> <span class="n">a0</span><span class="p">,</span> <span class="n">b0</span> <span class="o">=</span> <span class="n">M</span><span class="p">,</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span>
501506
<span class="n">M</span><span class="p">,</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span> <span class="o">=</span> <span class="n">nx</span><span class="o">.</span><span class="n">detach</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
502507

@@ -505,6 +510,12 @@ <h1>Source code for ot.solvers</h1><div class="highlight"><pre>
505510
<span class="n">max_iter</span> <span class="o">=</span> <span class="mi">1000</span>
506511
<span class="k">if</span> <span class="n">tol</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
507512
<span class="n">tol</span> <span class="o">=</span> <span class="mf">1e-9</span>
513+
<span class="k">if</span> <span class="n">grad</span> <span class="o">==</span> <span class="s2">&quot;last_step&quot;</span><span class="p">:</span>
514+
<span class="k">if</span> <span class="n">max_iter</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
515+
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
516+
<span class="s2">&quot;The maximum number of iterations must be greater than 0 when using grad=last_step.&quot;</span>
517+
<span class="p">)</span>
518+
<span class="n">max_iter</span> <span class="o">=</span> <span class="n">max_iter</span> <span class="o">-</span> <span class="mi">1</span>
508519

509520
<span class="n">plan</span><span class="p">,</span> <span class="n">log</span> <span class="o">=</span> <span class="n">sinkhorn_log</span><span class="p">(</span>
510521
<span class="n">a</span><span class="p">,</span>
@@ -517,6 +528,22 @@ <h1>Source code for ot.solvers</h1><div class="highlight"><pre>
517528
<span class="n">verbose</span><span class="o">=</span><span class="n">verbose</span><span class="p">,</span>
518529
<span class="p">)</span>
519530

531+
<span class="n">potentials</span> <span class="o">=</span> <span class="p">(</span><span class="n">log</span><span class="p">[</span><span class="s2">&quot;log_u&quot;</span><span class="p">],</span> <span class="n">log</span><span class="p">[</span><span class="s2">&quot;log_v&quot;</span><span class="p">])</span>
532+
533+
<span class="c1"># if last_step, compute the last step of the Sinkhorn algorithm with the non-detached inputs</span>
534+
<span class="k">if</span> <span class="n">grad</span> <span class="o">==</span> <span class="s2">&quot;last_step&quot;</span><span class="p">:</span>
535+
<span class="n">loga</span> <span class="o">=</span> <span class="n">nx</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">a0</span><span class="p">)</span>
536+
<span class="n">logb</span> <span class="o">=</span> <span class="n">nx</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">b0</span><span class="p">)</span>
537+
<span class="n">v</span> <span class="o">=</span> <span class="n">logb</span> <span class="o">-</span> <span class="n">nx</span><span class="o">.</span><span class="n">logsumexp</span><span class="p">(</span><span class="o">-</span><span class="n">M0</span> <span class="o">/</span> <span class="n">reg</span> <span class="o">+</span> <span class="n">potentials</span><span class="p">[</span><span class="mi">0</span><span class="p">][:,</span> <span class="kc">None</span><span class="p">],</span> <span class="mi">0</span><span class="p">)</span>
538+
<span class="n">u</span> <span class="o">=</span> <span class="n">loga</span> <span class="o">-</span> <span class="n">nx</span><span class="o">.</span><span class="n">logsumexp</span><span class="p">(</span><span class="o">-</span><span class="n">M0</span> <span class="o">/</span> <span class="n">reg</span> <span class="o">+</span> <span class="n">potentials</span><span class="p">[</span><span class="mi">1</span><span class="p">][</span><span class="kc">None</span><span class="p">,</span> <span class="p">:],</span> <span class="mi">1</span><span class="p">)</span>
539+
<span class="n">plan</span> <span class="o">=</span> <span class="n">nx</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="n">M0</span> <span class="o">/</span> <span class="n">reg</span> <span class="o">+</span> <span class="n">u</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">+</span> <span class="n">v</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:])</span>
540+
<span class="n">potentials</span> <span class="o">=</span> <span class="p">(</span><span class="n">u</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span>
541+
<span class="n">log</span><span class="p">[</span><span class="s2">&quot;niter&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">max_iter</span> <span class="o">+</span> <span class="mi">1</span>
542+
<span class="n">log</span><span class="p">[</span><span class="s2">&quot;log_u&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">u</span>
543+
<span class="n">log</span><span class="p">[</span><span class="s2">&quot;log_v&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">v</span>
544+
<span class="n">log</span><span class="p">[</span><span class="s2">&quot;u&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">nx</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">u</span><span class="p">)</span>
545+
<span class="n">log</span><span class="p">[</span><span class="s2">&quot;v&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">nx</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">v</span><span class="p">)</span>
546+
520547
<span class="n">value_linear</span> <span class="o">=</span> <span class="n">nx</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">M</span> <span class="o">*</span> <span class="n">plan</span><span class="p">)</span>
521548

522549
<span class="k">if</span> <span class="n">reg_type</span><span class="o">.</span><span class="n">lower</span><span class="p">()</span> <span class="o">==</span> <span class="s2">&quot;entropy&quot;</span><span class="p">:</span>
@@ -526,8 +553,6 @@ <h1>Source code for ot.solvers</h1><div class="highlight"><pre>
526553
<span class="n">plan</span><span class="p">,</span> <span class="n">a</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">b</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span>
527554
<span class="p">)</span>
528555

529-
<span class="n">potentials</span> <span class="o">=</span> <span class="p">(</span><span class="n">log</span><span class="p">[</span><span class="s2">&quot;log_u&quot;</span><span class="p">],</span> <span class="n">log</span><span class="p">[</span><span class="s2">&quot;log_v&quot;</span><span class="p">])</span>
530-
531556
<span class="k">if</span> <span class="n">grad</span> <span class="o">==</span> <span class="s2">&quot;envelope&quot;</span><span class="p">:</span> <span class="c1"># set the gradient at convergence</span>
532557
<span class="n">value</span> <span class="o">=</span> <span class="n">nx</span><span class="o">.</span><span class="n">set_gradients</span><span class="p">(</span>
533558
<span class="n">value</span><span class="p">,</span>

0 commit comments

Comments (0)