-
Notifications
You must be signed in to change notification settings - Fork 77
/
Copy pathnoisy_circle.py
69 lines (58 loc) · 2.07 KB
/
noisy_circle.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from __future__ import print_function
import torch, torch.nn as nn, numpy as np, matplotlib.pyplot as plt
from topologylayer.nn import LevelSetLayer2D, SumBarcodeLengths, PartialSumBarcodeLengths
# Ground-truth coefficients: a smooth ring (Gaussian radial profile) on an
# n x n grid, produced by circlefn / gen_circle below.
n = 50  # grid side length; the coefficient image has n**2 pixels
def circlefn(i, j, n):
    """Evaluate the ring profile at pixel (i, j) of an n x n grid.

    The value is ``exp(-(r - n/3)**2 / (2*n))``, where ``r`` is the Euclidean
    distance from (i, j) to the grid centre (n/2, n/2). It equals 1 exactly
    on the circle of radius n/3 and decays smoothly away from it. Only NumPy
    ufuncs are used, so ``i`` and ``j`` may be scalars or broadcastable arrays.
    """
    center = n/2.
    radius = n/3.
    dist = np.sqrt((i - center)**2 + (j - center)**2)
    return np.exp(-(dist - radius)**2/(n*2))
def gen_circle(n):
    """Return an (n, n) float64 array whose entry [i, j] is circlefn(i, j, n).

    circlefn is written purely with NumPy ufuncs, so it is evaluated once on
    broadcast index grids rather than once per pixel in a Python double loop.

    Args:
        n: grid side length (n == 0 yields an empty (0, 0) array).

    Returns:
        A new (n, n) array of circlefn values (row index i, column index j).
    """
    # Open grids: rows has shape (n, 1) and cols has shape (1, n); broadcasting
    # inside circlefn expands them to the full (n, n) result.
    rows, cols = np.ogrid[:n, :n]
    return circlefn(rows, cols, n)
beta = gen_circle(n)  # ground-truth coefficient image, shape (n, n)

# Synthetic linear-regression problem: y = X . vec(beta) + noise.
m = 1500  # number of observations
X = np.random.randn(m, n**2)  # i.i.d. standard-normal design matrix
noise = 0.05*np.random.randn(m)  # Gaussian noise, standard deviation 0.05
y = X.dot(beta.ravel()) + noise

# Ordinary least-squares baseline. With n = 50 there are fewer observations
# (m = 1500) than unknowns (n**2 = 2500), so lstsq returns the minimum-norm
# least-squares solution.
coef = np.linalg.lstsq(X, y, rcond=None)[0]
beta_ols = coef.reshape(n, n)
class TopLoss(nn.Module):
    """Topological penalty on a 2D image, computed from its persistence diagram.

    The returned penalty is the sum of two barcode-length terms:
    dimension-1 bars via PartialSumBarcodeLengths(dim=1, skip=1), plus
    dimension-0 bars via SumBarcodeLengths(dim=0).

    NOTE(review): skip=1 presumably excludes the single most persistent
    dimension-1 bar, so that one loop is allowed and extra loops/components
    are penalized. Confirm against the topologylayer documentation.

    Args:
        size: image shape passed to LevelSetLayer2D; it must match the
            tensor given to forward().
    """

    def __init__(self, size):
        super(TopLoss, self).__init__()
        # NOTE(review): sublevel=False presumably selects a superlevel-set
        # filtration of the image -- confirm against topologylayer docs.
        self.pdfn = LevelSetLayer2D(size=size, sublevel=False)
        self.topfn = PartialSumBarcodeLengths(dim=1, skip=1)
        self.topfn2 = SumBarcodeLengths(dim=0)

    def forward(self, beta):
        """Return the scalar topological penalty for the image ``beta``."""
        diagram = self.pdfn(beta)
        loop_term = self.topfn(diagram)  # dimension-1 barcode term
        component_term = self.topfn2(diagram)  # dimension-0 barcode term
        return loop_term + component_term
# Objective: MSE data-fit term plus a weighted topological penalty.
# The penalty size is derived from n so it always matches beta's (n, n) shape.
tloss = TopLoss((n, n))  # topology penalty
dloss = nn.MSELoss()  # data loss

# Start from the OLS estimate. torch.autograd.Variable is deprecated; a tensor
# created with requires_grad=True is the equivalent trainable leaf.
beta_t = torch.tensor(beta_ols, dtype=torch.float, requires_grad=True)
X_t = torch.tensor(X, dtype=torch.float)
y_t = torch.tensor(y, dtype=torch.float)
optimizer = torch.optim.Adam([beta_t], lr=1e-2)
for i in range(500):
    optimizer.zero_grad()
    tlossi = tloss(beta_t)
    # MSELoss takes (input, target): the prediction first, the observations second.
    dlossi = dloss(torch.matmul(X_t, beta_t.view(-1)), y_t)
    loss = 0.1*tlossi + dlossi  # 0.1 weights the topology penalty vs. the data fit
    loss.backward()
    optimizer.step()
    if i % 10 == 0:
        print(i, tlossi.item(), dlossi.item())
# Save a side-by-side comparison of the true, OLS, and topology-regularized
# coefficient images.
beta_est = beta_t.detach().numpy()
fig, ax = plt.subplots(ncols=3, figsize=(15,5))
panels = [
    (beta, "Truth"),
    (beta_ols, "OLS"),
    (beta_est, "Topology Regularization"),
]
for axis, (image, title) in zip(ax, panels):
    axis.imshow(image)
    axis.set_title(title)
    # Hide tick labels and tick marks; only the images matter.
    axis.set_yticklabels([])
    axis.set_xticklabels([])
    axis.tick_params(bottom=False, left=False)
plt.savefig('noisy_circle.png')