"""
Convert a list of HDF5 files into triplets of TFRecord files of the basic
types (train, valid, test) - assumes the two-deep MINERvA "spacetime" HDF5
file format.
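
Example invocation (the paths, pattern, and structure name below are
placeholders - see `mnvtf.utils.get_fields_list` for supported structures):

    python hdf5_to_tfrec_minerva_xtxutuvtv.py --in_dir /path/to/hdf5s
        --input_file_pattern some_pattern --out_dir /path/to/tfrecs
        --nevents 25000 --tfrec_struct some_struct --compress_to_gz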
"""
from __future__ import print_function
from collections import OrderedDict
from six.moves import range
import tensorflow as tf
import numpy as np
import sys
import os
import logging
import glob
import mnvtf.utils as utils
from mnvtf.hdf5_readers import MnvHDF5Reader
from mnvtf.data_constants import make_mnv_data_dict_from_fields
from mnvtf.data_constants import EVENT_DATA
from mnvtf.data_constants import PLANECODES, SEGMENTS
LOGGER = logging.getLogger(__name__)
def slices_maker(n, slice_size=100000):
"""
make "slices" of size `slice_size` from a file of `n` events
(so, [0, slice_size), [slice_size, 2 * slice_size), etc.)
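    e.g., slices_maker(250000, slice_size=100000) ->
        [(0, 100000), (100000, 200000), (200000, 250000)]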
"""
if n < slice_size:
return [(0, n)]
remainder = n % slice_size
n = n - remainder
nblocks = n // slice_size
counter = 0
slices = []
for i in range(nblocks):
end = counter + slice_size
slices.append((counter, end))
counter += slice_size
if remainder != 0:
slices.append((counter, counter + remainder))
return slices
def get_binary_data(reader, name, start_idx, stop_idx):
"""
* reader - hdf5_reader
* name of dset in the hdf5 file
* indices
returns byte data
NOTE: we must treat the 'planecodes' dataset as special - TF has some
issues with 16bit dtypes as of TF 1.2, so we must cast planecodes as 32-bit
_prior_ to byte conversion.
Note: syntax to cap a numpy array: `b[np.where(b > 5)] = 5`
"""
dta = reader.get_data(name, start_idx, stop_idx)
if name == PLANECODES:
# we must cast the 16 bit values into 32 bit values
dta = dta.astype(np.int32)
return dta.tobytes()
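# A quick decode sketch for bytes produced by `get_binary_data` (a sketch,
# assuming PLANECODES names the int16 dataset cast to int32 above):
#   raw = get_binary_data(reader, PLANECODES, 0, 10)
#   planecodes = np.frombuffer(raw, dtype=np.int32)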
def write_tfrecord(
reader, data_dict, start_idx, stop_idx, tfrecord_file, compress_to_gz
):
writer = tf.python_io.TFRecordWriter(tfrecord_file)
    for idx in range(start_idx, stop_idx):
        # build a fresh feature map for each event so a key with empty byte
        # data cannot silently reuse the previous event's bytes
        features_dict = {}
        for k in data_dict:
data_dict[k]['byte_data'] = get_binary_data(
reader, k, idx, idx + 1
)
if len(data_dict[k]['byte_data']) > 0:
features_dict[k] = tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[data_dict[k]['byte_data']]
)
)
example = tf.train.Example(
features=tf.train.Features(feature=features_dict)
)
writer.write(example.SerializeToString())
writer.close()
if compress_to_gz:
utils.gz_compress(tfrecord_file)
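# A minimal sketch of inspecting one serialized record by hand (assumes an
# uncompressed tfrecord file; `name` would be one of the data_dict keys
# written above):
#   it = tf.python_io.tf_record_iterator(tfrecord_file)
#   example = tf.train.Example()
#   example.ParseFromString(next(it))
#   raw = example.features.feature[name].bytes_list.value[0]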
def test_read_tfrecord(
tfrecord_file, tfrec_struct, compression,
img_h, imgw_x, imgw_uv, img_depth,
n_planecodes, data_format
):
tf.reset_default_graph()
LOGGER.info('opening {} for reading'.format(tfrecord_file))
dd = utils.make_data_reader_dict(
filenames_list=[tfrecord_file],
batch_size=64,
name='test_read',
compression=compression,
img_shp=(img_h, imgw_x, imgw_uv, img_depth),
data_format=data_format,
n_planecodes=n_planecodes
)
reader_class = utils.get_reader_class(tfrec_struct)
reader = reader_class(dd)
# get an ordered dict
batch_dict = reader.batch_generator()
with tf.Session() as sess:
# have to run local variable init for `string_input_producer`
sess.run(tf.local_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
try:
for batch_num in range(10):
                tensor_list = sess.run(list(batch_dict.values()))
results = OrderedDict(zip(batch_dict.keys(), tensor_list))
LOGGER.info('batch = {}'.format(batch_num))
for k, v in results.items():
LOGGER.info('{} shape = {}'.format(k, v.shape))
if 'hitimes' not in k:
if k == PLANECODES or k == SEGMENTS:
LOGGER.info(' {} = {}'.format(
k, np.argmax(v, axis=1)
))
else:
LOGGER.info(' {} = {}'.format(k, v))
except tf.errors.OutOfRangeError:
LOGGER.info('Reading stopped - queue is empty.')
except Exception as e:
LOGGER.info(e)
finally:
coord.request_stop()
coord.join(threads)
def write_all(
n_events_per_tfrecord_triplet, max_triplets, file_num_start,
hdf5_file, train_file_pat, valid_file_pat, test_file_pat,
train_fraction, valid_fraction, dry_run, compress_to_gz,
file_num_start_write, tfrec_struct
):
# todo, make this a while loop that keeps making tf record files
# until we run out of events in the hdf5, then pass back the
# file number we stopped on
LOGGER.info('opening hdf5 file {} for file start number {}'.format(
hdf5_file, file_num_start
))
m = MnvHDF5Reader(hdf5_file)
m.open()
n_total = m.get_nevents(group=EVENT_DATA)
slcs = slices_maker(n_total, n_events_per_tfrecord_triplet)
n_processed = 0
new_files = []
for i, slc in enumerate(slcs):
file_num = i + file_num_start
n_slc = slc[-1] - slc[0]
n_train = int(n_slc * train_fraction)
n_valid = int(n_slc * valid_fraction)
n_test = n_slc - n_train - n_valid
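        # e.g., n_slc = 100000 with the default fractions (0.88 train,
        # 0.09 valid) gives n_train = 88000, n_valid = 9000, n_test = 3000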
train_start, train_stop = n_processed, n_processed + n_train
valid_start, valid_stop = train_stop, train_stop + n_valid
test_start, test_stop = valid_stop, valid_stop + n_test
train_file = train_file_pat % file_num
valid_file = valid_file_pat % file_num
test_file = test_file_pat % file_num
LOGGER.info("slice {}, {} total events".format(i, n_slc))
LOGGER.info(
"slice {}, {} train events, [{}-{}): {}".format(
i, n_train, train_start, train_stop, train_file)
)
LOGGER.info(
"slice {}, {} valid events, [{}-{}): {}".format(
i, n_valid, valid_start, valid_stop, valid_file)
)
LOGGER.info(
"slice {}, {} test events, [{}-{}): {}".format(
i, n_test, test_start, test_stop, test_file)
)
if not dry_run and file_num >= file_num_start_write:
# clean up existing files in the slice
for filename in [train_file, valid_file, test_file]:
                check_file = filename + '.gz' if compress_to_gz else filename
if os.path.isfile(check_file):
LOGGER.info(
'found existing tfrecord file {}, removing...'.format(
check_file
)
)
os.remove(check_file)
list_of_fields = utils.get_fields_list(tfrec_struct)
data_dict = make_mnv_data_dict_from_fields(
list_of_fields=list_of_fields
)
# events included are [start, stop)
if n_train > 0:
LOGGER.info('creating train file...')
write_tfrecord(
m, data_dict, train_start, train_stop,
train_file, compress_to_gz
)
new_files.append(
train_file + '.gz' if compress_to_gz else train_file
)
if n_valid > 0:
LOGGER.info('creating valid file...')
write_tfrecord(
m, data_dict, valid_start, valid_stop,
valid_file, compress_to_gz
)
new_files.append(
valid_file + '.gz' if compress_to_gz else valid_file
)
if n_test > 0:
LOGGER.info('creating test file...')
write_tfrecord(
m, data_dict, test_start, test_stop,
test_file, compress_to_gz
)
new_files.append(
test_file + '.gz' if compress_to_gz else test_file
)
if (max_triplets > 0) and (len(new_files) / 3 >= max_triplets):
break
n_processed += n_slc
LOGGER.info("Processed {} events, finished with file number {}".format(
n_processed, file_num
))
m.close()
return file_num, new_files
def read_all(
files_written, tfrec_struct, dry_run, compressed, imgw_x, imgw_uv,
n_planecodes, img_h=127, img_depth=2, data_format='NHWC'
):
LOGGER.info('reading files...')
for filename in files_written:
if os.path.isfile(filename):
LOGGER.info(
'found existing tfrecord file {} with size {}...'.format(
filename, os.stat(filename).st_size
)
)
if not dry_run:
compression = 'gz' if compressed else ''
test_read_tfrecord(
filename, tfrec_struct, compression,
img_h, imgw_x, imgw_uv, img_depth,
n_planecodes, data_format
)
if __name__ == '__main__':
    from optparse import OptionParser

    def arg_list_split(option, opt, value, parser):
        # optparse callback: split a comma-separated value into a list
        setattr(parser.values, option.dest, value.split(','))

    parser = OptionParser(usage=__doc__)
parser.add_option('-l', '--file_list', dest='file_list',
help='HDF5 file list (csv, full paths)',
metavar='FILELIST', type='string', action='callback',
callback=arg_list_split)
parser.add_option('-p', '--input_file_pattern', dest='input_file_pattern',
help='Input file pattern', metavar='FILEPATTERN',
default=None, type='string')
parser.add_option('-i', '--in_dir', dest='in_dir',
help='In directory (for file patterns)',
metavar='IN_DIR', default=None, type='string')
parser.add_option('-o', '--out_dir', dest='out_dir',
help='Out directory', metavar='OUT_DIR', default=None,
type='string')
parser.add_option('-n', '--nevents', dest='n_events', default=0,
help='Number of events per file', metavar='N_EVENTS',
type='int')
parser.add_option('-s', '--start_idx', dest='start_idx', default=0,
help='Start writing to disk at index value',
metavar='START_IDX', type='int')
parser.add_option('-m', '--max_triplets', dest='max_triplets', default=0,
help='Max number of each file type',
metavar='MAX_TRIPLETS', type='int')
parser.add_option('-r', '--test_read', dest='do_test', default=False,
help='Test read', metavar='DO_TEST',
action='store_true')
parser.add_option('-d', '--dry_run', dest='dry_run', default=False,
help='Dry run for write', metavar='DRY_RUN',
action='store_true')
parser.add_option('-c', '--compress_to_gz', dest='compress_to_gz',
default=False, help='Gzip compression',
metavar='COMPRESS_TO_GZ', action='store_true')
parser.add_option('-t', '--train_fraction', dest='train_fraction',
default=0.88, help='Train fraction',
metavar='TRAIN_FRAC', type='float')
parser.add_option('-v', '--valid_fraction', dest='valid_fraction',
default=0.09, help='Valid fraction',
metavar='VALID_FRAC', type='float')
parser.add_option('-g', '--logfile', dest='logfilename',
help='Log file name', metavar='LOGFILENAME',
default=None, type='string')
parser.add_option('--imgw_x', dest='imgw_x', default=94,
help='Image width-x', metavar='IMGWX',
type='int')
parser.add_option('--imgw_uv', dest='imgw_uv', default=47,
help='Image width-uv', metavar='IMGWUV',
type='int')
parser.add_option('--n_planecodes', dest='n_planecodes', default=173,
help='Number (count) of planecodes', metavar='NPCODES',
type='int')
parser.add_option('--tfrec_struct', dest='tfrec_struct',
default=None, help='TFRecord structure',
metavar='TFRECSTRUCT', type='string')
parser.add_option('--playlist', dest='playlist',
default='UNK', help='Playlist label',
metavar='PLAYLIST', type='string')
(options, args) = parser.parse_args()
    if (not options.file_list) and (not options.input_file_pattern):
        print("\nSpecify a file list or a file pattern:\n\n")
        print(__doc__)
        sys.exit(1)
    if not options.tfrec_struct:
        print("\nSpecify a TFRecord structure (--tfrec_struct):\n\n")
        print(__doc__)
        sys.exit(1)
    if (options.train_fraction + options.valid_fraction) > 1.001:
        print("\nTraining and validation fractions sum to more than 1!")
        print(__doc__)
        sys.exit(1)
logfilename = options.logfilename or \
'log_hdf5_to_tfrec_minerva_xtxutuvtv.txt'
logging.basicConfig(
filename=logfilename, level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
LOGGER.info("Starting...")
LOGGER.info(__file__)
    if options.in_dir:
        LOGGER.info(' in-directory contents: {}'.format(
            ','.join(os.listdir(options.in_dir))
        ))
    files = options.file_list or []
    # only glob when both a directory and a pattern were supplied -
    # otherwise the path concatenation below would fail on None
    if options.in_dir and options.input_file_pattern:
        for ext in ['*.hdf5', '*.h5']:
            extra_files = glob.glob(
                options.in_dir + '/' + options.input_file_pattern + ext
            )
            files.extend(extra_files)
    # kill any repeats
    files = list(set(files))
    files.sort()
LOGGER.info("Datasets:")
dataset_statsinfo = 0
LOGGER.info(' files: {}'.format(
','.join([str(f) for f in files])
))
for hdf5_file in files:
        fsize = os.stat(hdf5_file).st_size
        dataset_statsinfo += fsize
LOGGER.info(" {}, size = {}".format(hdf5_file, fsize))
LOGGER.info("Total dataset size: {}".format(dataset_statsinfo))
# loop over list of hdf5 files (glob for patterns?), for each file, create
# tfrecord files of specified size, putting remainders in new files.
file_num = 0
for i, hdf5_file in enumerate(files):
# base name for output files:
# NOTE - this includes the path, so don't put '.' or '..' in the dirs!
base_name = options.tfrec_struct + '_127x' + str(options.imgw_x) + \
'_' + options.playlist
# base_name = hdf5_file.split('/')[-1]
# base_name = options.out_dir + '/' + base_name.split('.')[0]
base_name = options.out_dir + '/' + base_name
# create file patterns to fill tfrecord files by number
train_file_pat = base_name + '_%06d_train.tfrecord'
valid_file_pat = base_name + '_%06d_valid.tfrecord'
test_file_pat = base_name + '_%06d_test.tfrecord'
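        # e.g., file_num = 3 fills these patterns as
        # <base_name>_000003_{train,valid,test}.tfrecord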
# TODO: add a 'mask' that properly trims imgw_x/uv, caps the pcodes
out_num, files_written = write_all(
n_events_per_tfrecord_triplet=options.n_events,
max_triplets=options.max_triplets, file_num_start=file_num,
hdf5_file=hdf5_file, train_file_pat=train_file_pat,
valid_file_pat=valid_file_pat, test_file_pat=test_file_pat,
train_fraction=options.train_fraction,
valid_fraction=options.valid_fraction,
dry_run=options.dry_run, compress_to_gz=options.compress_to_gz,
file_num_start_write=options.start_idx,
tfrec_struct=options.tfrec_struct
)
file_num = out_num + 1
if options.do_test:
read_all(
files_written,
options.tfrec_struct,
options.dry_run,
options.compress_to_gz,
options.imgw_x,
options.imgw_uv,
options.n_planecodes
)
        if (options.max_triplets > 0) and \
                (len(files_written) / 3 >= options.max_triplets):
break