-
Notifications
You must be signed in to change notification settings - Fork 23
/
chatbots.py
329 lines (271 loc) · 11.4 KB
/
chatbots.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
# Class definitions for the chatbots: questioner (Q-bot) and answerer (A-bot)
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.autograd as autograd
import sys
from utilities import initializeWeights
#---------------------------------------------------------------------------
# Parent class for both q and a bots
class ChatBot(nn.Module):
    """Machinery shared by the questioner and answerer bots.

    Every key of ``params`` is copied onto the instance as an attribute; the
    code below reads at least ``inVocabSize``, ``outVocabSize``,
    ``embedSize``, ``hiddenSize`` and ``useGPU``. Subclasses must define
    ``self.rnn`` (called like an ``nn.LSTMCell``) before ``listen`` is used.

    In training mode every sampled token is recorded in ``self.actions`` as
    its negative log-probability; ``reinforce`` scales those by a reward and
    ``performBackward`` backpropagates their sum (a REINFORCE-style update).
    """

    def __init__(self, params):
        super(ChatBot, self).__init__()
        # absorb all parameters to self
        for attr in params:
            setattr(self, attr, params[attr])
        # placeholders; real state tensors are created by resetStates
        self.hState = torch.Tensor()
        self.cState = torch.Tensor()
        self.actions = []
        self.evalFlag = False
        # modules (common): token embedding in, vocabulary scores out
        self.inNet = nn.Embedding(self.inVocabSize, self.embedSize)
        self.outNet = nn.Linear(self.hiddenSize, self.outVocabSize)
        initializeWeights([self.inNet, self.outNet], 'xavier')

    def resetStates(self, batchSize, retainActions=False):
        """Zero the RNN hidden and cell states for a batch of ``batchSize`` rows.

        Args:
            batchSize: number of rows in the new state tensors.
            retainActions: when False (default) the recorded actions are
                cleared, i.e. a new episode starts; when True they are kept.
        """
        self.hState = Variable(torch.zeros(batchSize, self.hiddenSize))
        self.cState = Variable(torch.zeros(batchSize, self.hiddenSize))
        if self.useGPU:
            self.hState = self.hState.cuda()
            self.cState = self.cState.cuda()
        if not retainActions:
            self.actions = []

    def freeze(self):
        """Stop gradient tracking for every parameter of this bot."""
        for p in self.parameters():
            p.requires_grad = False

    def unfreeze(self):
        """Re-enable gradient tracking for every parameter of this bot."""
        for p in self.parameters():
            p.requires_grad = True

    def listen(self, inputToken, imgEmbed=None):
        """Embed ``inputToken`` and advance the RNN by one step.

        Args:
            inputToken: token indices looked up in ``inNet``.
            imgEmbed: optional features concatenated to the token embedding
                along dim 1 before the RNN step.
        """
        tokenEmbeds = self.inNet(inputToken)
        if imgEmbed is not None:
            tokenEmbeds = torch.cat((tokenEmbeds, imgEmbed), 1)
        self.hState, self.cState = self.rnn(tokenEmbeds,
                                            (self.hState, self.cState))

    def speak(self):
        """Choose one output token per row of the current hidden state.

        Evaluation mode picks the most probable token and records nothing.
        Training mode samples a token and appends its negative
        log-probability to ``self.actions``.

        Returns:
            Tensor of chosen token indices, one per row of ``hState``.
        """
        outDistr = nn.functional.softmax(self.outNet(self.hState), dim=-1)
        if self.evalFlag:
            _, actions = outDistr.max(1)
        else:
            actionSampler = torch.distributions.Categorical(outDistr)
            actions = actionSampler.sample()
            # record actions for the policy-gradient update
            self.actions.append(-actionSampler.log_prob(actions))
        return actions

    def reinforce(self, rewards):
        """Multiply every recorded action by ``rewards`` (list updated in place)."""
        self.actions[:] = [action * rewards for action in self.actions]

    def performBackward(self):
        """Backpropagate the sum of all recorded (reward-weighted) actions.

        Does nothing when no action was recorded (e.g. after an evaluation
        episode) instead of failing on ``int.backward``.
        """
        if not self.actions:
            return
        sum(action.sum() for action in self.actions).backward()

    def evaluate(self):
        """Switch to greedy decoding (evaluation mode); returns self."""
        return self.train(False)

    def train(self, mode=True):
        """Switch between sampling (``mode=True``) and greedy decoding.

        Keeps the ``nn.Module.train(mode)`` signature so ``nn.Module.eval()``
        (which calls ``self.train(False)``) works, and keeps
        ``self.training`` in sync with ``evalFlag``. Returns self.
        """
        self.evalFlag = not mode
        return super(ChatBot, self).train(mode)
