Skip to content

Commit fd1ff29

Browse files
Circle CICircle CI
Circle CI
authored and
Circle CI
committed
CircleCI update of dev docs (2779).
1 parent 1075619 commit fd1ff29

File tree

375 files changed

+738752
-735639
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

375 files changed

+738752
-735639
lines changed
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"\n# GMM Flow\n\nIllustration of the flow of a Gaussian Mixture with\nrespect to its GMM-OT distance with respect to a\nfixed GMM.\n"
8+
]
9+
},
10+
{
11+
"cell_type": "code",
12+
"execution_count": null,
13+
"metadata": {
14+
"collapsed": false
15+
},
16+
"outputs": [],
17+
"source": [
18+
"# Author: Eloi Tanguy <eloi.tanguy@u-paris>\n# Remi Flamary <[email protected]>\n# Julie Delon <[email protected]>\n#\n# License: MIT License\n\n# sphinx_gallery_thumbnail_number = 4\n\nimport numpy as np\nimport matplotlib.pylab as pl\nfrom matplotlib import colormaps as cm\nimport ot\nimport ot.plot\nfrom ot.utils import proj_SDP, proj_simplex\nfrom ot.gmm import gmm_ot_loss\nimport torch\nfrom torch.optim import Adam\nfrom matplotlib.patches import Ellipse"
19+
]
20+
},
21+
{
22+
"cell_type": "markdown",
23+
"metadata": {},
24+
"source": [
25+
"## Generate data and plot it\n\n"
26+
]
27+
},
28+
{
29+
"cell_type": "code",
30+
"execution_count": null,
31+
"metadata": {
32+
"collapsed": false
33+
},
34+
"outputs": [],
35+
"source": [
36+
"torch.manual_seed(3)\nks = 3\nkt = 2\nd = 2\neps = 0.1\nm_s = torch.randn(ks, d)\nm_s.requires_grad_()\nm_t = torch.randn(kt, d)\nC_s = torch.randn(ks, d, d)\nC_s = torch.matmul(C_s, torch.transpose(C_s, 2, 1))\nC_s += eps * torch.eye(d)[None, :, :] * torch.ones(ks, 1, 1)\nC_s.requires_grad_()\nC_t = torch.randn(kt, d, d)\nC_t = torch.matmul(C_t, torch.transpose(C_t, 2, 1))\nC_t += eps * torch.eye(d)[None, :, :] * torch.ones(kt, 1, 1)\nw_s = torch.randn(ks)\nw_s = proj_simplex(w_s)\nw_s.requires_grad_()\nw_t = torch.tensor(ot.unif(kt))\n\n\ndef draw_cov(mu, C, color=None, label=None, nstd=1, alpha=.5):\n\n def eigsorted(cov):\n vals, vecs = np.linalg.eigh(cov)\n order = vals.argsort()[::-1]\n return vals[order], vecs[:, order]\n\n vals, vecs = eigsorted(C)\n theta = np.degrees(np.arctan2(*vecs[:, 0][::-1]))\n w, h = 2 * nstd * np.sqrt(vals)\n ell = Ellipse(xy=(mu[0], mu[1]),\n width=w, height=h, alpha=alpha,\n angle=theta, facecolor=color, edgecolor=color, label=label, fill=True)\n pl.gca().add_artist(ell)\n\n\ndef draw_gmm(ms, Cs, ws, color=None, nstd=.5, alpha=1):\n for k in range(ms.shape[0]):\n draw_cov(ms[k], Cs[k], color, None, nstd,\n alpha * ws[k])\n\n\naxis = [-3, 3, -3, 3]\npl.figure(1, (20, 10))\npl.clf()\n\npl.subplot(1, 2, 1)\npl.scatter(m_s[:, 0].detach(), m_s[:, 1].detach(), color='C0')\ndraw_gmm(m_s.detach(), C_s.detach(),\n torch.softmax(w_s, 0).detach().numpy(),\n color='C0')\npl.axis(axis)\npl.title('Source GMM')\n\npl.subplot(1, 2, 2)\npl.scatter(m_t[:, 0].detach(), m_t[:, 1].detach(), color='C1')\ndraw_gmm(m_t.detach(), C_t.detach(), w_t.numpy(), color='C1')\npl.axis(axis)\npl.title('Target GMM')"
37+
]
38+
},
39+
{
40+
"cell_type": "markdown",
41+
"metadata": {},
42+
"source": [
43+
"## Gradient descent loop\n\n"
44+
]
45+
},
46+
{
47+
"cell_type": "code",
48+
"execution_count": null,
49+
"metadata": {
50+
"collapsed": false
51+
},
52+
"outputs": [],
53+
"source": [
54+
"n_gd_its = 100\nlr = 3e-2\nopt = Adam([{'params': m_s, 'lr': 2 * lr},\n {'params': C_s, 'lr': lr},\n {'params': w_s, 'lr': lr}])\nm_list = [m_s.data.numpy().copy()]\nC_list = [C_s.data.numpy().copy()]\nw_list = [torch.softmax(w_s, 0).data.numpy().copy()]\nloss_list = []\n\nfor _ in range(n_gd_its):\n opt.zero_grad()\n loss = gmm_ot_loss(m_s, m_t, C_s, C_t,\n torch.softmax(w_s, 0), w_t)\n loss.backward()\n opt.step()\n with torch.no_grad():\n C_s.data = proj_SDP(C_s.data, vmin=1e-6)\n m_list.append(m_s.data.numpy().copy())\n C_list.append(C_s.data.numpy().copy())\n w_list.append(torch.softmax(w_s, 0).data.numpy().copy())\n loss_list.append(loss.item())\n\npl.figure(2)\npl.clf()\npl.plot(loss_list)\npl.title('Loss')\npl.xlabel('its')\npl.ylabel('loss')"
55+
]
56+
},
57+
{
58+
"cell_type": "markdown",
59+
"metadata": {},
60+
"source": [
61+
"## Last step visualisation\n\n"
62+
]
63+
},
64+
{
65+
"cell_type": "code",
66+
"execution_count": null,
67+
"metadata": {
68+
"collapsed": false
69+
},
70+
"outputs": [],
71+
"source": [
72+
"axis = [-3, 3, -3, 3]\npl.figure(3, (10, 10))\npl.clf()\npl.title('GMM flow, last step')\npl.scatter(m_list[0][:, 0], m_list[0][:, 1], color='C0', label='Source')\ndraw_gmm(m_list[0], C_list[0], w_list[0], color='C0')\npl.axis(axis)\n\npl.scatter(m_t[:, 0].detach(), m_t[:, 1].detach(), color='C1', label='Target')\ndraw_gmm(m_t.detach(), C_t.detach(), w_t.numpy(), color='C1')\npl.axis(axis)\n\nk = -1\npl.scatter(m_list[k][:, 0], m_list[k][:, 1], color='C2', alpha=1, label='Last step')\ndraw_gmm(m_list[k], C_list[k], w_list[0], color='C2', alpha=1)\n\npl.axis(axis)\npl.legend(fontsize=15)"
73+
]
74+
},
75+
{
76+
"cell_type": "markdown",
77+
"metadata": {},
78+
"source": [
79+
"## Steps visualisation\n\n"
80+
]
81+
},
82+
{
83+
"cell_type": "code",
84+
"execution_count": null,
85+
"metadata": {
86+
"collapsed": false
87+
},
88+
"outputs": [],
89+
"source": [
90+
"def index_to_color(i):\n return int(i**0.5)\n\n\nn_steps_visu = 100\npl.figure(3, (10, 10))\npl.clf()\npl.title('GMM flow, all steps')\n\nits_to_show = [int(x) for x in np.linspace(1, n_gd_its - 1, n_steps_visu)]\ncmp = cm['plasma'].resampled(index_to_color(n_steps_visu))\n\npl.scatter(m_list[0][:, 0], m_list[0][:, 1],\n color=cmp(index_to_color(0)), label='Source')\ndraw_gmm(m_list[0], C_list[0], w_list[0],\n color=cmp(index_to_color(0)))\n\npl.scatter(m_t[:, 0].detach(), m_t[:, 1].detach(),\n color=cmp(index_to_color(n_steps_visu - 1)), label='Target')\ndraw_gmm(m_t.detach(), C_t.detach(), w_t.numpy(),\n color=cmp(index_to_color(n_steps_visu - 1)))\n\n\nfor k in its_to_show:\n pl.scatter(m_list[k][:, 0], m_list[k][:, 1],\n color=cmp(index_to_color(k)), alpha=0.8)\n draw_gmm(m_list[k], C_list[k], w_list[0],\n color=cmp(index_to_color(k)), alpha=0.04)\n\npl.axis(axis)\npl.legend(fontsize=15)"
91+
]
92+
}
93+
],
94+
"metadata": {
95+
"kernelspec": {
96+
"display_name": "Python 3",
97+
"language": "python",
98+
"name": "python3"
99+
},
100+
"language_info": {
101+
"codemirror_mode": {
102+
"name": "ipython",
103+
"version": 3
104+
},
105+
"file_extension": ".py",
106+
"mimetype": "text/x-python",
107+
"name": "python",
108+
"nbconvert_exporter": "python",
109+
"pygments_lexer": "ipython3",
110+
"version": "3.10.14"
111+
}
112+
},
113+
"nbformat": 4,
114+
"nbformat_minor": 0
115+
}
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

