forked from cgpotts/cs224u
-
Notifications
You must be signed in to change notification settings - Fork 0
/
retrofitting.py
160 lines (133 loc) · 4.91 KB
/
retrofitting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.spatial.distance import euclidean
import utils
# Module-level metadata strings: author name and version label.
__author__ = "Christopher Potts"
__version__ = "CS224u, Stanford, Fall 2020"
class Retrofitter(object):
    """
    Implements the baseline retrofitting method of Faruqui et al.

    Parameters
    ----------
    max_iter : int
        The maximum number of iterations to run.
    alpha : func from `edges.keys()` to floats or None
        Weight given to a node's original vector. If None, `fit`
        uses a constant weight of 1.0 for every node.
    beta : func from `edges.keys()` to floats or None
        Weight given to each neighbor of a node. If None, `fit`
        uses `1.0 / len(edges[i])` for node `i`, computed from the
        `edges` passed to that particular `fit` call.
    tol : float
        If the change between two rounds, as computed by
        `_measure_changes`, is at or below this value, we stop.
        Default to 10^-2 as suggested in the paper.
    verbose : bool
        Whether to print information about the optimization process.
    introspecting : bool
        Whether to accumulate a list of the retrofitting matrices
        at each step. This should be set to `True` only for small
        illustrative tasks. For large ones, it will impose huge
        memory demands.

    """
    def __init__(self, max_iter=100, alpha=None, beta=None, tol=1e-2,
                 verbose=False, introspecting=False):
        self.max_iter = max_iter
        self.alpha = alpha
        self.beta = beta
        self.tol = tol
        self.verbose = verbose
        self.introspecting = introspecting

    def fit(self, X, edges):
        """
        The core internal retrofitting method.

        Parameters
        ----------
        X : np.array or pd.DataFrame (distributional embeddings)
            Not modified. Non-floating input (e.g. integer arrays)
            is converted to float so that updates are not truncated.
        edges : dict
            Mapping indices into `X` into sets of indices into `X`.
            Every row index of `X` must be a key; rows with an empty
            neighbor set are left at their original values.

        Attributes
        ----------
        self.Y : np.array, same dimensions and arrangement as `X`.
            The retrofitting matrix. A `pd.DataFrame` with the index
            and columns of `X` if `X` was a `pd.DataFrame`.
        self.all_Y : list
            Set only if `self.introspecting=True`. Holds a copy of
            the matrix after each round that did not meet `tol`.

        Returns
        -------
        np.array or pd.DataFrame
            The same object as `self.Y`.

        """
        index = None
        columns = None
        if isinstance(X, pd.DataFrame):
            index = X.index
            columns = X.columns
            X = X.values

        X = np.asarray(X)
        # In-place row updates on an integer array would silently
        # truncate every averaged vector, so promote to float first.
        if not np.issubdtype(X.dtype, np.inexact):
            X = X.astype(float)

        # Resolve the default weighting functions into locals rather
        # than storing them on `self`: the default `beta` closes over
        # this call's `edges`, so persisting it would make a later
        # `fit` on a different graph use stale neighbor counts.
        alpha = self.alpha
        if alpha is None:
            alpha = lambda i: 1.0
        beta = self.beta
        if beta is None:
            beta = lambda i: 1.0 / len(edges[i])

        if self.introspecting:
            self.all_Y = []

        Y = X.copy()
        Y_prev = Y.copy()
        for iteration in range(1, self.max_iter + 1):
            # Rows are updated in place, so later rows in the same
            # round already see the new values of earlier rows.
            for i in range(len(X)):
                neighbors = edges[i]
                n_neighbors = len(neighbors)
                if not n_neighbors:
                    continue
                a = alpha(i)
                b = beta(i)
                neighbor_sum = Y[list(neighbors)].sum(axis=0)
                # Weighted average of the neighbors' current vectors
                # and this node's original vector.
                Y[i] = (b * neighbor_sum + a * X[i]) / (b * n_neighbors + a)
            changes = self._measure_changes(Y, Y_prev)
            if changes <= self.tol:
                self._progress_bar(
                    "Converged at iteration {}; change was {:.4f} ".format(
                        iteration, changes))
                break
            else:
                if self.introspecting:
                    self.all_Y.append(Y.copy())
                Y_prev = Y.copy()
                self._progress_bar(
                    "Iteration {:d}; change was {:.4f}".format(
                        iteration, changes))
        if index is not None:
            Y = pd.DataFrame(Y, index=index, columns=columns)
        self.Y = Y
        return self.Y

    @staticmethod
    def _measure_changes(Y, Y_prev):
        """
        Return a scalar measuring how far `Y` moved from `Y_prev`.

        This is `np.linalg.norm(..., ord=2)` of the squeezed
        difference. Per the NumPy documentation, `ord=2` is the
        vector 2-norm for 1d input and the matrix 2-norm (largest
        singular value) for 2d input; the enclosing `np.mean` and
        `np.abs` are applied to that single value.

        """
        return np.abs(
            np.mean(
                np.linalg.norm(
                    np.squeeze(Y_prev) - np.squeeze(Y),
                    ord=2)))

    def _progress_bar(self, msg):
        """Forward `msg` to `utils.progress_bar` if `self.verbose`."""
        if self.verbose:
            utils.progress_bar(msg)
