train_code.py
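"""Train the autoregressive code model: next-token prediction over discrete
code sequences, with TensorBoard logging and periodic checkpointing."""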
import argparse
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

from dataset import CodeDataset
from model.code import CodeModel
def train(args):
    # Select the GPU device
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device
    device = torch.device("cuda:0")

    # Initialize dataset loader
    dataset = CodeDataset(datapath=args.input, maxlen=args.seqlen)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             shuffle=True,
                                             batch_size=args.batchsize,
                                             num_workers=5)

    # Initialize code model
    model = CodeModel(
        config={
            'hidden_dim': 512,
            'embed_dim': 256,
            'num_layers': 8,
            'num_heads': 8,
            'dropout_rate': 0.1
        },
        max_len=args.seqlen,
        classes=args.code,
    )
    model = model.to(device).train()

    # Initialize optimizer
    network_parameters = list(model.parameters())
    optimizer = torch.optim.Adam(network_parameters, lr=1e-3)

    # Logging
    writer = SummaryWriter(log_dir=args.output)

    # Main training loop
    iters = 0
    print('Start training...')
    for epoch in range(800):
        print(f'Epoch {epoch}')
        for batch in dataloader:
            code = batch.to(device)

            # Pass through the code prediction module.
            # Next-token objective: feed all tokens but the last,
            # predict all tokens but the first (shifted target).
            logits = model(code[:, :-1])
            c_pred = logits.reshape(-1, logits.shape[-1])
            c_target = code[:, 1:].reshape(-1)
            code_loss = F.cross_entropy(c_pred, c_target)
            total_loss = code_loss

            # Logging
            if iters % 20 == 0:
                writer.add_scalar("Loss/Total", total_loss.item(), iters)

            # Backprop
            optimizer.zero_grad()
            total_loss.backward()
            nn.utils.clip_grad_norm_(network_parameters, max_norm=1.0)  # clip gradient
            optimizer.step()
            iters += 1

        writer.flush()

        # Save a checkpoint every 500 epochs
        # (with 800 epochs this fires once, at epoch 500)
        if (epoch + 1) % 500 == 0:
            torch.save(model.state_dict(),
                       os.path.join(args.output, 'code_epoch_' + str(epoch + 1) + '.pt'))

    writer.close()
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", type=str, required=True, help="Path to the training dataset")
    parser.add_argument("--output", type=str, required=True, help="Folder for logs and checkpoints")
    parser.add_argument("--batchsize", type=int, required=True, help="Training batch size")
    parser.add_argument("--device", type=str, required=True, help="CUDA device id, e.g. '0'")
    parser.add_argument("--seqlen", type=int, required=True, help="Maximum code sequence length")
    parser.add_argument("--code", type=int, required=True, help="Number of code classes")
    args = parser.parse_args()

    # Create training folder
    result_folder = args.output
    if not os.path.exists(result_folder):
        os.makedirs(result_folder)

    # Start training
    train(args)
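
# Example invocation (a sketch: the paths and hyperparameter values below are
# illustrative, not taken from the repo; the flags themselves match the
# argparse definitions above):
#
#   python train_code.py --input data/train_code.pkl --output proj_log/code \
#       --batchsize 256 --device 0 --seqlen 200 --code 500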