-
Notifications
You must be signed in to change notification settings - Fork 4
/
generate_pseudo_label.py
186 lines (126 loc) · 4.2 KB
/
generate_pseudo_label.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
# Modified by Wei Jiacheng
#
#
# Originally written by Hugues THOMAS - 11/06/2018
# ----------------------------------------------------------------------------------------------------------------------
#
# Imports and global variables
# \**********************************/
#
# Common libs
import time
import os
import numpy as np
# My libs
from utils.config import Config
from utils.tester_cam import ModelTester
from models.KPCNN_model import KernelPointCNN
from models.KPFCNN_mprm import KernelPointFCNN
# Datasets
from datasets.Scannet_subcloud import ScannetDataset
# ----------------------------------------------------------------------------------------------------------------------
#
# Utility functions
# \***********************/
#
def _select_snapshot(path, step_ind):
    """Return the checkpoint prefix of the ``step_ind``-th snapshot under ``path/snapshots``.

    Snapshots are found through their ``snap-<step>.meta`` files and ordered by
    training step, so ``step_ind=-1`` selects the most recent one.

    Raises:
        FileNotFoundError: if ``path/snapshots`` is missing or holds no ``.meta`` file.
        IndexError: if ``step_ind`` is out of range for the available snapshots.
    """
    snap_path = os.path.join(path, 'snapshots')
    # The step number is the text after the last '-' once '.meta' is stripped
    # (e.g. 'snap-1200.meta' -> 1200).
    snap_steps = [int(f[:-len('.meta')].split('-')[-1])
                  for f in os.listdir(snap_path) if f.endswith('.meta')]
    if not snap_steps:
        raise FileNotFoundError('No snapshot (.meta file) found in: ' + snap_path)
    chosen_step = sorted(snap_steps)[step_ind]
    return os.path.join(snap_path, 'snap-{:d}'.format(chosen_step))


def test_caller(path, step_ind, on_val, gpu_id='3'):
    """Restore a trained KPFCNN snapshot and run cloud segmentation on ScanNet.

    Args:
        path: training log folder; ``Config.load`` reads the saved parameters from
            it and snapshots are looked up in its ``snapshots`` sub-folder.
        step_ind: index into the step-sorted list of snapshots (-1 = latest).
        on_val: if True, run ``test_cloud_segmentation_on_val`` on the validation
            pipeline; otherwise load the test split and run
            ``test_cloud_segmentation``.
        gpu_id: value written to ``CUDA_VISIBLE_DEVICES`` (default ``'3'``, the
            device this script previously hard-coded).

    Raises:
        FileNotFoundError: if no snapshot exists under ``path/snapshots``.
    """

    ##########################
    # Initiate the environment
    ##########################

    # Restrict the process to the chosen GPU.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)

    # TensorFlow C++ log level. NOTE: '0' is the most verbose setting (all
    # messages are shown); use '2' to hide warnings.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'

    ###########################
    # Load the model parameters
    ###########################

    config = Config()
    config.load(path)

    # Resolve the snapshot now so that a missing/empty snapshot folder fails
    # fast, before the costly dataset and model construction below.
    chosen_snap = _select_snapshot(path, step_ind)

    ##################################
    # Change model parameters for test
    ##################################

    # Override parameters for the test here; for example, augmentation settings
    # such as config.augment_noise / config.augment_color, or config.batch_num.
    config.validation_size = 1201

    ##############
    # Prepare Data
    ##############

    print()
    print('Dataset Preparation')
    print('*******************')

    dataset = ScannetDataset(config.input_threads, load_test=(not on_val))

    # Create subsampled clouds at the first subsampling grid size.
    dataset.load_subsampled_clouds(config.first_subsampling_dl)

    # The validation and test splits use different input pipelines.
    if on_val:
        dataset.init_input_pipeline(config)
    else:
        dataset.init_test_input_pipeline(config)

    ##############
    # Define Model
    ##############

    print('Creating Model')
    print('**************\n')
    t1 = time.time()

    model = KernelPointFCNN(dataset.flat_inputs, config)
    tester = ModelTester(model, restore_snap=chosen_snap)

    t2 = time.time()
    print('\n----------------')
    print('Done in {:.1f} s'.format(t2 - t1))
    print('----------------\n')

    ############
    # Start test
    ############

    print('Start Test')
    print('**********\n')

    if on_val:
        tester.test_cloud_segmentation_on_val(model, dataset)
    else:
        tester.test_cloud_segmentation(model, dataset)
# ----------------------------------------------------------------------------------------------------------------------
#
# Main Call
# \***************/
#
if __name__ == '__main__':

    ##########################
    # Choose the model to test
    ##########################

    # Path of the training log folder to test (holds the saved config and the
    # 'snapshots' sub-folder). Must be filled in before running.
    chosen_log = ''

    # Index of the snapshot to load, in training-step order (-1 = last one).
    chosen_snapshot = -1

    # True: test on the validation split; False: test on the test split.
    on_val = True

    # To modify Config parameters for testing (for example, to stop augmenting
    # the input data), edit the "Change model parameters for test" section of
    # the function test_caller defined above.

    ###########################
    # Call the test initializer
    ###########################

    # Give a clear message when the placeholder path was left empty.
    if not chosen_log:
        raise ValueError('chosen_log is empty: set it to the path of a training log folder')

    # Check that the log exists before any heavy setup.
    if not os.path.exists(chosen_log):
        raise ValueError('The given log does not exist: ' + chosen_log)

    test_caller(chosen_log, chosen_snapshot, on_val)