-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain_auto_encoder.py
257 lines (221 loc) · 9.57 KB
/
train_auto_encoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
# Training DenseFuse network
# auto-encoder
import os
import sys
import time
import numpy as np
from tqdm import tqdm, trange
import scipy.io as scio
import random
import torch
from torch.optim import Adam
from torch.autograd import Variable
import utils
from model_SFPFusion import MODEL,WaveDecoder
from args_fusion import args
import pytorch_msssim
# Pin training to physical GPU 2. NOTE(review): this runs after `import torch`,
# which is usually still early enough (CUDA is initialized lazily), but it only
# takes effect if no CUDA context exists yet — confirm against the launcher.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
# Preferred device handle. NOTE(review): `device` is defined but the rest of
# this file moves tensors with .cuda()/args.cuda instead — looks like a
# leftover; verify before removing.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main():
    """Entry point: sample and shuffle 40 000 training images, then train once.

    Uses SSIM-weight configuration index 1 (see args.ssim_weight / args.ssim_path).
    """
    # NOTE: CUDA_VISIBLE_DEVICES is already pinned at module import time;
    # re-setting it here (after torch import) was a redundant no-op and was removed.
    original_imgs_path = utils.list_images(args.dataset)  # collect dataset file paths
    train_num = 40000
    original_imgs_path = original_imgs_path[:train_num]
    random.shuffle(original_imgs_path)  # randomize training order
    # Originally a loop over several SSIM-weight configs; only index 1 is used.
    i = 1
    train(i, original_imgs_path)
def _timestamp():
    """Return a filesystem-safe timestamp (spaces/colons replaced with '_')."""
    return time.ctime().replace(' ', '_').replace(':', '_')


def _save_loss_mat(subdir, filename, key, values):
    """Save a loss-history list as a .mat file under args.save_loss_dir/subdir.

    Returns the full path written.
    """
    save_loss_path = os.path.join(args.save_loss_dir, subdir + '/' + filename)
    scio.savemat(save_loss_path, {key: np.array(values)})
    return save_loss_path


