-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsegment.py
97 lines (75 loc) · 4.02 KB
/
segment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
'''Script that segments the left and right images of each reprojected image folder
with an MMSegmentation model and saves the results in a folder with the same name
under <input_dir>/<neighbourhood>_<quality>/segmented (e.g. res/dataset/osdorp_small/segmented).
@author: Andrea Lombardo
'''
import argparse
import os
import re
import psutil
import concurrent.futures
from tqdm import tqdm
import mmcv
import mmseg
from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot
from mmseg.core.evaluation import get_palette
def extract_left_right(args, img_path):
    """Return the paths of the left and right images inside one reprojected folder.

    The folder ``<args.input_dir>/<img_path>`` must contain exactly four
    entries. After sorting them by name, the third is taken as the left image
    and the fourth as the right image.

    Args:
        args: Parsed CLI namespace; only ``args.input_dir`` is read.
        img_path: Name of the image folder relative to ``args.input_dir``.

    Returns:
        Tuple ``(left, right)`` of full file paths.

    Raises:
        ValueError: If the folder does not contain exactly four entries.
    """
    input_dir = os.path.join(args.input_dir, img_path)
    # os.listdir returns entries in arbitrary, filesystem-dependent order, so
    # sort them to make the positional pick below deterministic.
    # NOTE(review): assumes the file names sort so that the left/right images
    # are the last two entries -- confirm against the reprojection naming.
    entries = sorted(os.listdir(input_dir))
    if len(entries) != 4:
        raise ValueError(
            f"expected 4 entries in {input_dir!r}, found {len(entries)}: {entries}"
        )
    _, _, left, right = entries
    return os.path.join(input_dir, left), os.path.join(input_dir, right)
def segment(args, img_path, output_directory, model):
    """Segment the left/right images of one folder and write four PNGs.

    Under ``<output_directory>/<img_path>`` this writes:
      * ``left.png`` / ``right.png`` -- the model's rendered segmentation;
      * ``left_sidewalk.png`` / ``right_sidewalk.png`` -- the original image
        with every pixel whose predicted class is not 1 (sidewalk) set to 0.

    Args:
        args: Parsed CLI namespace, forwarded to ``extract_left_right``.
        img_path: Name of the image folder relative to ``args.input_dir``.
        output_directory: Root directory the results are written under.
        model: Segmentor returned by ``init_segmentor``.
    """
    # (side, rendered-segmentation filename, sidewalk-only filename)
    outputs = (
        ('left', 'left.png', 'left_sidewalk.png'),
        ('right', 'right.png', 'right_sidewalk.png'),
    )
    src = dict(zip(('left', 'right'), extract_left_right(args, img_path)))
    dest = os.path.join(output_directory, img_path)

    # Load both images before running any inference.
    pixels = {side: mmcv.imread(src[side]) for side, _, _ in outputs}
    # Segment both images.
    preds = {side: inference_segmentor(model, pixels[side]) for side, _, _ in outputs}

    # Render the model's segmentation for each side to disk.
    for side, overlay_name, _ in outputs:
        model.show_result(src[side], preds[side], show=False,
                          out_file=os.path.join(dest, overlay_name))

    # Keep only pixels predicted as class 1 (sidewalk); black out the rest.
    sidewalk_only = {}
    for side, _, _ in outputs:
        keep = pixels[side].copy()
        keep[preds[side][0] != 1] = 0
        sidewalk_only[side] = keep

    # Write the sidewalk-only images next to the rendered segmentations.
    for side, _, sidewalk_name in outputs:
        mmcv.imwrite(sidewalk_only[side], os.path.join(dest, sidewalk_name))
def main(args):
    """Segment the reprojected image folders of one neighbourhood.

    The input directory is ``<input_dir>/<neighbourhood>_<quality>/reprojected``.
    Results are written to the sibling ``segmented`` directory, and ``segment``
    is run on each image folder (at most ``limit`` of them, in sorted order).

    Args:
        args: Parsed CLI namespace. Reads ``input_dir``, ``neighbourhood``,
            ``quality``, ``config_file`` and ``checkpoint_file``. Optionally
            reads ``limit`` (default 1000; ``None`` or <= 0 processes all) and
            ``device`` (default ``'cuda:0'``). ``neighbourhood`` and
            ``input_dir`` are overwritten in place with the normalized values.
    """
    # Replace every non-letter with an underscore and lowercase the result,
    # e.g. "Nieuw-West" -> "nieuw_west".
    args.neighbourhood = re.sub(r'[^a-zA-Z]', '_', args.neighbourhood).lower()
    # Build <input_dir>/<neighbourhood>_<quality>/reprojected
    args.input_dir = os.path.join(
        os.path.join(args.input_dir, args.neighbourhood) + '_' + args.quality,
        'reprojected',
    )
    print(args.input_dir)
    # Results go next to the reprojected folder:
    # <input_dir>/<neighbourhood>_<quality>/segmented
    output_directory = os.path.join(os.path.dirname(args.input_dir), 'segmented')
    print(output_directory)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_directory, exist_ok=True)

    # os.listdir order is arbitrary: sort so the processed subset is
    # reproducible, and skip stray files that are not image folders (segment
    # would fail listing them).
    img_list = sorted(
        name for name in os.listdir(args.input_dir)
        if os.path.isdir(os.path.join(args.input_dir, name))
    )
    # Optional cap on the number of folders processed (was a hard-coded
    # 1000-image testing limit; the default keeps that behaviour).
    limit = getattr(args, 'limit', 1000)
    if limit is not None and limit > 0:
        img_list = img_list[:limit]

    # Build the model from a config file and a checkpoint file.
    model = init_segmentor(args.config_file, args.checkpoint_file,
                           device=getattr(args, 'device', 'cuda:0'))
    # tqdm shows a progress bar over the image folders.
    for img in tqdm(img_list):
        segment(args, img, output_directory, model)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_dir', type=str, default='res/dataset', help='input directory')
    parser.add_argument('--neighbourhood', type=str, default='Osdorp', help='neighbourhood')
    parser.add_argument('--quality', type=str, default='small', help='quality of the images')
    parser.add_argument('--config_file', type=str, default='lib/mmsegmentation/configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py', help='config file')
    parser.add_argument('--checkpoint_file', type=str, default='lib/mmsegmentation/checkpoints/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth', help='checkpoint file')
    parser.add_argument('--limit', type=int, default=1000, help='maximum number of image folders to process (<= 0 processes all)')
    parser.add_argument('--device', type=str, default='cuda:0', help='device to run the model on')
    args = parser.parse_args()
    main(args)