# fine-tune.py: fine-tune a pre-trained MXNet model with an MMD-based
# domain-adaptation loss (JAN, or DAN with --use-dan; see mmd.py).
import os
import sys
import argparse
import logging
import urllib.request
logging.basicConfig(level=logging.DEBUG)
#from common import find_mxnet
from common import data
from common import fit
import mxnet as mx
import mmd


def download(url):
    """Download url into model/ unless the file is already there."""
    filename = url.split("/")[-1]
    if not os.path.exists('model/' + filename):
        os.makedirs('model', exist_ok=True)
        urllib.request.urlretrieve(url, 'model/' + filename)


def get_model(prefix, epoch):
    """Fetch the symbol and parameter files of a pre-trained checkpoint."""
    download(prefix + '-symbol.json')
    download(prefix + '-%04d.params' % (epoch,))
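# A usage sketch for get_model: the checkpoint prefix below follows the MXNet
# model-zoo layout and is an illustrative assumption, not something this
# script downloads by default.
# get_model('http://data.mxnet.io/models/imagenet-11k/resnet-152/resnet-152', 0)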
LABEL_WIDTH = 1  # one label per example


def get_fine_tune_model(symbol, arg_params, args):
    """
    symbol: the pre-trained network symbol
    arg_params: the argument parameters of the pre-trained model
    args: parsed command-line arguments; uses args.num_classes (the number of
        classes in the fine-tune dataset) and args.layer_before_fullc (the
        name of the layer before the last fully-connected layer)
    """
    layer_name = args.layer_before_fullc
    all_layers = symbol.get_internals()
    # Truncate the network right before the original classifier.
    last_before = all_layers[layer_name + '_output']
    lr_mult = 1
    feature = last_before
    # Attach a new fully-connected classifier sized for the fine-tune dataset.
    fc = mx.symbol.FullyConnected(data=feature, num_hidden=args.num_classes,
                                  name='fc', lr_mult=lr_mult)
    # Combine the feature, the new classifier, and the MMD loss (see mmd.py).
    net = mmd.mmd(feature, fc, args)
    if args.train_stage == 0:
        # Stage 0: discard the pre-trained classifier weights so the new
        # 'fc' layer is initialized from scratch.
        new_args = {k: arg_params[k] for k in arg_params if 'fc' not in k}
    else:
        # Stage 1: keep all weights (e.g. when resuming from a stage-0 run).
        new_args = arg_params
    return (net, new_args)
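# To find a valid --layer-before-fullc name for a different backbone, the
# symbol's internal outputs can be listed; a sketch ('flatten0' is the
# ResNet default used below):
# sym, arg_params, aux_params = mx.model.load_checkpoint('model/resnet-152', 0)
# print(sym.get_internals().list_outputs()[-10:])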
if __name__ == "__main__":
    # parse args
    parser = argparse.ArgumentParser(description="fine-tune a dataset",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    train = fit.add_fit_args(parser)
    data.add_data_args(parser)
    aug = data.add_data_aug_args(parser)
    parser.add_argument('--pretrained-model', type=str, default='model/resnet-152',
                        help='the pre-trained model')
    parser.add_argument('--pretrained-epoch', type=int, default=0,
                        help='the pre-trained model epoch to load')
    parser.add_argument('--layer-before-fullc', type=str, default='flatten0',
                        help='the name of the layer before the last fully-connected layer')
    parser.add_argument('--no-checkpoint', action="store_true", default=False,
                        help='do not save checkpoints')
    parser.add_argument('--freeze', action="store_true", default=False,
                        help='freeze the lower layers')
    parser.add_argument('--train-stage', type=int, default=0,
                        help='training stage: stage 0 trains the softmax classifier only; '
                             'stage 1 adds the MMD loss')
    parser.add_argument('--null-label', type=int, default=9999,
                        help='the label id that marks an invalid label')
    parser.add_argument('--use-dan', action="store_true", default=False,
                        help='use DAN instead of JAN')
    # use less aggressive augmentation for fine-tuning
    data.set_data_aug_level(parser, 2)
    parser.set_defaults(data_dir="./data", top_k=0, kv_store='local', data_nthreads=15)
    #parser.set_defaults(model_prefix="", data_nthreads=15, batch_size=64, num_classes=263, gpus='0,1,2,3')
    #parser.set_defaults(image_shape='3,320,320', num_epochs=32,
    #                    lr=.0001, lr_step_epochs='12,20,24,28', wd=0, mom=0.9, lr_factor=0.5)
    parser.set_defaults(image_shape='3,320,320', wd=0, mom=0.9)
    args = parser.parse_args()
    args.label_width = LABEL_WIDTH
    args.gpu_num = len(args.gpus.split(','))
    args.batch_per_gpu = args.batch_size // args.gpu_num  # integer division
    # The epoch size comes from the source-domain list; in stage 1 it is
    # capped by the target-domain list, so an epoch is bounded by the
    # smaller of the two datasets.
    with open(args.data_dir + '/source.lst') as f:
        args.num_examples = sum(1 for _ in f)
    if args.train_stage == 1:
        with open(args.data_dir + '/target.lst') as f:
            target_num_examples = sum(1 for _ in f)
        args.num_examples = min(args.num_examples, target_num_examples)
    print('num_examples', args.num_examples)
    print('gpu_num', args.gpu_num)
    # load the pre-trained model
    dir_path = os.path.dirname(os.path.realpath(__file__))
    prefix = args.pretrained_model
    epoch = args.pretrained_epoch
    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
    fixed_params = None
    if args.freeze:
        # Freeze every parameter whose name does not start with one of the
        # active prefixes: only the top residual stages, the final batch
        # norm, and the new classifier are trained.
        fixed_params = []
        active_list = ['bn1', 'fc', 'stage3', 'stage4']
        for k in arg_params:
            is_active = False
            for a in active_list:
                if k.startswith(a):
                    is_active = True
                    break
            if not is_active:
                fixed_params.append(k)
        print(fixed_params)
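    # Note: fixed_param_names is presumably forwarded by fit.fit to
    # mx.mod.Module, which excludes the named parameters from gradient
    # updates; this mirrors the stock MXNet image-classification fit helper.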
    # replace the last fully-connected layer and attach the MMD loss
    (new_sym, new_args) = get_fine_tune_model(sym, arg_params, args)
    # train
    fit.fit(args=args,
            network=new_sym,
            data_loader=data.get_rec_iter,
            arg_params=new_args,
            aux_params=aux_params,
            fixed_param_names=fixed_params)
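# A sketch of a two-stage run. --data-dir must hold source.lst (and target.lst
# for stage 1) as read above; the GPU ids, batch size, and class count are
# illustrative values taken from the commented-out defaults, and the stage-1
# checkpoint prefix is hypothetical:
#
#   python fine-tune.py --train-stage 0 --gpus 0,1,2,3 --batch-size 64 \
#       --num-classes 263 --pretrained-model model/resnet-152
#   python fine-tune.py --train-stage 1 --gpus 0,1,2,3 --batch-size 64 \
#       --num-classes 263 --pretrained-model <stage-0-checkpoint-prefix>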