-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathgmm2d.py
305 lines (247 loc) · 11.1 KB
/
gmm2d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
"""
Acknowledgement: the code below is adapted from Trajectron++:
https://github.com/StanfordASL/Trajectron-plus-plus
"""
import torch
import torch.distributions as td
import torch.nn.functional as F
import numpy as np
# from model.model_utils import to_one_hot
def to_one_hot(labels, n_labels):
    """Return one-hot rows for ``labels`` by indexing an identity matrix.

    :param labels: Integer tensor of class indices.
    :param n_labels: Number of classes (size of the one-hot axis).
    :return: Tensor of shape ``labels.shape + (n_labels,)`` on ``labels.device``.
    """
    identity = torch.eye(n_labels, device=labels.device)
    one_hot_rows = identity[labels]
    return one_hot_rows
class GMM2D(td.Distribution):
    r"""
    Gaussian Mixture Model of N bivariate (2D) Gaussian components.

    Sampling uses the Cholesky decomposition and an affine transformation:

    .. math:: Z \sim N(0, I)
    .. math:: S = \mu + LZ
    .. math:: S \sim N(\mu, \Sigma) \rightarrow N(\mu, LL^T)

    where :math:`L = chol(\Sigma)` and

    .. math:: \Sigma = \left[ {\begin{array}{cc} \sigma^2_x & \rho \sigma_x \sigma_y \\ \rho \sigma_x \sigma_y & \sigma^2_y \\ \end{array} } \right]

    such that

    .. math:: L = chol(\Sigma) = \left[ {\begin{array}{cc} \sigma_x & 0 \\ \rho \sigma_y & \sigma_y \sqrt{1-\rho^2} \\ \end{array} } \right]

    :param log_pis: Log mixing proportions :math:`log(\pi)`. [..., N]. Clamped at
        -1e5 and renormalized with logsumexp.
    :param mus: Component means :math:`\mu`. [..., N * 2], or already
        [..., N, 2] when the tensor is 5-dimensional.
    :param log_sigmas: Log standard deviations :math:`log(\sigma_d)`.
        [..., N * 2], or already [..., N, 2] when the tensor is 5-dimensional.
    :param corrs: Correlation coefficients :math:`\rho`. [..., N]. Used as given
        (no squashing), so they are expected to lie in (-1, 1); the derived
        :math:`1 - \rho^2` is clamped to [1e-5, 1].
    """

    def __init__(self, log_pis, mus, log_sigmas, corrs):
        # NOTE(review): batch_shape is a plain int (leading dim of log_pis), not
        # a torch.Size; kept as-is because callers may read `self.batch_shape`.
        super(GMM2D, self).__init__(
            batch_shape=log_pis.shape[0], event_shape=log_pis.shape[1:]
        )
        self.components = log_pis.shape[-1]
        self.dimensions = 2
        self.device = log_pis.device

        # Clamp so that all -inf logits cannot turn logsumexp into NaN, then
        # normalize so the mixing weights sum to one.
        log_pis = torch.clamp(log_pis, min=-1e5)
        self.log_pis = log_pis - torch.logsumexp(
            log_pis, dim=-1, keepdim=True
        )  # [..., N]

        self.mus = self.reshape_to_components(mus)  # [..., N, 2]
        self.log_sigmas = self.reshape_to_components(log_sigmas)  # [..., N, 2]
        self.sigmas = torch.exp(self.log_sigmas)  # [..., N, 2]

        # Keep 1 - rho^2 inside [1e-5, 1]; otherwise sqrt/log of it can be NaN.
        self.one_minus_rho2 = torch.clamp(1 - corrs ** 2, min=1e-5, max=1)  # [..., N]
        self.corrs = corrs  # [..., N]

        # Lower-triangular Cholesky factor of each component covariance.
        # The zero entry is derived from sigma_x so it always matches its shape,
        # dtype and device (no hard-coded CUDA transfer).
        sigma_x = self.sigmas[..., 0]
        sigma_y = self.sigmas[..., 1]
        self.L = torch.stack(
            [
                torch.stack([sigma_x, torch.zeros_like(sigma_x)], dim=-1),
                torch.stack(
                    [
                        sigma_y * self.corrs,
                        sigma_y * torch.sqrt(self.one_minus_rho2),
                    ],
                    dim=-1,
                ),
            ],
            dim=-2,
        )  # [..., N, 2, 2]
        self.pis_cat_dist = td.Categorical(logits=log_pis)

    @classmethod
    def from_log_pis_mus_cov_mats(cls, log_pis, mus, cov_mats):
        """
        Build a GMM2D from explicit 2x2 covariance matrices.

        :param log_pis: Log mixing proportions. [..., N]
        :param mus: Component means, in any layout accepted by ``__init__``.
        :param cov_mats: Component covariance matrices. [..., N, 2, 2]
        :return: A new GMM2D instance.
        """
        cov_xy = cov_mats[..., 0, 1]
        # Diagonal entries are variances; clamp before sqrt/log.
        var_x = torch.clamp(cov_mats[..., 0, 0], min=1e-8)
        var_y = torch.clamp(cov_mats[..., 1, 1], min=1e-8)
        sigmas = torch.stack([torch.sqrt(var_x), torch.sqrt(var_y)], dim=-1)  # [..., N, 2]
        corrs = cov_xy / torch.prod(sigmas, dim=-1)  # [..., N]
        log_sigmas = torch.log(sigmas)
        # __init__ re-splits log_sigmas from [..., N * 2] unless it is 5D, so
        # flatten the component axis; 5D inputs round-trip unchanged and
        # lower-rank inputs no longer get a spurious extra axis.
        log_sigmas = log_sigmas.reshape(*log_sigmas.shape[:-2], -1)
        return cls(log_pis, mus, log_sigmas, corrs)

    def rsample(self, sample_shape=torch.Size()):
        """
        Draw reparameterized samples from every mixture component.

        Each component sample is ``mu + L z`` with ``z ~ N(0, I)``, so gradients
        flow back to the parameters. No component is selected according to the
        mixing weights: the result keeps the component axis.

        :param sample_shape: Shape of the samples, prepended to the output shape.
        :return: Samples of shape ``sample_shape + mus.shape`` ([..., N, 2]).
        """
        noise = torch.randn(
            size=torch.Size(sample_shape) + self.mus.shape, device=self.device
        )  # sample_shape + [..., N, 2]
        # [..., N, 2, 2] @ [..., N, 2, 1] -> [..., N, 2, 1] -> [..., N, 2]
        mvn_samples = self.mus + torch.squeeze(
            torch.matmul(self.L, torch.unsqueeze(noise, dim=-1)), dim=-1
        )
        return mvn_samples

    def log_prob(self, value, mask=None):
        r"""
        Negative log-density of ``value`` under each bivariate Gaussian component.

        Despite its name, this returns the *negative* log-likelihood per
        component (mixing weights are not applied), suitable for use as a loss.
        Each component density is

        .. math::
            f(x | \mu, \sigma, \rho)={\frac {1}{2\pi \sigma _{x}\sigma _{y}{\sqrt {1-\rho ^{2}}}}}\exp
            \left(-{\frac {1}{2(1-\rho ^{2})}}\left[{\frac {(x-\mu _{x})^{2}}{\sigma _{x}^{2}}}+
            {\frac {(y-\mu _{y})^{2}}{\sigma _{y}^{2}}}-{\frac {2\rho (x-\mu _{x})(y-\mu _{y})}
            {\sigma _{x}\sigma _{y}}}\right]\right)

        The density is clamped below at 1e-20 before the log, so each returned
        value is at most ``-log(1e-20)`` (about 46.05).

        :param value: Points to evaluate. [..., 2]; broadcast over the component axis.
        :param mask: Optional padding mask. When given, the per-component
            densities are squeezed and must then be 3D (``"ntv"``) like ``mask``;
            entries where ``mask == 0`` yield 0. Defaults to ``None`` (no masking).
        :return: Negative log-density per component, [..., N] (squeezed when a
            mask is given).
        """
        epsilon = 1e-20
        value = torch.unsqueeze(value, dim=-2)  # [..., 1, 2]
        dx = value - self.mus  # [..., N, 2]
        # Bracketed quadratic form of the exponent (before the 1 / (2(1 - rho^2)) factor).
        exp_nominator = torch.sum(
            (dx / self.sigmas) ** 2, dim=-1
        ) - 2 * self.corrs * torch.prod(dx, dim=-1) / torch.prod(
            self.sigmas, dim=-1
        )  # [..., N]
        density = torch.exp(-exp_nominator / (2 * self.one_minus_rho2))
        density_norm = (
            2
            * np.pi
            * (torch.prod(self.sigmas, dim=-1) * torch.sqrt(self.one_minus_rho2))
        )
        density = density / density_norm
        if mask is not None:
            density = torch.einsum("ntv, ntv->ntv", density.squeeze(), mask)
            # Density 1 -> -log(1) = 0, so padded entries contribute nothing.
            density[mask == 0] = 1.0
        return -torch.log(torch.clamp(density, min=epsilon))

    def get_for_node_at_time(self, n, t):
        """
        Return a GMM2D restricted to index ``n`` of axis 1 and ``t`` of axis 2.

        Singleton axes are kept. Presumably the parameters are laid out as
        [samp, bs, time, N(, 2)] (the layout ``mode`` unpacks) -- confirm with callers.
        """
        return self.__class__(
            self.log_pis[:, n : n + 1, t : t + 1],
            self.mus[:, n : n + 1, t : t + 1],
            self.log_sigmas[:, n : n + 1, t : t + 1],
            self.corrs[:, n : n + 1, t : t + 1],
        )

    def mode(self, required_accuracy=0.01):
        """
        Approximate the mode of the GMM by grid search over the mixture density.

        For more than one component, the mixture log-density is evaluated on a
        grid spanning the bounding box of the component means (inclusive), and
        the best grid point is returned for every (node, time) pair. With a
        single component the mode is simply its mean.

        :param required_accuracy: Spacing of the search grid (default 0.01).
        :return: [1, bs, time, 2] for multi-component 5D parameters, otherwise
            ``self.mus`` with the component axis squeezed.
        """
        if self.mus.shape[-2] > 1:
            samp, bs, time, _, _ = self.mus.shape
            assert samp == 1, "For taking the mode only one sample makes sense."
            step = required_accuracy
            mode_node_list = []
            for n in range(bs):
                mode_t_list = []
                for t in range(time):
                    nt_gmm = self.get_for_node_at_time(n, t)
                    x_min = self.mus[:, n, t, :, 0].min().item()
                    x_max = self.mus[:, n, t, :, 0].max().item()
                    y_min = self.mus[:, n, t, :, 1].min().item()
                    y_max = self.mus[:, n, t, :, 1].max().item()
                    # "+ step" includes the upper bound, so the grid is never
                    # empty when all means share a coordinate.
                    search_grid = (
                        torch.stack(
                            torch.meshgrid(
                                [
                                    torch.arange(x_min, x_max + step, step),
                                    torch.arange(y_min, y_max + step, step),
                                ]
                            ),
                            dim=2,
                        )
                        .view(-1, 2)
                        .float()
                        .to(self.device)
                    )
                    # log_prob gives per-component NLL; combine it with the
                    # mixing weights into the mixture log-density, then maximize.
                    component_nll = nt_gmm.log_prob(search_grid, mask=None)
                    mixture_ll = torch.logsumexp(
                        nt_gmm.log_pis - component_nll, dim=-1
                    )
                    best = torch.argmax(mixture_ll.reshape(-1), dim=0)
                    mode_t_list.append(search_grid[best])
                mode_node_list.append(torch.stack(mode_t_list, dim=0))
            return torch.stack(mode_node_list, dim=0).unsqueeze(dim=0)
        return torch.squeeze(self.mus, dim=-2)

    def reshape_to_components(self, tensor):
        """
        Split the last axis [..., N * 2] into [..., N, 2].

        5-dimensional tensors are assumed to be in component layout already
        and are returned unchanged.
        """
        if len(tensor.shape) == 5:
            return tensor
        return torch.reshape(
            tensor, list(tensor.shape[:-1]) + [self.components, self.dimensions]
        )

    def get_covariance_matrix(self):
        """Return the per-component 2x2 covariance matrices. [..., N, 2, 2]"""
        cov = self.corrs * torch.prod(self.sigmas, dim=-1)
        E = torch.stack(
            [
                torch.stack([self.sigmas[..., 0] ** 2, cov], dim=-1),
                torch.stack([cov, self.sigmas[..., 1] ** 2], dim=-1),
            ],
            dim=-2,
        )
        return E
if __name__ == "__main__":
    # Smoke test: 32 single-component 2D Gaussians, on GPU when available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log_pis = torch.ones(32, 1, device=device)
    mus = torch.ones(32, 1 * 2, device=device)
    log_sigma = torch.ones(32, 1 * 2, device=device) * 0.5
    corr = torch.ones(32, 1, device=device) * 0.7
    gmm2d = GMM2D(log_pis, mus, log_sigma, corr)
    print(gmm2d.rsample())
    print(gmm2d.rsample().shape)
    # log_prob returns per-component negative log-densities; no padding mask here.
    query = torch.ones(mus.shape, device=device) * 0.66
    print(torch.mean(gmm2d.log_prob(query, mask=None)))