#---------------------------------------------------------------------------
class Answerer(ChatBot):
    """Answerer bot (A-bot): sees the image attributes and answers questions.

    Besides what ChatBot needs, reads ``aInVocab``, ``aOutVocab``,
    ``qOutVocab``, ``props`` and ``imgFeatSize`` from ``params``.
    """

    def __init__(self, params):
        # Work on a shallow copy: writing inVocabSize/outVocabSize into the
        # caller's dict leaked into other users of it (Team passes one dict
        # to both bots).
        params = dict(params)
        params['inVocabSize'] = params['aInVocab']
        params['outVocabSize'] = params['aOutVocab']
        super(Answerer, self).__init__(params)
        # number of attribute values across all attributes
        numAttrs = sum(len(values) for values in self.props.values())
        # number of unique attributes
        numUniqAttr = len(self.props)
        # RNN input = flattened image embedding + token embedding (see listen)
        rnnInputSize = numUniqAttr * self.imgFeatSize + self.embedSize
        self.imgNet = nn.Embedding(numAttrs, self.imgFeatSize)
        self.rnn = nn.LSTMCell(rnnInputSize, self.hiddenSize)
        initializeWeights([self.rnn, self.imgNet], 'xavier')
        # added to this bot's own answer tokens when it re-listens to them,
        # keeping them apart from the question tokens it also hears
        self.listenOffset = params['qOutVocab']

    def embedImage(self, batch):
        """Embed attribute indices and concatenate them per example.

        Args:
            batch: attribute indices; presumably (batch, numUniqAttr) so the
                flattened width matches the RNN input size — TODO confirm.

        Returns:
            ``(batch.shape[0], -1)`` tensor: the per-attribute embeddings
            concatenated (not summed).
        """
        embeds = self.imgNet(batch)
        return embeds.view(embeds.shape[0], -1)
