import torch
import torch.utils.data
import torchvision
import torchvision.transforms as transforms

# Fast data loader that serves batches directly from tensors held in memory.
class InMemDataLoader(object):
    __initialized = False

    def __init__(self, tensors, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, drop_last=False):
        """A DataLoader-like object that fetches whole batches from an in-memory
        dataset (e.g. a torch.utils.data.TensorDataset) by indexing it with the
        index lists produced by a batch sampler."""
        self.dataset = tensors  # expected to be an in-memory dataset, e.g. a TensorDataset
        self.batch_size = batch_size
        self.drop_last = drop_last

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            self.batch_size = None
            self.drop_last = None

        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = torch.utils.data.RandomSampler(self.dataset)
                else:
                    sampler = torch.utils.data.SequentialSampler(self.dataset)
            batch_sampler = torch.utils.data.BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.__initialized = True

    def __setattr__(self, attr, val):
        # Mirror torch.utils.data.DataLoader: sampling-related attributes are
        # frozen once the loader has been constructed.
        if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
            raise ValueError('{} attribute should not be set after {} is '
                             'initialized'.format(attr, self.__class__.__name__))
        super(InMemDataLoader, self).__setattr__(attr, val)

    def __iter__(self):
        # Each batch is built by fancy-indexing the in-memory dataset with a
        # list of indices, avoiding per-sample overhead and worker processes.
        for batch_indices in self.batch_sampler:
            yield self.dataset[batch_indices]

    def __len__(self):
        return len(self.batch_sampler)

    def to(self, device):
        # Move the underlying tensors to `device` once, so every yielded batch
        # is already on that device.
        self.dataset.tensors = tuple(t.to(device) for t in self.dataset.tensors)
        return self
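

# The sketch below is not part of the original file: it illustrates how
# InMemDataLoader is assumed to be used, by wrapping in-memory tensors in a
# TensorDataset and iterating over batches. `_example_inmem_usage` and its
# random tensors are hypothetical and exist only for illustration.
def _example_inmem_usage():
    features = torch.randn(1000, 3, 32, 32)      # fake images kept in memory
    labels = torch.randint(0, 10, (1000,))       # fake class labels
    dataset = torch.utils.data.TensorDataset(features, labels)
    loader = InMemDataLoader(dataset, batch_size=64, shuffle=True)
    loader.to("cuda" if torch.cuda.is_available() else "cpu")  # move all tensors once
    for images, targets in loader:               # each batch is a tuple of tensors
        assert images.shape[0] == targets.shape[0]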


def getMNIST(batch_size=32, drop_last=False):
    """Return train/test DataLoaders for MNIST, resized to 32x32 and repeated to 3 channels."""
    transform = transforms.Compose([transforms.Resize(32),
                                    transforms.ToTensor(),
                                    # MNIST is grayscale; repeat the channel to get a 3-channel image.
                                    transforms.Lambda(lambda x: x.repeat(3, 1, 1))])
    # download=False assumes the MNIST files are already present under data/mnist.
    mnist_data = torchvision.datasets.MNIST(download=False, root='data/mnist',
                                            transform=transform)
    mnist_test_data = torchvision.datasets.MNIST(download=False, root='data/mnist',
                                                 train=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(mnist_data,
                                               batch_size=batch_size,
                                               shuffle=False,
                                               num_workers=16,
                                               drop_last=drop_last)
    test_loader = torch.utils.data.DataLoader(mnist_test_data,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              num_workers=16,
                                              drop_last=drop_last)
    return train_loader, test_loader
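

# A hedged sketch (not in the original file) of combining getMNIST-style loaders
# with InMemDataLoader: concatenate every batch once, then serve batches from
# memory. `_cache_in_memory` is a hypothetical helper added for illustration.
def _cache_in_memory(loader, device="cpu", shuffle=True):
    images = torch.cat([x for x, _ in loader], dim=0)   # stack all image batches
    labels = torch.cat([y for _, y in loader], dim=0)   # stack the matching labels
    dataset = torch.utils.data.TensorDataset(images, labels)
    return InMemDataLoader(dataset, batch_size=loader.batch_size,
                           shuffle=shuffle).to(device)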


def getFashionMNIST(batch_size=32, drop_last=False):
    """Return train/test DataLoaders for FashionMNIST, resized to 32x32 (kept single-channel)."""
    transform = transforms.Compose([transforms.Resize(32),
                                    transforms.ToTensor(),
                                    # repeat(1, 1, 1) keeps the single grayscale channel unchanged.
                                    transforms.Lambda(lambda x: x.repeat(1, 1, 1))])
    fashionmnist_data = torchvision.datasets.FashionMNIST(download=True, root='data/fashionmnist',
                                                          transform=transform)
    fashionmnist_data_test = torchvision.datasets.FashionMNIST(download=True, root='data/fashionmnist',
                                                               train=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(fashionmnist_data,
                                               batch_size=batch_size,
                                               shuffle=False,
                                               num_workers=16,
                                               drop_last=drop_last)
    test_loader = torch.utils.data.DataLoader(fashionmnist_data_test,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              num_workers=16,
                                              drop_last=drop_last)
    return train_loader, test_loader


def getCifar10(batch_size=32, drop_last=False):
    """Return train/test DataLoaders for CIFAR-10 (32x32 RGB, tensor conversion only)."""
    transform = transforms.Compose([transforms.ToTensor()])
    cifar10_data = torchvision.datasets.CIFAR10(download=True, root='data/cifar10',
                                                transform=transform)
    cifar10_data_test = torchvision.datasets.CIFAR10(download=True, root='data/cifar10',
                                                     train=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(cifar10_data,
                                               batch_size=batch_size,
                                               shuffle=False,
                                               num_workers=16,
                                               drop_last=drop_last)
    test_loader = torch.utils.data.DataLoader(cifar10_data_test,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              num_workers=16,
                                              drop_last=drop_last)
    return train_loader, test_loader


def getDataset(dataset="MNIST", batch_size=32, drop_last=False):
    """Return (train_loader, test_loader, channels, height, width) for the requested dataset."""
    if dataset == "MNIST":
        train_loader, test_loader = getMNIST(batch_size, drop_last)
    elif dataset == "FashionMNIST":
        train_loader, test_loader = getFashionMNIST(batch_size, drop_last)
    elif dataset == "Cifar10":
        train_loader, test_loader = getCifar10(batch_size, drop_last)
    else:
        return None, None, None, None, None
    # Infer the image shape (channels, height, width) from a single sample.
    noChannels, dx, dy = train_loader.dataset[0][0].shape
    return train_loader, test_loader, noChannels, dx, dy
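

# Minimal usage sketch (not part of the original file); it only runs when the
# module is executed directly and assumes the dataset files can be found or
# downloaded under the data/ roots used above.
if __name__ == "__main__":
    train_loader, test_loader, noChannels, dx, dy = getDataset("Cifar10", batch_size=32)
    print("CIFAR-10 image shape:", (noChannels, dx, dy))   # expected (3, 32, 32)
    print("train batches:", len(train_loader), "test batches:", len(test_loader))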