-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
142 lines (124 loc) · 4.64 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils as vision_utils
import seaborn as sns
import pandas as pd
from data import reverse_transform
def cosine_beta_schedule(timesteps, s=0.008):
    """Return the per-step betas of a cosine noise schedule.

    The cumulative signal level is modelled as
    ``alpha_bar(t) = cos^2(((t / timesteps) + s) / (1 + s) * pi / 2)``,
    normalised so that ``alpha_bar(0) == 1``; each beta is then
    ``1 - alpha_bar(t) / alpha_bar(t - 1)``.

    Args:
        timesteps: Number of diffusion steps.
        s: Small offset that keeps the first betas from being vanishingly
            small.

    Returns:
        A 1-D float tensor with ``timesteps`` entries, clipped to
        ``[0.0001, 0.9999]``.
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    # The final ratio approaches 0, which would give beta == 1; clip to keep
    # every step strictly inside (0, 1).
    return torch.clip(betas, 0.0001, 0.9999)
def plot_batch(ax, batch, title=None, **kwargs):
    """Draw ``batch`` as a single image grid on the axes ``ax``.

    The batch is tiled with ``torchvision.utils.make_grid`` (2 px padding,
    normalised) and the grid is then passed through the callable returned by
    ``reverse_transform()`` (imported from ``data``; its exact behaviour is
    not visible in this module).

    Args:
        ax: Matplotlib axes to draw on; its axis decorations are turned off.
        batch: Images accepted by ``make_grid``.
        title: Optional axes title; skipped when ``None``.
        **kwargs: Forwarded unchanged to ``ax.imshow``.

    Returns:
        Whatever ``ax.imshow`` returns.
    """
    to_displayable = reverse_transform()
    grid = to_displayable(
        vision_utils.make_grid(batch, padding=2, normalize=True)
    )
    ax.set_axis_off()
    if title is not None:
        ax.set_title(title)
    artist = ax.imshow(grid, **kwargs)
    return artist
def save_images(batch, title):
    """Save a batch of images as one grid PNG.

    The grid is drawn with ``plot_batch`` and written to
    ``title + '_generated_images.png'``.

    Args:
        batch: Tensor whose first dimension is the batch size; it is moved to
            the CPU before plotting.
        title: Used both as the plot title and as the file-name prefix.
    """
    batch_size = batch.shape[0]
    row = int(np.sqrt(batch_size))
    col = batch_size // row
    # NOTE: matplotlib reads figsize as (width, height) in inches.
    fig = plt.figure(figsize=(row, col))
    try:
        ax = fig.add_subplot(111)
        plot_batch(ax, batch.cpu(), title)
        file_name = title + '_generated_images.png'
        # Save through the figure itself instead of pyplot's implicit
        # "current figure" state.
        fig.savefig(fname=file_name)
    finally:
        # pyplot keeps every figure alive until it is closed; without this,
        # repeated calls (e.g. once per epoch) leak figures.
        plt.close(fig)
def save_image_seqs(batch, title):
    """Save several image batches as stacked grid rows in one PNG.

    Each element of ``batch`` is drawn with ``plot_batch`` in its own subplot
    (one subplot per element, stacked vertically) and the figure is written
    to ``title + '_generated_sequential_images.png'``.

    Args:
        batch: Sequence of tensors; each is moved to the CPU before plotting.
        title: File-name prefix (not used as a plot title).
    """
    row = batch[0].shape[0]
    col = len(batch)
    # NOTE: matplotlib reads figsize as (width, height) in inches.
    fig = plt.figure(figsize=(row, col))
    try:
        for i, img_seq in enumerate(batch):
            ax = fig.add_subplot(col, 1, i + 1)
            plot_batch(ax, img_seq.cpu())
        file_name = title + '_generated_sequential_images.png'
        # Save through the figure itself instead of pyplot's implicit
        # "current figure" state.
        fig.savefig(fname=file_name)
    finally:
        # pyplot keeps every figure alive until it is closed; without this,
        # repeated calls leak figures.
        plt.close(fig)
def plot_seqs(imgs, with_orig=False, row_title=None, *, orig_img=None,
              figsize=(200, 200), **imshow_kwargs):
    """Plot a grid of images, one row per inner list.

    Args:
        imgs: A list of images (treated as a single row) or a list of lists
            of images (one row each).
        with_orig: When true, ``orig_img`` is prepended to every row and the
            top-left cell is titled ``'Original image'``.
        row_title: Optional sequence of y-axis labels, one per row.
        orig_img: Image shown in the first column when ``with_orig`` is true.
        figsize: Figure size in inches, passed to ``plt.subplots``.
        **imshow_kwargs: Forwarded unchanged to every ``ax.imshow`` call.

    Raises:
        ValueError: If ``with_orig`` is true but ``orig_img`` is not given.
    """
    if not isinstance(imgs[0], list):
        imgs = [imgs]
    # Fail early with a clear message rather than a NameError mid-plot.
    if with_orig and orig_img is None:
        raise ValueError("plot_seqs(with_orig=True) requires orig_img")

    num_rows = len(imgs)
    num_cols = len(imgs[0]) + bool(with_orig)
    # squeeze=False keeps axs 2-D even for a single row or column.
    fig, axs = plt.subplots(figsize=figsize, nrows=num_rows, ncols=num_cols,
                            squeeze=False)
    for row_idx, row in enumerate(imgs):
        row = [orig_img] + row if with_orig else row
        for col_idx, img in enumerate(row):
            ax = axs[row_idx, col_idx]
            ax.imshow(np.asarray(img), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if with_orig:
        axs[0, 0].set(title='Original image')
        axs[0, 0].title.set_size(8)
    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()
def create_img_seq(imgs_list):
    """Regroup a list of batches into one list of items per batch index.

    ``imgs_list`` is indexed as ``[step][batch_index]``; the result is
    indexed as ``[batch_index][step]``. The batch length is taken from the
    first element, and every item is moved to the CPU.
    """
    batch_len = imgs_list[0].shape[0]
    per_sample = []
    for sample_idx in range(batch_len):
        frames = []
        for step_batch in imgs_list:
            frames.append(step_batch[sample_idx].cpu())
        per_sample.append(frames)
    return per_sample
def plot_img_seq(img_seq, **kwargs):
    """Show a grid of images: one row per batch item, one column per step.

    ``img_seq`` is regrouped with ``create_img_seq`` and every item is
    converted with ``tensor_to_sample`` before being drawn.

    NOTE(review): ``tensor_to_sample`` is neither defined nor imported in
    this module as shown; it must be made available before this is called.

    Args:
        img_seq: List of batches, as accepted by ``create_img_seq``.
        **kwargs: Forwarded unchanged to every ``ax.imshow`` call.
    """
    # create_img_seq already moves every item to the CPU.
    img_list = [[tensor_to_sample(img) for img in seq]
                for seq in create_img_seq(img_seq)]
    row_num = len(img_list)
    col_num = len(img_list[0])
    # squeeze=False keeps axs 2-D, so axs[row, col] also works when there is
    # only a single row or a single column.
    fig, axs = plt.subplots(figsize=(10, 10), nrows=row_num, ncols=col_num,
                            squeeze=False)
    for idx_row, batch in enumerate(img_list):
        for idx_col, img in enumerate(batch):
            ax = axs[idx_row, idx_col]
            ax.imshow(img, **kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    plt.tight_layout()
    plt.show()
# The Fourier transform function
def spectrum(sample):
    """Return the log-magnitude of the centred 2-D FFT of ``sample``.

    The transform (orthonormal scaling) runs over the last two dimensions;
    any leading dimensions are treated as independent images.

    Args:
        sample: Tensor with at least two dimensions.

    Returns:
        Tensor of the same shape holding ``log(|FFT|)`` with the
        zero-frequency term shifted to the centre. Entries are ``-inf`` where
        the magnitude is exactly zero.
    """
    f = torch.fft.fft2(sample, norm='ortho')
    # Shift only the two transformed axes. Without ``dim`` fftshift rolls
    # every axis, which would also permute leading (batch/channel) entries.
    fshift = torch.fft.fftshift(f, dim=(-2, -1))
    magnitude_spectrum = torch.log(torch.abs(fshift))
    return magnitude_spectrum
def average_psd(tensor):
    """Return the power spectrum along axis 0, averaged over axis 1.

    A real FFT is taken along axis 0, converted to power
    (``real**2 + imag**2``) and averaged over axis 1.

    Args:
        tensor: Real tensor with at least two dimensions.

    Returns:
        ``(frequencies, power)`` as NumPy arrays. ``frequencies`` holds the
        ``rfftfreq`` bins for the length of axis 0, so it has the same length
        as the first dimension of ``power``.
    """
    a = torch.fft.rfft(tensor, dim=0)
    a = a.real ** 2 + a.imag ** 2
    a = torch.sum(a, dim=1) / a.shape[1]
    # The FFT runs along axis 0, so the frequency bins must be sized from
    # that axis; sizing them from axis 1 only agrees for square inputs.
    f = torch.fft.rfftfreq(tensor.shape[0])
    # detach/cpu so tensors on another device or tracking gradients convert.
    return f.numpy(), a.detach().cpu().numpy()
def plot_psds(imgs, time):
    """Overlay one power-spectrum curve per entry of ``imgs`` and show it.

    Args:
        imgs: Iterable of items whose element 0 is the frequency values and
            element 1 the matching amplitudes.
        time: Indexable of values used in each curve's legend label, in the
            same order as ``imgs``.
    """
    sns.set_style('ticks')
    for idx, psd in enumerate(imgs):
        frame = pd.DataFrame({'frequency': psd[0], 'amplitude': psd[1]})
        curve_label = f'noise level at: {time[idx]}'
        sns.lineplot(data=frame, x='frequency', y='amplitude', label=curve_label)

    # Shared axes styling, applied once after every curve is drawn.
    plt.yscale('log')
    plt.legend()
    plt.grid(True)
    sns.despine()
    plt.show()
def plot_psds_mnist(imgs, time):
    """Overlay one power-spectrum curve per entry of ``imgs`` and show it.

    Unlike ``plot_psds``, the amplitudes come from row 0 of element 1,
    truncated to the number of frequency values.

    Args:
        imgs: Iterable of items whose element 0 is the frequency values and
            element 1 a 2-D collection of amplitudes.
        time: Indexable of values used in each curve's legend label, in the
            same order as ``imgs``.
    """
    sns.set_style('ticks')
    for idx, psd in enumerate(imgs):
        frame = pd.DataFrame({
            'frequency': psd[0],
            'amplitude': psd[1][0][:len(psd[0])],
        })
        curve_label = f'noise level at: {time[idx]}'
        sns.lineplot(data=frame, x='frequency', y='amplitude', label=curve_label)

    # Shared axes styling, applied once after every curve is drawn.
    plt.yscale('log')
    plt.legend()
    plt.grid(True)
    sns.despine()
    plt.show()
def average_psd_mnist(tensor):
    """Return the power spectrum along axis 1, summed over axis 2.

    A real FFT is taken along axis 1, converted to power
    (``real**2 + imag**2``), summed over axis 2 and divided by the size of
    axis 0.

    Args:
        tensor: Real tensor with at least three dimensions.

    Returns:
        ``(frequencies, power)`` as NumPy arrays, where ``frequencies`` are
        the ``rfftfreq`` bins for the length of axis 1.
    """
    transformed = torch.fft.rfft(tensor, dim=1)
    power = transformed.real.pow(2) + transformed.imag.pow(2)
    scaled = power.sum(dim=2) / power.shape[0]
    freqs = torch.fft.rfftfreq(tensor[0].shape[0])
    return freqs.numpy(), scaled.numpy()
if __name__ == "__main__":
    # Smoke test: write two random 5x3x64x64 batches as stacked grids to
    # 'test_generated_sequential_images.png'.
    sample_shape = (5, 3, 64, 64)
    img_list = [torch.randn(size=sample_shape) for _ in range(2)]
    save_image_seqs(img_list, "test")