import os
import logging
import argparse
import numpy as np
from tqdm import tqdm
import torch
from torch.serialization import default_restore_location
from seq2seq import models, utils
from seq2seq.data.dictionary import Dictionary
from seq2seq.data.dataset import Seq2SeqDataset, BatchSampler
from seq2seq.beam import BeamSearch, BeamSearchNode
from collections import defaultdict


def get_args():
""" Defines generation-specific hyper-parameters. """
parser = argparse.ArgumentParser('Sequence to Sequence Model')
parser.add_argument('--cuda', default=False, help='Use a GPU')
parser.add_argument('--seed', default=42, type=int, help='pseudo random number generator seed')
# Add data arguments
parser.add_argument('--data', default='data_asg4/prepared_data', help='path to data directory')
parser.add_argument('--checkpoint-path', default='checkpoints_asg4/checkpoint_best.pt', help='path to the model file')
parser.add_argument('--batch-size', default=None, type=int, help='maximum number of sentences in a batch')
parser.add_argument('--output', default='model_translations.txt', type=str,
help='path to the output file destination')
parser.add_argument('--max-len', default=100, type=int, help='maximum length of generated sequence')
# Add beam search arguments
parser.add_argument('--beam-size', default=5, type=int, help='number of hypotheses expanded in beam search')
parser.add_argument('--alpha', default=0., type=float, help='parameter that controls length normalization')
parser.add_argument('--gamma', default=0., type=float, help='parameter that controls diversity reranking')
parser.add_argument('--n-best', default=1, type=int, help='n-best translations to be returned per source sentence')
return parser.parse_args()


def main(args):
""" Main translation function' """
# Load arguments from checkpoint
torch.manual_seed(args.seed)
state_dict = torch.load(args.checkpoint_path, map_location=lambda s, l: default_restore_location(s, 'cpu'))
args_loaded = argparse.Namespace(**{**vars(args), **vars(state_dict['args'])})
args_loaded.data = args.data
args = args_loaded
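# Note: where the command-line arguments and the checkpoint arguments overlap, the values
# saved in the checkpoint take precedence; the --data path given on the command line is
# restored explicitly above.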
utils.init_logging(args)
# Load dictionaries
src_dict = Dictionary.load(os.path.join(args.data, 'dict.{:s}'.format(args.source_lang)))
logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format(args.source_lang, len(src_dict)))
tgt_dict = Dictionary.load(os.path.join(args.data, 'dict.{:s}'.format(args.target_lang)))
logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format(args.target_lang, len(tgt_dict)))
# Load dataset
test_dataset = Seq2SeqDataset(
src_file=os.path.join(args.data, 'test.{:s}'.format(args.source_lang)),
tgt_file=os.path.join(args.data, 'test.{:s}'.format(args.target_lang)),
src_dict=src_dict, tgt_dict=tgt_dict)
test_loader = torch.utils.data.DataLoader(test_dataset, num_workers=1, collate_fn=test_dataset.collater,
batch_sampler=BatchSampler(test_dataset, 9999999,
args.batch_size, 1, 0, shuffle=False,
seed=args.seed))
# Build model and criterion
model = models.build_model(args, src_dict, tgt_dict)
if args.cuda:
model = model.cuda()
model.eval()
model.load_state_dict(state_dict['model'])
logging.info('Loaded a model from checkpoint {:s}'.format(args.checkpoint_path))
progress_bar = tqdm(test_loader, desc='| Generation', leave=False)
# Iterate over the test set
# for tasks 1 and 3:
# all_hyps = {}
# for task 4: change all_hyps to a dictionary containing lists as values
all_hyps = defaultdict(list)
for i, sample in enumerate(progress_bar):
# Create a beam search object for every input sentence in the batch
batch_size = sample['src_tokens'].shape[0]
searches = [BeamSearch(args.beam_size, args.max_len - 1, tgt_dict.unk_idx) for _ in range(batch_size)]
with torch.no_grad():
# Compute the encoder output
encoder_out = model.encoder(sample['src_tokens'], sample['src_lengths'])
# __QUESTION 1: What is "go_slice" used for and what do its dimensions represent?
go_slice = \
torch.ones(sample['src_tokens'].shape[0], 1).fill_(tgt_dict.eos_idx).type_as(sample['src_tokens'])
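# go_slice has shape (batch_size, 1) and is filled with the EOS index; it is the initial
# decoder input, acting as the start-of-sequence symbol for every sentence in the batch.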
if args.cuda:
go_slice = utils.move_to_cuda(go_slice)
# Compute the decoder output at the first time step
decoder_out, _ = model.decoder(go_slice, encoder_out)
# __QUESTION 2: Why do we keep one top candidate more than the beam size?
log_probs, next_candidates = torch.topk(torch.log(torch.softmax(decoder_out, dim=2)),
args.beam_size+1, dim=-1)
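# Keeping beam_size+1 candidates leaves a backoff option per position: if the best
# candidate is UNK, the torch.where calls below substitute the next-best word instead.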
# Create beam_size beam search nodes for every input sentence
for i in range(batch_size):
for j in range(args.beam_size):
best_candidate = next_candidates[i, :, j]
backoff_candidate = next_candidates[i, :, j+1]
best_log_p = log_probs[i, :, j]
backoff_log_p = log_probs[i, :, j+1]
next_word = torch.where(best_candidate == tgt_dict.unk_idx, backoff_candidate, best_candidate)
log_p = torch.where(best_candidate == tgt_dict.unk_idx, backoff_log_p, best_log_p)
log_p = log_p[-1]
# Store the encoder_out information for the current input sentence and beam
emb = encoder_out['src_embeddings'][:,i,:]
lstm_out = encoder_out['src_out'][0][:,i,:]
final_hidden = encoder_out['src_out'][1][:,i,:]
final_cell = encoder_out['src_out'][2][:,i,:]
try:
mask = encoder_out['src_mask'][i,:]
except TypeError:
mask = None
node = BeamSearchNode(searches[i], emb, lstm_out, final_hidden, final_cell,
mask, torch.cat((go_slice[i], next_word)), log_p, 1)
# __QUESTION 3: Why do we add the node with a negative score?
# for task 3: add args.alpha as argument
# for task 4: add args.alpha, args.gamma and 1 (=rank) as arguments
searches[i].add(-node.eval(args.alpha, args.gamma, 1), node)
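# Note: the score is negated here, presumably because BeamSearch orders nodes in a
# priority queue where lower scores are better, so the highest-probability hypothesis
# is retrieved first.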
# Start generating further tokens until max sentence length reached
for _ in range(args.max_len-1):
# Get the current nodes to expand
nodes = [n[1] for s in searches for n in s.get_current_beams()]
if not nodes:
break # All beams ended in EOS
# Reconstruct prev_words, encoder_out from current beam search nodes
prev_words = torch.stack([node.sequence for node in nodes])
encoder_out["src_embeddings"] = torch.stack([node.emb for node in nodes], dim=1)
lstm_out = torch.stack([node.lstm_out for node in nodes], dim=1)
final_hidden = torch.stack([node.final_hidden for node in nodes], dim=1)
final_cell = torch.stack([node.final_cell for node in nodes], dim=1)
encoder_out["src_out"] = (lstm_out, final_hidden, final_cell)
try:
encoder_out["src_mask"] = torch.stack([node.mask for node in nodes], dim=0)
except TypeError:
encoder_out["src_mask"] = None
with torch.no_grad():
# Compute the decoder output by feeding it the decoded sentence prefix
decoder_out, _ = model.decoder(prev_words, encoder_out)
# see __QUESTION 2
log_probs, next_candidates = torch.topk(torch.log(torch.softmax(decoder_out, dim=2)), args.beam_size+1, dim=-1)
# Create beam_size next nodes for every current node
for i in range(log_probs.shape[0]):
for j in range(args.beam_size):
best_candidate = next_candidates[i, :, j]
backoff_candidate = next_candidates[i, :, j+1]
best_log_p = log_probs[i, :, j]
backoff_log_p = log_probs[i, :, j+1]
next_word = torch.where(best_candidate == tgt_dict.unk_idx, backoff_candidate, best_candidate)
log_p = torch.where(best_candidate == tgt_dict.unk_idx, backoff_log_p, best_log_p)
log_p = log_p[-1]
next_word = torch.cat((prev_words[i][1:], next_word[-1:]))
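# Extend the hypothesis: drop the initial go token from the previous prefix and append
# the word predicted for the final position.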
# Get parent node and beam search object for corresponding sentence
node = nodes[i]
search = node.search
# __QUESTION 4: How are "add" and "add_final" different? What would happen if we did not make this distinction?
# Store the node as final if EOS is generated
if next_word[-1] == tgt_dict.eos_idx:
node = BeamSearchNode(search, node.emb, node.lstm_out, node.final_hidden,
node.final_cell, node.mask, torch.cat((prev_words[i][0].view([1]),
next_word)), node.logp, node.length)
# for task 3: add args.alpha as argument
# for task 4: add args.alpha, args.gamma, j (=rank) as argument
search.add_final(-node.eval(args.alpha, args.gamma, j), node)
# Add the node to current nodes for next iteration
else:
node = BeamSearchNode(search, node.emb, node.lstm_out, node.final_hidden,
node.final_cell, node.mask, torch.cat((prev_words[i][0].view([1]),
next_word)), node.logp + log_p, node.length + 1)
# for task 3: add args.alpha as argument
# for task 4: add args.alpha, args.gamma, j (=rank) as argument
search.add(-node.eval(args.alpha, args.gamma, j), node)
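# Assumption about BeamSearch internals: add_final stores a hypothesis that ended in EOS
# as a completed candidate that is not expanded any further, while add keeps the node in
# the active beam so it gets extended again in the next iteration.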
# __QUESTION 5: What happens internally when we prune our beams?
# How do we know we always maintain the best sequences?
for search in searches:
search.prune()
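# Pruning presumably discards all but the top-scoring beam_size nodes of each search;
# because nodes are stored together with their scores, the best partial hypotheses
# survive into the next step.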
# Segment into sentences
# for task 3:
#best_sents = torch.stack([search.get_best()[1].sequence[1:].cpu() for search in searches])
# for task 4: adapt best_sents in order to contain the n-best translations for each source sentence
best_sents = torch.stack([node[1].sequence[1:].cpu() for search in searches for node in search.get_best(args.n_best)])
decoded_batch = best_sents.numpy()
output_sentences = [decoded_batch[row, :] for row in range(decoded_batch.shape[0])]
# __QUESTION 6: What is the purpose of this for loop?
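# The loop below truncates each generated sentence at its first EOS token, so anything
# produced after the end-of-sentence marker is discarded before converting indices back
# to words.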
temp = list()
for sent in output_sentences:
first_eos = np.where(sent == tgt_dict.eos_idx)[0]
if len(first_eos) > 0:
temp.append(sent[:first_eos[0]])
else:
temp.append(sent)
output_sentences = temp
# Convert arrays of indices into strings of words
output_sentences = [tgt_dict.string(sent) for sent in output_sentences]
# for task 3:
# for ii, sent in enumerate(output_sentences):
# all_hyps[int(sample['id'].data[ii])] = sent
# for task 4:
# all_hyps is now a defaultdict containing lists: append the n-best sentences to the list
# increase the index only every n-th time to prevent IndexError
idx = -1
for ii, sent in enumerate(output_sentences):
if ii % args.n_best == 0:
idx += 1
all_hyps[int(sample['id'].data[idx])].append(sent)
# Write to file
if args.output is not None:
with open(args.output, 'w') as out_file:
for sent_id in range(len(all_hyps.keys())):
# for task 3:
# out_file.write(all_hyps[sent_id] + '\n')
# for task 4: iterate through n-best list
for sent in all_hyps[sent_id]:
out_file.write(sent + '\n')


if __name__ == '__main__':
args = get_args()
main(args)