-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathscript_train.py
executable file
·119 lines (105 loc) · 4.42 KB
/
script_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
import shlex
import socket
import subprocess
import time
# Batch driver: runs ``python3 train_knn.py`` once per combination of
# network (``stages``), validation sequence (``test_sequences_horto``) and
# point budget (``max_points_sweep``), then reports any runs that failed.

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Number of epochs passed to train_knn.py (--epochs).
epochs = 70

# Passed as --train (1 = train, 0 = test). When training, data is kept in
# RAM (--memory RAM instead of DISK) and main() pauses after each run.
TRAIN_FLAG = 0

# Passed as --save_predictions.
save_path = 'RALv3'

# NOTE(review): not passed to train_knn.py; the point budget actually used
# comes from ``max_points_sweep``. Kept for reference.
density = '10000'

# Passed as --experiment. (An earlier value was 'RALv3_kitti_test'.)
EXPERIMENT_NAME = 'Thesis_full_add_results'

# Passed as --eval_protocol; the alternative noted here is "cross_domain".
EVAL_PROTOCOL = "cross_validation"

# Extra flags appended (after shell-style splitting) to every run.
input_preprocessing = ' --roi 0 --augmentation 1 --shuffle_points 1'

# Validation sequences, one run each (--val_set). Alternatives used before:
#   KITTI: ['00', '02', '05', '06', '08'] or ['02', '05', '06', '08']
#   placeholder: ['-']
test_sequences_horto = ['ON23', 'OJ22', 'OJ23', 'ON22', 'SJ23', 'GTJ23']

# Networks to run (--network). Other names previously listed here:
#   'PointNetPGAP', 'PointNetPGAPLoss', 'PointNetVLADLoss', 'SPVVLADLoss',
#   'ResNet50VLADLoss', 'SPVGeMLoss', 'ResNet50GeMLoss', 'SPVMACLoss',
#   'ResNet50MACLoss', 'SPVSoAP3DLoss', 'LOGG3D', 'LOGG3DLoss',
#   'overlap_transformer', 'overlap_transformerLoss'.
stages = [
    'SPVSoAP3D',
    'SPVVLAD',
    'SPVGeM',
    'SPVMAC',
    'ResNet50VLAD',
    'ResNet50GeM',
    'ResNet50MAC',
    'PointNetVLAD',
    'PointNetGeM',
    'PointNetMAC',
]

# --eval_batch_size per validation sequence, aligned index-by-index with
# test_sequences_horto.
# NOTE(review): the original note said 14 is the maximum batch size for
# GTJ23, yet 15 is configured for it — confirm.
test_batchsize = [16, 16, 16, 16, 16, 15]

# --eval_roi_window per validation sequence, aligned index-by-index with
# test_sequences_horto. 100 is the maximum window size for GTJ23.
eval_windows = [600, 600, 600, 600, 600, 100]

# Point budgets to sweep (--max_points). A wider sweep used before:
#   [100, 500, 1000, 3000, 5000, 10000, 15000, 20000, 30000]
max_points_sweep = [10000]

# Fixed values passed as --mini_batch_size and --loss_alpha.
mini_batch_size = 1000
loss_alpha = 0.5

# Seconds to pause after each run when TRAIN_FLAG == 1.
train_pause_s = 600

# Base directory of the checkpoints passed via --resume.
checkpoint = (
    "checkpoints/Thesis_full/triplet/"
    "ground_truth_ar0.5m_nr10m_pr2m.pkl/10000/ON23"
)

# Dataset location per machine (the script is run on two machines); any
# host not listed here falls back to the default.
_DATASET_ROOTS = {'tiago-deep': '/home/tiago/workspace/DATASET'}
_DEFAULT_DATASET_ROOT = '/home/tbarros/workspace/DATASET'


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def dataset_root_for(hostname):
    """Return the dataset root directory to use on machine *hostname*."""
    return _DATASET_ROOTS.get(hostname, _DEFAULT_DATASET_ROOT)