#---------------------------------------------------------------------------
class Questioner(ChatBot):
    """Questioner bot (Q-bot): asks questions, then guesses attribute values.

    Besides what ChatBot needs, reads ``qInVocab``, ``qOutVocab``,
    ``aOutVocab`` and ``props`` from ``params``.
    """

    def __init__(self, params):
        # Work on a shallow copy so the caller's dict is not mutated.
        params = dict(params)
        params['inVocabSize'] = params['qInVocab']
        params['outVocabSize'] = params['qOutVocab']
        super(Questioner, self).__init__(params)
        # dialog RNN over token embeddings only (no image input)
        self.rnn = nn.LSTMCell(self.embedSize, self.hiddenSize)
        # prediction head: one class per attribute value across all attributes
        numPreds = sum(len(values) for values in self.props.values())
        self.predictRNN = nn.LSTMCell(self.embedSize, self.hiddenSize)
        self.predictNet = nn.Linear(self.hiddenSize, numPreds)
        initializeWeights([self.predictNet, self.predictRNN, self.rnn],
                          'xavier')
        # task ids are placed after both output vocabularies in the input vocab
        self.taskOffset = params['aOutVocab'] + params['qOutVocab']
        # added to this bot's own question tokens when it re-listens to them
        self.listenOffset = params['aOutVocab']

    def guessAttribute(self, inputEmbeds):
        """Advance the prediction RNN one step and pick an attribute value.

        The prediction RNN continues from, and overwrites, the same
        ``hState``/``cState`` used by the dialog RNN, so the guess is
        conditioned on whatever the dialog has fed in so far.

        Returns:
            ``(actions, outDistr)``: chosen indices and the softmax
            distribution they were drawn from. In training mode the negative
            log-probability of the sample is appended to ``self.actions``.
        """
        self.hState, self.cState = self.predictRNN(
            inputEmbeds, (self.hState, self.cState))
        outDistr = nn.functional.softmax(self.predictNet(self.hState), dim=-1)
        if self.evalFlag:
            _, actions = outDistr.max(1)
        else:
            actionSampler = torch.distributions.Categorical(outDistr)
            actions = actionSampler.sample()
            self.actions.append(-actionSampler.log_prob(actions))
        return actions, outDistr

    def predict(self, tasks, numTokens):
        """Make ``numTokens`` successive guesses, each conditioned on the task.

        Returns:
            ``(guessTokens, guessDistr)``: lists of length ``numTokens``.
        """
        # the task embedding does not change between guesses: compute it once
        taskEmbeds = self.embedTask(tasks)
        guessTokens, guessDistr = [], []
        for _ in range(numTokens):
            guess, distr = self.guessAttribute(taskEmbeds)
            guessTokens.append(guess)
            guessDistr.append(distr)
        return guessTokens, guessDistr

    def embedTask(self, tasks):
        """Embed task ids after shifting them into the task region of the vocab."""
        return self.inNet(tasks + self.taskOffset)
