train.py
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AdamW, BertConfig
from tqdm import tqdm

from model import BertClassifier
from dataset import CNewsDataset
from dbData import CLoadDbDataset

def main():
    # Load the raw data from the database
    CLoadDbDataset()

    # Hyperparameters
    batch_size = 4
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("device: ", device)
    epochs = 2
    learning_rate = 1e-4  # the learning rate should not be too large

    # Build the datasets
    train_dataset = CNewsDataset('/home/songzhouran/py/bert/BertClassifier/trainData/bitcoin86markproc.csv')
    valid_dataset = CNewsDataset('/home/songzhouran/py/bert/BertClassifier/trainData/bitcoin86markproc1.csv')
    # train_dataset = CNewsDataset('/Users/a/fsdownload/bitcoin86mark.csv')
    # valid_dataset = CNewsDataset('/Users/a/fsdownload/bitcoin86mark1.csv')
    # test_data = load_data('cnews/cnews.test.txt')

    # Build batches
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
    # test_dataloader = DataLoader(valid_data, batch_size=batch_size, shuffle=False)
    # Load the BERT configuration
    bert_config = BertConfig.from_pretrained('bert-base-chinese')
    num_labels = len(train_dataset.labels)

    # Initialize the model
    model = BertClassifier(bert_config, num_labels).to(device)
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    best_acc = 0
    for epoch in range(1, epochs + 1):
        losses = 0      # accumulated loss
        accuracy = 0    # accumulated accuracy

        # Training
        model.train()
        train_bar = tqdm(train_dataloader)
        for input_ids, token_type_ids, attention_mask, label_id in train_bar:
            model.zero_grad()
            train_bar.set_description('Epoch %i train' % epoch)

            output = model(
                input_ids=input_ids.to(device),
                attention_mask=attention_mask.to(device),
                token_type_ids=token_type_ids.to(device),
            )

            loss = criterion(output, label_id.to(device))
            losses += loss.item()

            pred_labels = torch.argmax(output, dim=1)  # predicted labels
            acc = torch.sum(pred_labels == label_id.to(device)).item() / len(pred_labels)  # batch accuracy
            accuracy += acc

            loss.backward()
            optimizer.step()
            train_bar.set_postfix(loss=loss.item(), acc=acc)

        average_loss = losses / len(train_dataloader)
        average_acc = accuracy / len(train_dataloader)
        print('\tTrain ACC:', average_acc, '\tLoss:', average_loss)
        # Validation
        model.eval()
        losses = 0      # accumulated loss
        accuracy = 0    # accumulated accuracy
        valid_bar = tqdm(valid_dataloader)
        with torch.no_grad():  # no gradients needed during validation
            for input_ids, token_type_ids, attention_mask, label_id in valid_bar:
                valid_bar.set_description('Epoch %i valid' % epoch)

                output = model(
                    input_ids=input_ids.to(device),
                    attention_mask=attention_mask.to(device),
                    token_type_ids=token_type_ids.to(device),
                )

                loss = criterion(output, label_id.to(device))
                losses += loss.item()

                pred_labels = torch.argmax(output, dim=1)  # predicted labels
                acc = torch.sum(pred_labels == label_id.to(device)).item() / len(pred_labels)  # batch accuracy
                accuracy += acc
                valid_bar.set_postfix(loss=loss.item(), acc=acc)

        average_loss = losses / len(valid_dataloader)
        average_acc = accuracy / len(valid_dataloader)
        print('\tValid ACC:', average_acc, '\tLoss:', average_loss)

        # Keep the checkpoint with the best validation accuracy
        if average_acc > best_acc:
            best_acc = average_acc
            os.makedirs('models', exist_ok=True)
            torch.save(model.state_dict(), 'models/best_model.pkl')

if __name__ == '__main__':
    main()
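
# Note: CNewsDataset is assumed to yield (input_ids, token_type_ids, attention_mask, label_id)
# tuples, matching the unpacking in the training and validation loops above.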