def iter_jobs():
    """Yield the parameters of every run, in execution order.

    Each item is ``(network, val_set, eval_batch_size, eval_window,
    max_points)``. Networks vary slowest, point budgets fastest.

    Raises:
        ValueError: if ``test_sequences_horto``, ``test_batchsize`` and
            ``eval_windows`` differ in length (``zip`` would otherwise
            silently drop the unmatched entries).
    """
    lengths = (len(test_sequences_horto), len(test_batchsize),
               len(eval_windows))
    if len(set(lengths)) != 1:
        raise ValueError(
            "test_sequences_horto, test_batchsize and eval_windows must have "
            "the same length, got %d, %d and %d" % lengths
        )
    per_sequence = list(zip(test_sequences_horto, test_batchsize,
                            eval_windows))
    for network in stages:
        for val_set, eval_batch_size, eval_window in per_sequence:
            for max_points in max_points_sweep:
                yield network, val_set, eval_batch_size, eval_window, max_points


def build_args(network, val_set, eval_batch_size, eval_window, max_points,
               dataset_root):
    """Return the train_knn.py command-line arguments for one run.

    Args:
        network: Value for ``--network``; also selects the checkpoint
            sub-directory used for ``--resume``.
        val_set: Value for ``--val_set``.
        eval_batch_size: Value for ``--eval_batch_size``.
        eval_window: Value for ``--eval_roi_window``.
        max_points: Value for ``--max_points``.
        dataset_root: Value for ``--dataset_root``.

    Returns:
        A flat list of argument strings (without interpreter/script name).
    """
    # Other checkpoint layouts used before: run directories suffixed
    # '-LazyTripletLoss_L2-segment_loss-m0.5', and 'checkpoint.pth' instead
    # of 'best_model.pth'.
    resume = f'{checkpoint}/{network}-LazyTripletLoss_L2/best_model.pth'
    options = [
        ('--network', network),
        ('--train', TRAIN_FLAG),
        ('--dataset_root', dataset_root),
        ('--resume', resume),
        ('--val_set', val_set),
        ('--memory', 'RAM' if TRAIN_FLAG == 1 else 'DISK'),
        ('--device', 'cuda'),
        ('--save_predictions', save_path),
        ('--epochs', epochs),
        ('--max_points', max_points),
        ('--experiment', EXPERIMENT_NAME),
        ('--eval_batch_size', eval_batch_size),
        ('--mini_batch_size', mini_batch_size),
        ('--loss_alpha', loss_alpha),
        ('--eval_roi_window', eval_window),
        ('--eval_protocol', EVAL_PROTOCOL),
    ]
    args = [str(token) for pair in options for token in pair]
    args.extend(shlex.split(input_preprocessing))
    return args


def main():
    """Run train_knn.py for every job and report the runs that failed."""
    hostname = socket.gethostname()
    print(f"*** Hostname: {hostname}\n")
    dataset_root = dataset_root_for(hostname)

    failed = []
    for network, val_set, eval_batch_size, eval_window, max_points in iter_jobs():
        args = build_args(network, val_set, eval_batch_size, eval_window,
                          max_points, dataset_root)
        print(' '.join(args))
        # An argv list avoids the shell and keeps each argument intact. It
        # also exposes the exit status and lets Ctrl+C abort the whole sweep
        # (os.system ignores SIGINT in the parent while the child runs).
        result = subprocess.run(['python3', 'train_knn.py', *args])
        if result.returncode != 0:
            print(f"!!! train_knn.py exited with status {result.returncode} "
                  f"(network={network}, val_set={val_set}, "
                  f"max_points={max_points})")
            failed.append((network, val_set, max_points))
        if TRAIN_FLAG == 1:
            time.sleep(train_pause_s)

    if failed:
        print(f"!!! {len(failed)} run(s) failed:")
        for network, val_set, max_points in failed:
            print(f"    {network} on {val_set} (max_points={max_points})")


if __name__ == "__main__":
    main()