-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathutils.py
97 lines (86 loc) · 3.6 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler
import torch.utils.data as data_utils
import numpy as np
import sys
DATA_ROOT = './cifar10_data/'
def get_cifar10_data_loaders(batch_size=64, n_train=40000, n_val=10000,n_test=10000):
    """Build CIFAR-10 train/validation/test DataLoaders.

    The train and validation sets are both carved out of the 50k CIFAR-10
    training images using a pre-shuffled, saved index permutation
    (CIFAR10_indices.npy) so the split is reproducible across runs.

    Args:
        batch_size: mini-batch size for all three loaders.
        n_train: number of training-split images taken from the front of
            the shuffled index array.
        n_val: number of validation images taken immediately after the
            training slice.
        n_test: number of test-set images sampled (from the official
            10k CIFAR-10 test set).

    Returns:
        (train_loader, val_loader, test_loader) tuple of DataLoaders.
    """
    # Train-time augmentation: padded random crop + horizontal flip.
    # NOTE(review): Normalize((0,0,0),(1,1,1)) is an identity transform —
    # kept only for parity with the eval pipeline / easy editing later.
    train_transform = transforms.Compose([
        transforms.RandomCrop(size=32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.,0.,0.),(1.,1.,1.))
    ])
    # Validation and test use the same deterministic (no-augmentation)
    # pipeline, so one Compose serves both datasets.
    eval_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.,0.,0.),(1.,1.,1.))
    ])
    train_set = datasets.CIFAR10(root=DATA_ROOT, download=True, train=True, \
            transform=train_transform)
    val_set = datasets.CIFAR10(root=DATA_ROOT, download=True, train=True, \
            transform=eval_transform)
    test_set = datasets.CIFAR10(root=DATA_ROOT, download=True, train=False, \
            transform=eval_transform)
    # Pre-generated permutation of 0..49999, e.g.:
    #   indices = np.arange(0, 50000); np.random.shuffle(indices); np.save(...)
    indices = np.load(DATA_ROOT + '/CIFAR10_indices.npy')
    train_sampler = SubsetRandomSampler(indices[:n_train])
    # Bug fix: previously indices[n_train:], which ignored n_val and gave a
    # wrongly-sized validation set whenever n_train + n_val != 50000.
    val_sampler = SubsetRandomSampler(indices[n_train:n_train + n_val])
    test_sampler = SubsetRandomSampler(np.arange(n_test))
    train_loader = data_utils.DataLoader(train_set, batch_size=batch_size,\
            sampler=train_sampler)
    val_loader = data_utils.DataLoader(val_set, batch_size=batch_size,\
            sampler=val_sampler)
    test_loader = data_utils.DataLoader(test_set, batch_size=batch_size,\
            sampler=test_sampler)
    return train_loader, val_loader, test_loader
DATA_ROOT_MNIST = './mnist_data/'
def get_mnist_data_loaders(batch_size=64, n_train=50000, n_val=10000,n_test=10000):
    """Build MNIST train/validation/test DataLoaders.

    The train and validation sets are both carved out of the 60k MNIST
    training images using a pre-shuffled, saved index permutation
    (MNIST_indices.npy) so the split is reproducible across runs.

    Args:
        batch_size: mini-batch size for all three loaders.
        n_train: number of training-split images taken from the front of
            the shuffled index array.
        n_val: number of validation images taken immediately after the
            training slice.
        n_test: number of test-set images sampled (from the official
            10k MNIST test set).

    Returns:
        (train_loader, val_loader, test_loader) tuple of DataLoaders.
    """
    # No augmentation for MNIST: every split shares the same ToTensor-only
    # pipeline, so one Compose serves all three datasets.
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    train_set = datasets.MNIST(root=DATA_ROOT_MNIST, download=True, train=True, \
            transform=mnist_transform)
    val_set = datasets.MNIST(root=DATA_ROOT_MNIST, download=True, train=True, \
            transform=mnist_transform)
    test_set = datasets.MNIST(root=DATA_ROOT_MNIST, download=True, train=False, \
            transform=mnist_transform)
    # Pre-generated permutation of 0..59999, e.g.:
    #   indices = np.arange(0, 60000); np.random.shuffle(indices); np.save(...)
    indices = np.load(DATA_ROOT_MNIST + '/MNIST_indices.npy')
    train_sampler = SubsetRandomSampler(indices[:n_train])
    # Bug fix: previously indices[n_train:], which ignored n_val and gave a
    # wrongly-sized validation set whenever n_train + n_val != 60000.
    val_sampler = SubsetRandomSampler(indices[n_train:n_train + n_val])
    test_sampler = SubsetRandomSampler(np.arange(n_test))
    train_loader = data_utils.DataLoader(train_set, batch_size=batch_size,\
            sampler=train_sampler)
    val_loader = data_utils.DataLoader(val_set, batch_size=batch_size,\
            sampler=val_sampler)
    test_loader = data_utils.DataLoader(test_set, batch_size=batch_size,\
            sampler=test_sampler)
    return train_loader, val_loader, test_loader
def progress(curr, total, suffix=''):
    """Draw an in-place 48-character progress bar on stdout.

    Repeated calls overwrite the same terminal line (via ``\\r``); when
    ``curr == total`` a final fully-filled bar is printed with a trailing
    "Completed" marker and a newline.

    Args:
        curr: number of completed steps.
        total: total number of steps.
        suffix: optional text appended after the bar.
    """
    width = 48
    # At least one cell is always drawn so the '>' head is visible at 0%.
    done = int(round(width * curr / float(total))) or 1
    body = '=' * (done - 1) + '>' + '-' * (width - done)
    sys.stdout.write('\r[%s] .. %s' % (body, suffix))
    sys.stdout.flush()
    if curr == total:
        # Final frame: solid bar plus completion marker and newline.
        sys.stdout.write('\r[%s] .. %s .. Completed\n' % (width * '=', suffix))