-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataprovider.py
273 lines (212 loc) · 8.74 KB
/
dataprovider.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gzip
import os
import sys, json
from six.moves import urllib
import tensorflow as tf
import matplotlib
FLAGS = tf.app.flags.FLAGS
# Size of the crop:
IMAGE_SIZE = 224
TRAIN_FILE = 'train.tfrecords'
VALIDATION_FILE = 'validation.tfrecords'
TEST_FILE = 'test.tfrecords'
def getsize(filename):
    """Return the number of records in a TFRecord file, caching the count.

    The count is cached in a sidecar '<filename>.json' file so the slow
    full scan of the TFRecord file only ever happens once per dataset.

    Args:
        filename: path to a .tfrecords file.
    Returns:
        int: number of records in the file.
    """
    jsfile = filename + ".json"
    if tf.gfile.Exists(jsfile):
        # Fast path: reuse the cached count. Open through tf.gfile so that
        # non-local filesystems (e.g. GCS) work, consistent with the
        # Exists() check above — the original used builtin open() here,
        # which only works for local paths.
        with tf.gfile.GFile(jsfile, 'r') as f:
            N = json.load(f)['count']
    else:
        # Slow path: count every record once, then cache the result.
        N = 0
        for record in tf.python_io.tf_record_iterator(filename):
            N += 1
        with tf.gfile.GFile(jsfile, 'w') as f:
            f.write(json.dumps({'count': N}))
    return N
def read_record(filename_queue):
    """Parse a single serialized example from a queue of TFRecord filenames.

    Args:
        filename_queue: A queue of strings with the filenames to read from.
    Returns:
        An object representing a single example, with the following fields:
            height: number of rows in the result
            width: number of columns in the result
            depth: number of color channels in the result (3)
            label: an int64 Tensor with the image label.
            id: an int64 Tensor with the image id.
            image: a [height, width, depth] float64 Tensor with the image data
    """
    class DataRecord(object):
        pass

    # Fixed-length feature schema matching what the writer serialized.
    feature_spec = {
        'height': tf.FixedLenFeature([], tf.int64),
        'width': tf.FixedLenFeature([], tf.int64),
        'depth': tf.FixedLenFeature([], tf.int64),
        'label': tf.FixedLenFeature([], tf.int64),
        'id': tf.FixedLenFeature([], tf.int64),
        'image_raw': tf.FixedLenFeature([], tf.string)
    }
    _, serialized = tf.TFRecordReader().read(filename_queue)
    parsed = tf.parse_single_example(serialized, features=feature_spec)

    rec = DataRecord()
    rec.height = tf.cast(parsed['height'], tf.int32)
    rec.width = tf.cast(parsed['width'], tf.int32)
    rec.depth = tf.cast(parsed['depth'], tf.int32)
    # The raw bytes were written as float64; restore the original 3-D shape
    # from the per-example size fields.
    flat_image = tf.decode_raw(parsed['image_raw'], tf.float64)
    rec.image = tf.reshape(
        flat_image, tf.stack([rec.height, rec.width, rec.depth]))
    rec.label = parsed['label']
    rec.id = parsed['id']
    return rec
def _generate_batch(L, AB, label, min_queue_examples,
                    batch_size, shuffle, size=0):
    """Assemble a queued batch of (L, AB, label) examples.

    Args:
        L: 3-D Tensor of [height, width, 1] of type.float32.
        AB: 3-D Tensor of [height, width, 2] of type.float32.
        label: 1-D Tensor of type.int32
        min_queue_examples: int32, minimum number of samples to retain
            in the queue that provides of batches of examples.
        batch_size: Number of images per batch.
        shuffle: boolean indicating whether to use a shuffling queue.
        size: optional dataset size; when > 0 it is appended to the
            returned tuple so callers can recover the example count.
    Returns:
        Ls, ABs, labels: batched tensors of [batch_size, ...] size,
        followed by size when size > 0.
    """
    # Read 'batch_size' examples at a time through a (possibly shuffling)
    # queue fed by multiple preprocessing threads.
    threads = 16
    queue_capacity = min_queue_examples + 3 * batch_size
    if shuffle:
        Ls, ABs, labels = tf.train.shuffle_batch(
            [L, AB, label],
            batch_size=batch_size,
            num_threads=threads,
            capacity=queue_capacity,
            min_after_dequeue=min_queue_examples)
    else:
        Ls, ABs, labels = tf.train.batch(
            [L, AB, label],
            batch_size=batch_size,
            num_threads=threads,
            capacity=queue_capacity)
    if size > 0:
        return Ls, ABs, labels, size
    return Ls, ABs, labels
