# init_depth_gen_infer.py
# Forked from YvanYin/VNL_Monocular_Depth_Prediction.
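"""Run VNL monocular depth inference over a Mirror3D COCO-format val split.

Loads a MetricDepthModel checkpoint, predicts depth for every image listed in
the val JSON, rescales each prediction to a 16-bit depth map, and scores it
against the selected ground truth (raw/refined x sensor/mesh) with Mirror3dEval.
"""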
import os
import cv2
import torch
import numpy as np
from tools.parse_arg_test import TestOptions
from lib.core.config import cfg, merge_cfg_from_file
import json
import time
import logging

test_args = TestOptions().parse()
test_args.thread = 1
test_args.batchsize = 1
merge_cfg_from_file(test_args)

time_tag = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
output_folder = os.path.join(test_args.output_save_folder, "VNL_infer_{}".format(time_tag))
os.makedirs(output_folder, exist_ok=True)
cfg.TRAIN.LOG_DIR = output_folder

# These imports are kept after the cfg setup above so that modules which read
# the populated cfg (e.g. the logging setup, which uses cfg.TRAIN.LOG_DIR)
# see the final values.
from lib.utils.net_tools import load_ckpt
from lib.utils.logging import setup_logging
import torchvision.transforms as transforms
from data.load_dataset import CustomerDataLoader
from lib.models.metric_depth_model import MetricDepthModel
from lib.models.image_transfer import bins_to_depth
from mirror3d.utils.mirror3d_metrics import Mirror3dEval
from tqdm import tqdm

logger = setup_logging(__name__)

def scale_torch(img, scale):
    """
    Scale and normalize an image and return it as a torch.Tensor.
    :param img: input image, [H, W, C] in BGR order (as read by cv2)
    :param scale: the scale factor, float
    :return: normalized image tensor, [C, H, W]
    """
    img = np.transpose(img, (2, 0, 1))  # [H, W, C] -> [C, H, W]
    img = img[::-1, :, :]               # BGR -> RGB
    img = img.astype(np.float32)
    img /= scale
    img = torch.from_numpy(img.copy())
    img = transforms.Normalize(cfg.DATASET.RGB_PIXEL_MEANS, cfg.DATASET.RGB_PIXEL_VARS)(img)
    return img
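
# Example (assumed shapes): for a 480x640 BGR uint8 frame read by cv2.imread,
# scale_torch(frame, 255) returns a normalized float32 tensor of shape
# [3, 480, 640], ready to be batched with tensor[None, ...] as done below.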

def read_json(json_path):
    with open(json_path, 'r') as j:
        info = json.loads(j.read())
    return info

if __name__ == '__main__':
    data_loader = CustomerDataLoader(test_args)
    test_datasize = len(data_loader)
    logger.info('{:>15}: {:<30}'.format('test_data_size', test_datasize))

    # load model
    model = MetricDepthModel()
    model.eval()

    # load checkpoint
    if test_args.load_ckpt:
        load_ckpt(test_args, model)
    model.cuda()
    model = torch.nn.DataParallel(model)

    coco_val_info = read_json(test_args.coco_val)
    val_root = test_args.coco_val_root.strip().split(",")[0]  # first entry if several roots are supplied

    FORMAT = '%(levelname)s %(filename)s:%(lineno)4d: %(message)s'
    log_file_save_path = os.path.join(output_folder, "infer.log")
    logging.basicConfig(filename=log_file_save_path, filemode="a", level=logging.INFO, format=FORMAT)
    logging.info("output folder {}".format(output_folder))
    logging.info("checkpoint {}".format(test_args.load_ckpt))

    mirror3d_eval = Mirror3dEval(test_args.refined_depth, logger=logging, input_tag="RGB",
                                 method_tag="VNL", dataset_root=test_args.coco_val_root)
    for info in tqdm(coco_val_info["images"]):
        img_path = os.path.join(test_args.coco_val_root, info["mirror_color_image_path"])
        with torch.no_grad():
            img = cv2.imread(img_path)
            # Identity resize (target size equals source size); kept from the original pipeline.
            img_resize = cv2.resize(img, (int(img.shape[1]), int(img.shape[0])), interpolation=cv2.INTER_LINEAR)
            img_torch = scale_torch(img_resize, 255)
            img_torch = img_torch[None, :, :, :].cuda()
            _, pred_depth_softmax = model.module.depth_model(img_torch)
            pred_depth = bins_to_depth(pred_depth_softmax)
            pred_depth = pred_depth.cpu().numpy().squeeze()

            # Clip negatives before the uint16 cast: a uint16 array can never be
            # < 0, so clipping after the cast (as the original code did) was a no-op.
            pred_depth = np.clip(pred_depth, 0, None)
            pred_depth_scale = (pred_depth / pred_depth.max() * 10000).astype(np.uint16)  # rescale to 16-bit for saving/visualization
            if test_args.refined_depth:
                if test_args.mesh_depth:  # mesh refined
                    gt_depth_path = os.path.join(val_root, info["refined_meshD_path"])
                else:  # hole refined
                    gt_depth_path = os.path.join(val_root, info["refined_sensorD_path"])
            else:
                if test_args.mesh_depth:  # mesh raw
                    gt_depth_path = os.path.join(val_root, info["raw_meshD_path"])
                else:  # hole raw
                    gt_depth_path = os.path.join(val_root, info["raw_sensorD_path"])
            gt_depth = cv2.resize(cv2.imread(gt_depth_path, cv2.IMREAD_ANYDEPTH),
                                  (pred_depth_scale.shape[1], pred_depth_scale.shape[0]),
                                  interpolation=cv2.INTER_NEAREST)

            color_img_path = img_path
            raw_meshD_path = os.path.join(val_root, info["raw_meshD_path"])
            mask_path = os.path.join(val_root, info["mirror_instance_mask_path"])
            mirror3d_eval.compute_and_update_mirror3D_metrics(pred_depth_scale / test_args.depth_shift, test_args.depth_shift,
                                                              color_img_path, raw_meshD_path, gt_depth_path, mask_path)
            mirror3d_eval.save_result(output_folder, pred_depth_scale / test_args.depth_shift, test_args.depth_shift,
                                      color_img_path, raw_meshD_path, gt_depth_path, mask_path)

    mirror3d_eval.print_mirror3D_score()
    print("checkpoint : ", test_args.load_ckpt)
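
# A minimal usage sketch. The flag names below are inferred from the test_args
# attributes used in this script; tools/parse_arg_test.py defines the
# authoritative argument set, and every path and value here is a placeholder.
#
#   python init_depth_gen_infer.py \
#       --load_ckpt <path/to/vnl_checkpoint.pth> \
#       --coco_val <path/to/val.json> \
#       --coco_val_root <dataset/root> \
#       --output_save_folder <output/dir> \
#       --depth_shift 1000 \
#       --refined_depth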