-
Notifications
You must be signed in to change notification settings - Fork 750
/
Batch.py
62 lines (52 loc) · 2.12 KB
/
Batch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
from torchtext import data
import numpy as np
from torch.autograd import Variable
def nopeak_mask(size, device):
    """Build the "no-peek" mask that hides later positions from earlier ones.

    Args:
        size: Side length of the square mask (the target sequence length).
        device: Where the mask should live. Anything ``torch.Tensor.to``
            accepts works here (a ``torch.device``, a device string such as
            ``"cpu"`` or ``"cuda:0"``, or a CUDA ordinal). ``None`` keeps the
            earlier behaviour of moving the mask to the current CUDA device.

    Returns:
        A tensor of shape ``(1, size, size)`` holding the result of an
        ``== 0`` comparison: entry ``[0, i, j]`` is true when ``j <= i``
        (position ``i`` may look at ``j``) and false when ``j > i``.
    """
    # Ones strictly above the main diagonal (k=1) mark the "future" positions.
    future = np.triu(np.ones((1, size, size)), k=1).astype('uint8')
    # Invert: allowed positions are exactly those that are NOT in the future.
    mask = torch.from_numpy(future) == 0
    if device is None:
        return mask.cuda()
    # .to() handles CPU as well as CUDA targets, unlike .cuda(), which
    # rejects non-CUDA devices.
    return mask.to(device)
def create_masks(src, trg, device, src_pad=-1, trg_pad=-1):
    """Create the padding mask for ``src`` and the combined mask for ``trg``.

    Args:
        src: Tensor of source token indices.
        trg: Tensor of target token indices, or ``None`` when no target mask
            is needed. ``trg.size(1)`` is used as the sequence length, so a
            layout of ``(batch, seq_len)`` is assumed -- TODO confirm against
            callers.
        device: Device the target mask is moved to and on which the no-peek
            mask is created. ``None`` leaves the target padding mask on the
            device of ``trg``.
        src_pad: Index treated as padding in ``src`` (default ``-1``, the
            value that used to be hard-coded).
        trg_pad: Index treated as padding in ``trg`` (default ``-1``).

    Returns:
        A tuple ``(src_mask, trg_mask)``. ``src_mask`` is ``src != src_pad``
        with an extra axis inserted at position ``-2``. ``trg_mask`` is the
        target padding mask AND-ed with the no-peek mask, or ``None`` when
        ``trg`` is ``None``.
    """
    src_mask = (src != src_pad).unsqueeze(-2)
    if trg is None:
        return src_mask, None

    trg_mask = (trg != trg_pad).unsqueeze(-2)
    if device is not None:
        # Tensor device moves are not in-place: the result has to be kept,
        # otherwise the mask silently stays on its original device.
        trg_mask = trg_mask.to(device)
    seq_len = trg.size(1)  # side length of the square no-peek matrix
    np_mask = nopeak_mask(seq_len, device)
    # A target position is visible only if it is neither padding nor in the
    # future; broadcasting expands the two masks against each other.
    return src_mask, trg_mask & np_mask
# patch on Torchtext's batching process that makes it more efficient
# from http://nlp.seas.harvard.edu/2018/04/03/attention.html#position-wise-feed-forward-networks
class MyIterator(data.Iterator):
    """Iterator whose batches are built from length-sorted examples."""

    def create_batches(self):
        """Fill ``self.batches`` according to ``self.train``.

        When not training, ``self.batches`` becomes a list of batches, each
        sorted by ``self.sort_key``. When training, it becomes a lazy
        generator: the data is read in chunks of ``self.batch_size * 100``,
        each chunk is sorted by ``self.sort_key`` and split into batches with
        ``self.batch_size_fn``, and the batches of a chunk are yielded in the
        order produced by ``self.random_shuffler``.
        """
        if not self.train:
            self.batches = [
                sorted(minibatch, key=self.sort_key)
                for minibatch in data.batch(
                    self.data(), self.batch_size, self.batch_size_fn)
            ]
            return

        def pool(examples, shuffle):
            # Sorting inside a large chunk groups examples with similar sort
            # keys into the same batch; shuffling afterwards reorders whole
            # batches, not individual examples.
            for chunk in data.batch(examples, self.batch_size * 100):
                ordered = sorted(chunk, key=self.sort_key)
                minibatches = data.batch(
                    ordered, self.batch_size, self.batch_size_fn)
                yield from shuffle(list(minibatches))

        self.batches = pool(self.data(), self.random_shuffler)
# A ``global`` statement at module level has no effect; it only signals that
# these two names are module-wide state written by batch_size_fn.
global max_src_in_batch, max_tgt_in_batch


def batch_size_fn(new, count, sofar):
    """Return the padded size of the batch after adding example ``new``.

    The running maxima of source and target length are kept in the module
    globals ``max_src_in_batch`` and ``max_tgt_in_batch`` and are reset
    whenever ``count`` is 1. The target length is counted as
    ``len(new.trg) + 2``.

    Args:
        new: Example being added; must expose ``src`` and ``trg`` sequences.
        count: Number of examples in the batch including ``new``.
        sofar: Unused by this function.

    Returns:
        The larger of ``count * max_src_in_batch`` and
        ``count * max_tgt_in_batch``.
    """
    global max_src_in_batch, max_tgt_in_batch
    if count == 1:
        # A count of 1 marks a fresh batch, so the previous maxima are dropped.
        max_src_in_batch = max_tgt_in_batch = 0

    src_len = len(new.src)
    tgt_len = len(new.trg) + 2
    if src_len > max_src_in_batch:
        max_src_in_batch = src_len
    if tgt_len > max_tgt_in_batch:
        max_tgt_in_batch = tgt_len

    # Every example is padded to the longest one, hence count * max length.
    return max(count * max_src_in_batch, count * max_tgt_in_batch)