-
Notifications
You must be signed in to change notification settings - Fork 18
/
dataset.py
75 lines (61 loc) · 2.26 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import numpy as np
class Dataset(object):
    """One split of an ``.npz`` dataset with sequential mini-batching.

    The first array stored in the archive is indexed with ``0`` for the train
    split and ``1`` for the test split. Within a split, column 0 is the label,
    columns 1 and 2 are ``s1`` and ``s2``, and every remaining column belongs
    to the image.
    """

    def __init__(self, filepath, mode=None, imsize=None, allow_pickle=True):
        """Load the ``mode`` split of the archive at ``filepath``.

        Args:
            filepath: Path to an ``.npz`` archive.
            mode: ``'train'`` selects element 0 of the first stored array,
                ``'test'`` selects element 1.
            imsize: If given, images are reshaped to
                ``[num_examples, imsize, imsize, -1]`` (for convolutions);
                otherwise they stay flat.
            allow_pickle: Forwarded to ``np.load``. It defaults to ``True``
                because object arrays (e.g. train/test splits of different
                lengths) cannot be loaded otherwise on NumPy >= 1.16.3.

        Raises:
            ValueError: If ``mode`` is neither ``'train'`` nor ``'test'``, or
                the archive contains no arrays.
        """
        if mode == 'train':
            split_index = 0
        elif mode == 'test':
            split_index = 1
        else:
            raise ValueError('mode can be either train or test.')
        # NOTE(review): allow_pickle=True can execute code embedded in a
        # crafted file -- only load archives from a trusted source.
        # The context manager closes the archive's file handle; indexing the
        # archive reads the array fully into memory, so the data outlives it.
        with np.load(filepath, allow_pickle=allow_pickle) as archive:
            if not archive.files:
                raise ValueError('archive %r contains no arrays' % (filepath,))
            # First array in archive order. This is what the Python 2-only
            # ``archive.items()[0][1]`` meant; on Python 3, items() returns a
            # view that cannot be indexed.
            data = archive[archive.files[0]][split_index]
        self._num_examples = data.shape[0]
        self._labels = data[:, 0]
        self._s1 = data[:, 1]
        self._s2 = data[:, 2]
        self._images = data[:, 3:]
        if imsize is not None:  # For convolutions
            self._images = self._images.reshape(
                [self._num_examples, imsize, imsize, -1])
        self._epochs_completed = 0
        self._index_in_epoch = 0

    @property
    def images(self):
        """Image columns (reshaped if ``imsize`` was given)."""
        return self._images

    @property
    def s1(self):
        """Column 1 of the split."""
        return self._s1

    @property
    def s2(self):
        """Column 2 of the split."""
        return self._s2

    @property
    def labels(self):
        """Column 0 of the split."""
        return self._labels

    @property
    def num_examples(self):
        """Number of rows in the loaded split."""
        return self._num_examples

    @property
    def epochs_completed(self):
        """Number of times the data has been reshuffled by ``next_batch``."""
        return self._epochs_completed

    def next_batch(self, batch_size):
        """Return the next ``batch_size`` examples from this data set.

        Batches are consecutive slices. When the remaining examples cannot
        fill a batch, the epoch counter is incremented and all four arrays
        are shuffled with one shared permutation, so rows stay aligned. The
        batch is then taken from the start, and the unused tail of the
        previous pass is skipped.

        Args:
            batch_size: Number of examples to return.

        Returns:
            Tuple ``(images, s1, s2, labels)`` sliced to the batch.

        Raises:
            ValueError: If ``batch_size`` exceeds the dataset size.
        """
        # Use an explicit check rather than ``assert``, which ``python -O``
        # strips; without it, oversize batches would silently come back short.
        if batch_size > self._num_examples:
            raise ValueError(
                'batch_size (%d) exceeds number of examples (%d)'
                % (batch_size, self._num_examples))
        start = self._index_in_epoch
        self._index_in_epoch += batch_size
        # Not enough data left for a full batch: start a new shuffled epoch.
        if self._index_in_epoch > self._num_examples:
            self._epochs_completed += 1
            perm = np.arange(self._num_examples)
            np.random.shuffle(perm)
            self._images = self._images[perm]
            self._s1 = self._s1[perm]
            self._s2 = self._s2[perm]
            self._labels = self._labels[perm]
            start = 0
            self._index_in_epoch = batch_size
        end = self._index_in_epoch
        return (self._images[start:end], self._s1[start:end],
                self._s2[start:end], self._labels[start:end])