master/_downloads/82f76233504be88025895b8d5276f8eb/plot_OT_1D.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# %%
12
# -*- coding: utf-8 -*-
23
"""
34
======================================
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# %%
2+
# -*- coding: utf-8 -*-
3+
r"""
4+
====================================================
5+
GMM Plan 1D
6+
====================================================
7+
8+
Illustration of the GMM plan for
9+
the Mixture Wasserstein between two GMM in 1D,
10+
as well as the two maps T_mean and T_rand.
11+
T_mean is the barycentric projection of the GMM coupling,
12+
and T_rand takes a random gaussian image between two components,
13+
according to the coupling and the GMMs.
14+
See [69] for details.
15+
.. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970.
16+
17+
"""
18+
19+
# Author: Eloi Tanguy <eloi.tanguy@u-paris>
20+
# Remi Flamary <[email protected]>
21+
# Julie Delon <[email protected]>
22+
#
23+
# License: MIT License
24+
25+
# sphinx_gallery_thumbnail_number = 1
26+
27+
import numpy as np
28+
from ot.plot import plot1D_mat, rescale_for_imshow_plot
29+
from ot.gmm import gmm_ot_plan_density, gmm_pdf, gmm_ot_apply_map
30+
import matplotlib.pyplot as plt
31+
32+
##############################################################################
33+
# Generate GMMOT plan plot it
34+
# ---------------------------
35+
ks = 2
36+
kt = 3
37+
d = 1
38+
eps = 0.1
39+
m_s = np.array([[1], [2]])
40+
m_t = np.array([[3], [4.2], [5]])
41+
C_s = np.array([[[.05]], [[.06]]])
42+
C_t = np.array([[[.03]], [[.07]], [[.04]]])
43+
w_s = np.array([.4, .6])
44+
w_t = np.array([.4, .2, .4])
45+
46+
n = 500
47+
a_x, b_x = 0, 3
48+
x = np.linspace(a_x, b_x, n)
49+
a_y, b_y = 2, 6
50+
y = np.linspace(a_y, b_y, n)
51+
plan_density = gmm_ot_plan_density(x[:, None], y[:, None],
52+
m_s, m_t, C_s, C_t, w_s, w_t,
53+
plan=None, atol=2e-2)
54+
55+
a = gmm_pdf(x[:, None], m_s, C_s, w_s)
56+
b = gmm_pdf(y[:, None], m_t, C_t, w_t)
57+
plt.figure(figsize=(8, 8))
58+
plot1D_mat(a, b, plan_density, title='GMM OT plan', plot_style='xy',
59+
a_label='Source distribution', b_label='Target distribution')
60+
61+
62+
##############################################################################
63+
# Generate GMMOT maps and plot them over plan
64+
# -------------------------------------------
65+
plt.figure(figsize=(8, 8))
66+
ax_s, ax_t, ax_M = plot1D_mat(a, b, plan_density, plot_style='xy',
67+
title='GMM OT plan with T_mean and T_rand maps',
68+
a_label='Source distribution',
69+
b_label='Target distribution')
70+
T_mean = gmm_ot_apply_map(x[:, None], m_s, m_t, C_s, C_t,
71+
w_s, w_t, method='bary')[:, 0]
72+
x_rescaled, T_mean_rescaled = rescale_for_imshow_plot(x, T_mean, n,
73+
a_y=a_y, b_y=b_y)
74+
75+
ax_M.plot(x_rescaled, T_mean_rescaled, label='T_mean', alpha=.5,
76+
linewidth=5, color='aqua')
77+
78+
T_rand = gmm_ot_apply_map(x[:, None], m_s, m_t, C_s, C_t,
79+
w_s, w_t, method='rand', seed=0)[:, 0]
80+
x_rescaled, T_rand_rescaled = rescale_for_imshow_plot(x, T_rand, n,
81+
a_y=a_y, b_y=b_y)
82+
83+
ax_M.scatter(x_rescaled, T_rand_rescaled, label='T_rand', alpha=.5,
84+
s=20, color='orange')
85+
86+
ax_M.legend(loc='upper left', fontsize=13)
87+
88+
# %%
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"\n# GMM Plan 1D\n\nIllustration of the GMM plan for\nthe Mixture Wasserstein between two GMM in 1D,\nas well as the two maps T_mean and T_rand.\nT_mean is the barycentric projection of the GMM coupling,\nand T_rand takes a random gaussian image between two components,\naccording to the coupling and the GMMs.\nSee [69] for details.\n.. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970.\n"
8+
]
9+
},
10+
{
11+
"cell_type": "code",
12+
"execution_count": null,
13+
"metadata": {
14+
"collapsed": false
15+
},
16+
"outputs": [],
17+
"source": [
18+
"# Author: Eloi Tanguy <eloi.tanguy@u-paris>\n# Remi Flamary <[email protected]>\n# Julie Delon <[email protected]>\n#\n# License: MIT License\n\n# sphinx_gallery_thumbnail_number = 1\n\nimport numpy as np\nfrom ot.plot import plot1D_mat, rescale_for_imshow_plot\nfrom ot.gmm import gmm_ot_plan_density, gmm_pdf, gmm_ot_apply_map\nimport matplotlib.pyplot as plt"
19+
]
20+
},
21+
{
22+
"cell_type": "markdown",
23+
"metadata": {},
24+
"source": [
25+
"## Generate GMMOT plan plot it\n\n"
26+
]
27+
},
28+
{
29+
"cell_type": "code",
30+
"execution_count": null,
31+
"metadata": {
32+
"collapsed": false
33+
},
34+
"outputs": [],
35+
"source": [
36+
"ks = 2\nkt = 3\nd = 1\neps = 0.1\nm_s = np.array([[1], [2]])\nm_t = np.array([[3], [4.2], [5]])\nC_s = np.array([[[.05]], [[.06]]])\nC_t = np.array([[[.03]], [[.07]], [[.04]]])\nw_s = np.array([.4, .6])\nw_t = np.array([.4, .2, .4])\n\nn = 500\na_x, b_x = 0, 3\nx = np.linspace(a_x, b_x, n)\na_y, b_y = 2, 6\ny = np.linspace(a_y, b_y, n)\nplan_density = gmm_ot_plan_density(x[:, None], y[:, None],\n m_s, m_t, C_s, C_t, w_s, w_t,\n plan=None, atol=2e-2)\n\na = gmm_pdf(x[:, None], m_s, C_s, w_s)\nb = gmm_pdf(y[:, None], m_t, C_t, w_t)\nplt.figure(figsize=(8, 8))\nplot1D_mat(a, b, plan_density, title='GMM OT plan', plot_style='xy',\n a_label='Source distribution', b_label='Target distribution')"
37+
]
38+
},
39+
{
40+
"cell_type": "markdown",
41+
"metadata": {},
42+
"source": [
43+
"## Generate GMMOT maps and plot them over plan\n\n"
44+
]
45+
},
46+
{
47+
"cell_type": "code",
48+
"execution_count": null,
49+
"metadata": {
50+
"collapsed": false
51+
},
52+
"outputs": [],
53+
"source": [
54+
"plt.figure(figsize=(8, 8))\nax_s, ax_t, ax_M = plot1D_mat(a, b, plan_density, plot_style='xy',\n title='GMM OT plan with T_mean and T_rand maps',\n a_label='Source distribution',\n b_label='Target distribution')\nT_mean = gmm_ot_apply_map(x[:, None], m_s, m_t, C_s, C_t,\n w_s, w_t, method='bary')[:, 0]\nx_rescaled, T_mean_rescaled = rescale_for_imshow_plot(x, T_mean, n,\n a_y=a_y, b_y=b_y)\n\nax_M.plot(x_rescaled, T_mean_rescaled, label='T_mean', alpha=.5,\n linewidth=5, color='aqua')\n\nT_rand = gmm_ot_apply_map(x[:, None], m_s, m_t, C_s, C_t,\n w_s, w_t, method='rand', seed=0)[:, 0]\nx_rescaled, T_rand_rescaled = rescale_for_imshow_plot(x, T_rand, n,\n a_y=a_y, b_y=b_y)\n\nax_M.scatter(x_rescaled, T_rand_rescaled, label='T_rand', alpha=.5,\n s=20, color='orange')\n\nax_M.legend(loc='upper left', fontsize=13)"
55+
]
56+
}
57+
],
58+
"metadata": {
59+
"kernelspec": {
60+
"display_name": "Python 3",
61+
"language": "python",
62+
"name": "python3"
63+
},
64+
"language_info": {
65+
"codemirror_mode": {
66+
"name": "ipython",
67+
"version": 3
68+
},
69+
"file_extension": ".py",
70+
"mimetype": "text/x-python",
71+
"name": "python",
72+
"nbconvert_exporter": "python",
73+
"pygments_lexer": "ipython3",
74+
"version": "3.10.14"
75+
}
76+
},
77+
"nbformat": 4,
78+
"nbformat_minor": 0
79+
}
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)