-
Notifications
You must be signed in to change notification settings - Fork 9
/
utils.py
158 lines (127 loc) · 5.14 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import numpy
from numpy.testing import assert_array_equal
import csv
import os
# Pick the timer used to measure loading times: prefer the high-resolution
# perf_counter and fall back to time.time where it is not available.
try:
    from time import perf_counter
except ImportError:
    # only a missing name should trigger the fallback; anything else
    # (e.g. KeyboardInterrupt) must propagate
    from time import time
    perf_counter = time
# Default directory in which the dataset files are looked up
DATA_PATH = "datasets"


def csv_2_numpy(filename, path=DATA_PATH, sep=',', type='int8'):
    """
    Utility to read a dataset in csv format into a numpy array

    Parameters
    ----------
    filename : name of the csv file, joined onto ``path``
    path : directory containing the file (defaults to DATA_PATH)
    sep : field delimiter handed to ``csv.reader``
    type : numpy dtype (or dtype string) the parsed values are cast to

    Returns
    -------
    numpy array built from the rows of the file, cast to ``type``
    """
    file_path = os.path.join(path, filename)
    # the context manager guarantees the handle is closed once every row
    # has been read, even if parsing raises
    with open(file_path, "r") as csv_file:
        rows = list(csv.reader(csv_file, delimiter=sep))
    return numpy.array(rows).astype(type)
def load_train_valid_test_csvs(dataset_name,
                               path=DATA_PATH,
                               sep=',',
                               type='int32',
                               suffix='data',
                               splits=['train',
                                       'valid',
                                       'test'],
                               verbose=True):
    """
    Load the splits of a dataset from their csv files.

    One file per entry of ``splits`` is read; its name is
    ``<dataset_name>.<split>.<suffix>`` and it is looked up in ``path``.
    Each file is parsed with csv_2_numpy using ``sep`` and ``type``.
    When ``verbose`` is set, the elapsed loading time and the shape of
    every split are printed.

    Returns the list of loaded arrays, in the order given by ``splits``.
    """
    #
    # composing one file name per split
    split_filenames = []
    for split_name in splits:
        split_filenames.append('{0}.{1}.{2}'.format(dataset_name, split_name, suffix))

    #
    # loading them while timing the whole operation
    tic = perf_counter()
    loaded = []
    for split_filename in split_filenames:
        loaded.append(csv_2_numpy(split_filename, path, sep, type))
    toc = perf_counter()

    if verbose:
        print('Dataset splits for {0} loaded in {1} secs'.format(dataset_name,
                                                                 toc - tic))
        for split_array, split_name in zip(loaded, splits):
            print('\t{0}:\t{1}'.format(split_name, split_array.shape))

    return loaded
def load_train_valid_test_npz(dataset_name,
                              path=DATA_PATH,
                              ext='.npz',
                              type='int32',
                              suffix='_data',
                              splits=['train',
                                      'valid',
                                      'test'],
                              verbose=True):
    """
    Loading training, validation and test splits from the npz archive

    The archive is expected at ``<path>/<dataset_name><ext>``; for every
    entry of ``splits`` the array stored under the key ``<split><suffix>``
    is extracted and cast to ``type``.
    When ``verbose`` is set, the elapsed extraction time and the shape of
    every split are printed.

    Returns the list of extracted arrays, in the order given by ``splits``.
    """
    file_path = os.path.join(path, dataset_name + ext)

    #
    # numpy.load hands back an archive object backed by an open file:
    # using it as a context manager closes that file as soon as the
    # arrays have been materialized
    with numpy.load(file_path) as uncompressed_data:
        load_start_t = perf_counter()
        dataset_splits = [uncompressed_data[key + suffix].astype(type) for key in splits]
        load_end_t = perf_counter()

    if verbose:
        print('Dataset splits for {0} loaded in {1} secs'.format(dataset_name,
                                                                 load_end_t - load_start_t))
        for data, split in zip(dataset_splits, splits):
            print('\t{0}:\t{1}'.format(split, data.shape))

    return dataset_splits
def compress_numpy_splits_made(dataset_name,
                               output_path,
                               dataset_splits,
                               splits=['train',
                                       'valid',
                                       'test'],
                               data_suffix='_data',
                               length_suffix='_length'):
    """
    Store the splits of a dataset into a single npz archive.

    The archive is written with numpy.savez to ``<output_path>/<dataset_name>``
    and contains:
      - 'inputsize': the number of columns shared by all the splits (int32)
      - '<split><data_suffix>': the split data cast to float32
      - '<split><length_suffix>': the number of rows of the split (int32)

    An AssertionError is raised (by assert_array_equal) when the splits do
    not all have the same number of columns.
    """
    archive_path = os.path.join(output_path, dataset_name)

    #
    # number of columns of each split
    n_features = numpy.array([split_data.shape[1] for split_data in dataset_splits],
                             dtype='int32')

    #
    # they must all match the first one
    expected = numpy.repeat(n_features[:1], n_features.shape[0])
    assert_array_equal(n_features, expected)

    archive = {'inputsize': numpy.array(n_features[0], dtype='int32')}
    for split_data, split_name in zip(dataset_splits, splits):
        n_rows = split_data.shape[0]
        archive[split_name + data_suffix] = split_data.astype('float32')
        archive[split_name + length_suffix] = numpy.array(n_rows, dtype='int32')

    #
    # writing everything to disk
    numpy.savez(archive_path, **archive)
if __name__ == '__main__':
    #
    # Script entry point: converts every listed dataset from its csv
    # splits into an npz archive and checks that reading the archive back
    # yields the same arrays
    DATASET_NAMES = [
        'accidents', 'ad', 'baudio', 'bbc', 'bnetflix',
        'book', 'c20ng', 'cr52', 'cwebkb', 'dna',
        'jester', 'kdd', 'msnbc', 'msweb', 'nltcs',
        'plants', 'pumsb_star', 'tmovie', 'tretail',
    ]

    NPZ_OUTPUT = '../MADE/datasets/'

    for dataset in DATASET_NAMES:
        print('Processing dataset', dataset)

        #
        # reading the splits from the csv files in the dataset's own folder
        csv_dir = os.path.join(DATA_PATH, dataset)
        csv_splits = load_train_valid_test_csvs(dataset, path=csv_dir)

        #
        # storing them in a single npz archive
        compress_numpy_splits_made(dataset, NPZ_OUTPUT, csv_splits)

        #
        # reading the archive back
        npz_splits = load_train_valid_test_npz(dataset, NPZ_OUTPUT)

        #
        # check for exactness, split by split
        for from_csv, from_npz in zip(csv_splits, npz_splits):
            assert_array_equal(from_csv, from_npz)