-
Notifications
You must be signed in to change notification settings - Fork 123
/
data_loader.py
85 lines (66 loc) · 2.98 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import tensorflow as tf
from . import cyclegan_datasets
from . import model
def _load_samples(csv_name, image_type):
    """Build TF1 queue-based ops that read a pair of image files per CSV line.

    Each line of ``csv_name`` is parsed as two string columns. Each column is
    used as a file path whose contents are read and decoded as an image.

    :param csv_name: Path to a CSV file whose rows hold two image file paths.
    :param image_type: Image file extension, either ``'.jpg'`` or ``'.png'``.
    :return: Tuple ``(image_decoded_A, image_decoded_B)`` of decoded image
        tensors, decoded with ``model.IMG_CHANNELS`` channels.
    :raises ValueError: If ``image_type`` is not a supported extension.
    """
    # Validate before building any graph ops. Previously an unsupported
    # extension fell through both branches and raised UnboundLocalError at
    # the return statement.
    if image_type not in ('.jpg', '.png'):
        raise ValueError(
            'Unsupported image type %r; expected .jpg or .png.' % (image_type,))

    filename_queue = tf.train.string_input_producer(
        [csv_name])

    reader = tf.TextLineReader()
    _, csv_filename = reader.read(filename_queue)

    # Two string columns per line: the path of image A and of image B.
    record_defaults = [tf.constant([], dtype=tf.string),
                       tf.constant([], dtype=tf.string)]

    filename_i, filename_j = tf.decode_csv(
        csv_filename, record_defaults=record_defaults)

    file_contents_i = tf.read_file(filename_i)
    file_contents_j = tf.read_file(filename_j)
    if image_type == '.jpg':
        image_decoded_A = tf.image.decode_jpeg(
            file_contents_i, channels=model.IMG_CHANNELS)
        image_decoded_B = tf.image.decode_jpeg(
            file_contents_j, channels=model.IMG_CHANNELS)
    else:  # '.png' -- guaranteed by the validation above.
        image_decoded_A = tf.image.decode_png(
            file_contents_i, channels=model.IMG_CHANNELS, dtype=tf.uint8)
        image_decoded_B = tf.image.decode_png(
            file_contents_j, channels=model.IMG_CHANNELS, dtype=tf.uint8)

    return image_decoded_A, image_decoded_B
def load_data(dataset_name, image_size_before_crop,
              do_shuffle=True, do_flipping=False):
    """Build the TF1 input pipeline for a CycleGAN dataset.

    Reads paired image paths for ``dataset_name`` and processes each image in
    four steps:

    - resize it to a square of ``image_size_before_crop``;
    - optionally flip it left/right at random;
    - randomly crop it to ``model.IMG_HEIGHT x model.IMG_WIDTH``;
    - map pixel values in [0, 255] to [-1, 1].

    The results are then batched with batch size 1.

    :param dataset_name: The name of the dataset; must be a key of
        ``cyclegan_datasets.DATASET_TO_SIZES``.
    :param image_size_before_crop: Resize to this size before random cropping.
    :param do_shuffle: Shuffle switch. If truthy, batches come from
        ``tf.train.shuffle_batch``; otherwise from ``tf.train.batch``.
    :param do_flipping: Flip switch. If truthy, each image is randomly flipped
        left/right.
    :return: Dict holding the single preprocessed images under ``'image_i'``
        and ``'image_j'``, and the batched tensors under ``'images_i'`` and
        ``'images_j'``.
    :raises ValueError: If ``dataset_name`` is not a known dataset.
    """
    if dataset_name not in cyclegan_datasets.DATASET_TO_SIZES:
        raise ValueError('split name %s was not recognized.'
                         % dataset_name)

    csv_name = cyclegan_datasets.PATH_TO_CSV[dataset_name]

    image_i, image_j = _load_samples(
        csv_name, cyclegan_datasets.DATASET_TO_IMAGETYPE[dataset_name])
    inputs = {
        'image_i': image_i,
        'image_j': image_j
    }
    keys = ('image_i', 'image_j')

    # Preprocessing. Each stage is applied to both images before the next
    # stage starts. This keeps graph ops created in the same order as before,
    # which affects op-level random seeds when a graph seed is set.
    for key in keys:
        inputs[key] = tf.image.resize_images(
            inputs[key], [image_size_before_crop, image_size_before_crop])

    if do_flipping:
        for key in keys:
            inputs[key] = tf.image.random_flip_left_right(inputs[key])

    # Crop with the same channel count the images were decoded with
    # (model.IMG_CHANNELS) rather than a hard-coded 3.
    crop_shape = [model.IMG_HEIGHT, model.IMG_WIDTH, model.IMG_CHANNELS]
    for key in keys:
        inputs[key] = tf.random_crop(inputs[key], crop_shape)

    # x / 127.5 - 1 maps 0 -> -1 and 255 -> 1.
    for key in keys:
        inputs[key] = tf.subtract(tf.div(inputs[key], 127.5), 1)

    # Batch
    if do_shuffle:
        inputs['images_i'], inputs['images_j'] = tf.train.shuffle_batch(
            [inputs['image_i'], inputs['image_j']],
            batch_size=1, capacity=5000, min_after_dequeue=100)
    else:
        inputs['images_i'], inputs['images_j'] = tf.train.batch(
            [inputs['image_i'], inputs['image_j']], batch_size=1)

    return inputs