Skip to content

Commit

Permalink
[MRG] Tiny fix in SSNB (#535)
Browse files Browse the repository at this point in the history
* added functions to a new mapping module

* simplify ssnb function structure

* SSNB example

* removed numpy saves from example for prod

* tests apart from the import exception catch

* tests apart from the import exception catch

* da class and tests

* guessed PR number

* removed unused import

* PEP8 tab errors fix

* skip ssnb test if no cvxpy

* test and doc fixes

* doc dependency + minor comment in ot __init__.py

* PEP8 fixes

* test typo fix

* ssnb da backend test fix

* moved joint ot mappings to the mapping module

* better ssnb example + ssnb initilisation + small joint_ot_mapping tests

* better ssnb example + ssnb initilisation + small joint_ot_mapping tests

* removed unused dependency in example

* no longer import mapping in __init__ + example thumbnail fix + made qcqp_constants function private

* merge with POT main

* fix barycentric projection factor omission in SSNB solver init

* added modif in RELEASES.md

* fix PR number in RELEASES.md

* broadcast fix
  • Loading branch information
eloitanguy authored Oct 18, 2023
1 parent 8a4a5a6 commit 57eda61
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 17 deletions.
2 changes: 1 addition & 1 deletion RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
## 0.9.2dev

#### New features
+ Added support for [Nearest Brenier Potentials (SSNB)](http://proceedings.mlr.press/v108/paty20a/paty20a.pdf) (PR #526)
+ Added support for [Nearest Brenier Potentials (SSNB)](http://proceedings.mlr.press/v108/paty20a/paty20a.pdf) (PR #526) + minor fix (PR #535)
+ Tweaked `get_backend` to ignore `None` inputs (PR #525)
+ Callbacks for generalized conditional gradient in `ot.da.sinkhorn_l1l2_gl` are now vectorized to improve performance (PR #507)
+ The `linspace` method of the backends now has the `type_as` argument to convert to the same dtype and device. (PR #533)
Expand Down
16 changes: 1 addition & 15 deletions examples/others/plot_SSNB.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
# Author: Eloi Tanguy <[email protected]>
# License: MIT License

# sphinx_gallery_thumbnail_number = 4
# sphinx_gallery_thumbnail_number = 3

import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -63,20 +63,6 @@
plt.legend(loc='upper right')
plt.show()

# %%
# Plotting image of barycentric projection (SSNB initialisation values)
plt.clf()
pi = ot.emd(ot.unif(n_fitting_samples), ot.unif(n_fitting_samples), ot.dist(Xs, Xt))
plt.scatter(Xs[:, 0], Xs[:, 1], c='dodgerblue', label='source')
plt.scatter(Xt[:, 0], Xt[:, 1], c='red', label='target')
bar_img = pi @ Xt
for i in range(n_fitting_samples):
plt.plot([Xs[i, 0], bar_img[i, 0]], [Xs[i, 1], bar_img[i, 1]], color='black', alpha=.5)
plt.title('Images of in-data source samples by the barycentric map')
plt.legend(loc='upper right')
plt.axis('equal')
plt.show()

# %%
# Fitting the Nearest Brenier Potential
L = 3 # need L > 2 to allow the 2*y term, default is 1.4
Expand Down
2 changes: 1 addition & 1 deletion ot/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def nearest_brenier_potential_fit(X, V, X_classes=None, a=None, b=None, strongly
if init_method == 'target':
G_val = V
else: # Init G_val with barycentric projection
G_val = emd(a, b, dist(X, V)) @ V
G_val = emd(a, b, dist(X, V)) @ V / a.reshape(n, 1)
phi_val = None
log_dict = {
'G_list': [],
Expand Down

0 comments on commit 57eda61

Please sign in to comment.