-
Notifications
You must be signed in to change notification settings - Fork 1
/
train-pytorch.py
117 lines (89 loc) · 3.06 KB
/
train-pytorch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import sys
from model import *
# ---- Model dimensions --------------------------------------------------------
hidden_size = 500   # hidden size passed to both EncoderRNN and DecoderRNN
embed_size = 50     # encoder output size / decoder input size (presumably the latent dim)

# ---- Optimisation ------------------------------------------------------------
learning_rate = 1e-4   # Adam learning rate
n_epochs = 100000      # number of training steps; each step uses one random chunk
grad_clip = 1.0        # maximum total gradient norm for clipping

# ---- KL-divergence annealing -------------------------------------------------
# Once the step index exceeds kld_start_inc, kld_weight grows by kld_inc per
# step for as long as it is still below kld_max.
kld_start_inc = 10000
kld_weight = 0.05
kld_max = 0.1
kld_inc = 2e-6

# ---- Sampling temperature ----------------------------------------------------
# Decreases by temperature_dec per step while still above temperature_min.
temperature = 0.9
temperature_min = 0.5
temperature_dec = 2e-6
# Training
# ------------------------------------------------------------------------------
if len(sys.argv) < 2:
print("Usage: python train.py [filename]")
sys.exit(1)
file, file_len = read_file(sys.argv[1])
# file, file_len = read_file('../practical-pytorch/data/first-names.txt')
lines = [line.strip() for line in file.split('\n')]
print('n lines', len(lines))
def gen_chunks(lines, n):
    """Group consecutive lines into newline-joined chunks of up to ``n`` lines.

    When ``len(lines)`` is not a multiple of ``n`` the final chunk holds the
    leftover lines. An empty ``lines`` yields an empty list.

    Args:
        lines: sequence of strings, in the order they should be grouped.
        n: number of lines per chunk; must be non-zero (``range`` rejects a
            zero step with ``ValueError``).

    Returns:
        A list of strings, each made of up to ``n`` consecutive entries of
        ``lines`` joined with newline characters.
    """
    return ['\n'.join(lines[start:start + n]) for start in range(0, len(lines), n)]
def good_size(line):
    """Return True when ``line`` has between MIN_LENGTH and MAX_LENGTH characters, inclusive.

    MIN_LENGTH and MAX_LENGTH are not defined in this file (presumably they
    come from ``from model import *``).
    """
    return MIN_LENGTH <= len(line) <= MAX_LENGTH
def good_content(line):
    """Return False if ``line`` contains ``'http'`` or ``'/'`` (e.g. URLs), True otherwise."""
    banned_substrings = ('http', '/')
    return not any(token in line for token in banned_substrings)
# Keep only lines whose length is within bounds (good_size) and that contain
# neither 'http' nor '/' (good_content).
lines = [line for line in lines if good_size(line) and good_content(line)]
print('n lines', len(lines))
# Group consecutive surviving lines into chunks of 5; random_training_set()
# uses one chunk as both the input and the reconstruction target.
chunks = gen_chunks(lines,5)
# NOTE(review): `lines` is shuffled only after `chunks` has been built, so this
# shuffle does not change the chunk contents, and `lines` is not read again in
# this script. Confirm whether the intent was to shuffle before chunking.
random.shuffle(lines)
# Randomise chunk order. Training samples also use random.choice over chunks.
random.shuffle(chunks)
def random_training_set():
    """Pick one chunk uniformly at random and return it as an ``(input, target)`` pair.

    Both elements are built from the same chunk by two separate ``char_tensor``
    calls (input first, then target). The model is trained to reconstruct its
    own input.
    """
    sample = random.choice(chunks)
    return char_tensor(sample), char_tensor(sample)
# Build the VAE from a character-level encoder and decoder (classes from `model`).
e = EncoderRNN(n_characters, hidden_size, embed_size)
d = DecoderRNN(embed_size, hidden_size, n_characters, 2)
vae = VAE(e, d)
criterion = nn.CrossEntropyLoss()

# Move the model to the GPU *before* constructing the optimizer, as the
# torch.optim documentation recommends. This guarantees the optimizer holds
# the parameter objects that will actually be trained.
if USE_CUDA:
    vae.cuda()
    criterion.cuda()

optimizer = torch.optim.Adam(vae.parameters(), lr=learning_rate)
log_every = 200    # print a progress report every this many steps
save_every = 5000  # write a checkpoint every this many steps


def save(save_filename='vae-pytorch.pt'):
    """Serialise the whole ``vae`` module to ``save_filename`` with ``torch.save``.

    The entire module object is pickled, not just its ``state_dict``, so
    loading the file requires the model classes to be importable.

    Args:
        save_filename: destination path. The default keeps existing ``save()``
            calls writing the same file as before.
    """
    torch.save(vae, save_filename)
    print('Saved as %s' % save_filename)
# Main training loop. Ctrl-C saves a checkpoint before exiting.
try:
    for epoch in range(n_epochs):
        # `inp` instead of `input` so the builtin input() is not shadowed.
        inp, target = random_training_set()
        optimizer.zero_grad()

        # Forward pass. The KLD formula below treats m as the mean and l as
        # the log-variance of the latent distribution.
        m, l, z, decoded = vae(inp, temperature)

        # Anneal the sampling temperature towards temperature_min.
        if temperature > temperature_min:
            temperature -= temperature_dec

        # Reconstruction loss plus weighted KL divergence to a standard normal:
        # -0.5 * sum(1 + l - m^2 - exp(l)) over dim 1, then averaged.
        loss = criterion(decoded, target)
        KLD = (-0.5 * torch.sum(l - torch.pow(m, 2) - torch.exp(l) + 1, 1)).mean().squeeze()
        loss += KLD * kld_weight

        # KL annealing: after kld_start_inc steps, ramp the weight up to kld_max.
        if epoch > kld_start_inc and kld_weight < kld_max:
            kld_weight += kld_inc

        loss.backward()
        # clip_grad_norm_ is the non-deprecated in-place API. It returns the
        # total gradient norm measured before clipping.
        ec = torch.nn.utils.clip_grad_norm_(vae.parameters(), grad_clip)
        optimizer.step()

        if epoch % log_every == 0:
            print('[%d] %.4f (k=%.4f, t=%.4f, kl=%.4f, ec=%.4f)' % (
                epoch, loss.item(), kld_weight, temperature, KLD.item(), float(ec)
            ))
            print(' (target) "%s"' % longtensor_to_string(target))
            # Sampling is for monitoring only, so keep it out of the autograd graph.
            with torch.no_grad():
                generated = vae.decoder.generate(z, MAX_LENGTH, temperature)
            print('(generated) "%s"' % tensor_to_string(generated))
            print('')

        if epoch > 0 and epoch % save_every == 0:
            save()

    # Final checkpoint once all steps have completed.
    save()

except KeyboardInterrupt as err:
    print("ERROR", err)
    print("Saving before quit...")
    save()