-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_fusion_model.py
105 lines (87 loc) · 4.5 KB
/
test_fusion_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""测试融合网络"""
import argparse
import os
import random
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
from data_loader.msrs_data import MSRS_data
from models.common import YCrCb2RGB, RGB2YCrCb, clamp
from models.fusion_model import PIAFusion
from utils.metirc import Evaluator
def init_seeds(seed=0, cuda=None):
    """Seed every RNG (python `random`, NumPy, torch) for reproducibility.

    See https://pytorch.org/docs/stable/notes/randomness.html.
    cudnn seed-0 settings are slower and more reproducible; any other seed
    picks the faster, less reproducible configuration.

    Args:
        seed: base seed applied to all RNGs.
        cuda: whether to also seed the CUDA RNG. Defaults to
            ``torch.cuda.is_available()``. (The original read the
            module-level ``args`` global, which is only defined when the
            file runs as a script — importing the module and calling this
            function raised ``NameError``.)
    """
    import torch.backends.cudnn as cudnn
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if cuda is None:
        cuda = torch.cuda.is_available()
    if cuda:
        torch.cuda.manual_seed(seed)
    # seed == 0 => deterministic cudnn (slower); otherwise benchmark mode.
    cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch PIAFusion')
    parser.add_argument('--dataset_path', metavar='DIR', default='test_data/MSRS',
                        help='path to dataset (default: imagenet)')  # test-data location
    parser.add_argument('-a', '--arch', metavar='ARCH', default='fusion_model',
                        choices=['fusion_model'])
    parser.add_argument('--save_path', default='results/fusion')  # where fused results are written
    parser.add_argument('-j', '--workers', default=1, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--fusion_pretrained', default='pretrained/fusion_model_epoch_31.pth',
                        help='use cls pre-trained model')
    parser.add_argument('--seed', default=0, type=int,
                        help='seed for initializing training. ')
    # NOTE: type=bool is an argparse trap — bool('False') is True, so the
    # flag could never be disabled from the command line. Parse the string
    # explicitly; the default (True) and the flag name are unchanged.
    parser.add_argument('--cuda', default=True,
                        type=lambda s: str(s).lower() in ('1', 'true', 'yes', 'y'),
                        help='use GPU or not.')
    args = parser.parse_args()

    init_seeds(args.seed)

    test_dataset = MSRS_data(args.dataset_path)
    test_loader = DataLoader(
        test_dataset, batch_size=1, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # exist_ok avoids the check-then-create race of exists()+makedirs().
    os.makedirs(args.save_path, exist_ok=True)

    # Fusion network: load pretrained weights and run inference over the test set.
    if args.arch == 'fusion_model':
        model = PIAFusion()
        model = model.cuda()
        # DataParallel wrapping must match how the checkpoint was saved
        # (state-dict keys carry the 'module.' prefix).
        model.encoder = torch.nn.DataParallel(model.encoder)
        model.decoder = torch.nn.DataParallel(model.decoder)
        model.load_state_dict(torch.load(args.fusion_pretrained))
        model.eval()

        # Accumulated metrics: EN, SD, SF, MI, SCD, VIFF, Qabf, SSIM.
        metric_result = np.zeros(8)
        test_tqdm = tqdm(test_loader, total=len(test_loader))
        with torch.no_grad():
            for _, vis_y_image, cb, cr, inf_image, name in test_tqdm:
                vis_y_image = vis_y_image.cuda()
                cb = cb.cuda()
                cr = cr.cuda()
                inf_image = inf_image.cuda()

                # Fuse the visible Y channel with the infrared image,
                # then recombine with the original Cb/Cr and save as RGB.
                fused_image = model(vis_y_image, inf_image)
                fused_image = clamp(fused_image)
                rgb_fused_image = YCrCb2RGB(fused_image[0], cb[0], cr[0])
                rgb_fused_image = transforms.ToPILImage()(rgb_fused_image)
                rgb_fused_image.save(f'{args.save_path}/{name[0]}')

                vis_imgs = vis_y_image.cpu().numpy()
                ir_imgs = inf_image.cpu().numpy()
                # .detach() is unnecessary inside torch.no_grad().
                fused_np = fused_image.cpu().numpy()

                # Min-max normalize the fused image to [0, 255].
                # Guard the denominator: the original divided by zero when
                # the fused image was constant (max == min).
                lo, hi = fused_np.min(), fused_np.max()
                scale = (hi - lo) if hi > lo else 1.0
                fuse_img = np.round(255 * (fused_np - lo) / scale).astype(np.uint8)[0][0]

                vis_imgs = 255 * vis_imgs[0][0]
                ir_imgs = 255 * ir_imgs[0][0]

                metric_result += np.round(np.array([
                    Evaluator.EN(fuse_img),
                    Evaluator.SD(fuse_img),
                    Evaluator.SF(fuse_img),
                    Evaluator.MI(fuse_img, ir_imgs, vis_imgs),
                    Evaluator.SCD(fuse_img, ir_imgs, vis_imgs),
                    Evaluator.VIFF(fuse_img, ir_imgs, vis_imgs),
                    Evaluator.Qabf(fuse_img, ir_imgs, vis_imgs),
                    Evaluator.SSIM(fuse_img, ir_imgs, vis_imgs)]), 3)

        # Mean of each metric over the whole test set.
        print(metric_result / len(test_loader))