-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain_multi_gpu.py
336 lines (306 loc) · 15.8 KB
/
main_multi_gpu.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
import os
import time
import paddle
from valid import validation
from data import prep_dataset, prep_loader
from utils import same_seeds, save_model, \
ReduceOnPlateauWithAnnael, NMTMetric, get_logger, get_grad_norm
from models import build_model
import paddle.distributed as dist
from paddlenlp.transformers import CrossEntropyCriterion, LinearDecayWithWarmup
from config import get_config, get_arguments
from paddle.optimizer.lr import CosineAnnealingDecay, NoamDecay
from visualdl import LogWriter
from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients
def train_one_epoch(dataloader,
                    model,
                    criterion,
                    optimizer,
                    scaler,
                    epoch,
                    step_id,
                    metric,
                    logger,
                    logwriter,
                    max_epoch,
                    pad_idx=1,
                    amp=False,
                    log_steps=100,
                    update_freq=1,
                    scheduler=None):  # scheduler: for per-step (warmup-style) schedules
    """Train the model for one epoch.

    Gradients are accumulated over ``update_freq`` batches; an optimizer update
    happens every ``update_freq`` batches and always on the last batch of the
    epoch. When more than one card is used, DataParallel's automatic gradient
    sync is disabled with ``model.no_sync()`` and gradients are all-reduced
    explicitly with ``fused_allreduce_gradients`` right before each update.

    Args:
        dataloader: paddle.io.DataLoader yielding
            (samples_id, src_tokens, prev_tokens, tgt_tokens).
        model: ConvS2S model; must provide ``no_sync()`` when world_size > 1
            (main_worker wraps it in paddle.DataParallel).
        criterion: loss callable returning (sum_cost, avg_cost, token_num).
        optimizer: paddle optimizer.
        scaler: paddle.amp.GradScaler, only used when ``amp`` is True.
        epoch: int, current epoch number (1-based, see the step estimate below).
        step_id: int, global step counter at the start of this epoch.
        metric: metric object with update/accumulate (NMTMetric in main_worker).
        logger: logger used for progress output.
        logwriter: VisualDL LogWriter; only written to on rank 0.
        max_epoch: int, total number of epochs (for logging only).
        pad_idx: int, padding token id forwarded to ``metric.update``.
        amp: enable mixed-precision training when exactly ``True``.
        log_steps: int, log every ``log_steps`` batches.
        update_freq: int, number of batches to accumulate gradients over.
        scheduler: LR scheduler; stepped once per batch with the global step
            when it is a LinearDecayWithWarmup, CosineAnnealingDecay or
            NoamDecay (other schedulers are stepped by the caller).

    Returns:
        int: the global step counter after this epoch (incremented once per batch).
    """
    from contextlib import nullcontext

    world_size = dist.get_world_size()
    is_distributed = world_size > 1
    use_amp = amp is True
    num_batches = len(dataloader)
    model.train()
    sentences = 0
    tic_train = time.time()
    for batch_id, input_data in enumerate(dataloader):
        _, src_tokens, prev_tokens, tgt_tokens = input_data
        is_update_step = ((batch_id + 1) % update_freq == 0) or (batch_id + 1 == num_batches)

        # Multi-card: skip the per-backward gradient sync; gradients are
        # all-reduced once, explicitly, just before the optimizer update.
        sync_ctx = model.no_sync() if is_distributed else nullcontext()
        with sync_ctx:
            cast_ctx = paddle.amp.auto_cast() if use_amp else nullcontext()
            with cast_ctx:
                logits = model(src_tokens=src_tokens, prev_output_tokens=prev_tokens)[0]
                sum_cost, avg_cost, token_num = criterion(logits, tgt_tokens)
            if use_amp:
                scaled = scaler.scale(avg_cost)
                scaled.backward()
            else:
                avg_cost.backward()

        # NOTE(review): with amp the gradients here are still loss-scaled, so
        # gnorm presumably includes the loss scale — confirm against get_grad_norm.
        gnorm = get_grad_norm(grads=[p.grad for p in optimizer._param_groups])

        if is_update_step:
            if is_distributed:
                fused_allreduce_gradients(list(model.parameters()), None)
            if use_amp:
                scaler.minimize(optimizer, scaled)
            else:
                optimizer.step()
            optimizer.clear_grad()

        # aggregate metric
        metric.update(sum_cost, logits, target=tgt_tokens, sample_size=token_num, pad_id=pad_idx,
                      gnorm=gnorm)
        sentences += src_tokens.shape[0]

        # log
        if (batch_id + 1) % log_steps == 0:
            avg_bsz = sentences / (batch_id + 1)
            bsz = int(avg_bsz * update_freq)
            # Estimated number of iterations of each epoch on a single card.
            avg_total_steps = int(len(dataloader.dataset) // avg_bsz // world_size)
            # Current forward steps (single card).
            cur_steps = avg_total_steps * (epoch - 1) + batch_id + 1
            num_updates = (cur_steps // update_freq) * world_size
            loss, nll_loss, ppl, gnorm = metric.accumulate()
            logger.info(
                f"Train | epoch:[{epoch}/{int(max_epoch)}], step:[{batch_id + 1}/{avg_total_steps}], bsz:{bsz}, "
                f"speed:{log_steps / (time.time() - tic_train):.2f}step/s "
                f"loss:{float(loss):.3f}, nll_loss:{float(nll_loss):.3f}, ppl:{float(ppl):.3f}, gnorm:{gnorm:.4f},num_updates:{num_updates}, lr:{optimizer.get_lr():.5f}")
            tic_train = time.time()
            if dist.get_rank() == 0:
                logwriter.add_scalar(tag='train/loss', step=step_id, value=loss)
                logwriter.add_scalar(tag='train/ppl', step=step_id, value=ppl)

        # These schedulers are updated every batch with the global step.
        if isinstance(scheduler, (LinearDecayWithWarmup, CosineAnnealingDecay, NoamDecay)):
            scheduler.step(step_id)
        step_id += 1
    return step_id
def _build_scheduler(conf, steps_per_epoch, global_step_id):
    """Create the LR scheduler named by ``conf.learning_strategy.sched``.

    Args:
        conf: experiment config.
        steps_per_epoch: int, batches per epoch on one card (1 when there is
            no train loader, i.e. eval-only runs).
        global_step_id: int, global step training (re)starts from; used as
            NoamDecay's ``last_epoch``.

    Returns:
        The scheduler instance.

    Raises:
        ValueError: if ``sched`` is not one of plateau|warmup|cosine|noamdecay.
    """
    strategy = conf.learning_strategy
    sched = strategy.sched
    if sched == "plateau":
        # Stepped once per epoch on val ppl in main_worker; lr is reduced
        # by lr_shrink down to min_lr.
        return ReduceOnPlateauWithAnnael(learning_rate=strategy.learning_rate,
                                         patience=strategy.patience,
                                         force_anneal=strategy.force_anneal,
                                         factor=strategy.lr_shrink,
                                         min_lr=strategy.min_lr)
    if sched == "warmup":
        return LinearDecayWithWarmup(learning_rate=strategy.learning_rate,
                                     warmup=strategy.warmup,
                                     last_epoch=conf.train.last_epoch,
                                     total_steps=conf.train.max_epoch * steps_per_epoch)
    if sched == "cosine":
        # train_one_epoch steps this scheduler once per batch, so the period
        # must be expressed in steps, not epochs.
        return CosineAnnealingDecay(learning_rate=strategy.learning_rate,
                                    T_max=conf.train.max_epoch * steps_per_epoch,
                                    last_epoch=conf.train.last_epoch)
    if sched == "noamdecay":
        return NoamDecay(d_model=conf.model.dmodel,
                         warmup_steps=strategy.warmup,
                         learning_rate=strategy.learning_rate,
                         last_epoch=global_step_id)
    raise ValueError("Sched should be [plateau|warmup|cosine|noamdecay]")


def _build_optimizer(conf, scheduler, parameters):
    """Create the optimizer named by ``conf.learning_strategy.optimizer``.

    All optimizers use global-norm gradient clipping with ``clip_norm``.

    Args:
        conf: experiment config.
        scheduler: LR scheduler used as the optimizer's learning rate.
        parameters: model parameters to optimize.

    Returns:
        The optimizer instance.

    Raises:
        ValueError: if ``optimizer`` is not one of nag|adam|adamw.
    """
    strategy = conf.learning_strategy
    clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=strategy.clip_norm)
    name = strategy.optimizer
    if name == "nag":
        return paddle.optimizer.Momentum(
            learning_rate=scheduler,
            momentum=strategy.momentum,
            weight_decay=float(strategy.weight_decay),  # float(): an int here gave "int object not callable"
            use_nesterov=strategy.use_nesterov,
            grad_clip=clip,
            parameters=parameters)
    if name == "adam":
        return paddle.optimizer.Adam(
            learning_rate=scheduler,
            weight_decay=float(strategy.weight_decay),  # float(): an int here gave "int object not callable"
            grad_clip=clip,
            parameters=parameters)
    if name == "adamw":
        return paddle.optimizer.AdamW(
            learning_rate=scheduler,
            weight_decay=float(strategy.weight_decay),  # float(): an int here gave "int object not callable"
            grad_clip=clip,
            parameters=parameters)
    raise ValueError("Optimizer should be [nag|adam|adamw]")


def main_worker(*args):
    """Per-process entry point spawned by ``dist.spawn``: train or validate ConvS2S.

    Args:
        *args: ``(conf, dataset_train, dataset_val)`` as passed by ``main``;
            ``dataset_train`` is None when ``conf.eval`` is set.

    Raises:
        FileNotFoundError: if resuming and a checkpoint file is missing.
        ValueError: if the configured scheduler or optimizer is unknown.
    """
    # 0. Preparation
    conf = args[0]
    dist.init_parallel_env()
    last_epoch = conf.train.last_epoch
    world_size = paddle.distributed.get_world_size()
    local_rank = paddle.distributed.get_rank()
    logger = get_logger(loggername=f"ConvS2S_{local_rank}", save_path=conf.SAVE)
    logger.info(f'----- world_size = {world_size}, local_rank = {local_rank}')
    # The seed is offset by the rank, so every process gets a different seed.
    same_seeds(conf.seed + local_rank)

    # 1. Create train and val dataloaders
    dataset_train, dataset_val = args[1], args[2]
    train_loader = None
    if not conf.eval:
        train_loader = prep_loader(conf, dataset_train, 'train', multi_process=True)
        logger.info(
            f'----- Total of train set:{len(train_loader.dataset)} ,train batch: {len(train_loader)} [single gpu]')
    dev_loader = prep_loader(conf, dataset_val, 'dev', multi_process=True)
    logger.info(f'----- Total of valid set:{len(dev_loader.dataset)} ,valid batch: {len(dev_loader)} [single gpu]')
    if local_rank == 0:
        logger.info(f'configs:\n{conf}')

    # 2. Create model
    model = paddle.DataParallel(build_model(conf, is_test=False))

    # 3. Define criterion, metric and (rank 0 only) the VisualDL writer
    criterion = CrossEntropyCriterion(conf.learning_strategy.label_smooth_eps, pad_idx=conf.model.pad_idx)
    metric = NMTMetric(name=conf.model.model_name)
    logwriter = None
    best_bleu = 0
    if local_rank == 0:
        logwriter = LogWriter(
            logdir=os.path.join(conf.SAVE, f'vislogs/convs2s_{conf.data.src_lang}{conf.data.tgt_lang}'))

    # 4. Define lr scheduler and optimizer
    if train_loader is not None:
        steps_per_epoch = len(train_loader)
        global_step_id = conf.train.last_epoch * steps_per_epoch + 1
    else:
        # Eval-only run: no train loader, so avoid len(None) in the schedulers.
        steps_per_epoch = 1
        global_step_id = 0
    scheduler = _build_scheduler(conf, steps_per_epoch, global_step_id)
    optimizer = _build_optimizer(conf, scheduler, model.parameters())

    # 5. Load resumed model/optimizer states
    if conf.train.resume:
        model_path = os.path.join(conf.train.resume, 'convs2s.pdparams')
        optim_path = os.path.join(conf.train.resume, 'convs2s.pdopt')
        for path in (model_path, optim_path):
            if not os.path.isfile(path):
                raise FileNotFoundError(f"File {path} does not exist.")
        model_state = paddle.load(model_path)
        opt_state = paddle.load(optim_path)
        if conf.learning_strategy.reset_lr:  # whether to reset lr
            opt_state['LR_Scheduler']['last_lr'] = conf.learning_strategy.learning_rate
        # resume best bleu (stored alongside the scheduler state)
        best_bleu = opt_state['LR_Scheduler'].get('best_bleu', 0)
        model.set_dict(model_state)
        optimizer.set_state_dict(opt_state)
        logger.info(
            f"----- Resume Training: Load model and optmizer states from {conf.train.resume},LR={optimizer.get_lr():.5f}----- ")

    # 6. Validation only
    if conf.eval:
        logger.info('----- Start Validating')
        validation(conf, dev_loader, model, criterion, logger)
        if logwriter is not None:
            logwriter.close()
        return

    # 7. Start training and validation
    # Define the GradScaler (used by train_one_epoch only when amp is enabled).
    growth_interval = conf.train.growth_interval if conf.train.amp_scale_window else 2000
    scaler = paddle.amp.GradScaler(init_loss_scaling=conf.train.fp16_init_scale,
                                   incr_every_n_steps=growth_interval)
    # Start at +inf so the first validation loss always counts as an improvement.
    lowest_val_loss = float("inf")
    num_runs = 0
    for epoch in range(last_epoch + 1, conf.train.max_epoch + 1):
        # train
        logger.info(f"Now training epoch {epoch}. LR={optimizer.get_lr():.5f}")
        global_step_id = train_one_epoch(
            dataloader=train_loader,
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scaler=scaler,
            epoch=epoch,
            step_id=global_step_id,
            metric=metric,
            logger=logger,
            logwriter=logwriter,
            max_epoch=conf.train.max_epoch,
            pad_idx=conf.model.pad_idx,
            amp=conf.train.amp,
            log_steps=conf.train.log_steps,
            update_freq=conf.train.update_freq,
            scheduler=scheduler
        )
        metric.reset()

        # evaluate model on valid data after one epoch
        val_loss, val_nll_loss, val_ppl, dev_bleu = validation(conf, dev_loader, model, criterion, logger)

        # save best model state
        if best_bleu < dev_bleu and local_rank == 0:
            best_bleu = dev_bleu
            save_dir = os.path.join(conf.SAVE, conf.model.save_model, "model_best")
            save_model(model, optimizer, save_dir=save_dir, best_bleu=best_bleu)
            logger.info(f"Epoch:[{epoch}] | Best Valid Bleu: {best_bleu:.3f} saved to {save_dir}!")

        # visualize valid metrics
        if local_rank == 0:
            logwriter.add_scalar(tag='valid/loss', step=epoch, value=val_loss)
            logwriter.add_scalar(tag='valid/ppl', step=epoch, value=val_ppl)
            logwriter.add_scalar(tag='valid/bleu', step=epoch, value=dev_bleu)

        # adjust learning rate when val ppl stops improving (each epoch).
        if conf.learning_strategy.sched == "plateau":
            scheduler.step(val_ppl)

        # NOTE(review): both stop conditions below must fire on every rank at
        # the same epoch, otherwise the remaining ranks block in the gradient
        # all-reduce. That holds if validation() returns identical metrics on
        # all ranks — confirm.

        # stop training when lr too small (all ranks leave; only rank 0 saves)
        cur_lr = round(optimizer.get_lr(), 5)
        min_lr = round(conf.learning_strategy.min_lr, 5)
        if cur_lr <= min_lr:
            logger.info("early stop since min lr is achieved.")
            if local_rank == 0:
                save_model(model, optimizer, save_dir=os.path.join(conf.SAVE, conf.model.save_model, "min_lr"))
            break

        # early stop when val loss has not improved for stop_patience epochs
        if conf.train.stop_patience > 1:
            if val_loss < lowest_val_loss:
                lowest_val_loss = val_loss
                num_runs = 0
            else:
                num_runs += 1
            if num_runs >= conf.train.stop_patience:
                logger.info(
                    f"early stop since valid performance hasn't improved for last {conf.train.stop_patience} runs")
                break

        # save model after several epochs
        if epoch % conf.train.save_epoch == 0 and local_rank == 0:
            save_model(model, optimizer, save_dir=os.path.join(conf.SAVE, conf.model.save_model, f"epoch_{epoch}"))

    # save last model
    if conf.model.save_model and local_rank == 0:
        save_model(model, optimizer, save_dir=os.path.join(conf.SAVE, conf.model.save_model, "epoch_final"))
    if local_rank == 0:
        logwriter.close()
def main():
    """Build the config and datasets in the parent process, then spawn ``conf.ngpus`` workers."""
    conf = get_config(get_arguments())
    # The training split is only loaded when we are not in eval-only mode.
    dataset_train = None if conf.eval else prep_dataset(conf, mode='train')
    dataset_dev = prep_dataset(conf, mode='dev')
    worker_args = (conf, dataset_train, dataset_dev)
    dist.spawn(main_worker, args=worker_args, nprocs=conf.ngpus)


if __name__ == '__main__':
    # Run only when executed as a script, not when imported.
    main()