master/_downloads/0eb518a5cb2af5a1870fc1f120d1483e/plot_partial_wass_and_gromov.ipynb (+1 -1)
@@ -4,7 +4,7 @@
4
4
"cell_type": "markdown",
5
5
"metadata": {},
6
6
"source": [
7
-
"\n# Partial Wasserstein and Gromov-Wasserstein example\n\nThis example is designed to show how to use the Partial (Gromov-)Wasserstein\ndistance computation in POT.\n"
7
+
"\n# Partial Wasserstein and Gromov-Wasserstein example\n\nThis example is designed to show how to use the Partial (Gromov-)Wasserstein\ndistance computation in POT [29].\n\n[29] Chapel, L., Alaya, M., Gasso, G. (2020). \"Partial Optimal\nTransport with Applications on Positive-Unlabeled Learning\". NeurIPS.\n"
"\n# Plot partial FGW for subgraph matching\n\nThis example illustrates the computation of partial (Fused) Gromov-Wasserstein\ndivergences for subgraph matching tasks, using the exact formulation $p(F)GW$ and\nthe entropically regularized one $p(F)GW_e$ [18, 29].\n\nWe first create a clean circular graph of 15 nodes with node features correlated with\nnode positions on the unit circle, and a noisy version where 5 nodes out of the\ncircle are added. Then knowing the proportion of clean samples in the target graph\n$m=3/4$, we show how to identify them using :\n - The partial GW matching and its entropic counterpart, omitting node features.\n - The partial Fused GW matching and its entropic counterpart.\n\n[18] Vayer Titouan, Chapel Laetitia, Flamary R\u00e9mi, Tavenard Romain\nand Courty Nicolas\n\"Optimal Transport for structured data with application on graphs\"\nInternational Conference on Machine Learning (ICML). 2019.\n\n[29] Chapel, L., Alaya, M., Gasso, G. (2020). \"Partial Optimal\nTransport with Applications on Positive-Unlabeled Learning\". NeurIPS.\n"
"import numpy as np\nimport pylab as pl\nimport networkx as nx\nimport math\nfrom scipy.sparse.csgraph import shortest_path\nimport matplotlib.colors as mcol\nfrom matplotlib import cm\nfrom ot.gromov import (\n partial_gromov_wasserstein,\n entropic_partial_gromov_wasserstein,\n partial_fused_gromov_wasserstein,\n entropic_partial_fused_gromov_wasserstein,\n)\nfrom ot import unif, dist"
30
+
]
31
+
},
32
+
{
33
+
"cell_type": "markdown",
34
+
"metadata": {},
35
+
"source": [
36
+
"## Utils for generation and visualization\n\n"
37
+
]
38
+
},
39
+
{
40
+
"cell_type": "code",
41
+
"execution_count": null,
42
+
"metadata": {
43
+
"collapsed": false
44
+
},
45
+
"outputs": [],
46
+
"source": [
47
+
"def build_noisy_circular_graph(n_clean=15, n_noise=5, random_seed=0):\n \"\"\"Create a noisy circular graph\"\"\"\n # create clean circle\n np.random.seed(random_seed)\n g = nx.Graph()\n g.add_nodes_from(np.arange(n_clean + n_noise))\n for i in range(n_clean):\n g.add_node(i, weight=math.sin(2 * i * math.pi / n_clean))\n if i == (n_clean - 1):\n g.add_edge(i, 0)\n else:\n g.add_edge(i, i + 1)\n # add nodes out of the circle as structure noise\n if n_noise > 0:\n noisy_nodes = np.random.choice(np.arange(n_clean), n_noise)\n for i, j in enumerate(noisy_nodes):\n g.add_node(i + n_clean, weight=math.sin(2 * j * math.pi / n_clean))\n g.add_edge(i + n_clean, j)\n g.add_edge(i + n_clean, (j + 1) % n_clean)\n return g\n\n\ndef graph_colors(nx_graph, vmin=0, vmax=7):\n cnorm = mcol.Normalize(vmin=vmin, vmax=vmax)\n cpick = cm.ScalarMappable(norm=cnorm, cmap=\"viridis\")\n cpick.set_array([])\n val_map = {}\n for k, v in nx.get_node_attributes(nx_graph, \"weight\").items():\n val_map[k] = cpick.to_rgba(v)\n colors = []\n for node in nx_graph.nodes():\n colors.append(val_map[node])\n return colors\n\n\ndef draw_graph(\n G,\n C,\n nodes_color_part,\n Gweights=None,\n pos=None,\n edge_color=\"black\",\n node_size=None,\n shiftx=0,\n):\n if pos is None:\n pos = nx.kamada_kawai_layout(G)\n\n if shiftx != 0:\n for k, v in pos.items():\n v[0] = v[0] + shiftx\n\n alpha_edge = 0.7\n width_edge = 1.8\n if Gweights is None:\n nx.draw_networkx_edges(\n G, pos, width=width_edge, alpha=alpha_edge, edge_color=edge_color\n )\n else:\n # We make more visible connections between activated nodes\n n = len(Gweights)\n edgelist_activated = []\n edgelist_deactivated = []\n for i in range(n):\n for j in range(n):\n if Gweights[i] * Gweights[j] * C[i, j] > 0:\n edgelist_activated.append((i, j))\n elif C[i, j] > 0:\n edgelist_deactivated.append((i, j))\n\n nx.draw_networkx_edges(\n G,\n pos,\n edgelist=edgelist_activated,\n width=width_edge,\n alpha=alpha_edge,\n edge_color=edge_color,\n )\n 
nx.draw_networkx_edges(\n G,\n pos,\n edgelist=edgelist_deactivated,\n width=width_edge,\n alpha=0.1,\n edge_color=edge_color,\n )\n\n if Gweights is None:\n for node, node_color in enumerate(nodes_color_part):\n nx.draw_networkx_nodes(\n G,\n pos,\n nodelist=[node],\n node_size=node_size,\n alpha=1,\n node_color=node_color,\n )\n else:\n scaled_Gweights = Gweights / (0.5 * Gweights.max())\n nodes_size = node_size * scaled_Gweights\n for node, node_color in enumerate(nodes_color_part):\n if nodes_size[node] == 0:\n local_node_size = 0\n else:\n local_node_size = max(0.1 * node_size, nodes_size[node])\n nx.draw_networkx_nodes(\n G,\n pos,\n nodelist=[node],\n node_size=local_node_size,\n alpha=1,\n node_color=node_color,\n )\n return pos\n\n\ndef draw_transp_colored(\n G1,\n C1,\n G2,\n C2,\n p1,\n p2,\n T,\n pos1=None,\n pos2=None,\n shiftx=4,\n switchx=False,\n node_size=70,\n color_features=False,\n):\n if color_features:\n nodes_color_part1 = graph_colors(G1, vmin=-1, vmax=1)\n nodes_color_part2 = graph_colors(G2, vmin=-1, vmax=1)\n else:\n nodes_color_part1 = C1.shape[0] * [\"C0\"]\n nodes_color_part2 = C2.shape[0] * [\"C0\"]\n\n pos1 = draw_graph(\n G1,\n C1,\n nodes_color_part1,\n Gweights=p1,\n pos=pos1,\n node_size=node_size,\n shiftx=0,\n )\n pos2 = draw_graph(\n G2,\n C2,\n nodes_color_part2,\n Gweights=p2,\n pos=pos2,\n node_size=node_size,\n shiftx=shiftx,\n )\n T_max = T.max()\n for k1, v1 in pos1.items():\n for k2, v2 in pos2.items():\n if T[k1, k2] > 0:\n pl.plot(\n [pos1[k1][0], pos2[k2][0]],\n [pos1[k1][1], pos2[k2][1]],\n \"-\",\n lw=0.8,\n alpha=max(0.05, 0.8 * T[k1, k2] / T_max),\n color=nodes_color_part1[k1],\n )\n return pos1, pos2"
48
+
]
49
+
},
50
+
{
51
+
"cell_type": "markdown",
52
+
"metadata": {},
53
+
"source": [
54
+
"## Generate and visualize data\nWe build a clean circular graph that will be matched to a noisy circular graph.\n\n"
55
+
]
56
+
},
57
+
{
58
+
"cell_type": "code",
59
+
"execution_count": null,
60
+
"metadata": {
61
+
"collapsed": false
62
+
},
63
+
"outputs": [],
64
+
"source": [
65
+
"clean_graph = build_noisy_circular_graph(n_clean=15, n_noise=0)\n\nnoisy_graph = build_noisy_circular_graph(n_clean=15, n_noise=5)\n\ngraphs = [clean_graph, noisy_graph]\nlist_pos = []\npl.figure(figsize=(6, 3))\nfor i in range(2):\n pl.subplot(1, 2, i + 1)\n g = graphs[i]\n if i == 0:\n pl.title(\"clean graph\", fontsize=16)\n else:\n pl.title(\"noisy graph\", fontsize=16)\n pos = nx.kamada_kawai_layout(g)\n list_pos.append(pos)\n nx.draw_networkx(\n g,\n pos=pos,\n node_color=graph_colors(g, vmin=-1, vmax=1),\n with_labels=False,\n node_size=100,\n )\npl.show()"
66
+
]
67
+
},
68
+
{
69
+
"cell_type": "markdown",
70
+
"metadata": {},
71
+
"source": [
72
+
"## Partial (Entropic) Gromov-Wasserstein computation and visualization\nAdjacency matrices are compared using both exact and entropic partial GW\ndiscarding for now node features.\nThen for illustration, the node sizes are proportional to their optimized masses\nand the intensity of the link between two nodes across graphs is set proportionally\nto the corresponding transported mass.\n\n"
73
+
]
74
+
},
75
+
{
76
+
"cell_type": "code",
77
+
"execution_count": null,
78
+
"metadata": {
79
+
"collapsed": false
80
+
},
81
+
"outputs": [],
82
+
"source": [
83
+
"Cs = [nx.adjacency_matrix(G).toarray().astype(np.float64) for G in graphs]\nps = [unif(C.shape[0]) for C in Cs]\n\n# provide an informative initialization for better visualization\nm = 3.0 / 4.0\npartial_id = np.zeros((15, 20))\npartial_id[:15, :15] = np.eye(15) / 15.0\nG0 = (np.outer(ps[0], ps[1]) + partial_id) * m / 2\n\n# compute exact partial GW\nT, log = partial_gromov_wasserstein(\n Cs[0], Cs[1], ps[0], ps[1], m=m, G0=G0, symmetric=True, log=True\n)\n\n# compute entropic partial GW leading to dense transport plans\nTent, logent = entropic_partial_gromov_wasserstein(\n Cs[0], Cs[1], ps[0], ps[1], reg=0.01, m=m, G0=G0, symmetric=True, log=True\n)\n\n# Plot matchings\nlist_T = [T, Tent]\nlist_dist = [\n np.round(log[\"partial_gw_dist\"], 3),\n np.round(logent[\"partial_gw_dist\"], 3),\n]\nlist_dist_str = [\"pGW\", \"pGW_e\"]\n\npl.figure(2, figsize=(10, 3))\npl.clf()\nfor i in range(2):\n pl.subplot(1, 2, i + 1)\n pl.axis(\"off\")\n pl.title(\n r\"$%s(\\mathbf{C_1},\\mathbf{p_1}^\\star,\\mathbf{C_2},\\mathbf{p_2}^\\star) =%s$\"\n % (list_dist_str[i], list_dist[i]),\n fontsize=14,\n )\n\n p2 = list_T[i].sum(0)\n\n pos1, pos2 = draw_transp_colored(\n clean_graph,\n Cs[0],\n noisy_graph,\n Cs[1],\n p1=None,\n p2=p2,\n T=list_T[i],\n shiftx=3,\n node_size=50,\n )\n\npl.tight_layout()\npl.show()"
84
+
]
85
+
},
86
+
{
87
+
"cell_type": "markdown",
88
+
"metadata": {},
89
+
"source": [
90
+
"## Partial (Entropic) Fused Gromov-Wasserstein computation and visualization\nWe add now node features compared using pairwise euclidean distance\nto illustrate partial FGW computation with trade-off parameter alpha=0.5\n\n"
91
+
]
92
+
},
93
+
{
94
+
"cell_type": "code",
95
+
"execution_count": null,
96
+
"metadata": {
97
+
"collapsed": false
98
+
},
99
+
"outputs": [],
100
+
"source": [
101
+
"Ys = [\n np.array([v for (k, v) in nx.get_node_attributes(G, \"weight\").items()]).reshape(\n -1, 1\n )\n for G in graphs\n]\nM = dist(Ys[0], Ys[1])\n# provide an informative initialization for better visualization\nm = 3.0 / 4.0\npartial_id = np.zeros((15, 20))\npartial_id[:15, :15] = np.eye(15) / 15.0\nG0 = (np.outer(ps[0], ps[1]) + partial_id) * m / 2\n\n# compute exact partial GW\nT, log = partial_fused_gromov_wasserstein(\n M,\n Cs[0],\n Cs[1],\n ps[0],\n ps[1],\n alpha=0.5,\n m=m,\n G0=G0,\n symmetric=True,\n log=True,\n)\n\n# compute entropic partial GW leading to dense transport plans\nTent, logent = entropic_partial_fused_gromov_wasserstein(\n M,\n Cs[0],\n Cs[1],\n ps[0],\n ps[1],\n reg=0.01,\n alpha=0.5,\n m=m,\n G0=G0,\n symmetric=True,\n log=True,\n)\n\n# Plot matchings\nlist_T = [T, Tent]\nlist_dist = [\n np.round(log[\"partial_fgw_dist\"], 3),\n np.round(logent[\"partial_fgw_dist\"], 3),\n]\nlist_dist_str = [\"pFGW\", \"pFGW_e\"]\n\npl.figure(3, figsize=(10, 3))\npl.clf()\nfor i in range(2):\n pl.subplot(1, 2, i + 1)\n pl.axis(\"off\")\n pl.title(\n r\"$%s(\\mathbf{C_1},\\mathbf{p_1}^\\star,\\mathbf{C_2}, \\mathbf{p_2}^\\star) =%s$\"\n % (list_dist_str[i], list_dist[i]),\n fontsize=14,\n )\n\n p2 = list_T[i].sum(0)\n pos1, pos2 = draw_transp_colored(\n clean_graph,\n Cs[0],\n noisy_graph,\n Cs[1],\n p1=None,\n p2=p2,\n T=list_T[i],\n shiftx=3,\n node_size=50,\n color_features=True,\n )\n\npl.tight_layout()\npl.show()"
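A note on the `G0` initialization used in both computation cells, `G0 = (np.outer(ps[0], ps[1]) + partial_id) * m / 2`: it appears to be built so that it already satisfies the partial transport constraints (row sums at most `p`, column sums at most `q`, total mass equal to `m`). The sketch below is not part of the notebook. It rebuilds the same matrix in plain Python with exact rationals and checks those three conditions, assuming uniform weights as returned by `unif`. The helper names `informative_init` and `is_feasible` are hypothetical and introduced only for illustration.

```python
from fractions import Fraction


def informative_init(n_src=15, n_tgt=20, m=Fraction(3, 4)):
    """Rebuild G0 = (p q^T + partial_id) * m / 2 with exact rationals.

    Hypothetical helper mirroring the notebook's NumPy expression;
    assumes n_src <= n_tgt and uniform weights on both graphs.
    """
    p = [Fraction(1, n_src)] * n_src  # uniform source weights, like unif(15)
    q = [Fraction(1, n_tgt)] * n_tgt  # uniform target weights, like unif(20)
    G0 = []
    for i in range(n_src):
        row = []
        for j in range(n_tgt):
            # partial_id holds 1 / n_src on the leading diagonal block
            identity_part = Fraction(1, n_src) if i == j else Fraction(0)
            row.append((p[i] * q[j] + identity_part) * m / 2)
        G0.append(row)
    return p, q, G0


def is_feasible(p, q, G0, m):
    """Check row sums <= p, column sums <= q, and total mass == m."""
    row_sums = [sum(row) for row in G0]
    col_sums = [sum(col) for col in zip(*G0)]
    rows_ok = all(r <= pi for r, pi in zip(row_sums, p))
    cols_ok = all(c <= qj for c, qj in zip(col_sums, q))
    return rows_ok and cols_ok and sum(row_sums) == m


m = Fraction(3, 4)
p, q, G0 = informative_init(m=m)
print(is_feasible(p, q, G0, m))  # True
```

Exact fractions are used here only to avoid floating-point tolerances in the comparisons. The product term and the identity term each carry total mass 1, so scaling their sum by `m / 2` gives total mass `m`. With `m = 3/4`, each row sums to 1/20, which stays below the source weight 1/15.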