Merge pull request #653 from FunAudioLLM/dev/lyuxiang.lx
resume training
aluminumbox authored Nov 15, 2024
2 parents 7701325 + c3dfd23 commit d6dbdfb
Showing 2 changed files with 18 additions and 5 deletions.
18 changes: 15 additions & 3 deletions cosyvoice/bin/train.py
@@ -118,9 +118,15 @@ def main():
 
     # load checkpoint
     model = configs[args.model]
+    start_step, start_epoch = 0, -1
     if args.checkpoint is not None:
         if os.path.exists(args.checkpoint):
-            model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'), strict=False)
+            state_dict = torch.load(args.checkpoint, map_location='cpu')
+            model.load_state_dict(state_dict, strict=False)
+            if 'step' in state_dict:
+                start_step = state_dict['step']
+            if 'epoch' in state_dict:
+                start_epoch = state_dict['epoch']
         else:
             logging.warning('checkpoint {} do not exsist!'.format(args.checkpoint))
 
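The load path above works because load_state_dict(strict=False) skips keys that do not match any parameter or buffer, so the scalar step and epoch entries can travel inside the same state dict as the weights. A minimal sketch of that behaviour, using a throwaway nn.Linear instead of the repo's model classes (values are made up):

```python
import torch
import torch.nn as nn

# Toy stand-in for the real model; all names/values below are illustrative only.
model = nn.Linear(4, 2)

# A "checkpoint" dict shaped the way this PR now writes it: weights plus progress.
state_dict = {**model.state_dict(), 'step': 2500, 'epoch': 3}

start_step, start_epoch = 0, -1
# strict=False: 'step' and 'epoch' are reported as unexpected keys and ignored,
# so the weight loading itself still succeeds.
model.load_state_dict(state_dict, strict=False)
if 'step' in state_dict:
    start_step = state_dict['step']
if 'epoch' in state_dict:
    start_epoch = state_dict['epoch']
print(start_step, start_epoch)  # 2500 3
```

Checkpoints written before this change simply lack the two keys, so the defaults (0, -1) are kept and training starts from scratch.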
@@ -129,19 +135,25 @@ def main():
 
     # Get optimizer & scheduler
     model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
+    scheduler.set_step(start_step)
+    if scheduler_d is not None:
+        scheduler_d.set_step(start_step)
 
     # Save init checkpoints
     info_dict = deepcopy(configs['train_conf'])
+    info_dict['step'] = start_step
+    info_dict['epoch'] = start_epoch
     save_model(model, 'init', info_dict)
 
     # Get executor
     executor = Executor(gan=gan)
+    executor.step = start_step
 
     # Init scaler, used for pytorch amp mixed precision training
     scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
-
+    print('start step {} start epoch {}'.format(start_step, start_epoch))
     # Start training loop
-    for epoch in range(info_dict['max_epoch']):
+    for epoch in range(start_epoch + 1, info_dict['max_epoch']):
         executor.epoch = epoch
         train_dataset.set_epoch(epoch)
         dist.barrier()
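Two details in this hunk carry the actual resume semantics: the scheduler is fast-forwarded to start_step so the learning-rate schedule continues where it stopped instead of restarting its warmup, and the epoch loop begins at start_epoch + 1 because the checkpoint records the last completed epoch. A rough sketch, with a standard PyTorch scheduler standing in for the repo's scheduler (which exposes set_step() directly):

```python
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Hypothetical linear warmup over 1000 steps, just to make the effect visible.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / 1000))

start_step, start_epoch, max_epoch = 2500, 3, 10  # as read back from the checkpoint

# Replay the schedule up to the saved step; the repo's scheduler achieves the
# same thing in one call with scheduler.set_step(start_step).
for _ in range(start_step):
    scheduler.step()
print(scheduler.get_last_lr())  # warmup already finished, so the full lr is restored

# start_epoch is the last epoch that was saved, so training resumes at the next one.
for epoch in range(start_epoch + 1, max_epoch):
    pass  # run one training epoch (epochs 4..9 here)
```

Copying start_step into the Executor and into info_dict keeps the step counter and the saved metadata consistent with the fast-forwarded scheduler.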
5 changes: 3 additions & 2 deletions cosyvoice/utils/train_utils.py
@@ -199,7 +199,7 @@ def save_model(model, model_name, info_dict):
 
     if info_dict["train_engine"] == "torch_ddp":
         if rank == 0:
-            torch.save(model.module.state_dict(), save_model_path)
+            torch.save({**model.module.state_dict(), 'epoch': info_dict['epoch'], 'step': info_dict['step']}, save_model_path)
     else:
         with torch.no_grad():
             model.save_checkpoint(save_dir=model_dir,
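The saving side mirrors the loading change in train.py: rather than introducing a nested layout (e.g. {'model': ..., 'step': ...}), the two integers are merged flat into the weight dict, which is why older checkpoints without them still load. A small round-trip sketch (path and values are made up):

```python
import torch

model = torch.nn.Linear(4, 2)
save_model_path = '/tmp/epoch_3_whole.pt'  # illustrative path, not the repo's naming scheme

# Flat layout: metadata keys sit next to parameter keys, so they must not
# collide with any parameter name ('epoch' and 'step' are safe here).
torch.save({**model.state_dict(), 'epoch': 3, 'step': 2500}, save_model_path)

ckpt = torch.load(save_model_path, map_location='cpu')
print(sorted(ckpt.keys()))  # ['bias', 'epoch', 'step', 'weight']
```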
@@ -284,7 +284,8 @@ def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict):
             # We don't check grad here since that if the gradient
             # has inf/nan values, scaler.step will skip
             # optimizer.step().
-            scaler.step(optimizer)
+            if torch.isfinite(grad_norm):
+                scaler.step(optimizer)
             scaler.update()
         else:
             grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
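This last hunk adds an explicit guard on top of GradScaler: clip_grad_norm_ returns the total (pre-clipping) gradient norm, and if that value is inf/NaN the optimizer step is skipped outright instead of relying only on scaler.step's internal overflow handling. A sketch of the guarded AMP update on a toy model (assumes a CUDA device; this is not the repo's executor loop):

```python
import torch

model = torch.nn.Linear(4, 2).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(8, 4, device='cuda')
with torch.cuda.amp.autocast():
    loss = model(x).pow(2).mean()

scaler.scale(loss).backward()
scaler.unscale_(optimizer)                      # bring gradients back to real scale
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)

if torch.isfinite(grad_norm):                   # skip the step on inf/NaN gradients
    scaler.step(optimizer)
scaler.update()                                 # loss scale is adjusted either way
optimizer.zero_grad()
```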
