-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtrain.py
121 lines (101 loc) · 4.51 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import argparse
import os
import json
import logging
from pathlib import Path
import torch
from torch.utils import data
from models.model_factory import get_model
from dataloader import Dataset
from utils import train_model, evaluate_model
def get_arg(argv=None):
    """Parse command-line arguments for a training run.

    Args:
        argv: Optional list of argument strings. Defaults to None, in which
            case argparse reads sys.argv (backward-compatible with the old
            zero-argument call).

    Returns:
        argparse.Namespace with vocab/train/val paths and training
        hyper-parameters.
    """
    parser = argparse.ArgumentParser(description="Train a tone-restoration model.")
    parser.add_argument('vocab_path',
                        help="path to a torch-saved dict of tokenizers with keys 'notone' and 'tone'")
    parser.add_argument('train_path',
                        help="prefix of the training files; src/trg postfixes are appended")
    parser.add_argument('val_path',
                        help="prefix of the validation files; src/trg postfixes are appended")
    # Required: this name builds the experiment folder path. Leaving it as the
    # old default of None made os.path.join crash later with an opaque TypeError;
    # failing here gives a clear argparse error instead.
    parser.add_argument('--experiment_name', required=True,
                        help="name of the experiments/<name> folder for logs and weights")
    parser.add_argument('--src_postfix', default='.notone',
                        help="file extension of the source (no-tone) side")
    parser.add_argument('--trg_postfix', default='.tone',
                        help="file extension of the target (tone) side")
    parser.add_argument('--config_file', default='model_config.json',
                        help="JSON file mapping model names to hyper-parameter dicts")
    parser.add_argument('--model_name', default='big_evolved',
                        help="key into the config file selecting the model hyper-parameters")
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--cuda', action='store_true', default=False,
                        help="use CUDA if available")
    parser.add_argument('--learning_rate', type=float, default=0.0001)
    parser.add_argument('--num_epochs', type=int, default=1)
    parser.add_argument('--restore_file', default=None,
                        help="checkpoint to resume model and optimizer state from")
    parser.add_argument('--initial_epoch', default=1, type=int,
                        help="first epoch number, for resumed runs")
    return parser.parse_args(argv)
if __name__ == '__main__':
    args = get_arg()

    # Each run gets its own folder under experiments/ for logs and checkpoints.
    experiment_folder = os.path.join("experiments", args.experiment_name)
    Path(experiment_folder).mkdir(parents=True, exist_ok=True)

    # Remove any pre-existing root handlers so basicConfig takes effect even
    # if an imported module already configured logging.
    log_file = os.path.join(experiment_folder, "logs.txt")
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
    logging.basicConfig(filename=log_file,
                        filemode='a',
                        level=logging.INFO,
                        format="%(levelname)s - %(asctime)s: %(message)s")
    logger = logging.getLogger(__name__)

    # Tokenizers are stored together as a dict: 'notone' = source side,
    # 'tone' = target side.
    print("Load tokenizer")
    tokenizer = torch.load(args.vocab_path)
    src_tokenizer = tokenizer['notone']
    trg_tokenizer = tokenizer['tone']
    src_pad_token = 0  # padding index, used for attention masking downstream
    trg_pad_token = 0

    # Parallel corpora: source/target files share a path prefix and differ
    # only by postfix (e.g. train.notone / train.tone).
    print("Load data")
    train_src_file = args.train_path + args.src_postfix
    train_trg_file = args.train_path + args.trg_postfix
    train_dataset = Dataset(src_tokenizer, trg_tokenizer, train_src_file, train_trg_file)
    train_iter = data.dataloader.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_src_file = args.val_path + args.src_postfix
    val_trg_file = args.val_path + args.trg_postfix
    val_dataset = Dataset(src_tokenizer, trg_tokenizer, val_src_file, val_trg_file)
    val_iter = data.dataloader.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)

    # Model hyper-parameters come from a JSON file keyed by model name.
    with open(args.config_file) as f:
        config = json.load(f)
    if args.model_name not in config:
        raise Exception("Invalid model name")
    model_param = config[args.model_name]
    # +1 presumably because index 0 is reserved for padding while word_index
    # starts at 1 (Keras-style tokenizer) — TODO confirm against the tokenizer.
    model_param['src_vocab_size'] = len(src_tokenizer.word_index) + 1
    model_param['trg_vocab_size'] = len(trg_tokenizer.word_index) + 1

    # Use the GPU only when requested AND actually available.
    print("Init model")
    device = torch.device('cuda' if torch.cuda.is_available() and args.cuda else 'cpu')

    model = get_model(model_param)
    if device.type == 'cuda':
        model = model.cuda()
    # betas/eps follow the Transformer paper's Adam settings.
    optim = torch.optim.Adam(model.parameters(), lr=args.learning_rate, betas=(0.9, 0.98), eps=1e-9)
    print("Using", device.type)

    # Optionally resume model and optimizer state from a checkpoint.
    if args.restore_file is not None:
        if not os.path.isfile(args.restore_file):
            raise Exception("Invalid weight path")
        print("Load model")
        # map_location lets a checkpoint saved on GPU be restored on a
        # CPU-only host (previously this failed without it).
        state = torch.load(args.restore_file, map_location=device)
        model.load_state_dict(state['model'])
        optim.load_state_dict(state['optim'])

    # Checkpoints are written per epoch under experiments/<name>/weights/.
    weight_folder = os.path.join(experiment_folder, "weights")
    Path(weight_folder).mkdir(parents=True, exist_ok=True)

    print("Start training %d epochs" % args.num_epochs)
    for e in range(args.initial_epoch, args.num_epochs + 1):
        logger.info("Epoch %02d/%02d" % (e, args.num_epochs))
        logger.info("Start training")
        print("\nEpoch %02d/%02d" % (e, args.num_epochs), flush=True)
        save_file = os.path.join(weight_folder, 'epoch_%02d.h5' % e)
        train_loss = train_model(model, optim, train_iter, src_pad_token, use_mask=model_param["use_mask"], device=device, save_path=save_file)
        logger.info("End training")
        logger.info("train_loss = %.8f" % train_loss)
        val_loss = evaluate_model(model, val_iter, src_pad_token, use_mask=model_param["use_mask"], device=device)
        logger.info("val_loss = %.8f\n" % val_loss)