-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy path02_train.py
74 lines (57 loc) · 2.16 KB
/
02_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
import sys
from time import time
from label_the_sky.training.trainer import Trainer, set_random_seeds
# Accepted values for the <dataset_mode> command-line argument.
VALID_DATASET_MODES = ('lowdata', 'full')

USAGE = (
    'usage: python {} <dataset> <backbone> <pretraining_dataset> '
    '<n_channels> <finetune> <dataset_mode> <timestamp>'
)


def parse_args(argv):
    """Parse the seven positional command-line arguments of this script.

    Args:
        argv: full argument vector with the program name first, laid out
            like ``sys.argv``.

    Returns:
        A dict with the keys ``dataset``, ``backbone``,
        ``pretraining_dataset`` (``None`` when the literal string ``'None'``
        was passed), ``n_channels`` (int), ``finetune`` (bool, true only for
        the string ``'1'``), ``dataset_mode`` and ``timestamp``.

    Raises:
        SystemExit: with status 1, after printing the usage line, when the
            number of arguments is wrong.
        ValueError: when ``n_channels`` is not an integer or when
            ``dataset_mode`` is not one of ``VALID_DATASET_MODES``.
    """
    if len(argv) != 8:
        print(USAGE.format(argv[0]))
        # sys.exit instead of the builtin exit(): the latter is injected by
        # the `site` module and is not guaranteed to exist.
        sys.exit(1)

    parsed = {
        'dataset': argv[1],
        'backbone': argv[2],
        'pretraining_dataset': None if argv[3] == 'None' else argv[3],
        'n_channels': int(argv[4]),
        'finetune': argv[5] == '1',
        'dataset_mode': argv[6],
        'timestamp': argv[7],
    }
    if parsed['dataset_mode'] not in VALID_DATASET_MODES:
        # ValueError is a subclass of Exception, so callers catching the
        # previously raised bare Exception keep working.
        raise ValueError('dataset_mode must be: lowdata, full')
    return parsed


def resolve_weights_file(base_dir, timestamp, backbone, n_channels, pretraining_dataset):
    """Return the value to pass as ``weights`` to the Trainer.

    ``None`` and ``'imagenet'`` are passed through unchanged; any other
    pretraining dataset name is mapped to an ``.h5`` file path under
    ``<base_dir>/trained_models``.
    """
    if pretraining_dataset is None or pretraining_dataset == 'imagenet':
        return pretraining_dataset
    return os.path.join(
        base_dir,
        'trained_models',
        f'{timestamp}_{backbone}_{n_channels}_{pretraining_dataset}.h5',
    )


def select_mode(pretraining_dataset, finetune):
    """Pick the training mode string handed to the Trainer.

    Returns ``'from_scratch'`` when there is no pretraining dataset,
    otherwise ``'finetune'`` when ``finetune`` is true, else ``'top_clf'``.
    """
    if pretraining_dataset is None:
        return 'from_scratch'
    return 'finetune' if finetune else 'top_clf'


def main(argv=None):
    """Train a classifier as described by the command-line arguments.

    Args:
        argv: optional argument vector (program name first); defaults to
            ``sys.argv``.
    """
    # Seeding stays the very first action, exactly as in the flat script.
    set_random_seeds()

    args = parse_args(sys.argv if argv is None else argv)
    dataset = args['dataset']
    backbone = args['backbone']
    pretraining_dataset = args['pretraining_dataset']
    n_channels = args['n_channels']
    finetune = args['finetune']
    dataset_mode = args['dataset_mode']
    timestamp = args['timestamp']

    # expanduser('~') yields $HOME when it is set, and still resolves a home
    # directory where os.environ['HOME'] would raise KeyError.
    base_dir = os.path.expanduser('~')

    weights_file = resolve_weights_file(
        base_dir, timestamp, backbone, n_channels, pretraining_dataset)

    model_name = f'{timestamp}_{backbone}_{n_channels}_{pretraining_dataset}_clf_ft{int(finetune)}_{dataset_mode}'

    trainer = Trainer(
        backbone=backbone,
        n_channels=n_channels,
        output_type='class',
        base_dir=base_dir,
        weights=weights_file,
        model_name=model_name
    )

    print('loading data')
    X_train, y_train = trainer.load_data(dataset=dataset, split='train')
    X_val, y_val = trainer.load_data(dataset=dataset, split='val')
    # NOTE(review): the test split is loaded but never used below. The call
    # is kept because load_data may have side effects not visible from this
    # script; confirm and drop it if it is safe to do so.
    X_test, y_test = trainer.load_data(dataset=dataset, split='test')

    # The timer starts after data loading, so reported minutes exclude it.
    start = time()

    mode = select_mode(pretraining_dataset, finetune)

    trainer.describe(verbose=True)

    print(f'training: {mode}; dataset mode: {dataset_mode}')
    if dataset_mode == 'full':
        trainer.train(X_train, y_train, X_val, y_val, mode=mode)
    else:
        trainer.train_lowdata(X_train, y_train, X_val, y_val, mode=mode)
    # NOTE(review): nesting reconstructed from a copy that lost indentation;
    # history dump and timing are assumed to run for both dataset modes.
    trainer.dump_history('history')
    print('--- minutes taken:', int((time() - start) / 60))

    if dataset_mode == 'full':
        # NOTE(review): the whole load/evaluate/timing tail is assumed to
        # belong to the 'full' branch; verify against the original file.
        print('loading best model')
        trainer.load_weights(os.path.join(base_dir, 'trained_models', model_name+'.h5'))

        print('evaluating model on validation set')
        trainer.evaluate(X_val, y_val)

        print('--- minutes taken:', int((time() - start) / 60))


if __name__ == '__main__':
    main()