def plot_retro_vsm(Q, edges, ax=None, lims=None):
    """
    Scatter-plot a 2d embedding with row labels and edge arrows.

    Parameters
    ----------
    Q : pd.DataFrame
        One row per node. The first two columns (by position) are
        used as the x and y coordinates for labels and arrows.
    edges : dict
        Maps row positions in `Q` to iterables of row positions in
        `Q`; an arrow is drawn from each key toward each of its
        values.
    ax : matplotlib axis or None
        Passed through to `Q.plot.scatter`.
    lims : sequence or None
        If given, applied to both the x and the y axis limits.

    Returns
    -------
    The axis returned by `Q.plot.scatter`.

    """
    ax = Q.plot.scatter(x=0, y=1, ax=ax)
    if lims is not None:
        ax.set_xlim(lims)
        ax.set_ylim(lims)
    # Read coordinates by position. Indexing a row Series with `x[0]`
    # falls back to positional access when the column labels are not
    # integers, and pandas has deprecated that fallback.
    coords = Q.iloc[:, :2].values
    for name, (x, y) in zip(Q.index, coords):
        ax.text(x, y, name, fontsize=18)
    for i, vals in edges.items():
        for j in vals:
            x0, y0 = coords[i]
            # Stop the arrow at 90% of the distance to the neighbor.
            dx, dy = (coords[j] - coords[i]) * 0.9
            ax.arrow(x0, y0, dx, dy, head_width=0.05, head_length=0.05)
    return ax
def plot_retro_path(Q_hat, edges, retrofitter=None):
    """
    Plot `Q_hat` followed by each intermediate retrofitting step.

    Parameters
    ----------
    Q_hat : pd.DataFrame
        The original embedding; its index and columns are reused for
        every intermediate matrix.
    edges : dict
        Passed to `Retrofitter.fit` and to `plot_retro_vsm`.
    retrofitter : Retrofitter or None
        If None, a `Retrofitter(introspecting=True)` is created. In
        either case `introspecting` is set to `True` on the object
        before fitting.

    Returns
    -------
    The fitted retrofitter.

    """
    if retrofitter is None:
        retrofitter = Retrofitter(introspecting=True)
    # The per-step matrices are only recorded when introspecting.
    retrofitter.introspecting = True
    retrofitter.fit(Q_hat, edges)
    snapshots = retrofitter.all_Y

    # One shared range for both axes of every panel, padded by 0.1.
    raw = Q_hat.values
    lims = [raw.min() - 0.1, raw.max() + 0.1]

    fig, axes = plt.subplots(
        nrows=1, ncols=len(snapshots) + 1, figsize=(12, 4), squeeze=False)
    panels = axes[0]

    # First panel is the input; the rest are the recorded steps.
    plot_retro_vsm(Q_hat, edges, panels[0], lims=lims)
    for step, panel in zip(snapshots, panels[1:]):
        frame = pd.DataFrame(step, index=Q_hat.index, columns=Q_hat.columns)
        plot_retro_vsm(frame, edges, ax=panel, lims=lims)

    plt.tight_layout()
    return retrofitter