Skip to content

Commit

Permalink
Add tf support for vis_pred.py example (#374)
Browse files Browse the repository at this point in the history
* Add tf support for vis_pred.py example

* Apply style convention to vis_pred.py

Co-authored-by: Sanskar Agrawal <[email protected]>
  • Loading branch information
jokokojote and sanskar107 authored Nov 18, 2021
1 parent a808963 commit 5144b46
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 16 deletions.
2 changes: 1 addition & 1 deletion docs/howtos.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ This page is an effort to give short examples for common tasks and will be
extended over time.

## Visualize network predictions
Users can inspect the prediction results using the visualizer. Run `python examples/vis_pred.py` to see an example.
Users can inspect the prediction results using the visualizer. Run `python examples/vis_pred.py` to see an example (torch and tf version is available).

First, initialize a `Visualizer` and set up `LabelLUT` as label names to visualize. Here we would like to visualize points from `SemanticKITTI`. The labels can be obtained by `get_label_to_names()`
```python
Expand Down
69 changes: 54 additions & 15 deletions examples/vis_pred.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#!/usr/bin/env python
import open3d.ml.torch as ml3d
import open3d.ml.torch as ml3d # just switch to open3d.ml.tf for tf usage
import numpy as np
import os
from os.path import exists, join
import sys
from os.path import exists, join, dirname

example_dir = os.path.dirname(os.path.realpath(__file__))

Expand Down Expand Up @@ -71,6 +72,50 @@ def pred_custom_data(pc_names, pcs, pipeline_r, pipeline_k):
return vis_points


def get_torch_ckpts():
kpconv_url = "https://storage.googleapis.com/open3d-releases/model-zoo/kpconv_semantickitti_202009090354utc.pth"
randlanet_url = "https://storage.googleapis.com/open3d-releases/model-zoo/randlanet_semantickitti_202009090354utc.pth"

ckpt_path_r = example_dir + "/vis_weights_{}.pth".format('RandLANet')
if not exists(ckpt_path_r):
cmd = "wget {} -O {}".format(randlanet_url, ckpt_path_r)
os.system(cmd)

ckpt_path_k = example_dir + "/vis_weights_{}.pth".format('KPFCNN')
if not exists(ckpt_path_k):
cmd = "wget {} -O {}".format(kpconv_url, ckpt_path_k)
print(cmd)
os.system(cmd)

return ckpt_path_r, ckpt_path_k


def get_tf_ckpts():
kpconv_url = "https://storage.googleapis.com/open3d-releases/model-zoo/kpconv_semantickitti_202010021102utc.zip"
randlanet_url = "https://storage.googleapis.com/open3d-releases/model-zoo/randlanet_semantickitti_202010091306.zip"

ckpt_path_dir = example_dir + "/vis_weights_{}".format('RandLANet')
if not exists(ckpt_path_dir):
ckpt_path_zip = example_dir + "/vis_weights_{}.zip".format('RandLANet')
cmd = "wget {} -O {}".format(randlanet_url, ckpt_path_zip)
os.system(cmd)
cmd = "unzip -j -o {} -d {}".format(ckpt_path_zip, ckpt_path_dir)
os.system(cmd)
ckpt_path_r = example_dir + "/vis_weights_{}/{}_{}".format(
'RandLANet', 'randlanet', 'semantickitti')

ckpt_path_dir = example_dir + "/vis_weights_{}".format('KPFCNN')
if not exists(ckpt_path_dir):
ckpt_path_zip = example_dir + "/vis_weights_{}.zip".format('KPFCNN')
cmd = "wget {} -O {}".format(kpconv_url, ckpt_path_zip)
os.system(cmd)
cmd = "unzip -j -o {} -d {}".format(ckpt_path_zip, ckpt_path_dir)
os.system(cmd)
ckpt_path_k = example_dir + "/vis_weights_{}/{}".format('KPFCNN', 'ckpt-1')

return ckpt_path_r, ckpt_path_k


# ------------------------------


Expand All @@ -83,23 +128,17 @@ def main():
v.set_lut("labels", lut)
v.set_lut("pred", lut)

kpconv_url = "https://storage.googleapis.com/open3d-releases/model-zoo/kpconv_semantickitti_202009090354utc.pth"
randlanet_url = "https://storage.googleapis.com/open3d-releases/model-zoo/randlanet_semantickitti_202009090354utc.pth"
# load pretrained weights depending on used ml framework (torch or tf)
if ("open3d.ml.torch" in sys.modules): # torch is used
ckpt_path_r, ckpt_path_k = get_torch_ckpts()
else: # tf is used
ckpt_path_r, ckpt_path_k = get_tf_ckpts()

ckpt_path = example_dir + "/vis_weights_{}.pth".format('RandLANet')
if not exists(ckpt_path):
cmd = "wget {} -O {}".format(randlanet_url, ckpt_path)
os.system(cmd)
model = ml3d.models.RandLANet(ckpt_path=ckpt_path)
model = ml3d.models.RandLANet(ckpt_path=ckpt_path_r)
pipeline_r = ml3d.pipelines.SemanticSegmentation(model)
pipeline_r.load_ckpt(model.cfg.ckpt_path)

ckpt_path = example_dir + "/vis_weights_{}.pth".format('KPFCNN')
if not exists(ckpt_path):
cmd = "wget {} -O {}".format(kpconv_url, ckpt_path)
print(cmd)
os.system(cmd)
model = ml3d.models.KPFCNN(ckpt_path=ckpt_path, in_radius=10)
model = ml3d.models.KPFCNN(ckpt_path=ckpt_path_k)
pipeline_k = ml3d.pipelines.SemanticSegmentation(model)
pipeline_k.load_ckpt(model.cfg.ckpt_path)

Expand Down

0 comments on commit 5144b46

Please sign in to comment.