def train(i, original_imgs_path):
    """Train the SFPFusion auto-encoder (MODEL encoder + WaveDecoder).

    Minimizes L = Lp + lambda * Lssim, where Lp is per-pixel MSE and
    Lssim = 1 - MS-SSIM, averaged over the decoder's multi-scale outputs.
    Periodically logs to the progress bar, checkpoints both networks, and
    dumps the loss histories to .mat files.

    Args:
        i: index into args.ssim_path / args.ssim_weight selecting the SSIM
            weight (lambda) and the checkpoint sub-directory.
        original_imgs_path: list of training image file paths.
    """
    batch_size = args.batch_size

    # Input channels: 1 -> grayscale ('L'), 3 -> RGB.
    in_c = 1
    img_model = 'L' if in_c == 1 else 'RGB'

    AE_Encoder = MODEL(embed_dim=[64, 128, 320, 512],  # 25M params, 4.4G FLOPs, 677 FPS
                       depths=[3, 5, 9, 3],
                       num_heads=[1, 2, 5, 8],
                       n_iter=[1, 1, 1, 1],
                       stoken_size=[8, 4, 1, 1],
                       projection=1024,
                       mlp_ratio=4,
                       stoken_refine=True,
                       stoken_refine_attention=True,
                       hard_label=False,
                       rpe=False,
                       qkv_bias=True,
                       qk_scale=None,
                       use_checkpoint=False,
                       checkpoint_num=[0, 0, 0, 0],
                       layerscale=[False] * 4,
                       init_values=1e-6,
                       option_unpool='sum')
    AE_Decoder = WaveDecoder('sum')
    print(AE_Encoder)
    print(AE_Decoder)

    # Separate Adam optimizers for encoder and decoder.
    optimizer1 = Adam(AE_Encoder.parameters(), args.lr)
    optimizer2 = Adam(AE_Decoder.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()        # pixel loss Lp
    ssim_loss = pytorch_msssim.msssim    # multi-scale SSIM

    if args.cuda:
        AE_Encoder = AE_Encoder.cuda()
        AE_Decoder = AE_Decoder.cuda()

    tbar = trange(args.epochs)  # progress bar over epochs
    print('Start training.....')

    # Create checkpoint / loss output directories.
    temp_path_model = os.path.join(args.save_model_dir, args.ssim_path[i])
    if not os.path.exists(temp_path_model):
        os.makedirs(temp_path_model)
    temp_path_loss = os.path.join(args.save_loss_dir, args.ssim_path[i])
    if not os.path.exists(temp_path_loss):
        os.makedirs(temp_path_loss)

    Loss_pixel = []
    Loss_ssim = []
    Loss_all = []
    all_ssim_loss = 0.
    all_pixel_loss = 0.
    for e in tbar:
        print('Epoch %d.....' % e)
        # Load (and re-batch) the training set for this epoch.
        image_set_ir, batches = utils.load_dataset(original_imgs_path, batch_size)
        AE_Encoder.train()
        AE_Decoder.train()
        count = 0
        for batch in range(batches):
            image_paths = image_set_ir[batch * batch_size:(batch * batch_size + batch_size)]
            # Load this batch as a tensor in the configured mode/size.
            img = utils.get_train_images_auto(image_paths, height=args.HEIGHT, width=args.WIDTH, mode=img_model)
            count += 1
            optimizer1.zero_grad()
            optimizer2.zero_grad()
            # NOTE: the deprecated torch.autograd.Variable wrappers were removed;
            # plain tensors do not require grad by default.
            if args.cuda:
                img = img.cuda()

            # Forward pass: encode, then decode with skip connections.
            en, skips = AE_Encoder(img)
            outputs = AE_Decoder(en, skips)

            # Reconstruction target: a detached copy of the input.
            x = img.detach().clone()
            ssim_loss_value = 0.
            pixel_loss_value = 0.
            for output in outputs:
                pixel_loss_temp = mse_loss(output, x)                  # pixel loss Lp
                ssim_loss_temp = ssim_loss(output, x, normalize=True)  # MS-SSIM in [0, 1]
                ssim_loss_value += (1 - ssim_loss_temp)                # Lssim = 1 - SSIM
                pixel_loss_value += pixel_loss_temp
            ssim_loss_value /= len(outputs)
            pixel_loss_value /= len(outputs)

            # Total loss: L = Lp + lambda * Lssim
            total_loss = pixel_loss_value + args.ssim_weight[i] * ssim_loss_value
            total_loss.backward()
            optimizer1.step()
            optimizer2.step()

            all_ssim_loss += ssim_loss_value.item()
            all_pixel_loss += pixel_loss_value.item()
            if (batch + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\t pixel loss: {:.6f}\t ssim loss: {:.6f}\t total: {:.6f}".format(
                    time.ctime(), e + 1, count, batches,
                    all_pixel_loss / args.log_interval,
                    all_ssim_loss / args.log_interval,
                    (args.ssim_weight[i] * all_ssim_loss + all_pixel_loss) / args.log_interval
                )
                tbar.set_description(mesg)
                Loss_pixel.append(all_pixel_loss / args.log_interval)
                Loss_ssim.append(all_ssim_loss / args.log_interval)
                Loss_all.append((args.ssim_weight[i] * all_ssim_loss + all_pixel_loss) / args.log_interval)
                all_ssim_loss = 0.
                all_pixel_loss = 0.

            if (batch + 1) % (200 * args.log_interval) == 0:
                # Periodic checkpoint: move both nets to CPU, save weights and
                # the loss curves accumulated so far, then resume on GPU.
                AE_Encoder.eval()
                AE_Decoder.eval()
                AE_Encoder.cpu()
                AE_Decoder.cpu()
                tag = "_iters_" + str(count) + "_" + _timestamp() + "_" + args.ssim_path[i]
                save_encoder_filename = args.ssim_path[i] + '/' + "Encoder_Epoch_" + str(e) + tag + ".model"
                save_encoder_path = os.path.join(args.save_model_dir, save_encoder_filename)
                torch.save(AE_Encoder.state_dict(), save_encoder_path)
                save_decoder_filename = args.ssim_path[i] + '/' + "Decoder_Epoch_" + str(e) + tag + ".model"
                save_decoder_path = os.path.join(args.save_model_dir, save_decoder_filename)
                torch.save(AE_Decoder.state_dict(), save_decoder_path)
                loss_tag = "epoch_" + str(args.epochs) + tag + ".mat"
                _save_loss_mat(args.ssim_path[i], "loss_pixel_" + loss_tag, 'loss_pixel', Loss_pixel)
                _save_loss_mat(args.ssim_path[i], "loss_ssim_" + loss_tag, 'loss_ssim', Loss_ssim)
                _save_loss_mat(args.ssim_path[i], "loss_total_" + loss_tag, 'loss_total', Loss_all)
                AE_Encoder.train()
                AE_Encoder.cuda()
                AE_Decoder.train()
                AE_Decoder.cuda()
                # BUGFIX: set_description takes a single string; the path was
                # previously passed as tqdm's `refresh` argument and never shown.
                tbar.set_description("\nCheckpoint, trained model saved at " + save_decoder_path)

    # Final loss histories.
    final_tag = "epoch_" + str(args.epochs) + "_" + _timestamp() + "_" + args.ssim_path[i] + ".mat"
    _save_loss_mat(args.ssim_path[i], "Final_loss_pixel_" + final_tag, 'loss_pixel', Loss_pixel)
    _save_loss_mat(args.ssim_path[i], "Final_loss_ssim_" + final_tag, 'loss_ssim', Loss_ssim)
    _save_loss_mat(args.ssim_path[i], "Final_loss_total_" + final_tag, 'loss_total', Loss_all)

    # Final model weights (saved on CPU).
    AE_Encoder.eval()
    AE_Decoder.eval()
    AE_Encoder.cpu()
    AE_Decoder.cpu()
    save_encoder_path = os.path.join(args.save_model_dir, args.ssim_path[i] + '/' + "Final_Encoder_epoch" + ".model")
    torch.save(AE_Encoder.state_dict(), save_encoder_path)
    save_decoder_path = os.path.join(args.save_model_dir, args.ssim_path[i] + '/' + "Final_Decoder_epoch" + ".model")
    torch.save(AE_Decoder.state_dict(), save_decoder_path)
    print("\nDone, trained model saved at", save_decoder_path)
# Run training only when executed as a script, not on import.
if __name__ == "__main__":
    main()