forked from Atcold/NYU-DLSP20
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathplot_lib.py
76 lines (60 loc) · 2.33 KB
/
plot_lib.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from matplotlib import pyplot as plt
import numpy as np
import torch
def set_default(figsize=(10, 10)):
    """Configure matplotlib's global look for dark, black-backed figures.

    Applies the 'dark_background' and 'bmh' style sheets (in that order),
    then forces the axes and figure face colours to black and sets the
    default figure size.

    Args:
        figsize: Default figure size stored in the 'figure' rc group.
    """
    style_sheets = ['dark_background', 'bmh']
    plt.style.use(style_sheets)

    # Override the face colours after the style sheets so black always wins.
    black = 'k'
    plt.rc('axes', facecolor=black)
    plt.rc('figure', facecolor=black, figsize=figsize)
def plot_data(X, y, d=0, auto=False, zoom=1):
    """Scatter the first two columns of ``X`` coloured by ``y``.

    Draws on the current matplotlib axes, hides the axis decorations and
    adds thin dark-grey lines through x = 0 and y = 0.

    Args:
        X: Tensor of points; columns 0 and 1 are used as x and y.
        y: Tensor of per-point values mapped through the Spectral colour map.
        d: Not used by this function.
        auto: When exactly ``True``, the axis mode is switched to 'equal'
            after the fixed limits have been applied.
        zoom: Multiplier applied to the base limits (-1.1, 1.1, -1.1, 1.1).
    """
    features = X.cpu()
    labels = y.cpu()
    coords = features.numpy()
    plt.scatter(coords[:, 0], coords[:, 1], c=labels, s=20, cmap=plt.cm.Spectral)

    # Square aspect first, then the zoom-scaled window.
    plt.axis('square')
    limits = np.array((-1.1, 1.1, -1.1, 1.1)) * zoom
    plt.axis(limits)
    if auto is True:
        plt.axis('equal')
    plt.axis('off')

    # Cross-hair through the origin, kept behind the points (zorder=0).
    line_start, line_colour = 0, '.15'
    crosshair = dict(color=line_colour, lw=1, zorder=0)
    plt.axvline(0, ymin=line_start, **crosshair)
    plt.axhline(0, xmin=line_start, **crosshair)
def plot_model(X, y, model):
    """Shade the model's predicted class over a 2-D grid, then plot the data.

    The model is moved to the CPU (this affects the caller's model object),
    evaluated without gradient tracking on a grid covering [-1.1, 1.1) in
    steps of 0.01 along both axes, and the arg-max over axis 1 of its output
    is drawn as a translucent filled contour. ``plot_data(X, y)`` is then
    called to overlay the points.

    Args:
        X: Points forwarded to ``plot_data``.
        y: Labels forwarded to ``plot_data``.
        model: Callable taking a float tensor of grid points with two columns;
            presumably returns one score per class in axis 1 -- confirm
            against the models used by callers.
    """
    model.cpu()

    # Regular grid of (x, y) locations at which the model is queried.
    ticks = np.arange(-1.1, 1.1, 0.01)
    grid_x, grid_y = np.meshgrid(ticks, ticks)
    stacked = np.vstack((grid_x.reshape(-1), grid_y.reshape(-1)))

    with torch.no_grad():
        grid_points = torch.from_numpy(stacked.T).float()
        scores = model(grid_points).detach()

    # Winning index per grid point, folded back into the grid's shape.
    predicted = np.argmax(scores, axis=1).reshape(grid_x.shape)
    plt.contourf(grid_x, grid_y, predicted, cmap=plt.cm.Spectral, alpha=0.3)
    plot_data(X, y)
def show_scatterplot(X, colors, title=''):
    """Open a new figure and scatter the first two columns of ``X``.

    Args:
        X: Tensor of points; columns 0 and 1 are used as x and y.
        colors: Tensor passed (as a NumPy array) to the scatter ``c`` argument.
        title: Title placed on the plot.
    """
    point_colours = colors.cpu().numpy()
    points = X.cpu().numpy()

    plt.figure()
    plt.axis('equal')
    xs, ys = points[:, 0], points[:, 1]
    plt.scatter(xs, ys, c=point_colours, s=30)
    plt.title(title)
    plt.axis('off')
def plot_bases(bases, width=0.04):
    """Draw two arrows on the current axes from a tensor of four points.

    Rows 0 and 1 of ``bases`` are the tails of the first and second arrow;
    rows 2 and 3 are the corresponding tips. The first arrow is red, the
    second green.

    The caller's tensor is never modified.

    Args:
        bases: Tensor indexed as four rows of (x, y) coordinates.
        width: Arrow shaft width forwarded to ``plt.arrow``.
    """
    bases = bases.cpu()
    tails = bases[:2]
    # Out-of-place subtraction on purpose: ``Tensor.cpu()`` hands back the
    # very same object when the tensor already lives on the CPU, so an
    # in-place ``-=`` here would silently rewrite the caller's data (and only
    # on CPU, making the behaviour device-dependent).
    deltas = bases[2:] - tails

    red, green = (1, 0, 0), (0, 1, 0)
    for tail, delta, colour in zip(tails, deltas, (red, green)):
        plt.arrow(*tail, *delta, width=width, color=colour, zorder=10,
                  alpha=1., length_includes_head=True)
def show_mat(mat, vect, prod, threshold=-1):
    """Display a matrix, a vector and a product side by side with ``matshow``.

    Three panels share the y axis and have width ratios 5:1:1. The first two
    use colour limits (-1, 1); the third uses (``threshold``, 1).

    Tensors are moved to the CPU before conversion to NumPy, matching the
    other helpers in this file, so inputs on another device are accepted.

    Args:
        mat: Tensor shown in the first panel; its two sizes appear in the title.
        vect: Tensor shown in the second panel (``matshow`` needs 2-D input --
            presumably a column vector; confirm against callers).
        prod: Tensor shown in the third panel.
        threshold: Lower colour limit of the third panel.
    """
    # ``.cpu()`` is a no-op for tensors already in host memory.
    mat_np = mat.cpu().numpy()
    vect_np = vect.cpu().numpy()
    prod_np = prod.cpu().numpy()

    # Subplot grid definition
    fig, (ax1, ax2, ax3) = plt.subplots(
        1, 3, sharex=False, sharey=True,
        gridspec_kw={'width_ratios': [5, 1, 1]})

    # Plot matrices
    cax1 = ax1.matshow(mat_np, clim=(-1, 1))
    ax2.matshow(vect_np, clim=(-1, 1))
    cax3 = ax3.matshow(prod_np, clim=(threshold, 1))

    # Set titles
    ax1.set_title(f'A: {mat.size(0)} \u00D7 {mat.size(1)}')
    ax2.set_title(f'a^(i): {vect.numel()}')
    ax3.set_title(f'p: {prod.numel()}')

    # Plot colourbars: the first two panels share the (-1, 1) limits, so the
    # bar built from the matrix image is attached next to the vector panel.
    fig.colorbar(cax1, ax=ax2)
    fig.colorbar(cax3, ax=ax3)

    # Fix y-axis limits: with a shared y axis, extend the bottom edge to fit
    # the longer of ``prod`` and ``vect``.
    ax1.set_ylim(bottom=max(len(prod), len(vect)) - 0.5)