forked from fhieber/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathget_data.py
71 lines (65 loc) · 2.39 KB
/
get_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
import mxnet as mx
def get_mnist(data_dir, url="http://data.mxnet.io/mxnet/data/mnist.zip"):
    """Make sure the four MNIST idx files are present in ``data_dir``.

    If any of the files is missing, the archive at ``url`` is downloaded
    into ``data_dir`` as ``mnist.zip``, extracted there, and then removed.

    The process working directory is left untouched: every path is built
    from ``data_dir`` instead of ``chdir``-ing into it, so nested or
    absolute directories work and a failure cannot strand the caller in
    another directory.

    Args:
        data_dir: Directory that should hold the MNIST files. It is created
            (including missing parents) when it does not exist.
        url: Location of the zip archive to fetch when files are missing.

    Raises:
        OSError: If the directory cannot be created or the download fails.
        zipfile.BadZipfile: If the downloaded file is not a valid archive.
    """
    import zipfile
    # ``urlretrieve`` lives in different modules on Python 2 and Python 3.
    try:
        from urllib.request import urlretrieve
    except ImportError:
        from urllib import urlretrieve

    # os.makedirs instead of a shell "mkdir": portable, handles spaces and
    # nested paths, and raises on failure instead of failing silently.
    if not os.path.isdir(data_dir):
        os.makedirs(data_dir)

    required = (
        'train-images-idx3-ubyte',
        'train-labels-idx1-ubyte',
        't10k-images-idx3-ubyte',
        't10k-labels-idx1-ubyte',
    )
    if all(os.path.exists(os.path.join(data_dir, name)) for name in required):
        return

    zippath = os.path.join(data_dir, "mnist.zip")
    try:
        urlretrieve(url, zippath)
        with zipfile.ZipFile(zippath, "r") as zf:
            zf.extractall(data_dir)
    finally:
        # Never leave a (possibly partial) archive behind.
        if os.path.exists(zippath):
            os.remove(zippath)
def get_cifar10(data_dir, url="http://data.mxnet.io/mxnet/data/cifar10.zip"):
    """Make sure ``train.rec`` and ``test.rec`` are present in ``data_dir``.

    If either file is missing, the archive at ``url`` is downloaded into
    ``data_dir`` as ``cifar10.zip``, extracted there, and then removed. The
    archive's contents are expected under a top-level ``cifar`` folder; its
    entries are moved up into ``data_dir`` and the empty folder is deleted.

    The process working directory is never changed, so it cannot be left
    pointing elsewhere when the download or extraction fails.

    Args:
        data_dir: Directory that should hold the record files. It is created
            (including missing parents) when it does not exist.
        url: Location of the zip archive to fetch when files are missing.

    Raises:
        OSError: If the directory cannot be created, the download fails, or
            an extracted entry cannot be moved into place.
        zipfile.BadZipfile: If the downloaded file is not a valid archive.
    """
    import zipfile
    # ``urlretrieve`` lives in different modules on Python 2 and Python 3.
    try:
        from urllib.request import urlretrieve
    except ImportError:
        from urllib import urlretrieve

    # os.makedirs instead of a shell "mkdir": portable, handles spaces and
    # nested paths, and raises on failure instead of failing silently.
    if not os.path.isdir(data_dir):
        os.makedirs(data_dir)

    dirname = os.path.abspath(data_dir)
    if os.path.exists(os.path.join(dirname, 'train.rec')) and \
            os.path.exists(os.path.join(dirname, 'test.rec')):
        return

    zippath = os.path.join(dirname, "cifar10.zip")
    try:
        urlretrieve(url, zippath)
        with zipfile.ZipFile(zippath, "r") as zf:
            zf.extractall(dirname)
    finally:
        # Never leave a (possibly partial) archive behind.
        if os.path.exists(zippath):
            os.remove(zippath)

    # Flatten the extracted "cifar" folder into data_dir. os.listdir is used
    # rather than a glob so that directory names containing glob characters
    # and hidden entries are handled too.
    extracted = os.path.join(dirname, "cifar")
    if os.path.isdir(extracted):
        for name in os.listdir(extracted):
            os.rename(os.path.join(extracted, name),
                      os.path.join(dirname, name))
        os.rmdir(extracted)
# data
def get_cifar10_iterator(args, kv):
    """Build the CIFAR-10 training and validation record iterators.

    When ``args.data_dir`` does not contain ``://`` it is treated as a local
    directory and ``get_cifar10`` is called first so the record files are
    available there.

    Args:
        args: Object providing ``data_dir`` (directory or URI holding
            ``train.rec``, ``test.rec`` and ``mean.bin``) and ``batch_size``.
        kv: Object providing ``num_workers`` and ``rank``; they are passed
            to both iterators as ``num_parts`` and ``part_index``.

    Returns:
        A ``(train, val)`` tuple of ``mx.io.ImageRecordIter`` objects. The
        training iterator enables random crop and mirror; the validation
        iterator disables both.
    """
    # presumably (channels, height, width) -- confirm against ImageRecordIter
    data_shape = (3, 28, 28)
    data_dir = args.data_dir
    # On Windows a trailing "/" is swapped for "\\". Only do this when the
    # slash is really there: slicing off the last character unconditionally
    # would corrupt a directory given without a trailing separator.
    if os.name == "nt" and data_dir.endswith("/"):
        data_dir = data_dir[:-1] + "\\"
    if '://' not in args.data_dir:
        get_cifar10(data_dir)

    train = mx.io.ImageRecordIter(
        path_imgrec=os.path.join(data_dir, "train.rec"),
        mean_img=os.path.join(data_dir, "mean.bin"),
        data_shape=data_shape,
        batch_size=args.batch_size,
        rand_crop=True,
        rand_mirror=True,
        num_parts=kv.num_workers,
        part_index=kv.rank)

    val = mx.io.ImageRecordIter(
        path_imgrec=os.path.join(data_dir, "test.rec"),
        mean_img=os.path.join(data_dir, "mean.bin"),
        rand_crop=False,
        rand_mirror=False,
        data_shape=data_shape,
        batch_size=args.batch_size,
        num_parts=kv.num_workers,
        part_index=kv.rank)

    return (train, val)