-
Notifications
You must be signed in to change notification settings - Fork 5
/
datasets.py
64 lines (49 loc) · 1.9 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
from torch.utils import data
class CustomDataset(data.Dataset):
    """Subset view of ``dataset`` that optionally relabels one class.

    Item ``i`` of this dataset is item ``indices[i]`` of the wrapped dataset.
    If that item's label equals ``source_class`` it is replaced by
    ``target_class``; every other label is passed through unchanged.

    Args:
        dataset: Indexable dataset whose items expose the input at ``[0]``
            and the label at ``[1]``.
        indices: Sequence of positions into ``dataset``. Each entry is passed
            through ``int()``, so numpy/tensor scalars are accepted.
        source_class: Label to replace. ``None`` disables relabeling.
        target_class: Label substituted for ``source_class``.
    """

    def __init__(self, dataset, indices, source_class=None, target_class=None):
        self.dataset = dataset
        self.indices = indices
        self.source_class = source_class
        self.target_class = target_class
        # Not read anywhere in this class; kept as a public flag that
        # external code may set or inspect.
        self.contains_source_class = False

    def __getitem__(self, index):
        """Return ``(x, y)`` for position ``index`` of this subset."""
        # Fetch the underlying sample exactly once. Indexing the wrapped
        # dataset twice doubles loading/transform cost, and with random
        # transforms could pair x and y taken from different draws.
        sample = self.dataset[int(self.indices[index])]
        x, y = sample[0], sample[1]
        if self.source_class is not None and y == self.source_class:
            y = self.target_class
        return x, y

    def __len__(self):
        """Number of selected indices."""
        return len(self.indices)
class PoisonedDataset(data.Dataset):
    """Wrap a whole dataset, relabeling ``source_class`` as ``target_class``.

    Items are returned in the wrapped dataset's order. Any item whose label
    equals ``source_class`` gets ``target_class`` instead; all other labels
    are passed through unchanged.

    Args:
        dataset: Indexable dataset whose items expose the input at ``[0]``
            and the label at ``[1]``.
        source_class: Label to replace. ``None`` disables relabeling.
        target_class: Label substituted for ``source_class``.
    """

    def __init__(self, dataset, source_class=None, target_class=None):
        self.dataset = dataset
        self.source_class = source_class
        self.target_class = target_class

    def __getitem__(self, index):
        """Return ``(x, y)`` for item ``index`` of the wrapped dataset."""
        # Index the wrapped dataset once; a second lookup would redo any
        # loading/transform work and could mix samples from different draws.
        sample = self.dataset[index]
        x, y = sample[0], sample[1]
        if self.source_class is not None and y == self.source_class:
            y = self.target_class
        return x, y

    def __len__(self):
        """Same length as the wrapped dataset."""
        return len(self.dataset)
class IMDBDataset(data.Dataset):
    """Map-style dataset pairing review rows with their targets as tensors.

    Args:
        reviews: Indexable collection of integer sequences, e.g. a 2-D numpy
            array with one row per review or a list of equal-purpose lists.
            Presumably padded token ids — TODO confirm against the caller.
        targets: Indexable collection with one numeric target per review.

    Each item is ``(x, y)`` where ``x`` is ``reviews[index]`` as a
    ``torch.long`` tensor and ``y`` is ``targets[index]`` as a
    ``torch.float`` tensor.
    """

    def __init__(self, reviews, targets):
        self.reviews = reviews
        # Attribute name kept as ``target`` (singular) for existing callers.
        self.target = targets

    def __len__(self):
        """Number of reviews."""
        return len(self.reviews)

    def __getitem__(self, index):
        """Return ``(review, target)`` at ``index`` as freshly copied tensors."""
        # ``reviews[index]`` equals ``reviews[index, :]`` for 2-D arrays and
        # also works for plain lists of sequences.
        x = torch.tensor(self.reviews[index], dtype=torch.long)
        y = torch.tensor(self.target[index], dtype=torch.float)
        return x, y
def combine_datasets(list_of_datasets):
    """Concatenate the given datasets end to end into a single dataset.

    Args:
        list_of_datasets: Datasets to join, in order. Each must support
            ``len()`` and integer indexing.

    Returns:
        A ``torch.utils.data.ConcatDataset`` whose indices run through each
        input dataset in turn.
    """
    merged = data.ConcatDataset(list_of_datasets)
    return merged