#---------------------------------------------------------------------------
class Team:
    """A questioner and an answerer trained jointly on a shared reward.

    Every key of ``params`` becomes an attribute; the code below reads at
    least ``batchSize``, ``rlScale``, ``useGPU``, ``numRounds`` and
    ``remember``.
    """

    # sub-modules persisted by saveModel and restored by loadModel
    _moduleNames = ('rnn', 'inNet', 'outNet', 'imgNet',
                    'predictRNN', 'predictNet')

    def __init__(self, params):
        # memorize params
        for field, value in params.items():
            setattr(self, field, value)
        self.aBot = Answerer(params)
        self.qBot = Questioner(params)
        self.criterion = nn.NLLLoss()
        # per-example reward buffer; backward resizes it to the actual batch
        self.reward = torch.Tensor(self.batchSize)
        self.totalReward = None
        self.rlNegReward = -10 * self.rlScale
        if self.useGPU:
            self.aBot = self.aBot.cuda()
            self.qBot = self.qBot.cuda()
            self.reward = self.reward.cuda()
        print(self.aBot)
        print(self.qBot)

    def train(self):
        """Put both bots into sampling (training) mode."""
        self.aBot.train()
        self.qBot.train()

    def evaluate(self):
        """Put both bots into greedy (evaluation) mode."""
        self.aBot.evaluate()
        self.qBot.evaluate()

    def forward(self, batch, tasks, record=False):
        """Run one dialog episode followed by the Q-bot's two guesses.

        Args:
            batch: image attribute indices passed to ``aBot.embedImage``.
            tasks: task ids; shifted by ``qBot.taskOffset`` they form the
                Q-bot's first input, and they condition the final guesses.
            record: when True, collect the exchanged tokens in ``talk``.

        Returns:
            ``(guessToken, guessDistr, talk)`` where ``talk`` is
            ``[q_1, a_1, q_2, a_2, ...]`` (empty unless ``record``).
        """
        batchSize = batch.size(0)
        self.qBot.resetStates(batchSize)
        self.aBot.resetStates(batchSize)
        imgEmbed = self.aBot.embedImage(batch)
        # the (shifted) task is what the Q-bot hears before the first question
        aBotReply = tasks + self.qBot.taskOffset
        talk = []
        for _ in range(self.numRounds):
            # Q-bot hears the previous answer, asks, then hears its question
            self.qBot.listen(aBotReply)
            qBotQues = self.qBot.speak().detach()
            self.qBot.listen(self.qBot.listenOffset + qBotQues)
            # a memoryless A-bot forgets earlier rounds but keeps its actions
            if not self.remember:
                self.aBot.resetStates(batchSize, True)
            # A-bot hears the question, answers, then hears its answer
            self.aBot.listen(qBotQues, imgEmbed)
            aBotReply = self.aBot.speak().detach()
            self.aBot.listen(aBotReply + self.aBot.listenOffset, imgEmbed)
            if record:
                talk.extend([qBotQues, aBotReply])
        # listen to the last answer, then predict the image attributes
        self.qBot.listen(aBotReply)
        self.guessToken, self.guessDistr = self.qBot.predict(tasks, 2)
        return self.guessToken, self.guessDistr, talk

    def backward(self, optimizer, gtLabels, epoch, baseline=None):
        """Assign rewards, backpropagate both bots and clamp gradients.

        An example earns ``rlScale`` when both guesses from the last
        ``forward`` match ``gtLabels[:, 0]`` and ``gtLabels[:, 1]``, and
        ``rlNegReward`` otherwise. Only ``optimizer.zero_grad()`` is called
        here; presumably the caller performs ``optimizer.step()`` — confirm.

        Args:
            optimizer: object whose ``zero_grad`` is called before backprop.
            gtLabels: ground-truth indices, one row per example.
            epoch: unused.
            baseline: unused.

        Returns:
            Mean reward of the batch divided by ``rlScale``.
        """
        numExamples = gtLabels.size(0)
        # a batch may be smaller than self.batchSize (e.g. a final partial
        # batch); the reward buffer must match it for masking and reinforce
        if self.reward.numel() != numExamples:
            self.reward = self.reward.new_empty(numExamples)
        self.reward.fill_(self.rlNegReward)
        # both attributes need to match
        firstMatch = self.guessToken[0].data == gtLabels[:, 0]
        secondMatch = self.guessToken[1].data == gtLabels[:, 1]
        self.reward[firstMatch & secondMatch] = self.rlScale
        self.qBot.reinforce(self.reward)
        self.aBot.reinforce(self.reward)
        optimizer.zero_grad()
        self.qBot.performBackward()
        self.aBot.performBackward()
        # clamp the gradients; parameters without one (e.g. frozen) are skipped
        for bot in (self.qBot, self.aBot):
            for p in bot.parameters():
                if p.grad is not None:
                    p.grad.data.clamp_(min=-5., max=5.)
        # cumulative reward as an exponential moving average
        batchReward = torch.mean(self.reward) / self.rlScale
        if self.totalReward is None:
            self.totalReward = batchReward
        self.totalReward = 0.95 * self.totalReward + 0.05 * batchReward
        return batchReward

    def loadModel(self, savedModel):
        """Replace the bots' sub-modules with those stored in ``savedModel``.

        ``savedModel['aBot']`` / ``savedModel['qBot']`` are either dicts of
        name -> module (the layout saveModel writes) or objects exposing the
        modules as attributes. Only modules the current bot already has are
        replaced.
        """
        dictSaved = isinstance(savedModel['qBot'], dict)
        for agentName in ('aBot', 'qBot'):
            agent = getattr(self, agentName)
            saved = savedModel[agentName]
            for module in self._moduleNames:
                if not hasattr(agent, module):
                    continue
                if dictSaved:
                    savedModule = saved[module]
                else:
                    savedModule = getattr(saved, module)
                setattr(agent, module, savedModule)

    def saveModel(self, savePath, optimizer, params):
        """Write both bots' sub-modules, ``params`` and ``optimizer`` to ``savePath``.

        Whole module objects (not state_dicts) are saved with ``torch.save``,
        so reading the checkpoint back relies on pickle: load only trusted
        files.
        """
        toSave = {'aBot': {}, 'qBot': {}, 'params': params,
                  'optims': optimizer}
        for agentName in ('aBot', 'qBot'):
            agent = getattr(self, agentName)
            for module in self._moduleNames:
                if hasattr(agent, module):
                    toSave[agentName][module] = getattr(agent, module)
        torch.save(toSave, savePath)