import argparse
import ast
import functools
import os
import shutil
import time
from datetime import datetime, timedelta
import paddle
from paddle.distributed import fleet
from paddle.io import DataLoader
from paddle.metric import accuracy
from visualdl import LogWriter
from utils.arcmargin import ArcNet
from utils.reader import CustomDataset
from utils.se_resnet_vd import SE_ResNet_vd
from utils.utility import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('batch_size', int, 32, 'Batch size for training')
add_arg('num_workers', int, 4, 'Number of workers for data loading')
add_arg('num_epoch', int, 50, 'Number of training epochs')
add_arg('num_classes', int, 3242, 'Number of classification classes')
add_arg('learning_rate', float, 1e-3, 'Initial learning rate')
add_arg('input_shape', str, '(None, 1, 257, 257)', 'Shape of the input data')
add_arg('train_list_path', str, 'dataset/train_list.txt', 'Path to the training data list')
add_arg('test_list_path', str, 'dataset/test_list.txt', 'Path to the test data list')
add_arg('save_model', str, 'models/', 'Directory in which to save models')
add_arg('resume', str, None, 'Checkpoint directory to resume training from; None disables resuming')
add_arg('pretrained_model', str, None, 'Path to a pretrained model; None disables loading pretrained weights')
args = parser.parse_args()
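# Example invocations (paths here are just the defaults above, adjust as needed):
#   python train.py --batch_size=32 --num_epoch=50
#   python train.py --resume=models/epoch_10                    # resume from a checkpoint
#   python -m paddle.distributed.launch --gpus '0,1' train.py   # multi-GPU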
# Evaluate the model
@paddle.no_grad()
def test(model, metric_fc, test_loader):
    model.eval()
    accuracies = []
    for batch_id, (spec_mag, label) in enumerate(test_loader()):
        feature = model(spec_mag)
        # The ArcNet head takes the labels so it can apply the angular margin
        output = metric_fc(feature, label)
        label = paddle.reshape(label, shape=(-1, 1))
        acc = accuracy(input=output, label=label)
        accuracies.append(acc.numpy()[0])
    model.train()
    return float(sum(accuracies) / len(accuracies))
# Save the model
def save_model(args, epoch, model, metric_fc, optimizer):
    model_params_path = os.path.join(args.save_model, 'epoch_%d' % epoch)
    if not os.path.exists(model_params_path):
        os.makedirs(model_params_path)
    # Save the parameters of the backbone, the ArcNet head and the optimizer
    paddle.save(model.state_dict(), os.path.join(model_params_path, 'model.pdparams'))
    paddle.save(metric_fc.state_dict(), os.path.join(model_params_path, 'metric_fc.pdparams'))
    paddle.save(optimizer.state_dict(), os.path.join(model_params_path, 'optimizer.pdopt'))
    # Delete old checkpoints, keeping only the three most recent epochs
    old_model_path = os.path.join(args.save_model, 'epoch_%d' % (epoch - 3))
    if os.path.exists(old_model_path):
        shutil.rmtree(old_model_path)
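# Checkpoint layout produced above (with the default --save_model=models/):
#   models/epoch_<n>/model.pdparams      backbone weights
#   models/epoch_<n>/metric_fc.pdparams  ArcNet head weights
#   models/epoch_<n>/optimizer.pdopt     optimizer and LR scheduler state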
def train(args):
    # Number of devices taking part in training
    nranks = paddle.distributed.get_world_size()
    local_rank = paddle.distributed.get_rank()
    if nranks > 1:
        # Initialize the Fleet environment for collective multi-GPU training
        fleet.init(is_collective=True)
    if local_rank == 0:
        # VisualDL logger
        writer = LogWriter(logdir='log')
    # Shape of the input data; parsed with ast.literal_eval rather than eval
    input_shape = ast.literal_eval(args.input_shape)
    # Load the data; input_shape[3] is the spectrogram length fed to the reader
    train_dataset = CustomDataset(args.train_list_path, model='train', spec_len=input_shape[3])
    # Use a distributed batch sampler when training on multiple devices
    if nranks > 1:
        train_batch_sampler = paddle.io.DistributedBatchSampler(train_dataset, batch_size=args.batch_size, shuffle=True)
    else:
        train_batch_sampler = paddle.io.BatchSampler(train_dataset, batch_size=args.batch_size, shuffle=True)
    train_loader = DataLoader(dataset=train_dataset, batch_sampler=train_batch_sampler, num_workers=args.num_workers)
    test_dataset = CustomDataset(args.test_list_path, model='test', spec_len=input_shape[3])
    test_batch_sampler = paddle.io.BatchSampler(test_dataset, batch_size=args.batch_size)
    test_loader = DataLoader(dataset=test_dataset, batch_sampler=test_batch_sampler, num_workers=args.num_workers)
    # Build the model: SE-ResNet-vd backbone plus ArcNet classification head
    model = SE_ResNet_vd()
    metric_fc = ArcNet(feature_dim=model.pool2d_avg_channels, class_dim=args.num_classes)
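    # Judging by the module name (utils.arcmargin), ArcNet is an ArcFace-style
    # angular-margin head: it projects the backbone embedding of width
    # model.pool2d_avg_channels to margin-adjusted logits over num_classes,
    # which is why it also needs the labels in the forward pass.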
    if local_rank == 0:
        paddle.summary(model, input_size=input_shape)
    # Wrap the models for multi-GPU training
    if nranks > 1:
        model = paddle.DataParallel(model)
        metric_fc = paddle.DataParallel(metric_fc)
    # Initial epoch counter
    last_epoch = 0
    # Learning rate decay schedule
    scheduler = paddle.optimizer.lr.StepDecay(learning_rate=args.learning_rate, step_size=1, gamma=0.8)
    # Optimizer
    optimizer = paddle.optimizer.Momentum(parameters=model.parameters() + metric_fc.parameters(),
                                          learning_rate=scheduler,
                                          momentum=0.9,
                                          weight_decay=paddle.regularizer.L2Decay(1e-6))
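    # With step_size=1 and gamma=0.8 the learning rate decays once per epoch:
    # lr(epoch) = learning_rate * 0.8 ** epoch, i.e. 1e-3, 8e-4, 6.4e-4, ...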
    # Load the pretrained model
    if args.pretrained_model is not None:
        model_dict = model.state_dict()
        param_state_dict = paddle.load(os.path.join(args.pretrained_model, 'model.pdparams'))
        for name, weight in model_dict.items():
            if name in param_state_dict.keys():
                # Drop pretrained weights whose shapes do not match the current model
                if weight.shape != list(param_state_dict[name].shape):
                    print('{} not used, shape {} unmatched with {} in model.'.format(
                        name, list(param_state_dict[name].shape), weight.shape))
                    param_state_dict.pop(name, None)
            else:
                print('Lack weight: {}'.format(name))
        model.set_state_dict(param_state_dict)
        print('Successfully loaded the pretrained model parameters')
    # Resume training
    if args.resume is not None:
        model.set_state_dict(paddle.load(os.path.join(args.resume, 'model.pdparams')))
        metric_fc.set_state_dict(paddle.load(os.path.join(args.resume, 'metric_fc.pdparams')))
        optimizer_state = paddle.load(os.path.join(args.resume, 'optimizer.pdopt'))
        optimizer.set_state_dict(optimizer_state)
        # Recover the epoch to restart from out of the LR scheduler state
        last_epoch = optimizer_state['LR_Scheduler']['last_epoch']
        print('Successfully loaded the model parameters and optimizer state')
    # Loss function
    loss = paddle.nn.CrossEntropyLoss()
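    # CrossEntropyLoss applies softmax to the logits internally, so the raw ArcNet
    # output is passed straight to it; the explicit softmax in the training loop
    # below is only for the accuracy metric (it does not change the argmax).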
    train_step = 0
    test_step = 0
    sum_batch = len(train_loader) * (args.num_epoch - last_epoch)
    # Start training
    for epoch in range(last_epoch, args.num_epoch):
        loss_sum = []
        accuracies = []
        start = time.time()
        for batch_id, (spec_mag, label) in enumerate(train_loader()):
            feature = model(spec_mag)
            output = metric_fc(feature, label)
            # Compute the loss
            los = loss(output, label)
            los.backward()
            optimizer.step()
            optimizer.clear_grad()
            # Compute the accuracy
            label = paddle.reshape(label, shape=(-1, 1))
            acc = accuracy(input=paddle.nn.functional.softmax(output), label=label)
            accuracies.append(acc.numpy()[0])
            loss_sum.append(float(los))
            # Only one process prints logs during multi-GPU training
            if batch_id % 100 == 0 and local_rank == 0:
                # Average per-batch time over the last 100-batch logging interval
                batch_time = (time.time() - start) / 100
                eta_sec = batch_time * (sum_batch - (epoch - last_epoch) * len(train_loader) - batch_id)
                eta_str = str(timedelta(seconds=int(eta_sec)))
                print('[%s] Train epoch %d, batch: %d/%d, loss: %f, accuracy: %f, lr: %.8f, eta: %s' % (
                    datetime.now(), epoch, batch_id, len(train_loader), sum(loss_sum) / len(loss_sum),
                    sum(accuracies) / len(accuracies), scheduler.get_lr(), eta_str))
                writer.add_scalar('Train loss', float(los), train_step)
                train_step += 1
                loss_sum = []
                start = time.time()
        # Only one process runs evaluation and saves the model during multi-GPU training
        if local_rank == 0:
            acc = test(model, metric_fc, test_loader)
            print('=' * 70)
            print('[%s] Test %d, accuracy: %f' % (datetime.now(), epoch, acc))
            print('=' * 70)
            writer.add_scalar('Test acc', acc, test_step)
            # Record the learning rate
            writer.add_scalar('Learning rate', scheduler.last_lr, epoch)
            test_step += 1
            save_model(args, epoch, model, metric_fc, optimizer)
        scheduler.step()
if __name__ == '__main__':
    print_arguments(args)
    train(args)