forked from CuriousAI/mean-teacher
-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.py
150 lines (109 loc) · 5.02 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# Copyright (c) 2018, Curious AI Ltd. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""Functions to load data from folders and augment it"""
import itertools
import logging
import os.path
from PIL import Image
import numpy as np
from torch.utils.data.sampler import Sampler
# Module-level logger. It is the logger registered under the name 'main', so
# it is shared with any other code that calls logging.getLogger('main').
LOG = logging.getLogger('main')
# Sentinel target that relabel_dataset stores for images that have no label
# in the supplied mapping.
NO_LABEL = -1
class RandomTranslateWithReflect:
    """Randomly shift an image, filling the exposed border by mirroring.

    On every call an integer offset is drawn independently for the x and y
    axes, uniformly from ``[-max_translation, max_translation]``. The image is
    surrounded by mirrored copies of itself, placed so that each copy shares
    its edge row/column with the original (the edge is not duplicated). The
    result is shifted by the drawn offsets and cropped back to the input size.

    The output canvas is always created in "RGB" mode. Padding wider than
    ``size - 1`` pixels is not fully covered by the mirrors and keeps the
    canvas's default fill.
    """

    def __init__(self, max_translation):
        # Largest absolute shift, in pixels, along either axis.
        self.max_translation = max_translation

    def __call__(self, old_image):
        limit = self.max_translation
        dx, dy = np.random.randint(-limit, limit + 1, size=2)
        pad_x, pad_y = abs(dx), abs(dy)
        width, height = old_image.size

        mirror_h = old_image.transpose(Image.FLIP_LEFT_RIGHT)
        mirror_v = old_image.transpose(Image.FLIP_TOP_BOTTOM)
        mirror_hv = old_image.transpose(Image.ROTATE_180)

        # Top-left corners for the mirrored tiles: the "- 1"/"+ 1" makes each
        # mirror overlap the original by exactly its shared edge line.
        left, right = pad_x - width + 1, pad_x + width - 1
        top, bottom = pad_y - height + 1, pad_y + height - 1

        placements = (
            (old_image, (pad_x, pad_y)),
            (mirror_h, (right, pad_y)),
            (mirror_h, (left, pad_y)),
            (mirror_v, (pad_x, bottom)),
            (mirror_v, (pad_x, top)),
            (mirror_hv, (left, top)),
            (mirror_hv, (right, top)),
            (mirror_hv, (left, bottom)),
            (mirror_hv, (right, bottom)),
        )
        canvas = Image.new("RGB", (width + 2 * pad_x, height + 2 * pad_y))
        for tile, corner in placements:
            canvas.paste(tile, corner)

        # Crop a window of the original size, displaced by (dx, dy).
        box = (pad_x - dx, pad_y - dy, pad_x + width - dx, pad_y + height - dy)
        return canvas.crop(box)
class TransformTwice:
    """Wrap a transform so that each call returns two transformed copies.

    The wrapped transform is applied twice to the same input, and the results
    are returned as a 2-tuple in call order. If the transform is random, the
    two outputs may differ.
    """

    def __init__(self, transform):
        # Callable applied to the input on every invocation.
        self.transform = transform

    def __call__(self, inp):
        return tuple(self.transform(inp) for _ in range(2))
