-
Notifications
You must be signed in to change notification settings - Fork 7
/
data.py
56 lines (43 loc) · 1.54 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from scipy import stats
from statsmodels.distributions.empirical_distribution import ECDF
import mxnet as mx
import numpy as np
from mxnet import nd, autograd, gluon
def load_data(domain):
    """Load one of the supported image-classification datasets by name.

    Args:
        domain: dataset name; must be ``"mnist"`` or ``"cifar10"``.

    Returns:
        The ``(X, y, Xtest, ytest)`` tuple produced by the matching loader
        (``load_mnist`` or ``load_cifar10``).

    Raises:
        ValueError: if ``domain`` is not a supported dataset name. Raising
            here (instead of printing and implicitly returning ``None``)
            makes a typo fail at the call site rather than later as a
            confusing "cannot unpack NoneType" error in the caller.
    """
    if domain == "mnist":
        return load_mnist()
    if domain == "cifar10":
        return load_cifar10()
    raise ValueError(
        "Invalid dataset specified: {!r} (expected 'mnist' or 'cifar10')".format(domain)
    )
def load_mnist():
    """Load MNIST through ``mx.test_utils.get_mnist``.

    Returns:
        A tuple ``(X, y, Xtest, ytest)`` taken directly from the dict that
        ``get_mnist`` returns under the keys ``"train_data"``,
        ``"train_label"``, ``"test_data"`` and ``"test_label"``.

    NOTE(review): per the mxnet docs, ``get_mnist`` may download the data on
    first use, and its images are presumably float32 arrays of shape
    (N, 1, 28, 28) scaled to [0, 1] with 1-D label vectors — confirm against
    the installed mxnet version.
    """
    mnist = mx.test_utils.get_mnist()
    X = mnist["train_data"]
    y = mnist["train_label"]
    # Held-out split; plain arrays, no iterator is built here.
    Xtest = mnist["test_data"]
    ytest = mnist["test_label"]
    return (X, y, Xtest, ytest)
def load_cifar10():
    """Load CIFAR-10 via gluon and return it as NumPy arrays.

    Returns:
        A tuple ``(X, y, Xtest, ytest)``:
          * ``X`` / ``Xtest``: float32 images, axes reordered with
            ``(0, 3, 1, 2)`` (channels-first) and divided by 255.
          * ``y`` / ``ytest``: float32 label vectors.
    """
    def split(train):
        # Build one split and convert it to (images, labels) NumPy arrays.
        # NOTE(review): this reads gluon's private ``_data`` / ``_label``
        # buffers, bypassing ``Dataset.__getitem__`` (so no ``transform=`` is
        # passed — it would never run). ``_data`` is assumed to be a
        # channels-last uint8 NDArray and ``_label`` a 1-D integer array, as
        # in mxnet 1.x; confirm if mxnet is upgraded.
        dataset = gluon.data.vision.CIFAR10(train=train)
        images = nd.transpose(dataset._data.astype(np.float32), axes=(0, 3, 1, 2)) / 255
        labels = np.asarray(dataset._label, dtype=np.float32)
        return images.asnumpy(), labels

    X, y = split(train=True)
    Xtest, ytest = split(train=False)
    return (X, y, Xtest, ytest)