-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
28 lines (25 loc) · 906 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import os

import numpy as np
import torch.nn as nn
import torchvision.utils as vutils
from matplotlib import pyplot as plt
def weights_init(m):
    """Initialise a module's parameters in place (DCGAN-style).

    Designed to be passed to ``model.apply(weights_init)``, which calls it once
    for every submodule.

    * Modules whose class name contains ``'Conv'`` (this also matches e.g.
      ``ConvTranspose2d``) get their weight drawn from N(0.0, 0.02).
    * Modules whose class name contains ``'BatchNorm'`` get their weight drawn
      from N(1.0, 0.02) and their bias set to 0.
    * Every other module is left untouched.

    A weight or bias that is missing or ``None`` is skipped rather than
    dereferenced. Examples are ``BatchNorm2d(affine=False)`` and a container
    module whose name merely contains ``'Conv'``.

    Args:
        m: the ``nn.Module`` to initialise.
    """
    classname = type(m).__name__
    # getattr with a default: not every matching module actually owns parameters.
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)
    if "Conv" in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, 0.02)
    elif "BatchNorm" in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias.data, 0)
def save_image(real_cpu, img_list, sample_path, device):
    """Save a side-by-side comparison of real and generated images.

    Draws a grid of up to the first 64 images of ``real_cpu`` on the left and
    the last entry of ``img_list`` on the right. The figure is then written to
    ``<sample_path>/showcase.png`` and closed.

    Args:
        real_cpu: batch of real images. Sliced to the first 64 along dim 0 and
            tiled with ``torchvision.utils.make_grid``. Presumably shaped
            (N, C, H, W); TODO confirm against the caller.
        img_list: sequence of image grids; only ``img_list[-1]`` is drawn. Each
            entry is transposed with axes (1, 2, 0), so it is expected to be
            channel-first (C, H, W). Must be non-empty, otherwise
            ``IndexError`` is raised.
        sample_path: directory the PNG is written into (must already exist).
        device: device the real images are moved to before tiling. The grid is
            moved back to the CPU for plotting.
    """
    fig = plt.figure(figsize=(15, 15))
    try:
        # Left panel: real images.
        plt.subplot(1, 2, 1)
        plt.axis("off")
        plt.title("Real Images")
        # Slice before .to(device) so only the 64 displayed images are
        # transferred, not the whole batch.
        real_grid = vutils.make_grid(
            real_cpu[:64].to(device), padding=5, normalize=True
        ).cpu()
        plt.imshow(np.transpose(real_grid, (1, 2, 0)))

        # Right panel: the most recent generated (fake) images.
        plt.subplot(1, 2, 2)
        plt.axis("off")
        plt.title("Fake Images")
        plt.imshow(np.transpose(img_list[-1], (1, 2, 0)))

        fig.savefig(os.path.join(sample_path, "showcase.png"))
    finally:
        # pyplot keeps every figure alive until it is explicitly closed.
        # Without this, calling the function repeatedly during training
        # accumulates open figures and leaks memory.
        plt.close(fig)