-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtest_reconstruction_gray.py
53 lines (42 loc) · 1.93 KB
/
test_reconstruction_gray.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# Test grayscale image reconstruction results
from net_gray import *
import utils
import cv2
import torch
def main():
    """Reconstruct the grayscale COCO test images with the trained generator and save them.

    For each ``m`` in 1..20 this does the following:

    1. Reads ``./images/test_reconstruction_images/COCO_<m>.jpg`` with ``utils.read_image``.
    2. Feeds it through ``GenerativeNet(1, 1)``, loaded once from the scale-0 checkpoint.
    3. Min-max normalises the output to the 0-255 range.
    4. Writes it to ``./gray/Res2NetFuse_reconstruction/<m>_res2netfuse_a_image.jpg``.

    Requires a CUDA device, because the model and inputs are moved with ``.cuda()``.

    NOTE(review): ``GenerativeNet`` and ``args`` are not imported explicitly in this
    file; presumably they come from ``from net_gray import *``. Confirm.

    Raises:
        OSError: if ``cv2.imwrite`` reports that an output image could not be written.
    """
    import os

    import numpy as np

    scale_index = 0
    model_dir = './gray/model_gray/scale_' + str(scale_index) + '_final_model.model'
    output_dir = './gray/Res2NetFuse_reconstruction/'

    with torch.no_grad():
        # The checkpoint is the same for every image, so build and load the model once
        # instead of re-reading it from disk on each iteration.
        generator = GenerativeNet(1, 1)
        generator.load_state_dict(torch.load(model_dir))
        generator.eval()
        generator.cuda()

        # cv2.imwrite does not create directories; without this it silently fails.
        os.makedirs(output_dir, exist_ok=True)

        for m in range(1, 21):
            input_image_dir = './images/test_reconstruction_images/COCO_' + str(m) + '.jpg'
            input_image, _height, _width = utils.read_image(input_image_dir)
            # Add batch and channel axes: (H, W) -> (1, 1, H, W).
            input_image = input_image.reshape(1, 1, input_image.shape[0], input_image.shape[1])
            input_tensor = torch.from_numpy(input_image).float().cuda()

            out_image = generator(input_tensor)

            # Min-max normalise to the 0-255 range; EPSILON avoids division by zero
            # when the output is constant.
            out_min = torch.min(out_image)
            out_max = torch.max(out_image)
            result = (out_image - out_min) / (out_max - out_min + args.EPSILON)
            result = result * 255

            # Take the single (H, W) plane. This assumes the generator emits one output
            # channel, as GenerativeNet(1, 1) suggests (TODO confirm).
            # astype truncates rather than rounds, matching previously generated results.
            reconstructed = result[0, 0].cpu().numpy().astype(np.uint8)

            save_path = output_dir + str(m) + '_res2netfuse_a_image.jpg'
            if not cv2.imwrite(save_path, reconstructed):
                raise OSError('Failed to write reconstructed image: ' + save_path)
if __name__ == '__main__':
    # Run the reconstruction test only when this file is executed as a script.
    main()