-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata.py
72 lines (54 loc) · 2.27 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from typing import Optional

import torch
class Data(object):
    """Container for one example, or one batch, of preprocessed tensors.

    Bundles the input features with the optional target word and optional
    distractors so they can be printed, moved to the GPU and measured as a
    single unit.
    """

    def __init__(
        self,
        features: torch.FloatTensor,
        target_word: Optional[torch.LongTensor],
        distractors: Optional[torch.FloatTensor] = None,
    ):
        """A data object (single or batch)

        :features: preprocessed shape angle features; must have at least
            one dimension
        :target_word: preprocessed target word, or None when not available
        :distractors: preprocessed distractors, or None when not available
        :raises ValueError: if ``features`` is a zero-dimensional tensor
        """
        # An explicit check rather than a bare ``assert``: asserts are
        # removed under ``python -O`` and ``__len__`` needs a first dimension.
        if features.ndim == 0:
            raise ValueError(
                'features must have at least one dimension, got a 0-d tensor'
            )

        self.features = features
        self.target_word = target_word
        self.distractors = distractors

    def __repr__(self):
        """Return a multi-line dump of all fields.

        The shape of the distractors is appended only when distractors are
        present.
        """
        lines = [
            'Data(',
            f'\tfeatures={self.features}',
            f'\ttarget_word={self.target_word}',
            f'\tdistractors={self.distractors}',
        ]
        if self.distractors is not None:
            lines.append(f'\tdistractors.shape={self.distractors.shape}')
        lines.append(')')
        return '\n'.join(lines)

    def cuda(self):
        """Replace every non-None field with its ``.cuda()`` result.

        The object is modified in place.

        :returns: this same object, so calls can be chained
        """
        self.features = self.features.cuda()
        # target_word and distractors are optional, so skip them when unset.
        if self.target_word is not None:
            self.target_word = self.target_word.cuda()
        if self.distractors is not None:
            self.distractors = self.distractors.cuda()
        return self

    def __len__(self):
        """Return the size of the first dimension of ``features``."""
        return self.features.size(0)
def collate(examples):
    """Merge examples into one batch holding features and target words.

    The ``features`` and ``target_word`` tensors of the examples are each
    stacked along a new leading dimension; distractors are left out.

    :examples: sequence of objects exposing ``features`` and ``target_word``
    :returns: a ``Data`` object whose ``distractors`` is None
    """
    per_example_features = [item.features for item in examples]
    batched_features = torch.stack(per_example_features)

    per_example_words = [item.target_word for item in examples]
    batched_words = torch.stack(per_example_words)

    return Data(batched_features, batched_words, distractors=None)
def collate_with_distractors(examples):
    """Merge examples into one batch that also carries distractors.

    The ``features``, ``target_word`` and ``distractors`` tensors of the
    examples are each stacked along a new leading dimension.

    :examples: sequence of objects exposing ``features``, ``target_word``
        and ``distractors``
    :returns: a ``Data`` object with all three fields populated
    """
    per_example_features = [item.features for item in examples]
    batched_features = torch.stack(per_example_features)

    per_example_words = [item.target_word for item in examples]
    batched_words = torch.stack(per_example_words)

    per_example_distractors = [item.distractors for item in examples]
    batched_distractors = torch.stack(per_example_distractors)

    return Data(batched_features, batched_words, distractors=batched_distractors)
def collate_only_features(examples):
    """Merge examples into one batch that contains only features.

    The ``features`` tensors are stacked along a new leading dimension;
    target word and distractors are both set to None.

    :examples: sequence of objects exposing ``features``
    :returns: a ``Data`` object with only ``features`` populated
    """
    per_example_features = [item.features for item in examples]
    batched_features = torch.stack(per_example_features)
    return Data(batched_features, target_word=None, distractors=None)
class DataLoader(torch.utils.data.DataLoader):
    """torch DataLoader that falls back to ``collate`` for batching."""

    def __init__(self, *args, **kwargs):
        """Init dataloader

        A ``collate_fn`` keyword supplied by the caller is kept as given;
        only when the keyword is absent is ``collate`` used.

        :*args: passed to dataloader
        :**kwargs: passed to dataloader
        """
        # Fill in our collate function only if the caller did not pass one.
        kwargs.setdefault('collate_fn', collate)
        super().__init__(*args, **kwargs)