import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import util

class Generator(nn.Module):
    """LSTM generator that produces token sequences autoregressively."""

    def __init__(self, vocab_size, embedding_size, hidden_dim, num_layers):
        super(Generator, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.lstm = nn.LSTM(embedding_size, hidden_dim, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_dim, vocab_size)
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

    def forward(self, x):
        """
        Args:
            x: (batch_size, sequence_len) LongTensor of token ids
        Returns:
            logits: (batch_size, vocab_size, sequence_len), with the class
                dimension in position 1 as expected by F.cross_entropy
        """
        embedding = self.embedding(x)                    # (batch_size, sequence_len, embedding_size)
        batch_size = x.size(0)
        h0, c0 = self.init_hidden(self.num_layers, batch_size, self.hidden_dim)
        output, _ = self.lstm(embedding, (h0, c0))       # (batch_size, sequence_len, hidden_dim)
        logits = self.linear(output)                     # (batch_size, sequence_len, vocab_size)
        logits = logits.transpose(1, 2)                  # (batch_size, vocab_size, sequence_len)
        return logits

    def step(self, x, h, c):
        """
        Run the LSTM for a single time step.

        Args:
            x: (batch_size, 1) LongTensor, current input tokens
            h: (num_layers, batch_size, hidden_dim) LSTM hidden state
            c: (num_layers, batch_size, hidden_dim) LSTM cell state
        Returns:
            logits: (batch_size, vocab_size)
            h, c: updated LSTM states
        """
        embedding = self.embedding(x)                  # (batch_size, 1, embedding_size)
        self.lstm.flatten_parameters()
        output, (h, c) = self.lstm(embedding, (h, c))  # (batch_size, 1, hidden_dim)
        logits = self.linear(output).squeeze(1)        # (batch_size, vocab_size)
        return logits, h, c

    def sample(self, batch_size, sequence_len, x=None):
        """
        Sample token sequences of length `sequence_len`.

        Args:
            x: optional (batch_size, given_len) LongTensor prefix to condition
                on; if None, sampling starts from an all-zero start token.
        Returns:
            (batch_size, sequence_len) LongTensor of sampled token ids
        """
        from_scratch = x is None
        if from_scratch:
            x = util.to_var(torch.zeros(batch_size, 1).long())
        h, c = self.init_hidden(self.num_layers, batch_size, self.hidden_dim)
        samples = []
        if from_scratch:
            # Free-running sampling: feed each sampled token back as the next input.
            for _ in range(sequence_len):
                logits, h, c = self.step(x, h, c)
                probs = F.softmax(logits, dim=1)
                x = probs.multinomial(1)  # (batch_size, 1)
                samples.append(x)
        else:
            # Teacher-force the given prefix, then continue free-running.
            given_len = x.size(1)
            lis = x.chunk(given_len, dim=1)
            for i in range(given_len):
                logits, h, c = self.step(lis[i], h, c)
                samples.append(lis[i])
                x = F.softmax(logits, dim=1).multinomial(1)
            for i in range(given_len, sequence_len):
                samples.append(x)
                logits, h, c = self.step(x, h, c)
                x = F.softmax(logits, dim=1).multinomial(1)
        output = torch.cat(samples, 1)
        return output  # (batch_size, sequence_len)

    def init_hidden(self, num_layers, batch_size, hidden_dim):
        """Initialize zero hidden and cell states h0, c0."""
        h = util.to_var(torch.zeros(num_layers, batch_size, hidden_dim))
        c = util.to_var(torch.zeros(num_layers, batch_size, hidden_dim))
        return h, c
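

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It assumes the
# repo's `util.to_var` helper works on CPU tensors; the hyperparameters and
# batch/sequence sizes below are arbitrary placeholders chosen only to
# exercise the shapes returned by `sample` and `forward`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    gen = Generator(vocab_size=5000, embedding_size=32, hidden_dim=64, num_layers=1)

    # Unconditional sampling: a (4, 20) LongTensor of token ids.
    samples = gen.sample(batch_size=4, sequence_len=20)
    print(samples.size())

    # Teacher-forced forward pass: logits come back as
    # (batch_size, vocab_size, sequence_len), the layout F.cross_entropy
    # expects, so an MLE-style loss can be computed directly.
    # (Here the generator's own samples stand in for real data.)
    logits = gen(samples)
    loss = F.cross_entropy(logits, samples)
    print(loss)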