-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdatasets.py
112 lines (103 loc) · 4.5 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import math
import torch
import torch.nn as nn
import random
import numpy as np
import torch.nn.functional as F
import argparse
import os
import shutil
import time
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.autograd import Variable
from torch.utils.data import Dataset
from torch.utils.data.dataset import random_split
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
# Process-wide default compute device: the first CUDA GPU when one is
# available, otherwise the CPU. Known shortcut (acceptable for now): the
# device should really be handed to the dataset constructor instead of being
# read from a module-level global.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
class CifarExpertDataset(Dataset):
    """Images paired with labels, expert predictions and a "labeled" mask."""

    def __init__(self, images, targets, expert_fn, labeled, indices=None, expert_preds=None):
        """
        Build the dataset and pre-compute the expert prediction of every item.

        images: indexable collection of images; each item is passed through
            ToTensor + Normalize when it is fetched
        targets: ground-truth labels, one per image
        expert_fn: expert function, invoked once as
            expert_fn(None, torch.FloatTensor(targets)) when expert_preds is
            not supplied
        labeled: indicator array; a 0 entry means the expert did not label
            that item
        indices: positions of the items in the original (un-subsampled)
            dataset; defaults to 0..len(targets)-1
        expert_preds: optional precomputed expert predictions, used instead of
            calling expert_fn (e.g. when they come from a different expert model)
        """
        self.images = images
        self.targets = np.array(targets)
        self.expert_fn = expert_fn
        self.labeled = np.array(labeled)

        # Per-channel statistics are written on the 0-255 scale, hence the
        # division by 255 to match ToTensor's output range.
        channel_mean = [value / 255.0 for value in [125.3, 123.0, 113.9]]
        channel_std = [value / 255.0 for value in [63.0, 62.1, 66.7]]
        normalize = transforms.Normalize(mean=channel_mean, std=channel_std)
        self.transform_test = transforms.Compose([transforms.ToTensor(), normalize])

        if expert_preds is None:
            self.expert_preds = np.array(expert_fn(None, torch.FloatTensor(targets)))
        else:
            self.expert_preds = expert_preds

        # Items the expert did not label get the sentinel prediction -1.
        # NOTE(review): when expert_preds was supplied by the caller, this
        # writes into the caller's object in place - confirm that is intended.
        for position in range(len(self.expert_preds)):
            if self.labeled[position] != 0:
                continue
            self.expert_preds[position] = -1

        self.indices = (
            np.array(list(range(len(self.targets)))) if indices is None else indices
        )

    def __getitem__(self, index):
        """Return (normalized image tensor, label, expert prediction, index in the original dataset, labeled flag) for one item."""
        target = self.targets[index]
        tensor_image = self.transform_test(self.images[index])
        expert_label = self.expert_preds[index]
        original_position = self.indices[index]
        is_labeled = self.labeled[index]
        return torch.FloatTensor(tensor_image), target, expert_label, original_position, is_labeled

    def __len__(self):
        """Number of items, i.e. the number of targets."""
        return len(self.targets)
class CifarExpertDatasetLinear(Dataset):
    """Dataset that serves a model-computed vector representation of each image."""

    def __init__(self, images, targets, expert_fn, labeled, indices=None, model=None, device=None):
        """
        Build the dataset and pre-compute the expert prediction of every item.

        images: indexable collection of images; each item is passed through
            ToTensor + Normalize when it is fetched
        targets: ground-truth labels, one per image
        expert_fn: expert function, invoked once as
            expert_fn(None, torch.FloatTensor(targets))
        labeled: indicator array; a 0 entry means the expert did not label
            that item
        indices: positions of the items in the original (un-subsampled)
            dataset; defaults to 0..len(targets)-1
        model: object exposing repr(batch), which maps a (1, 3, 32, 32) image
            batch to a batch of vector representations; required by
            __getitem__
        device: device the image is moved to before calling model.repr; when
            None, the module-level `device` is used
        """
        self.images = images
        self.targets = np.array(targets)
        self.expert_fn = expert_fn
        self.model = model
        self.device = device
        self.labeled = np.array(labeled)

        # Per-channel statistics are written on the 0-255 scale, hence the
        # division by 255 to match ToTensor's output range.
        normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                         std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
        self.transform_test = transforms.Compose([transforms.ToTensor(), normalize])

        self.expert_preds = np.array(expert_fn(None, torch.FloatTensor(targets)))
        for i in range(len(self.expert_preds)):
            if self.labeled[i] == 0:
                self.expert_preds[i] = -1  # not labeled by expert

        # `is not None` rather than `!= None`: with a NumPy array the latter
        # compares elementwise and the `if` raises an ambiguous-truth-value
        # error.
        if indices is not None:
            self.indices = indices
        else:
            self.indices = np.arange(len(self.targets))

    def __getitem__(self, index):
        """Return (image representation, label, expert prediction, index in the original dataset, normalized image tensor) for one item."""
        label = self.targets[index]
        image = self.transform_test(self.images[index])

        # getattr keeps instances that never had `device` set (e.g. built by a
        # subclass with its own __init__) on the module-level default.
        target_device = getattr(self, "device", None)
        if target_device is None:
            target_device = device

        # The model is called on a batch of one image; only its first output
        # row is kept, and it is moved back to the CPU.
        # NOTE(review): the forward pass runs with autograd enabled unless the
        # caller disables it; if the representation is only used as fixed
        # features, wrapping this in torch.no_grad() may be wanted - confirm.
        batch = torch.reshape(image.to(target_device), (1, 3, 32, 32))
        image_repr = self.model.repr(batch)[0].to(torch.device('cpu'))

        expert_pred = self.expert_preds[index]
        indice = self.indices[index]
        return image_repr, label, expert_pred, indice, torch.FloatTensor(image)

    def __len__(self):
        """Number of items, i.e. the number of targets."""
        return len(self.targets)