def distorted_inputs(get_size=False):
    """Construct distorted input training using the Reader ops.

    Args:
        get_size: when True, count the dataset size (cached via getsize)
            and return it as a fourth element of the result tuple.
    Returns:
        images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
        labels: labels. 2D tensor of [batch_size, 1] size.
    """
    data_dir = FLAGS.data_dir
    batch_size = FLAGS.batch_size
    filename = os.path.join(data_dir, TRAIN_FILE)
    # BUG FIX: N was previously assigned only inside `if get_size:` yet is
    # always passed as size=N below, so distorted_inputs() with the default
    # get_size=False raised NameError. Initialize it first, matching inputs().
    N = 0
    if get_size:
        N = getsize(filename)
    filename_queue = tf.train.string_input_producer([filename])
    read_input = read_record(filename_queue)
    height = IMAGE_SIZE
    width = IMAGE_SIZE
    # Randomly crop a [height, width] section of the image.
    image = tf.random_crop(read_input.image, [height, width, 3])
    # Randomly flip the image horizontally.
    image = tf.image.random_flip_left_right(image)
    image = tf.cast(image, tf.float32)
    L, A, B = tf.split(axis=2, num_or_size_splits=3, value=image)
    # FIXME: clipping isn't very nice but anything clipped should still fall
    # in the same bin as non-clipped; without clipping the quantization in
    # the loss should probably be more proper.
    AB = tf.clip_by_value(tf.concat(values=[A, B], axis=2), -1.0, 1.0)
    # delta of 0.016 should give a CIEDE2000 of less than 1
    L = tf.image.random_brightness(L, 0.016)
    # Ensure that the random shuffling has good mixing properties.
    min_fraction_of_examples_in_queue = 0.3
    if get_size:
        min_queue_examples = max(100, int(min(10000, N) *
                                          min_fraction_of_examples_in_queue))
    else:
        min_queue_examples = int(10000 *
                                 min_fraction_of_examples_in_queue)
    print('Filling queue with %d images before starting to train. '
          'This might take a few minutes.' % min_queue_examples)
    # Generate a batch by building up a queue of examples.
    return _generate_batch(L, AB, read_input.label,
                           min_queue_examples, batch_size,
                           shuffle=True, size=N)
def inputs(get_size=False):
    """Construct input for evaluation using the Reader ops.

    Args:
        get_size: when True, count the dataset size (cached via getsize)
            and return it as a fourth element of the result tuple.
    Returns:
        images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
        labels: labels. 2D tensor of [batch_size, 1] size.
    """
    filename = os.path.join(FLAGS.data_dir, VALIDATION_FILE)
    N = 0
    if get_size:
        N = getsize(filename)

    queue = tf.train.string_input_producer([filename])
    record = read_record(queue)
    height = IMAGE_SIZE
    width = IMAGE_SIZE

    # Image processing for evaluation: deterministically crop/pad to the
    # central [height, width] region (no random augmentation here).
    image = tf.image.resize_image_with_crop_or_pad(record.image,
                                                   width, height)
    image = tf.reshape(image, (width, height, 3))
    image = tf.cast(image, tf.float32)
    L, A, B = tf.split(axis=2, num_or_size_splits=3, value=image)
    AB = tf.clip_by_value(tf.concat(values=[A, B], axis=2), -1.0, 1.0)

    # Ensure that the random shuffling has good mixing properties.
    min_fraction_of_examples_in_queue = 0.4
    if get_size:
        min_queue_examples = max(100, int(min(10000, N) *
                                          min_fraction_of_examples_in_queue))
    else:
        min_queue_examples = int(10000 *
                                 min_fraction_of_examples_in_queue)
    # Generate a batch by building up a queue of examples.
    return _generate_batch(L, AB, record.label,
                           min_queue_examples, FLAGS.batch_size,
                           shuffle=False, size=N)
def input(source='val', get_size=False):
    """Construct a single-example input using the Reader ops.

    NOTE(review): this function shadows the builtin `input`; renaming would
    break existing callers, so the name is kept.

    Args:
        source: which split to read: 'train', 'test', or anything else
            (default 'val') for the validation split.
        get_size: when True, count the dataset size (cached via getsize)
            and return it as a fourth element of the result tuple.
    Returns:
        image: Image. 4D tensor of [1, IMAGE_SIZE, IMAGE_SIZE, 3] size.
        label: label. 1D tensor of [1] size.
    """
    data_dir = FLAGS.data_dir
    # Removed unused local `batch_size` — this path always yields a single
    # example (batch dimension fixed to 1 below).
    sourcefile = VALIDATION_FILE
    if source.lower() == 'train':
        sourcefile = TRAIN_FILE
    elif source.lower() == 'test':
        sourcefile = TEST_FILE
    filename = os.path.join(data_dir, sourcefile)
    N = 0
    if get_size:
        N = getsize(filename)
    filename_queue = tf.train.string_input_producer([filename])
    read_input = read_record(filename_queue)
    height = IMAGE_SIZE
    width = IMAGE_SIZE
    # Deterministically crop/pad to the central [height, width] region.
    image = tf.image.resize_image_with_crop_or_pad(read_input.image,
                                                   width, height)
    # Add the leading batch dimension of 1.
    image = tf.reshape(image, (1, height, width, 3))
    image = tf.cast(image, tf.float32)
    # Channel axis is 3 here because of the batch dimension added above.
    L, A, B = tf.split(axis=3, num_or_size_splits=3, value=image)
    AB = tf.clip_by_value(tf.concat(values=[A, B], axis=3), -1.0, 1.0)
    label = tf.reshape(read_input.label, (1,))
    if get_size:
        return L, AB, label, N
    else:
        return L, AB, label