-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
124 lines (94 loc) · 3.23 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import sys
import pickle
from model import *  # presumably provides torch, nn, random, EncoderRNN, DecoderRNN, VAE, read_file, char_tensor, etc. — TODO confirm

# Model / training hyperparameters.
hidden_size = 500            # RNN hidden state size
embed_size = 50              # size of the latent embedding
learning_rate = 0.0001       # Adam learning rate
n_epochs = 2500              # number of training iterations (one random sample each)
grad_clip = 1.0              # gradient-norm clipping threshold
kld_start_inc = 10000        # epoch after which the KL weight starts growing
kld_weight = 0.05            # initial weight on the KL-divergence term
kld_max = 0.1                # ceiling for the KL weight
kld_inc = 0.000002           # per-epoch KL weight increment (after warm-up)
temperature = 0.9            # initial decoder sampling temperature
temperature_min = 0.5        # floor for the annealed temperature
temperature_dec = 0.000002   # per-epoch temperature decrement
# Training
# ------------------------------------------------------------------------------
# Require an input corpus path and a checkpoint name on the command line.
if len(sys.argv) < 3:
    print("Usage: python train.py [filename] [checkpoint name]")
    sys.exit(1)

# read_file comes from model's wildcard import — presumably returns (text, length); TODO confirm
file, file_len = read_file(sys.argv[1])
# file, file_len = read_file('../practical-pytorch/data/first-names.txt')
lines = [line.strip() for line in file.split('\n')]
print('n lines', len(lines))
def gen_chunks(lines, n):
    """Group *lines* into consecutive runs of *n*, each joined with newlines.

    The last chunk may hold fewer than *n* lines.
    """
    return ['\n'.join(lines[start:start + n])
            for start in range(0, len(lines), n)]
def good_size(line):
    """True when the line's length lies within [MIN_LENGTH, MAX_LENGTH]."""
    return MIN_LENGTH <= len(line) <= MAX_LENGTH
def good_content(line):
    """Reject lines containing URLs ('http') or a '/' character."""
    return not ('http' in line or '/' in line)
# Keep only lines of acceptable length that contain no URLs or slashes.
lines = [line for line in lines if good_size(line) and good_content(line)]
print('n lines', len(lines))
# n=1: every chunk is a single line, so each training example is one line.
chunks = gen_chunks(lines,1)
random.shuffle(lines)
random.shuffle(chunks)
def random_training_set():
    """Sample one random chunk; return (input, target) character tensors.

    Input and target encode the same sequence (autoencoder setup), built
    as two separate tensors.
    """
    sample = random.choice(chunks)
    encoded_in = char_tensor(sample)
    encoded_out = char_tensor(sample)
    return encoded_in, encoded_out
# Build the character VAE: encoder -> latent code -> 2-layer decoder.
e = EncoderRNN(n_characters, hidden_size, embed_size)
d = DecoderRNN(embed_size, hidden_size, n_characters, 2)
vae = VAE(e, d)
optimizer = torch.optim.Adam(vae.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()
# USE_CUDA presumably comes from model's wildcard import — TODO confirm
if USE_CUDA:
    vae.cuda()
    criterion.cuda()
log_every = 200     # epochs between progress printouts
save_every = 5000   # epochs between mid-training checkpoints
def save():
    """Checkpoint the whole VAE and the loss history under out/<checkpoint name>.

    The checkpoint name is taken from sys.argv[2]; losses are pickled
    alongside the model as out/losses_<name>.pkl.
    """
    import os
    # Fix: the original crashed with FileNotFoundError when out/ was absent.
    os.makedirs("out", exist_ok=True)
    save_filename = sys.argv[2]
    torch.save(vae, "out/%s" % save_filename)
    with open("out/losses_%s.pkl" % save_filename, "wb") as f:
        pickle.dump(losses, f)
    print('Saved as %s' % save_filename)
# Main training loop — checkpoints every save_every epochs and on Ctrl-C.
try:
    losses = []
    for epoch in range(n_epochs):
        input, target = random_training_set()
        optimizer.zero_grad()
        # m, l: latent mean / log-variance; z: sampled latent code;
        # decoded: decoder output logits — presumably, per standard VAE usage; TODO confirm against model.py
        m, l, z, decoded = vae(input, temperature)
        # Anneal the sampling temperature down toward temperature_min.
        if temperature > temperature_min:
            temperature -= temperature_dec
        loss = criterion(decoded, target)
        #job.record(epoch, loss.data[0])
        # KL divergence of N(m, exp(l)) from the unit Gaussian, mean over dim 1.
        KLD = (-0.5 * torch.sum(l - torch.pow(m, 2) - torch.exp(l) + 1, 1)).mean().squeeze()
        loss += KLD * kld_weight
        # KL warm-up: only grow the KL weight after kld_start_inc epochs, up to kld_max.
        if epoch > kld_start_inc and kld_weight < kld_max:
            kld_weight += kld_inc
        loss.backward()
        # print('from', next(vae.parameters()).grad.data[0][0])
        # ec: total gradient norm before clipping (clip_grad_norm is the
        # pre-0.4 torch spelling; newer versions use clip_grad_norm_).
        ec = torch.nn.utils.clip_grad_norm(vae.parameters(), grad_clip)
        # print('to ', next(vae.parameters()).grad.data[0][0])
        optimizer.step()
        # NOTE(review): appending the loss tensor itself may retain its autograd
        # graph across epochs; storing a plain float would be safer — TODO confirm
        # against the torch version in use.
        losses.append(loss)
        if epoch % log_every == 0:
            print('[%d] %.4f (k=%.4f, t=%.4f, kl=%.4f, ec=%.4f)' % (
                epoch, loss.data, kld_weight, temperature, KLD.data, ec
            ))
            print(' (target) "%s"' % longtensor_to_string(target))
            # Decode from the same latent z to compare against the target.
            generated = vae.decoder.generate(z, MAX_LENGTH, temperature)
            print('(generated) "%s"' % tensor_to_string(generated))
            print('')
        # Periodic mid-training checkpoint (skip epoch 0).
        if epoch > 0 and epoch % save_every == 0:
            save()
    # Final checkpoint after the loop completes.
    save()
except KeyboardInterrupt as err:
    print("ERROR", err)
    print("Saving before quit...")
    save()