## transfer_nasbench201.py: Transfer learning NAS-Bench-201 optimisation setup
# The script first runs optimisation on NAS-Bench-201 on the base task. We then use the motifs
# identified by the surrogate GP as a prior to run optimisation on the CIFAR-100 and ImageNet16-120 tasks.
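#
# Example invocation (a typical setup; adjust the paths to wherever the NAS-Bench-201 data lives):
#   python transfer_nasbench201.py --base_task cifar10-valid --n_repeat 20 \
#       --data_path ./data --save_path ./results --load_from_cache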
import argparse
import datetime
import os
import pickle
import time

import torch
from tabulate import tabulate

import bayesopt
from bayesopt.generate_test_graphs import random_sampling, mutation
from bayesopt.gp import GraphGP
from bayesopt.interpreter import Interpreter
from benchmarks import NAS201
from kernels import WeisfilerLehman
from misc.find_stuctures import find_wl_feature

parser = argparse.ArgumentParser(description='Transfer Learning NAS-Bench-201')
parser.add_argument('--base_task', default='cifar10-valid', help='the base task to first run optimisation on')
parser.add_argument('--n_repeat', type=int, default=20)
parser.add_argument('--data_path', default='./data')
parser.add_argument('--n_init', type=int, default=200)
parser.add_argument('--base_max_iters', type=int, default=48)
parser.add_argument('--transfer_max_iters', type=int, default=38)
parser.add_argument('--save_path', default='./results')
parser.add_argument('--batch_size', type=int, default=5)
parser.add_argument('--fixed_query_seed', type=int, default=None)
parser.add_argument('--load_from_cache', action='store_true')
parser.add_argument('--threshold', type=int, default=30)
args = parser.parse_args()
options = vars(args)
print(options)

tasks = ['cifar10-valid', 'cifar100', 'ImageNet16-120']
assert args.base_task in tasks


def filter_pool(pool, include_list, exclude_list, kernel):
    """Given a pool of candidate architectures, keep only those that contain at least
    one of the WL features in include_list and none of the features in exclude_list."""
    if (include_list is None or not len(include_list)) and (exclude_list is None or not len(exclude_list)):
        return pool
    pruned_pool = []
    for p in pool:
        # Include-only: the candidate must exhibit at least one of the listed features
        if include_list and not any(find_wl_feature(p, (f,), kernel) for f in include_list):
            continue
        # Exclude: the candidate must exhibit none of the listed features
        if exclude_list and any(find_wl_feature(p, (f,), kernel) for f in exclude_list):
            continue
        pruned_pool.append(p)
    return pruned_pool
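
# Illustrative use of filter_pool (the feature name below is hypothetical; in practice the
# motifs are the WL subtree features selected by the Interpreter further down):
#   pruned = filter_pool(pool, include_list=['nor_conv_3x3'], exclude_list=None, kernel=k)
# This keeps only candidates in which find_wl_feature locates the 'nor_conv_3x3' motif.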


def train(sampler, max_iters, include_feats=None, exclude_feats=None, base_kernel=None):
    """Main train loop."""
    columns = ['Iteration', 'Best func val', 'Best func test', 'Time']
    start_time = time.time()
    best_tests = []
    best_vals = []
    # Initial design: random-sample architectures, keeping only those that pass the
    # include/exclude motif filter, until n_init candidates are collected.
    x = []
    while len(x) < args.n_init:
        cand = random_sampling(args.n_init, benchmark='nasbench201')[0]
        cand = filter_pool(cand, include_feats, exclude_feats, base_kernel)
        x += cand
    x = x[:args.n_init]
    y_np_list = [sampler.eval(x_) for x_ in x]
    y = torch.tensor([y_[0] for y_ in y_np_list]).float()
    train_details = [y_[1] for y_ in y_np_list]
    test = torch.tensor([sampler.test(x_) for x_ in x])
    # Initialise the surrogate: a GP over graphs with a Weisfeiler-Lehman kernel
    k = WeisfilerLehman(oa=False, h=1, requires_grad=True)
    base_gp = GraphGP(x, y, [k])
    for i in range(max_iters):
        base_gp.fit(wl_subtree_candidates=())
        # Generate the candidate pool by mutating the best architectures seen so far,
        # again subject to the motif filter.
        pool = []
        while len(pool) < 200:
            cand = mutation(x, y, benchmark='nasbench201', pool_size=200, n_best=10, n_mutate=100,
                            allow_isomorphism=True)[0]
            cand = filter_pool(cand, include_feats, exclude_feats, base_kernel)
            pool += cand
        pool = pool[:200]
        # Propose the next batch by maximising expected improvement over the pool
        a = bayesopt.GraphExpectedImprovement(base_gp)
        next_x, eis, indices = a.propose_location(top_n=args.batch_size, candidates=pool)
        # Evaluate the proposed batch and set up the next iteration
        detail = [sampler.eval(x_) for x_ in next_x]
        next_y = [d[0] for d in detail]
        train_details += [d[1] for d in detail]
        next_test = [sampler.test(x_).item() for x_ in next_x]
        x.extend(next_x)
        y = torch.cat((y, torch.tensor(next_y).view(-1))).float()
        test = torch.cat((test, torch.tensor(next_test).view(-1)))
        base_gp.reset_XY(x, y)
        end_time = time.time()
        # Current best (objective values are stored on a negative log scale, hence exp(-max))
        best_val = torch.exp(-torch.max(y))
        best_test = torch.exp(-torch.max(test))
        values = [str(i), best_val.item(), best_test.item(), str(end_time - start_time)]
        table = tabulate([values], headers=columns, tablefmt='simple', floatfmt='8.4f')
        best_vals.append(best_val)
        best_tests.append(best_test)
        # Reprint the table header every 40 iterations; otherwise print the data row only
        if i % 40 == 0:
            table = table.split('\n')
            table = '\n'.join([table[1]] + table)
        else:
            table = table.split('\n')[2]
        print(table)
    return base_gp, k, best_vals, best_tests


# Load the NAS-Bench-201 benchmark, optionally from a pickled cache
cache_path = os.path.join(args.data_path, 'nasbench201.pickle')
o = None
if args.load_from_cache:
    if os.path.exists(cache_path):
        try:
            with open(cache_path, 'rb') as f:
                o = pickle.load(f)
            o.seed = args.fixed_query_seed
            o.task = args.base_task
        except Exception:
            o = None
if o is None:
    o = NAS201(args.data_path, args.base_task, seed=args.fixed_query_seed)
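
# Note on the benchmark interface assumed by train(): o.eval(arch) is expected to return a
# (value, details) tuple and o.test(arch) the corresponding test metric, as inferred from
# how their results are unpacked in the loop above.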

all_res = []
for run in range(args.n_repeat):
    print('######## STARTING REPEAT %d / %d ########' % (run + 1, args.n_repeat))
    res = {'iterative': None, 'transfer1': None, 'transfer2': None}
    o.task = args.base_task
    print('-------- Starting Base Task Optimisation ----------')
    base_gp, base_kernel, base_val, base_test = train(o, args.base_max_iters)
    res['iterative'] = [base_val, base_test]
    print('------------- Base Task Completed --------------')
    transfer_tasks = [t for t in tasks if t != args.base_task]
    for i, t in enumerate(transfer_tasks):
        # Extract the motifs the interpreter deems beneficial on the base-task GP and
        # preserve them as a prior for the transfer task
        interpreter = Interpreter(gp=base_gp, thres=args.threshold)
        o.task = t
        include = [interpreter.feat_list[j] for j in interpreter.good_idx]
        print('Preserving Motifs: ', include)
        _, _, tr_val, tr_test = train(o, args.transfer_max_iters, include, None, base_kernel)
        res['transfer' + str(i + 1)] = [tr_val, tr_test]
    all_res.append(res)

time_string = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
os.makedirs(args.save_path, exist_ok=True)
with open(os.path.join(args.save_path, 'transfer_201' + time_string + '.pickle'), 'wb') as f:
    pickle.dump(all_res, f)
print('All done')
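
# To inspect a saved run later (a minimal sketch; the file name below is an example):
#   import pickle
#   with open('./results/transfer_201<timestamp>.pickle', 'rb') as f:
#       results = pickle.load(f)
#   # results is a list with one dict per repeat, keyed by 'iterative', 'transfer1' and
#   # 'transfer2', each holding the [best_vals, best_tests] traces returned by train().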