diff --git a/README.md b/README.md
index bab2aa7e..b300af88 100644
--- a/README.md
+++ b/README.md
@@ -192,6 +192,7 @@ the scripts in [`scripts/download_datasets`](https://github.com/intel-isl/Open3D
 
 ## How-tos
 
+* [Visualize network predictions](docs/howtos.md#visualize-network-predictions)
 * [Visualize custom data](docs/howtos.md#visualize-custom-data)
 * [Adding a new model](docs/howtos.md#adding-a-new-model)
 * [Adding a new dataset](docs/howtos.md#adding-a-new-dataset)
diff --git a/docs/howtos.md b/docs/howtos.md
index 2577aeff..6b6b7289 100644
--- a/docs/howtos.md
+++ b/docs/howtos.md
@@ -3,6 +3,70 @@
 This page is an effort to give short examples for common tasks and will be
 extended over time.
 
+## Visualize network predictions
+
+Users can inspect prediction results with the visualizer. Run `python examples/vis_pred.py` to see an example.
+
+First, initialize a `Visualizer` and set up a `LabelLUT` with the label names to visualize. Here we visualize points from `SemanticKITTI`; its label-to-name mapping can be obtained with `get_label_to_names()`.
+```python
+from ml3d.vis import Visualizer, LabelLUT
+from ml3d.datasets import SemanticKITTI
+
+kitti_labels = SemanticKITTI.get_label_to_names()
+v = Visualizer()
+lut = LabelLUT()
+for val in sorted(kitti_labels.keys()):
+    lut.add_label(kitti_labels[val], val)
+v.set_lut("labels", lut)
+v.set_lut("pred", lut)
+```
+
+Second, we construct the networks and pipelines, load the pretrained weights, and prepare the data to be visualized.
+```python
+from ml3d.torch.pipelines import SemanticSegmentation
+from ml3d.torch.models import RandLANet, KPFCNN
+
+kpconv_url = "https://storage.googleapis.com/open3d-releases/model-zoo/kpconv_semantickitti_202009090354utc.pth"
+randlanet_url = "https://storage.googleapis.com/open3d-releases/model-zoo/randlanet_semantickitti_202009090354utc.pth"
+
+ckpt_path = "./logs/vis_weights_{}.pth".format('RandLANet')
+if not exists(ckpt_path):
+    cmd = "wget {} -O {}".format(randlanet_url, ckpt_path)
+    os.system(cmd)
+model = RandLANet(ckpt_path=ckpt_path)
+pipeline_r = SemanticSegmentation(model)
+pipeline_r.load_ckpt(model.cfg.ckpt_path)
+
+ckpt_path = "./logs/vis_weights_{}.pth".format('KPFCNN')
+if not exists(ckpt_path):
+    cmd = "wget {} -O {}".format(kpconv_url, ckpt_path)
+    print(cmd)
+    os.system(cmd)
+model = KPFCNN(ckpt_path=ckpt_path, in_radius=10)
+pipeline_k = SemanticSegmentation(model)
+pipeline_k.load_ckpt(model.cfg.ckpt_path)
+
+data_path = os.path.dirname(os.path.realpath(__file__)) + "/demo_data"
+pc_names = ["000700", "000750"]
+
+# see this function in examples/vis_pred.py;
+# it can also be a custom dataloader,
+# or one of the existing get_data() methods in ml3d/datasets
+pcs = get_custom_data(pc_names, data_path)
+```
+
+Third, we run inference, collect the results, and send them to `Visualizer.visualize(list_of_pointclouds_to_visualize)`. Note that the input to `visualize()` is a list of point clouds and their predictions. Each point cloud is a dictionary like this:
+```python
+vis_d = {
+    "name": name,
+    "points": pts,        # n x 3
+    "labels": label,      # n
+    "pred": pred_label,   # n
+}
+```
+The `name` and `points` entries are required; the other entries can be customized. For example, we can visualize the ground-truth `label` and our prediction `pred` for each point cloud.
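+
+A minimal sketch of this step, reusing `pipeline_r`, `pc_names`, and `pcs` from the code above (it assumes `numpy` is imported as `np`; `run_inference` returns a dict whose `predict_labels` entry holds the per-point predictions, which `examples/vis_pred.py` shifts by one so that 0 stays 'unlabeled'):
+```python
+results = pipeline_r.run_inference(pcs[0])
+# shift predictions by one, as done in examples/vis_pred.py
+pred_label = (results['predict_labels'] + 1).astype(np.int32)
+
+vis_d = {
+    "name": pc_names[0],
+    "points": pcs[0]['point'],   # n x 3
+    "labels": pcs[0]['label'],   # n, ground truth
+    "pred": pred_label,          # n, network prediction
+}
+v.visualize([vis_d])
+```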
+
+Here is the result of running `python examples/vis_pred.py`:
+![Visualize prediction GIF](images/visualizer_predictions.gif)
 
 ## Visualize custom data
diff --git a/docs/images/visualizer_predictions.gif b/docs/images/visualizer_predictions.gif
new file mode 100644
index 00000000..c8db473a
Binary files /dev/null and b/docs/images/visualizer_predictions.gif differ
diff --git a/examples/vis_pred.py b/examples/vis_pred.py
new file mode 100644
index 00000000..64793abf
--- /dev/null
+++ b/examples/vis_pred.py
@@ -0,0 +1,125 @@
+#!/usr/bin/env python
+from ml3d.datasets import ParisLille3D
+from ml3d.datasets import S3DIS
+from ml3d.datasets import Semantic3D
+from ml3d.datasets import SemanticKITTI
+from ml3d.datasets import Toronto3D
+from ml3d.vis import Visualizer, LabelLUT
+from ml3d.utils import get_module
+
+import argparse
+import math
+import numpy as np
+import os
+import random
+import sys
+import tensorflow as tf
+import torch
+from os.path import exists, join, isfile, dirname, abspath, split
+
+
+def get_custom_data(pc_names, path):
+
+    pc_data = []
+    for i, name in enumerate(pc_names):
+        pc_path = join(path, 'points', name + '.npy')
+        label_path = join(path, 'labels', name + '.npy')
+        point = np.load(pc_path)[:, 0:3]
+        label = np.squeeze(np.load(label_path))
+
+        data = {
+            'point': point,
+            'feat': None,
+            'label': label,
+        }
+        pc_data.append(data)
+
+    return pc_data
+
+
+def pred_custom_data(pc_names, pcs, pipeline_r, pipeline_k):
+    vis_points = []
+    for i, data in enumerate(pcs):
+        name = pc_names[i]
+
+        results_r = pipeline_r.run_inference(data)
+        pred_label_r = (results_r['predict_labels'] + 1).astype(np.int32)
+        pred_label_r[0] = 0
+
+        results_k = pipeline_k.run_inference(data)
+        pred_label_k = (results_k['predict_labels'] + 1).astype(np.int32)
+        pred_label_k[0] = 0
+
+        label = data['label']
+        pts = data['point']
+
+        vis_d = {
+            "name": name,
+            "points": pts,
+            "labels": label,
+            "pred": pred_label_k,
+        }
+        vis_points.append(vis_d)
+
+        vis_d = {
+            "name": name + "_randlanet",
+            "points": pts,
+            "labels": pred_label_r,
+        }
+        vis_points.append(vis_d)
+
+        vis_d = {
+            "name": name + "_kpconv",
+            "points": pts,
+            "labels": pred_label_k,
+        }
+        vis_points.append(vis_d)
+
+    return vis_points
+
+
+# ------------------------------
+
+from ml3d.torch.pipelines import SemanticSegmentation
+from ml3d.torch.models import RandLANet, KPFCNN
+
+
+def main():
+    kitti_labels = SemanticKITTI.get_label_to_names()
+    v = Visualizer()
+    lut = LabelLUT()
+    for val in sorted(kitti_labels.keys()):
+        lut.add_label(kitti_labels[val], val)
+    v.set_lut("labels", lut)
+    v.set_lut("pred", lut)
+
+    kpconv_url = "https://storage.googleapis.com/open3d-releases/model-zoo/kpconv_semantickitti_202009090354utc.pth"
+    randlanet_url = "https://storage.googleapis.com/open3d-releases/model-zoo/randlanet_semantickitti_202009090354utc.pth"
+
+    ckpt_path = "./logs/vis_weights_{}.pth".format('RandLANet')
+    if not exists(ckpt_path):
+        cmd = "wget {} -O {}".format(randlanet_url, ckpt_path)
+        os.system(cmd)
+    model = RandLANet(ckpt_path=ckpt_path)
+    pipeline_r = SemanticSegmentation(model)
+    pipeline_r.load_ckpt(model.cfg.ckpt_path)
+
+    ckpt_path = "./logs/vis_weights_{}.pth".format('KPFCNN')
+    if not exists(ckpt_path):
+        cmd = "wget {} -O {}".format(kpconv_url, ckpt_path)
+        print(cmd)
+        os.system(cmd)
+    model = KPFCNN(ckpt_path=ckpt_path, in_radius=10)
+    pipeline_k = SemanticSegmentation(model)
+    pipeline_k.load_ckpt(model.cfg.ckpt_path)
+
+    data_path = os.path.dirname(os.path.realpath(__file__)) + "/demo_data"
+    pc_names = ["000700", "000750"]
+    pcs = get_custom_data(pc_names, data_path)
+    pcs_with_pred = pred_custom_data(pc_names, pcs, pipeline_r, pipeline_k)
+
+    v.visualize(pcs_with_pred)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/ml3d/datasets/base_dataset.py b/ml3d/datasets/base_dataset.py
index 90a41958..8b9fbeb1 100644
--- a/ml3d/datasets/base_dataset.py
+++ b/ml3d/datasets/base_dataset.py
@@ -31,6 +31,16 @@ def __init__(self, **kwargs):
         self.cfg = Config(kwargs)
         self.name = self.cfg.name
 
+    @staticmethod
+    @abstractmethod
+    def get_label_to_names():
+        """Returns a label to names dict.
+
+        Returns:
+            A dict where keys are label numbers and
+            vals are the corresponding names.
+        """
+
     @abstractmethod
     def get_split(self, split):
         """Returns a dataset split.
diff --git a/ml3d/datasets/parislille3d.py b/ml3d/datasets/parislille3d.py
index 6d3563f9..270bf665 100644
--- a/ml3d/datasets/parislille3d.py
+++ b/ml3d/datasets/parislille3d.py
@@ -58,18 +58,7 @@ def __init__(self,
 
         cfg = self.cfg
 
-        self.label_to_names = {
-            0: 'unclassified',
-            1: 'ground',
-            2: 'building',
-            3: 'pole-road_sign-traffic_light',
-            4: 'bollard-small_pole',
-            5: 'trash_can',
-            6: 'barrier',
-            7: 'pedestrian',
-            8: 'car',
-            9: 'natural-vegetation'
-        }
+        self.label_to_names = self.get_label_to_names()
         self.num_classes = len(self.label_to_names)
         self.label_values = np.sort([k for k, v in self.label_to_names.items()])
@@ -88,6 +77,22 @@ def __init__(self,
 
         test_path = cfg.dataset_path + "/test_10_classes/"
         self.test_files = glob.glob(test_path + '*.ply')
 
+    @staticmethod
+    def get_label_to_names():
+        label_to_names = {
+            0: 'unclassified',
+            1: 'ground',
+            2: 'building',
+            3: 'pole-road_sign-traffic_light',
+            4: 'bollard-small_pole',
+            5: 'trash_can',
+            6: 'barrier',
+            7: 'pedestrian',
+            8: 'car',
+            9: 'natural-vegetation'
+        }
+        return label_to_names
+
     def get_split(self, split):
         return ParisLille3DSplit(self, split=split)
diff --git a/ml3d/datasets/s3dis.py b/ml3d/datasets/s3dis.py
index 2b9f06dd..0ca6a182 100644
--- a/ml3d/datasets/s3dis.py
+++ b/ml3d/datasets/s3dis.py
@@ -61,21 +61,7 @@ def __init__(self,
 
         cfg = self.cfg
 
-        self.label_to_names = {
-            0: 'ceiling',
-            1: 'floor',
-            2: 'wall',
-            3: 'beam',
-            4: 'column',
-            5: 'window',
-            6: 'door',
-            7: 'table',
-            8: 'chair',
-            9: 'sofa',
-            10: 'bookcase',
-            11: 'board',
-            12: 'clutter'
-        }
+        self.label_to_names = self.get_label_to_names()
         self.num_classes = len(self.label_to_names)
         self.label_values = np.sort([k for k, v in self.label_to_names.items()])
         self.label_to_idx = {l: i for i, l in enumerate(self.label_values)}
@@ -94,6 +80,25 @@ def __init__(self,
 
         self.all_files = glob.glob(
             str(Path(self.cfg.dataset_path) / 'original_ply' / '*.ply'))
 
+    @staticmethod
+    def get_label_to_names():
+        label_to_names = {
+            0: 'ceiling',
+            1: 'floor',
+            2: 'wall',
+            3: 'beam',
+            4: 'column',
+            5: 'window',
+            6: 'door',
+            7: 'table',
+            8: 'chair',
+            9: 'sofa',
+            10: 'bookcase',
+            11: 'board',
+            12: 'clutter'
+        }
+        return label_to_names
+
     def get_split(self, split):
         return S3DISSplit(self, split=split)
diff --git a/ml3d/datasets/semantic3d.py b/ml3d/datasets/semantic3d.py
index eb766ccf..f87408e4 100644
--- a/ml3d/datasets/semantic3d.py
+++ b/ml3d/datasets/semantic3d.py
@@ -64,17 +64,7 @@ def __init__(self,
 
         cfg = self.cfg
 
-        self.label_to_names = {
-            0: 'unlabeled',
-            1: 'man-made terrain',
-            2: 'natural terrain',
-            3: 'high vegetation',
-            4: 'low vegetation',
-            5: 'buildings',
-            6: 'hard scape',
-            7: 'scanning artefacts',
-            8: 'cars'
-        }
+        self.label_to_names = self.get_label_to_names()
         self.num_classes = len(self.label_to_names)
         self.label_values = np.sort([k for k, v in self.label_to_names.items()])
         self.label_to_idx = {l: i for i, l in enumerate(self.label_values)}
@@ -103,6 +93,21 @@ def __init__(self,
         self.train_files = np.sort(
             [f for f in self.train_files if f not in self.val_files])
 
+    @staticmethod
+    def get_label_to_names():
+        label_to_names = {
+            0: 'unlabeled',
+            1: 'man-made terrain',
+            2: 'natural terrain',
+            3: 'high vegetation',
+            4: 'low vegetation',
+            5: 'buildings',
+            6: 'hard scape',
+            7: 'scanning artefacts',
+            8: 'cars'
+        }
+        return label_to_names
+
     def get_split(self, split):
         return Semantic3DSplit(self, split=split)
diff --git a/ml3d/datasets/semantickitti.py b/ml3d/datasets/semantickitti.py
index f9ae8326..8ed4d344 100644
--- a/ml3d/datasets/semantickitti.py
+++ b/ml3d/datasets/semantickitti.py
@@ -72,7 +72,31 @@ def __init__(self,
 
         cfg = self.cfg
 
-        self.label_to_names = {
+        self.label_to_names = self.get_label_to_names()
+        self.num_classes = len(self.label_to_names)
+
+        data_config = join(dirname(abspath(__file__)), '_resources/',
+                           'semantic-kitti.yaml')
+        DATA = yaml.safe_load(open(data_config, 'r'))
+        remap_dict = DATA["learning_map_inv"]
+
+        # make lookup table for mapping
+        max_key = max(remap_dict.keys())
+        remap_lut = np.zeros((max_key + 100), dtype=np.int32)
+        remap_lut[list(remap_dict.keys())] = list(remap_dict.values())
+
+        remap_dict_val = DATA["learning_map"]
+        max_key = max(remap_dict_val.keys())
+        remap_lut_val = np.zeros((max_key + 100), dtype=np.int32)
+        remap_lut_val[list(remap_dict_val.keys())] = list(
+            remap_dict_val.values())
+
+        self.remap_lut_val = remap_lut_val
+        self.remap_lut = remap_lut
+
+    @staticmethod
+    def get_label_to_names():
+        label_to_names = {
             0: 'unlabeled',
             1: 'car',
             2: 'bicycle',
@@ -94,26 +118,7 @@ def __init__(self,
             18: 'pole',
             19: 'traffic-sign'
         }
-        self.num_classes = len(self.label_to_names)
-
-        data_config = join(dirname(abspath(__file__)), '_resources/',
-                           'semantic-kitti.yaml')
-        DATA = yaml.safe_load(open(data_config, 'r'))
-        remap_dict = DATA["learning_map_inv"]
-
-        # make lookup table for mapping
-        max_key = max(remap_dict.keys())
-        remap_lut = np.zeros((max_key + 100), dtype=np.int32)
-        remap_lut[list(remap_dict.keys())] = list(remap_dict.values())
-
-        remap_dict_val = DATA["learning_map"]
-        max_key = max(remap_dict_val.keys())
-        remap_lut_val = np.zeros((max_key + 100), dtype=np.int32)
-        remap_lut_val[list(remap_dict_val.keys())] = list(
-            remap_dict_val.values())
-
-        self.remap_lut_val = remap_lut_val
-        self.remap_lut = remap_lut
+        return label_to_names
 
     def get_split(self, split):
         return SemanticKITTISplit(self, split=split)
diff --git a/ml3d/datasets/toronto3d.py b/ml3d/datasets/toronto3d.py
index b4b93806..481afa4f 100644
--- a/ml3d/datasets/toronto3d.py
+++ b/ml3d/datasets/toronto3d.py
@@ -64,17 +64,7 @@ def __init__(self,
 
         cfg = self.cfg
 
-        self.label_to_names = {
-            0: 'Unclassified',
-            1: 'Ground',
-            2: 'Road_markings',
-            3: 'Natural',
-            4: 'Building',
-            5: 'Utility_line',
-            6: 'Pole',
-            7: 'Car',
-            8: 'Fence'
-        }
+        self.label_to_names = self.get_label_to_names()
         self.dataset_path = cfg.dataset_path
         self.num_classes = len(self.label_to_names)
@@ -90,6 +80,21 @@ def __init__(self,
         self.test_files = [
             join(self.cfg.dataset_path, f) for f in cfg.test_files
         ]
 
+    @staticmethod
+    def get_label_to_names():
+        label_to_names = {
+            0: 'Unclassified',
+            1: 'Ground',
+            2: 'Road_markings',
+            3: 'Natural',
+            4: 'Building',
+            5: 'Utility_line',
+            6: 'Pole',
+            7: 'Car',
+            8: 'Fence'
+        }
+        return label_to_names
+
     def get_split(self, split):
         return Toronto3DSplit(self, split=split)
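Since `get_label_to_names()` is now a `@staticmethod` on every dataset class (enforced by the new abstract method in `ml3d/datasets/base_dataset.py`), the label mapping can be queried without constructing a dataset object. A minimal sketch of such a call:

```python
from ml3d.datasets import SemanticKITTI, Toronto3D

# No dataset_path or config is needed for the static lookup.
print(SemanticKITTI.get_label_to_names()[1])  # 'car'
print(Toronto3D.get_label_to_names()[7])      # 'Car'
```

This is how `examples/vis_pred.py` builds its `LabelLUT` before any dataset or pipeline is instantiated.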