def relabel_dataset(dataset, labels):
    """Apply a filename -> class-name label mapping to ``dataset`` in place.

    Each entry of ``dataset.imgs`` is a ``(path, target)`` pair. An entry whose
    basename appears in ``labels`` gets target
    ``dataset.class_to_idx[labels[basename]]``. Every other entry gets
    ``NO_LABEL``. If several entries share a basename, only the first one is
    labeled.

    Args:
        dataset: object exposing a mutable ``imgs`` sequence of
            ``(path, target)`` pairs and a ``class_to_idx`` dict (presumably a
            torchvision ``ImageFolder``; confirm against callers).
        labels: mapping from image basename to class name. It is NOT modified.

    Returns:
        ``(labeled_idxs, unlabeled_idxs)``: ascending lists of indices into
        ``dataset.imgs``.

    Raises:
        LookupError: if ``labels`` names files that do not occur in ``dataset``.
        KeyError: if a class name in ``labels`` is missing from
            ``dataset.class_to_idx``.
    """
    # Work on a private copy: entries are consumed as files are matched, and
    # whatever is left over is reported as unknown. The caller's dict stays intact.
    pending = dict(labels)
    labeled_idxs = []
    unlabeled_idxs = []
    for idx, (path, _) in enumerate(dataset.imgs):
        filename = os.path.basename(path)
        if filename in pending:
            dataset.imgs[idx] = path, dataset.class_to_idx[pending.pop(filename)]
            labeled_idxs.append(idx)
        else:
            dataset.imgs[idx] = path, NO_LABEL
            unlabeled_idxs.append(idx)

    if pending:
        message = "List of labels contains {} unknown files: {}, ..."
        some_missing = ', '.join(list(pending)[:5])
        raise LookupError(message.format(len(pending), some_missing))
    # Indices were visited in ascending order, so both lists are already sorted.
    return labeled_idxs, unlabeled_idxs
class TwoStreamBatchSampler(Sampler):
    """Yield batches that mix indices drawn from two pools.

    Each batch holds ``batch_size - secondary_batch_size`` indices from the
    primary pool, followed by ``secondary_batch_size`` indices from the
    secondary pool.

    One pass over a freshly shuffled primary pool makes up an epoch. Primary
    indices left over that cannot fill a whole share are dropped. The
    secondary pool is reshuffled and cycled as many times as needed.
    """

    def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size):
        self.primary_indices = primary_indices
        self.secondary_indices = secondary_indices
        self.secondary_batch_size = secondary_batch_size
        # Whatever part of the batch the secondary stream does not take.
        self.primary_batch_size = batch_size - secondary_batch_size

        # Both streams must contribute at least one index per batch and hold
        # enough indices for one full share.
        assert len(self.primary_indices) >= self.primary_batch_size > 0
        assert len(self.secondary_indices) >= self.secondary_batch_size > 0

    def __iter__(self):
        # The primary permutation is drawn eagerly, here, when iteration starts.
        primary_batches = grouper(iterate_once(self.primary_indices),
                                  self.primary_batch_size)
        secondary_batches = grouper(iterate_eternally(self.secondary_indices),
                                    self.secondary_batch_size)
        return (head + tail for head, tail in zip(primary_batches, secondary_batches))

    def __len__(self):
        # Count only complete primary shares; the remainder is discarded.
        full_batches, _ = divmod(len(self.primary_indices), self.primary_batch_size)
        return full_batches
def iterate_once(iterable):
    """Return one random permutation of ``iterable`` as a NumPy array.

    The permutation is drawn from NumPy's global random state.
    """
    shuffled = np.random.permutation(iterable)
    return shuffled
def iterate_eternally(indices):
    """Return an endless stream of ``indices`` in repeatedly reshuffled order.

    A new permutation is drawn lazily, only when the previous one has been
    fully consumed.
    """
    # map() is lazy, so np.random.permutation runs once per exhausted pass.
    reshuffles = map(np.random.permutation, itertools.repeat(indices))
    return itertools.chain.from_iterable(reshuffles)
def grouper(iterable, n):
    """Split ``iterable`` into consecutive ``n``-tuples, dropping a short tail.

    Example: ``grouper('ABCDEFG', 3)`` yields ``('A', 'B', 'C')`` and
    ``('D', 'E', 'F')``. The leftover ``'G'`` is discarded.
    """
    # zip receives the *same* iterator n times, so each tuple it builds pulls
    # n consecutive items; it stops as soon as a full tuple cannot be formed.
    shared = iter(iterable)
    return zip(*itertools.repeat(shared, n))