#!/usr/bin/env python
"""Example code of learning a large-scale convnet from the ILSVRC2012 dataset.

Prerequisite: To run this example, crop the center of the ILSVRC2012 training
and validation images, scale them to 256x256, convert them to RGB, and make
two lists of space-separated CSV whose first column is the full path to an
image and whose second column is the zero-origin label (this format is the
same as that used by Caffe's ImageDataLayer).
"""
import argparse
import random

import numpy as np

import chainer
from chainer import dataset
from chainer import training
from chainer.training import extensions
import chainerx

import resnet50
import dali_util


class PreprocessedDataset(chainer.dataset.DatasetMixin):

    def __init__(self, path, root, mean, crop_size, random=True):
        self.base = chainer.datasets.LabeledImageDataset(path, root)
        self.mean = mean.astype(chainer.get_dtype())
        self.crop_size = crop_size
        self.random = random

    def __len__(self):
        return len(self.base)
    def get_example(self, i):
        # Reads the i-th image/label pair and returns a preprocessed image.
        # The following preprocessing steps are applied:
        # - Cropping (random or center rectangular)
        # - Random horizontal flip
        # - Scaling to [0, 1] values
        crop_size = self.crop_size

        image, label = self.base[i]
        _, h, w = image.shape

        if self.random:
            # Randomly crop a region and flip the image.
            # random.randint is inclusive on both ends, so the largest valid
            # offset is h - crop_size, not h - crop_size - 1.
            top = random.randint(0, h - crop_size)
            left = random.randint(0, w - crop_size)
            if random.randint(0, 1):
                image = image[:, :, ::-1]
        else:
            # Crop the center.
            top = (h - crop_size) // 2
            left = (w - crop_size) // 2
        bottom = top + crop_size
        right = left + crop_size

        image = image[:, top:bottom, left:right]
        image -= self.mean[:, top:bottom, left:right]
        image *= (1.0 / 255.0)  # Scale to [0, 1]
        return image, label
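
# A minimal usage sketch of PreprocessedDataset (the file names and the crop
# size of 224 are hypothetical; main() below passes model.insize instead):
#
#     mean = np.load('mean.npy')
#     train = PreprocessedDataset('train.txt', '.', mean, 224)
#     image, label = train.get_example(0)  # mean-subtracted CHW array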


def main():
    archs = {
        'resnet50': resnet50.ResNet50,
    }
    dtypes = {
        'float16': np.float16,
        'float32': np.float32,
        'float64': np.float64,
    }
    parser = argparse.ArgumentParser(
        description='Learning convnet from ILSVRC2012 dataset')
    parser.add_argument('train', help='Path to training image-label list file')
    parser.add_argument('val', help='Path to validation image-label list file')
    parser.add_argument('--arch', '-a', choices=archs.keys(),
                        default='resnet50', help='Convnet architecture')
    parser.add_argument('--batchsize', '-B', type=int, default=32,
                        help='Learning minibatch size')
    parser.add_argument('--dtype', choices=dtypes, help='Specify the dtype '
                        'used. If not supplied, the default dtype is used')
    parser.add_argument('--epoch', '-E', type=int, default=10,
                        help='Number of epochs to train')
    parser.add_argument('--device', '-d', type=str, default='-1',
                        help='Device specifier. Either ChainerX device '
                        'specifier or an integer. If non-negative integer, '
                        'CuPy arrays with specified device id are used. If '
                        'negative integer, NumPy arrays are used')
    parser.add_argument('--initmodel',
                        help='Initialize the model from given file')
    parser.add_argument('--loaderjob', '-j', type=int,
                        help='Number of parallel data loading threads '
                        '(used by the DALI pipelines)')
    parser.add_argument('--mean', '-m', default='mean.npy',
                        help='Mean file (computed by compute_mean.py)')
    parser.add_argument('--resume', '-r', default='',
                        help='Initialize the trainer from given file')
    parser.add_argument('--out', '-o', default='result',
                        help='Output directory')
    parser.add_argument('--root', '-R', default='.',
                        help='Root directory path of image files')
    parser.add_argument('--val_batchsize', '-b', type=int, default=32,
                        help='Validation minibatch size')
    parser.add_argument('--test', action='store_true')
    parser.set_defaults(test=False)
    parser.add_argument('--dali', action='store_true')
    parser.set_defaults(dali=False)
    group = parser.add_argument_group('deprecated arguments')
    group.add_argument('--gpu', '-g', dest='device',
                       type=int, nargs='?', const=0,
                       help='GPU ID (negative value indicates CPU)')
    args = parser.parse_args()

    # Expand the cuDNN convolution workspace limit (1024x the default) so
    # that faster convolution algorithms can be selected.
    worksize = chainer.backends.cuda.get_max_workspace_size()
    chainer.backends.cuda.set_max_workspace_size(worksize * 1024)

    device = chainer.get_device(args.device)

    # Set the dtype if supplied.
    if args.dtype is not None:
        chainer.config.dtype = args.dtype

    print('Device: {}'.format(device))
    print('Dtype: {}'.format(chainer.config.dtype))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))
    print('')

    # Initialize the model to train.
    model = archs[args.arch]()
    if args.initmodel:
        print('Load model from {}'.format(args.initmodel))
        chainer.serializers.load_npz(args.initmodel, model)
    model.to_device(device)
    device.use()

    if args.dali:
        if not dali_util._dali_available:
            raise RuntimeError('DALI does not seem to be available on your '
                               'system.')
        if not isinstance(device, chainer.backend.cuda.GpuDevice):
            raise RuntimeError('Using DALI requires a GPU device. Please '
                               'specify one with the --device option.')
        num_threads = args.loaderjob
        if num_threads is None or num_threads <= 0:
            num_threads = 1
        # The DALI pipelines normalize only by the per-channel std (i.e.
        # scale by 1/255); no mean subtraction is applied on this path.
        ch_std = [255.0, 255.0, 255.0]
        # Set up the DALI pipelines.
        train_pipe = dali_util.DaliPipelineTrain(
            args.train, args.root, model.insize, args.batchsize,
            num_threads, device.device.id, True, std=ch_std)
        val_pipe = dali_util.DaliPipelineVal(
            args.val, args.root, model.insize, args.val_batchsize,
            num_threads, device.device.id, False, std=ch_std)
        train_iter = chainer.iterators.DaliIterator(train_pipe)
        val_iter = chainer.iterators.DaliIterator(val_pipe, repeat=False)
        converter = dali_util.DaliConverter(crop_size=model.insize)
    else:
        # Load the mean file (computed by compute_mean.py) and set up the
        # dataset files. The mean is only needed on this non-DALI path.
        mean = np.load(args.mean)
        train = PreprocessedDataset(args.train, args.root, mean, model.insize)
        val = PreprocessedDataset(
            args.val, args.root, mean, model.insize, False)
        # These iterators load the images serially in the main process
        # (unlike MultiprocessIterator, SerialIterator does not use
        # subprocesses).
        train_iter = chainer.iterators.SerialIterator(
            train, args.batchsize)
        val_iter = chainer.iterators.SerialIterator(
            val, args.val_batchsize, repeat=False)
        converter = dataset.concat_examples

    # Set up an optimizer.
    optimizer = chainer.optimizers.MomentumSGD(lr=0.01, momentum=0.9)
    optimizer.setup(model)

    # Set up a trainer.
    updater = training.updaters.StandardUpdater(
        train_iter, optimizer, converter=converter, device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), args.out)

    val_interval = (1 if args.test else 100000), 'iteration'
    log_interval = (1 if args.test else 1000), 'iteration'

    trainer.extend(extensions.Evaluator(val_iter, model, converter=converter,
                                        device=device), trigger=val_interval)
    # TODO(sonots): Temporarily disabled for ChainerX. Fix it.
    if device.xp is not chainerx:
        trainer.extend(extensions.DumpGraph('main/loss'))
    trainer.extend(extensions.snapshot(), trigger=val_interval)
    trainer.extend(extensions.snapshot_object(
        model, 'model_iter_{.updater.iteration}'), trigger=val_interval)
    # Be careful to pass the interval directly to LogReport (it determines
    # when to emit a log rather than when to read observations).
    trainer.extend(extensions.LogReport(trigger=log_interval))
    trainer.extend(extensions.observe_lr(), trigger=log_interval)
    trainer.extend(extensions.PrintReport([
        'epoch', 'iteration', 'main/loss', 'validation/main/loss',
        'main/accuracy', 'validation/main/accuracy', 'lr'
    ]), trigger=log_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    if args.resume:
        # Resume training from a trainer snapshot.
        chainer.serializers.load_npz(args.resume, trainer)

    trainer.run()


if __name__ == '__main__':
    main()