diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..b58ee83
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,7 @@
+__pycache__/
+GAPartNet_All
+perception/
+wandb/
+ckpt/
+image_kuafu
+output/GAPartNet_result/
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..56715b3
--- /dev/null
+++ b/README.md
@@ -0,0 +1,100 @@
+# GAPartNet: Cross-Category Domain-Generalizable Object Perception and Manipulation via Generalizable and Actionable Parts
+
+**CVPR 2023 Highlight**
+
+This is the official repository of [**GAPartNet: Cross-Category Domain-Generalizable Object Perception and Manipulation via Generalizable and Actionable Parts**](https://arxiv.org/abs/2211.05272).
+
+For more information, please visit our [**project page**](https://pku-epic.github.io/GAPartNet/).
+
+
+## 💡 News
+- `2023/6/28` We polished our model with the user-friendly Lightning framework and released detailed training code! Check the gapartnet folder for more details!
+
+- `2023/5/21` GAPartNet Dataset has been released, including Object & Part Assets and Annotations, Rendered PointCloud Data and our Pre-trained Checkpoint.
+
+## GAPartNet Dataset
+
+(New!) GAPartNet Dataset has been released, including Object & Part Assets and Annotations, Rendered PointCloud Data and our Pre-trained Checkpoint.
+
+To obtain our dataset, please fill out [**this form**](https://forms.gle/3qzv8z5vP2BT5ARN7) and check the [**Terms&Conditions**](https://docs.google.com/document/d/1kjFCTcDLtaycZiJVmSVhT9Yw8oCAHl-3XKdJapvRdW0/edit?usp=sharing). Please cite our paper if you use our dataset.
+
+Download our pretrained checkpoint [**here**](https://drive.google.com/file/d/1D1PwfXPYPtxadthKAJdehhIBbPEyBB6X/view?usp=sharing)! (Note that the checkpoint included in the dataset release is outdated; please use this one instead.)
+
+## GAPartNet Network and Inference
+
+We release our network and checkpoint; check the gapartnet folder for more details. You can use them to segment parts
+and estimate their poses. We also provide visualization code. Here is a visualization example:
+![example](gapartnet/output/example.png)
+![example2](gapartnet/output/example2.png)
+
+## How to use our code and model:
+
+### 1. Install dependencies
+ - Python 3.8
+ - PyTorch >= 1.11.0
+ - CUDA >= 11.3
+ - Open3D with extension (See install guide below)
+ - epic_ops (See install guide below)
+ - pointnet2_ops (See install guide below)
+ - other pip packages
+
+### 2. Install Open3D & epic_ops & pointnet2_ops
+ See this repo for more details:
+
+ [GAPartNet_env](https://github.com/geng-haoran/GAPartNet_env): This repo includes Open3D, [epic_ops](https://github.com/geng-haoran/epic_ops) and pointnet2_ops. You can install them by following the instructions in this repo.
+
+### 3. Download our model and data
+ See gapartnet folder for more details.
+
+### 4. Inference and visualization
+ ```shell
+ cd gapartnet
+
+ CUDA_VISIBLE_DEVICES=0 \
+ python train.py test -c gapartnet.yaml \
+ --model.init_args.ckpt ckpt/new.ckpt
+ ```
+
+### 5. Training
+ You can run the following command to train the model:
+ ```shell
+ cd gapartnet
+
+ CUDA_VISIBLE_DEVICES=0 \
+ python train.py fit -c gapartnet.yaml
+ ```
+
+## Citation
+If you find our work useful in your research, please consider citing:
+
+```
+@article{geng2022gapartnet,
+ title={GAPartNet: Cross-Category Domain-Generalizable Object Perception and Manipulation via Generalizable and Actionable Parts},
+ author={Geng, Haoran and Xu, Helin and Zhao, Chengyang and Xu, Chao and Yi, Li and Huang, Siyuan and Wang, He},
+ journal={arXiv preprint arXiv:2211.05272},
+ year={2022}
+}
+```
+
+## Contact
+If you have any questions, please open a GitHub issue or contact us:
+
+Haoran Geng: ghr@stu.pku.edu.cn
+
+Helin Xu: xuhelin1911@gmail.com
+
+Chengyang Zhao: zhaochengyang@pku.edu.cn
+
+He Wang: hewang@pku.edu.cn
diff --git a/dataset/.gitignore b/dataset/.gitignore
new file mode 100644
index 0000000..733e4a9
--- /dev/null
+++ b/dataset/.gitignore
@@ -0,0 +1,14 @@
+.DS_Store
+*/.DS_Store
+*.pyc
+*.png
+log_*.txt
+
+__pycache__/
+example_data/
+example_rendered/
+sampled_data/
+visu/
+
+dataset_*/
+log_*/
diff --git a/dataset/README.md b/dataset/README.md
new file mode 100644
index 0000000..1626672
--- /dev/null
+++ b/dataset/README.md
@@ -0,0 +1,100 @@
+# GAPartNet Dataset
+
+## Data Format
+
+The GAPartNet dataset is built on two existing datasets, PartNet-Mobility and AKB-48, from which the 3D object shapes are collected, cleaned, and equipped with new, uniform GAPart-based semantic and pose annotations. The model_ids we use are provided in `render_tools/meta/{partnet_all_id_list.txt, akb48_all_id_list.txt}`.
+
+Four additional files accompany each object shape from PartNet-Mobility, providing annotations in the following formats:
+
+- `semantics_gapartnet.txt`: This file contains link semantics. Each line corresponds to a link in the kinematic chain, as indicated in `mobility_annotation_gapartnet.urdf`, formatted as "[link_name] [joint_type] [semantics]".
+- `mobility_annotation_gapartnet.urdf`: This file describes the kinematic chain, including our newly re-merged links and modified meshes. Each GAPart in the object shape corresponds to an individual link. We recommend using this file for annotation (semantics, poses) rendering and part property queries.
+- `mobility_texture_gapartnet.urdf`: This file also describes the kinematic chain but uses the original meshes, so each GAPart is not guaranteed to be an individual link. As mentioned in our paper, since the GAPart semantics are newly defined at a finer level of detail, the meshes and annotations in the original assets may be inconsistent with our definition. For example, in the original mesh of an "Oven" or "Dishwasher," a line_fixed_handle and a hinge_door could be merged into a single .obj mesh file. To address this issue, we modified the meshes to separate the GAParts. However, these modifications may have broken some textures, resulting in poor rendering quality. As a temporary workaround, we provide this file and use the original meshes for texture rendering. Example code for the joint correspondence between the kinematic chains in `mobility_annotation_gapartnet.urdf` and `mobility_texture_gapartnet.urdf` can be found in our rendering toolkit.
+- `link_annotation_gapartnet.json`: This JSON file contains the GAPart semantics and pose of each link in the kinematic chain of `mobility_annotation_gapartnet.urdf`. Specifically, for each link, "link_name", "is_gapart", "category", and "bbox" are provided, where "bbox" gives the eight vertices of the part's 3D bounding box in the rest state, i.e., with all joint states set to zero. The order of the eight vertices is as follows: [(-x,+y,+z), (+x,+y,+z), (+x,-y,+z), (-x,-y,+z), (-x,+y,-z), (+x,+y,-z), (+x,-y,-z), (-x,-y,-z)] (see the loading sketch below).
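+
+For example, here is a minimal sketch of reading the part annotations (field names follow the description above; it assumes the JSON file is a list of per-link records and is located in the current directory):
+
+```python
+import json
+import numpy as np
+
+with open('link_annotation_gapartnet.json', 'r') as f:
+    links = json.load(f)
+
+for link in links:
+    if not link['is_gapart']:
+        continue  # skip links that are not GAParts
+    # "bbox" holds the eight corners of the part's box in the rest state,
+    # ordered as listed above
+    bbox = np.asarray(link['bbox'])  # shape (8, 3)
+    print(link['link_name'], link['category'], bbox.mean(axis=0))
+```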
+
+## Data Split
+
+The data splits used in our paper can be found in `render_tools/meta/{partnet_all_split.json, akb48_all_split.json}`. We split all 27 object categories into 17 seen and 10 unseen categories. Each seen category was further split into seen and unseen instances. This two-level split ensures that all GAPart classes exist in both seen and unseen object categories, which helps evaluate intra- and inter-category generalizability.
+
+## Rendering Toolkit
+
+We provide an example toolkit for rendering and visualizing our GAPartNet dataset, located in `render_tools/`. This toolkit relies on [SAPIEN](https://github.com/haosulab/SAPIEN). To use it, please check the requirements in `render_tools/requirements.txt` and install the required packages.
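+
+A minimal setup sketch (assuming a working Python environment; package versions are pinned in the requirements file):
+
+```shell
+pip install -r render_tools/requirements.txt
+```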
+
+To render a single view of an object shape, use the `render_tools/render.py` script with the following command:
+
+```shell
+python render.py --model_id {MODEL_ID} \
+ --camera_idx {CAMERA_INDEX} \
+ --render_idx {RENDER_INDEX} \
+ --height {HEIGHT} \
+ --width {WIDTH} \
+ --ray_tracing {USE_RAY_TRACING} \
+ --replace_texture {REPLACE_TEXTURE}
+```
+
+The parameters are as follows:
+
+- `MODEL_ID`: The ID of the object shape you want to render.
+- `CAMERA_INDEX`: The index of the selected camera position range. This index is pre-defined in `render_tools/config_utils.py`.
+- `RENDER_INDEX`: The index of the specific rendered view.
+- `HEIGHT`: The height of the rendered image.
+- `WIDTH`: The width of the rendered image.
+- `USE_RAY_TRACING`: A boolean value specifying whether to use ray tracing for rendering. Use 'true' to enable and 'false' to disable.
+- `REPLACE_TEXTURE`: A boolean value that determines whether to use the original texture or the modified texture for rendering. Set it to 'true' to use the original texture (higher quality) and 'false' to use the modified one.
+
+To render the entire dataset, utilize the `render_tools/render_all_partnet.py` script with the following command:
+
+```shell
+python render_all_partnet.py --ray_tracing {USE_RAY_TRACING} \
+                             --replace_texture {REPLACE_TEXTURE} \
+                             --start_idx {START_INDEX} \
+                             --num_render {NUM_RENDER} \
+                             --log_dir {LOG_DIR}
+```
+
+The parameters are defined as follows:
+
+- `USE_RAY_TRACING` and `REPLACE_TEXTURE`: These parameters are identical to those described earlier.
+- `START_INDEX`: Specifies the starting render index, which is the same as the `RENDER_INDEX` mentioned previously.
+- `NUM_RENDER`: Specifies the number of views to render for each object shape and camera range.
+- `LOG_DIR`: The directory where the log files will be saved.
+
+To visualize the rendering results, use the `render_tools/visualize.py` script with this command:
+
+```shell
+python visualize.py --model_id {MODEL_ID} \
+ --category {CATEGORY} \
+ --camera_position_index {CAMERA_INDEX} \
+ --render_index {RENDER_INDEX}
+```
+
+The parameters are as follows:
+
+- `MODEL_ID`: The ID of the object shape to visualize.
+- `CATEGORY`: The category of the object.
+- `CAMERA_INDEX`: The index of the selected range for the camera position, pre-defined in `render_tools/config_utils.py`.
+- `RENDER_INDEX`: The index of the view that you wish to visualize.
+
+
+## Pre-processing Toolkit
+
+In addition to the rendering toolkit, we also provide a pre-processing toolkit to convert the rendered results into our model's input data format. This toolkit loads the rendered results, generates a partial point cloud via back-projection, and uses Farthest-Point-Sampling (FPS) to sample points from the dense point cloud.
+
+To use the toolkit, first install the PointNet++ library in `process_tools/utils/pointnet_lib` with the following command: `python setup.py install`. This installation will enable FPS performance on GPU. The library is sourced from [HalfSummer11/CAPTRA](https://github.com/HalfSummer11/CAPTRA), which is based on [sshaoshuai/Pointnet2.PyTorch](https://github.com/sshaoshuai/Pointnet2.PyTorch) and [yanx27/Pointnet_Pointnet2_pytorch](https://github.com/yanx27/Pointnet_Pointnet2_pytorch).
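+
+For example (run from the `dataset/` directory; this assumes a CUDA-enabled PyTorch installation, since the extension is compiled against CUDA):
+
+```shell
+cd process_tools/utils/pointnet_lib
+python setup.py install
+```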
+
+To pre-process the rendered results, use the `process_tools/convert_rendered_into_input.py` script with the following command:
+
+```shell
+python convert_rendered_into_input.py --data_path {DATA_PATH} \
+ --save_path {SAVE_PATH} \
+ --num_points {NUM_POINTS} \
+ --visualize {VISUALIZE}
+```
+
+The parameters are as follows:
+
+- `DATA_PATH`: Path to the directory containing the rendered results.
+- `SAVE_PATH`: Path to the directory where the pre-processed results will be stored.
+- `NUM_POINTS`: The number of points to sample from the partial point cloud.
+- `VISUALIZE`: A boolean value indicating whether to visualize the pre-processed results. Use 'true' to enable and 'false' to disable.
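+
+After pre-processing, each sample is saved as a tuple in `{SAVE_PATH}/pth/`. A minimal loading sketch (the tuple layout follows the docstring of `convert_rendered_into_input.py`; the file name below is a placeholder):
+
+```python
+import torch
+
+sample_path = 'sampled_data/pth/FILENAME.pth'  # replace FILENAME with an actual sample
+pc, rgb, sem_label, ins_label, npcs, idx = torch.load(sample_path)
+print(pc.shape, rgb.shape, sem_label.shape, ins_label.shape, npcs.shape, idx.shape)
+```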
+
diff --git a/dataset/process_tools/convert_rendered_into_input.py b/dataset/process_tools/convert_rendered_into_input.py
new file mode 100644
index 0000000..68d81d8
--- /dev/null
+++ b/dataset/process_tools/convert_rendered_into_input.py
@@ -0,0 +1,225 @@
+'''
+Convert the rendered data into the input format for the GAPartNet framework.
+
+Output .pth format:
+point_cloud: (N,3), float32, (x,y,z) in camera coordinate
+per_point_rgb: (N,3), float32, ranging in [0,1] (R,G,B)
+semantic_label: (N, ), int32, ranging in [0,nClass], 0 for others, [1, nClass] for part categories
+instance_label: (N, ), int32, ranging in {-100} \cup [0,nInstance-1], -100 for others, [0, nInstance-1] for parts
+NPCS: (N,3), float32, ranging in [-1,1] (x,y,z)
+idx: (N,2), int32, (y,x) in the image coordinate
+'''
+
+import os
+from os.path import join as pjoin
+from argparse import ArgumentParser
+import numpy as np
+import torch
+import open3d as o3d
+
+from utils.read_utils import load_rgb_image, load_depth_map, load_anno_dict, load_meta
+from utils.sample_utils import FPS
+
+LOG_PATH = './log_sample.txt'
+
+OBJECT_CATEGORIES = [
+ 'Box', 'Camera', 'CoffeeMachine', 'Dishwasher', 'KitchenPot', 'Microwave', 'Oven', 'Phone', 'Refrigerator',
+ 'Remote', 'Safe', 'StorageFurniture', 'Table', 'Toaster', 'TrashCan', 'WashingMachine', 'Keyboard', 'Laptop', 'Door', 'Printer',
+ 'Suitcase', 'Bucket', 'Toilet'
+]
+
+MAX_INSTANCE_NUM = 1000
+
+def log_string(file, s):
+ file.write(s + '\n')
+ print(s)
+
+
+def get_point_cloud(rgb_image, depth_map, sem_seg_map, ins_seg_map, npcs_map, meta):
+ width = meta['width']
+ height = meta['height']
+ K = np.array(meta['camera_intrinsic']).reshape(3, 3)
+
+ point_cloud = []
+ per_point_rgb = []
+ per_point_sem_label = []
+ per_point_ins_label = []
+ per_point_npcs = []
+ per_point_idx = []
+
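+    # Back-project every valid pixel (x_, y_) with depth z to camera coordinates
+    # using the pinhole model: x = (x_ - cx) * z / fx, y = (y_ - cy) * z / fy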
+ for y_ in range(height):
+ for x_ in range(width):
+ if sem_seg_map[y_, x_] == -2 or ins_seg_map[y_, x_] == -2:
+ continue
+ z_new = float(depth_map[y_, x_])
+ x_new = (x_ - K[0, 2]) * z_new / K[0, 0]
+ y_new = (y_ - K[1, 2]) * z_new / K[1, 1]
+ point_cloud.append([x_new, y_new, z_new])
+ per_point_rgb.append((rgb_image[y_, x_] / 255.0))
+ per_point_sem_label.append(sem_seg_map[y_, x_])
+ per_point_ins_label.append(ins_seg_map[y_, x_])
+ per_point_npcs.append(npcs_map[y_, x_])
+ per_point_idx.append([y_, x_])
+
+ return np.array(point_cloud), np.array(per_point_rgb), np.array(per_point_sem_label), np.array(
+ per_point_ins_label), np.array(per_point_npcs), np.array(per_point_idx)
+
+
+def FindMaxDis(pointcloud):
+ max_xyz = pointcloud.max(0)
+ min_xyz = pointcloud.min(0)
+ center = (max_xyz + min_xyz) / 2
+ max_radius = ((((pointcloud - center)**2).sum(1))**0.5).max()
+ return max_radius, center
+
+
+def WorldSpaceToBallSpace(pointcloud):
+ """
+    Normalize the raw point cloud from world space into the unit ball.
+    return: max_radius: the max distance from any raw point to the center
+            center: [x,y,z] of the raw point cloud center
+ """
+ max_radius, center = FindMaxDis(pointcloud)
+ pointcloud_normalized = (pointcloud - center) / max_radius
+ return pointcloud_normalized, max_radius, center
+
+
+def sample_and_save(filename, data_path, save_path, num_points, visualize=False):
+
+ pth_save_path = pjoin(save_path, 'pth')
+ os.makedirs(pth_save_path, exist_ok=True)
+ meta_save_path = pjoin(save_path, 'meta')
+ os.makedirs(meta_save_path, exist_ok=True)
+ gt_save_path = pjoin(save_path, 'gt')
+ os.makedirs(gt_save_path, exist_ok=True)
+
+ anno_dict = load_anno_dict(data_path, filename)
+ metafile = load_meta(data_path, filename)
+ rgb_image = load_rgb_image(data_path, filename)
+ depth_map = load_depth_map(data_path, filename)
+
+ # Get point cloud from back-projection
+ pcs, pcs_rgb, pcs_sem, pcs_ins, pcs_npcs, pcs_idx = get_point_cloud(rgb_image,
+ depth_map,
+ anno_dict['semantic_segmentation'],
+ anno_dict['instance_segmentation'],
+ anno_dict['npcs_map'],
+ metafile)
+
+ assert ((pcs_sem == -1) == (pcs_ins == -1)).all(), 'Semantic and instance labels do not match!'
+
+ # FPS sampling
+ pcs_sampled, fps_idx = FPS(pcs, num_points)
+ if pcs_sampled is None:
+ return -1
+
+ pcs_rgb_sampled = pcs_rgb[fps_idx]
+ pcs_sem_sampled = pcs_sem[fps_idx]
+ pcs_ins_sampled = pcs_ins[fps_idx]
+ pcs_npcs_sampled = pcs_npcs[fps_idx]
+ pcs_idx_sampled = pcs_idx[fps_idx]
+
+ # normalize point cloud
+ pcs_sampled_normalized, max_radius, center = WorldSpaceToBallSpace(pcs_sampled)
+ scale_param = np.array([max_radius, center[0], center[1], center[2]])
+
+ # convert semantic and instance labels
+ # old label:
+ # semantic label: -1 for others, [0, nClass-1] for part categories
+ # instance label: -1 for others, [0, nInstance-1] for parts
+ # new label:
+ # semantic label: 0 for others, [1, nClass] for part categories
+ # instance label: -100 for others, [0, nInstance-1] for parts
+ pcs_sem_sampled_converted = pcs_sem_sampled + 1
+ pcs_ins_sampled_converted = pcs_ins_sampled.copy()
+ mask = pcs_ins_sampled_converted == -1
+ pcs_ins_sampled_converted[mask] = -100
+
+ # re-label instance label to be continuous (discontinuous because of FPS sampling)
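+    # e.g. if the remaining labels after sampling are {0, 2, 3}, the points with the
+    # current max label (3) are moved into the first empty slot (1), giving {0, 1, 2}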
+ j = 0
+ while (j < pcs_ins_sampled_converted.max()):
+ if (len(np.where(pcs_ins_sampled_converted == j)[0]) == 0):
+ mask = pcs_ins_sampled_converted == pcs_ins_sampled_converted.max()
+ pcs_ins_sampled_converted[mask] = j
+ j += 1
+
+ # visualize
+ if visualize:
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(pcs_sampled_normalized)
+ pcd.colors = o3d.utility.Vector3dVector(pcs_rgb_sampled)
+ o3d.visualization.draw_geometries([pcd])
+
+ torch.save((pcs_sampled_normalized.astype(np.float32), pcs_rgb_sampled.astype(
+ np.float32), pcs_sem_sampled_converted.astype(np.int32), pcs_ins_sampled_converted.astype(
+ np.int32), pcs_npcs_sampled.astype(np.float32), pcs_idx_sampled.astype(np.int32)), pjoin(pth_save_path, filename + '.pth'))
+ np.savetxt(pjoin(meta_save_path, filename + '.txt'), scale_param, delimiter=',')
+
+ # save gt for evaluation
+ label_sem_ins = np.ones(pcs_ins_sampled_converted.shape, dtype=np.int32) * (-100)
+ inst_num = int(pcs_ins_sampled_converted.max() + 1)
+ for inst_id in range(inst_num):
+ instance_mask = np.where(pcs_ins_sampled_converted == inst_id)[0]
+ if instance_mask.shape[0] == 0:
+ raise ValueError(f'{filename} has a part missing from point cloud, instance label is not continuous')
+ semantic_label = int(pcs_sem_sampled_converted[instance_mask[0]])
+ if semantic_label == 0:
+ raise ValueError(f'{filename} has a part with semantic label [others]')
+ label_sem_ins[instance_mask] = semantic_label * MAX_INSTANCE_NUM + inst_id
+
+ np.savetxt(pjoin(gt_save_path, filename + '.txt'), label_sem_ins, fmt='%d')
+
+ return 0
+
+
+if __name__ == "__main__":
+
+ parser = ArgumentParser()
+ parser.add_argument('--data_path', type=str, default='./rendered_data', help='Specify the path to the rendered data')
+ parser.add_argument('--save_path', type=str, default='./sampled_data', help='Specify the path to save the sampled data')
+ parser.add_argument('--num_points', type=int, default=20000, help='Specify the number of points to sample')
+    # argparse's type=bool treats any non-empty string as True, so parse 'true'/'false' explicitly
+    parser.add_argument('--visualize', type=lambda x: str(x).lower() == 'true', default=False, help="Whether to visualize the sampled point cloud ('true' or 'false')")
+
+ args = parser.parse_args()
+
+ DATA_PATH = args.data_path
+ SAVE_PATH = args.save_path
+ if not os.path.exists(SAVE_PATH):
+ os.mkdir(SAVE_PATH)
+ NUM_POINTS = args.num_points
+ VISUALIZE = args.visualize
+
+ filename_list = sorted([x.split('.')[0] for x in os.listdir(pjoin(DATA_PATH, 'rgb'))])
+ filename_dict = {x: [] for x in OBJECT_CATEGORIES}
+ for fn in filename_list:
+ for x in OBJECT_CATEGORIES:
+ if fn.startswith(x):
+ filename_dict[x].append(fn)
+ break
+
+ LOG_FILE = open(LOG_PATH, 'w')
+
+ def log_writer(s):
+ log_string(LOG_FILE, s)
+
+ for category in filename_dict:
+ log_writer(f'Start: {category}')
+
+ fn_list = filename_dict[category]
+ log_writer(f'{category} : {len(fn_list)}')
+
+ for idx, fn in enumerate(fn_list):
+ log_writer(f'Sampling {idx}/{len(fn_list)} {fn}')
+
+ ret = sample_and_save(fn, DATA_PATH, SAVE_PATH, NUM_POINTS, VISUALIZE)
+ if ret == -1:
+ log_writer(f'Error in {fn} {category}, num of points less than NUM_POINTS!')
+ else:
+ log_writer(f'Finish: {fn}')
+
+ log_writer(f'Finish: {category}')
+
+ LOG_FILE.close()
+
+ print('All finished!')
+
diff --git a/dataset/process_tools/utils/pointnet_lib/README.md b/dataset/process_tools/utils/pointnet_lib/README.md
new file mode 100644
index 0000000..db0dbe8
--- /dev/null
+++ b/dataset/process_tools/utils/pointnet_lib/README.md
@@ -0,0 +1,3 @@
+# PointNet++ Library
+
+Copied from [HalfSummer11/CAPTRA](https://github.com/HalfSummer11/CAPTRA), based on [sshaoshuai/Pointnet2.PyTorch](https://github.com/sshaoshuai/Pointnet2.PyTorch) and [yanx27/Pointnet_Pointnet2_pytorch](https://github.com/yanx27/Pointnet_Pointnet2_pytorch).
\ No newline at end of file
diff --git a/dataset/process_tools/utils/pointnet_lib/pointnet2_modules.py b/dataset/process_tools/utils/pointnet_lib/pointnet2_modules.py
new file mode 100644
index 0000000..5f125ce
--- /dev/null
+++ b/dataset/process_tools/utils/pointnet_lib/pointnet2_modules.py
@@ -0,0 +1,160 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from . import pointnet2_utils
+from . import pytorch_utils as pt_utils
+from typing import List
+
+
+class _PointnetSAModuleBase(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ self.npoint = None
+ self.groupers = None
+ self.mlps = None
+ self.pool_method = 'max_pool'
+
+ def forward(self, xyz: torch.Tensor, features: torch.Tensor = None, new_xyz=None) -> (torch.Tensor, torch.Tensor):
+ """
+ :param xyz: (B, N, 3) tensor of the xyz coordinates of the features
+        :param features: (B, N, C) tensor of the descriptors of the features
+ :param new_xyz:
+ :return:
+ new_xyz: (B, npoint, 3) tensor of the new features' xyz
+ new_features: (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_features descriptors
+ """
+ new_features_list = []
+
+ xyz_flipped = xyz.transpose(1, 2).contiguous()
+ if new_xyz is None:
+ new_xyz = pointnet2_utils.gather_operation(
+ xyz_flipped,
+ pointnet2_utils.furthest_point_sample(xyz, self.npoint)
+ ).transpose(1, 2).contiguous() if self.npoint is not None else None
+
+ for i in range(len(self.groupers)):
+ new_features = self.groupers[i](xyz, new_xyz, features) # (B, C, npoint, nsample)
+
+ new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample)
+ if self.pool_method == 'max_pool':
+ new_features = F.max_pool2d(
+ new_features, kernel_size=[1, new_features.size(3)]
+ ) # (B, mlp[-1], npoint, 1)
+ elif self.pool_method == 'avg_pool':
+ new_features = F.avg_pool2d(
+ new_features, kernel_size=[1, new_features.size(3)]
+ ) # (B, mlp[-1], npoint, 1)
+ else:
+ raise NotImplementedError
+
+ new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint)
+ new_features_list.append(new_features)
+
+ return new_xyz, torch.cat(new_features_list, dim=1)
+
+
+class PointnetSAModuleMSG(_PointnetSAModuleBase):
+ """Pointnet set abstraction layer with multiscale grouping"""
+
+ def __init__(self, *, npoint: int, radii: List[float], nsamples: List[int], mlps: List[List[int]], bn: bool = True,
+ use_xyz: bool = True, pool_method='max_pool', instance_norm=False):
+ """
+ :param npoint: int
+ :param radii: list of float, list of radii to group with
+ :param nsamples: list of int, number of samples in each ball query
+ :param mlps: list of list of int, spec of the pointnet before the global pooling for each scale
+ :param bn: whether to use batchnorm
+ :param use_xyz:
+ :param pool_method: max_pool / avg_pool
+ :param instance_norm: whether to use instance_norm
+ """
+ super().__init__()
+
+ assert len(radii) == len(nsamples) == len(mlps)
+
+ self.npoint = npoint
+ self.groupers = nn.ModuleList()
+ self.mlps = nn.ModuleList()
+ for i in range(len(radii)):
+ radius = radii[i]
+ nsample = nsamples[i]
+ self.groupers.append(
+ pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
+ if npoint is not None else pointnet2_utils.GroupAll(use_xyz)
+ )
+ mlp_spec = mlps[i]
+ if use_xyz:
+ mlp_spec[0] += 3
+
+ self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn, instance_norm=instance_norm))
+ self.pool_method = pool_method
+
+
+class PointnetSAModule(PointnetSAModuleMSG):
+ """Pointnet set abstraction layer"""
+
+ def __init__(self, *, mlp: List[int], npoint: int = None, radius: float = None, nsample: int = None,
+ bn: bool = True, use_xyz: bool = True, pool_method='max_pool', instance_norm=False):
+ """
+ :param mlp: list of int, spec of the pointnet before the global max_pool
+ :param npoint: int, number of features
+ :param radius: float, radius of ball
+ :param nsample: int, number of samples in the ball query
+ :param bn: whether to use batchnorm
+ :param use_xyz:
+ :param pool_method: max_pool / avg_pool
+ :param instance_norm: whether to use instance_norm
+ """
+ super().__init__(
+ mlps=[mlp], npoint=npoint, radii=[radius], nsamples=[nsample], bn=bn, use_xyz=use_xyz,
+ pool_method=pool_method, instance_norm=instance_norm
+ )
+
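+# Example usage of PointnetSAModule (a minimal sketch; the underlying sampling and
+# grouping ops are CUDA-only, so a GPU is required; shapes are illustrative):
+#   sa = PointnetSAModule(mlp=[3, 64, 128], npoint=512, radius=0.2, nsample=32).cuda()
+#   xyz = torch.rand(2, 1024, 3).cuda()      # (B, N, 3) point coordinates
+#   feats = torch.rand(2, 3, 1024).cuda()    # (B, C, N) per-point features
+#   new_xyz, new_feats = sa(xyz, feats)      # (2, 512, 3), (2, 128, 512)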
+
+class PointnetFPModule(nn.Module):
+ r"""Propigates the features of one set to another"""
+
+ def __init__(self, *, mlp: List[int], bn: bool = True):
+ """
+ :param mlp: list of int
+ :param bn: whether to use batchnorm
+ """
+ super().__init__()
+ self.mlp = pt_utils.SharedMLP(mlp, bn=bn)
+
+ def forward(
+ self, unknown: torch.Tensor, known: torch.Tensor, unknow_feats: torch.Tensor, known_feats: torch.Tensor
+ ) -> torch.Tensor:
+ """
+ :param unknown: (B, n, 3) tensor of the xyz positions of the unknown features
+ :param known: (B, m, 3) tensor of the xyz positions of the known features
+        :param unknow_feats: (B, C1, n) tensor of the features to be propagated to
+        :param known_feats: (B, C2, m) tensor of features to be propagated
+ :return:
+ new_features: (B, mlp[-1], n) tensor of the features of the unknown features
+ """
+ if known is not None:
+ dist, idx = pointnet2_utils.three_nn(unknown, known)
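+            # inverse-distance weights over the three nearest known neighbors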
+ dist_recip = 1.0 / (dist + 1e-8)
+ norm = torch.sum(dist_recip, dim=2, keepdim=True)
+ weight = dist_recip / norm
+
+ interpolated_feats = pointnet2_utils.three_interpolate(known_feats, idx, weight)
+ else:
+ interpolated_feats = known_feats.expand(*known_feats.size()[0:2], unknown.size(1))
+
+ if unknow_feats is not None:
+ new_features = torch.cat([interpolated_feats, unknow_feats], dim=1) # (B, C2 + C1, n)
+ else:
+ new_features = interpolated_feats
+
+ new_features = new_features.unsqueeze(-1)
+ new_features = self.mlp(new_features)
+
+ return new_features.squeeze(-1)
+
+
+if __name__ == "__main__":
+ pass
diff --git a/dataset/process_tools/utils/pointnet_lib/pointnet2_utils.py b/dataset/process_tools/utils/pointnet_lib/pointnet2_utils.py
new file mode 100644
index 0000000..a52a6b4
--- /dev/null
+++ b/dataset/process_tools/utils/pointnet_lib/pointnet2_utils.py
@@ -0,0 +1,385 @@
+import torch
+from torch.autograd import Variable
+from torch.autograd import Function
+import torch.nn as nn
+from typing import Tuple
+
+import pointnet2_cuda as pointnet2
+
+
+class FurthestPointSampling(Function):
+ @staticmethod
+ def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
+ """
+ Uses iterative furthest point sampling to select a set of npoint features that have the largest
+ minimum distance
+ :param ctx:
+ :param xyz: (B, N, 3) where N > npoint
+ :param npoint: int, number of features in the sampled set
+ :return:
+ output: (B, npoint) tensor containing the set
+ """
+ xyz = xyz.contiguous()
+ # assert xyz.is_contiguous()
+
+ B, N, _ = xyz.size()
+ output = torch.cuda.IntTensor(B, npoint)
+ temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
+
+ pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
+ return output
+
+ @staticmethod
+ def backward(xyz, a=None):
+ return None, None
+
+
+furthest_point_sample = FurthestPointSampling.apply
+
+
+class GatherOperation(Function):
+
+ @staticmethod
+ def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
+ """
+ :param ctx:
+ :param features: (B, C, N)
+ :param idx: (B, npoint) index tensor of the features to gather
+ :return:
+ output: (B, C, npoint)
+ """
+ features = features.contiguous()
+ idx = idx.contiguous()
+ assert features.is_contiguous()
+ assert idx.is_contiguous()
+
+ B, npoint = idx.size()
+ _, C, N = features.size()
+ output = torch.cuda.FloatTensor(B, C, npoint)
+
+ pointnet2.gather_points_wrapper(B, C, N, npoint, features, idx, output)
+
+ ctx.for_backwards = (idx, C, N)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_out):
+ idx, C, N = ctx.for_backwards
+ B, npoint = idx.size()
+
+ grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
+ grad_out_data = grad_out.data.contiguous()
+ pointnet2.gather_points_grad_wrapper(B, C, N, npoint, grad_out_data, idx, grad_features.data)
+ return grad_features, None
+
+
+gather_operation = GatherOperation.apply
+
+class KNN(Function):
+
+ @staticmethod
+ def forward(ctx, k: int, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+        Find the k nearest neighbors of unknown in known
+ :param ctx:
+ :param unknown: (B, N, 3)
+ :param known: (B, M, 3)
+ :return:
+            dist: (B, N, k) l2 distance to the k nearest neighbors
+            idx: (B, N, k) indices of the k nearest neighbors
+ """
+ unknown = unknown.contiguous()
+ known = known.contiguous()
+ assert unknown.is_contiguous()
+ assert known.is_contiguous()
+
+ B, N, _ = unknown.size()
+ m = known.size(1)
+ dist2 = torch.cuda.FloatTensor(B, N, k)
+ idx = torch.cuda.IntTensor(B, N, k)
+
+ pointnet2.knn_wrapper(B, N, m, k, unknown, known, dist2, idx)
+ return torch.sqrt(dist2), idx
+
+ @staticmethod
+ def backward(ctx, a=None, b=None):
+ return None, None, None
+
+knn = KNN.apply
+
+class ThreeNN(Function):
+
+ @staticmethod
+ def forward(ctx, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Find the three nearest neighbors of unknown in known
+ :param ctx:
+ :param unknown: (B, N, 3)
+ :param known: (B, M, 3)
+ :return:
+ dist: (B, N, 3) l2 distance to the three nearest neighbors
+ idx: (B, N, 3) index of 3 nearest neighbors
+ """
+ unknown = unknown.contiguous()
+ known = known.contiguous()
+ assert unknown.is_contiguous()
+ assert known.is_contiguous()
+
+ B, N, _ = unknown.size()
+ m = known.size(1)
+ dist2 = torch.cuda.FloatTensor(B, N, 3)
+ idx = torch.cuda.IntTensor(B, N, 3)
+
+ pointnet2.three_nn_wrapper(B, N, m, unknown, known, dist2, idx)
+ return torch.sqrt(dist2), idx
+
+ @staticmethod
+ def backward(ctx, a=None, b=None):
+ return None, None
+
+
+three_nn = ThreeNN.apply
+
+
+class ThreeInterpolate(Function):
+
+ @staticmethod
+ def forward(ctx, features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
+ """
+        Performs weighted linear interpolation over 3 features
+ :param ctx:
+ :param features: (B, C, M) Features descriptors to be interpolated from
+ :param idx: (B, n, 3) three nearest neighbors of the target features in features
+ :param weight: (B, n, 3) weights
+ :return:
+ output: (B, C, N) tensor of the interpolated features
+ """
+ features = features.contiguous()
+ idx = idx.contiguous()
+ weight = weight.contiguous()
+ assert features.is_contiguous()
+ assert idx.is_contiguous()
+ assert weight.is_contiguous()
+
+ B, c, m = features.size()
+ n = idx.size(1)
+ ctx.three_interpolate_for_backward = (idx, weight, m)
+ output = torch.cuda.FloatTensor(B, c, n)
+
+ pointnet2.three_interpolate_wrapper(B, c, m, n, features, idx, weight, output)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ :param ctx:
+ :param grad_out: (B, C, N) tensor with gradients of outputs
+ :return:
+ grad_features: (B, C, M) tensor with gradients of features
+ None:
+ None:
+ """
+ idx, weight, m = ctx.three_interpolate_for_backward
+ B, c, n = grad_out.size()
+
+ grad_features = Variable(torch.cuda.FloatTensor(B, c, m).zero_())
+ grad_out_data = grad_out.data.contiguous()
+
+ pointnet2.three_interpolate_grad_wrapper(B, c, n, m, grad_out_data, idx, weight, grad_features.data)
+ return grad_features, None, None
+
+
+three_interpolate = ThreeInterpolate.apply
+
+
+class GroupingOperation(Function):
+
+ @staticmethod
+ def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
+ """
+ :param ctx:
+ :param features: (B, C, N) tensor of features to group
+        :param idx: (B, npoint, nsample) tensor containing the indices of features to group with
+ :return:
+ output: (B, C, npoint, nsample) tensor
+ """
+ features = features.contiguous()
+ idx = idx.contiguous()
+ assert features.is_contiguous()
+ assert idx.is_contiguous()
+ idx = idx.int()
+ B, nfeatures, nsample = idx.size()
+ _, C, N = features.size()
+ output = torch.cuda.FloatTensor(B, C, nfeatures, nsample)
+
+ pointnet2.group_points_wrapper(B, C, N, nfeatures, nsample, features, idx, output)
+
+ ctx.for_backwards = (idx, N)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ :param ctx:
+ :param grad_out: (B, C, npoint, nsample) tensor of the gradients of the output from forward
+ :return:
+ grad_features: (B, C, N) gradient of the features
+ """
+ idx, N = ctx.for_backwards
+
+ B, C, npoint, nsample = grad_out.size()
+ grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
+
+ grad_out_data = grad_out.data.contiguous()
+ pointnet2.group_points_grad_wrapper(B, C, N, npoint, nsample, grad_out_data, idx, grad_features.data)
+ return grad_features, None
+
+
+grouping_operation = GroupingOperation.apply
+
+
+class BallQuery(Function):
+
+ @staticmethod
+ def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor:
+ """
+ :param ctx:
+ :param radius: float, radius of the balls
+ :param nsample: int, maximum number of features in the balls
+ :param xyz: (B, N, 3) xyz coordinates of the features
+ :param new_xyz: (B, npoint, 3) centers of the ball query
+ :return:
+            idx: (B, npoint, nsample) tensor with the indices of the features that form the query balls
+ """
+ new_xyz = new_xyz.contiguous()
+ xyz = xyz.contiguous()
+ assert new_xyz.is_contiguous()
+ assert xyz.is_contiguous()
+
+ B, N, _ = xyz.size()
+ npoint = new_xyz.size(1)
+ idx = torch.cuda.IntTensor(B, npoint, nsample).zero_()
+
+ pointnet2.ball_query_wrapper(B, N, npoint, radius, nsample, new_xyz, xyz, idx)
+ return idx
+
+ @staticmethod
+ def backward(ctx, a=None):
+ return None, None, None, None
+
+
+ball_query = BallQuery.apply
+
+
+class QueryAndGroup(nn.Module):
+ def __init__(self, radius: float, nsample: int, use_xyz: bool = True):
+ """
+ :param radius: float, radius of ball
+ :param nsample: int, maximum number of features to gather in the ball
+ :param use_xyz:
+ """
+ super().__init__()
+ self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz
+
+ def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None) -> Tuple[torch.Tensor]:
+ """
+ :param xyz: (B, N, 3) xyz coordinates of the features
+ :param new_xyz: (B, npoint, 3) centroids
+ :param features: (B, C, N) descriptors of the features
+ :return:
+ new_features: (B, 3 + C, npoint, nsample)
+ """
+ idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
+ xyz_trans = xyz.transpose(1, 2).contiguous()
+ grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample)
+ grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)
+
+ if features is not None:
+ grouped_features = grouping_operation(features, idx)
+ if self.use_xyz:
+ new_features = torch.cat([grouped_features, grouped_xyz] , dim=1) # (B, C + 3, npoint, nsample)
+ else:
+ new_features = grouped_features
+ else:
+            assert self.use_xyz, "Cannot have no features and also not use xyz as a feature!"
+ new_features = grouped_xyz
+
+ return new_features
+
+
+class GroupAll(nn.Module):
+ def __init__(self, use_xyz: bool = True):
+ super().__init__()
+ self.use_xyz = use_xyz
+
+ def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None):
+ """
+ :param xyz: (B, N, 3) xyz coordinates of the features
+ :param new_xyz: ignored
+ :param features: (B, C, N) descriptors of the features
+ :return:
+ new_features: (B, C + 3, 1, N)
+ """
+ grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
+ if features is not None:
+ grouped_features = features.unsqueeze(2)
+ if self.use_xyz:
+ new_features = torch.cat([grouped_xyz, grouped_features], dim=1) # (B, 3 + C, 1, N)
+ else:
+ new_features = grouped_features
+ else:
+ new_features = grouped_xyz
+
+ return new_features
+
+class KNNAndGroup(nn.Module):
+ def __init__(self, radius:float, nsample: int, use_xyz: bool = True):
+ """
+ :param radius: float, radius of ball
+ :param nsample: int, maximum number of features to gather in the ball
+ :param use_xyz:
+ """
+ super().__init__()
+ self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz
+
+ def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor = None, idx: torch.Tensor = None, features: torch.Tensor = None) -> Tuple[torch.Tensor]:
+ """
+ :param xyz: (B, N, 3) xyz coordinates of the features
+ :param new_xyz: (B, M, 3) centroids
+        :param idx: (B, M, K) optional precomputed neighbor indices
+ :param features: (B, C, N) descriptors of the features
+ :return:
+ new_features: (B, 3 + C, M, K) if use_xyz = True else (B, C, M, K)
+ """
+
+ ##TODO: implement new_xyz into knn
+ if new_xyz is None:
+ new_xyz = xyz
+
+ if idx is None:
+            # knn expects (k, unknown, known) and returns (dist, idx); self.radius is unused here
+            _, idx = knn(self.nsample, new_xyz, xyz)  # idx: (B, M, K)
+ idx = idx.detach()
+
+ xyz_trans = xyz.transpose(1, 2).contiguous()
+ new_xyz_trans = new_xyz.transpose(1, 2).contiguous()
+
+ grouped_xyz = grouping_operation(xyz_trans, idx) # B, 3, M, K
+ grouped_xyz -= new_xyz_trans.unsqueeze(-1) # B, 3, M, K
+ #grouped_r = torch.norm(grouped_xyz, dim=1).max(dim=-1)[0]#B,M
+ #print(new_xyz.shape[1], grouped_r)
+
+ if features is not None:
+ grouped_features = grouping_operation(features, idx) # B, C, M, K
+ # grouped_features_test = grouping_operation(features, idx)
+ # assert (grouped_features == grouped_features).all()
+ if self.use_xyz:
+ new_features = torch.cat([grouped_xyz, grouped_features], dim=1) # (B, C + 3, M, K)
+ else:
+ new_features = grouped_features
+ else:
+            assert self.use_xyz, "Cannot have no features and also not use xyz as a feature!"
+ new_features = grouped_xyz
+
+ return new_features
+
+
diff --git a/dataset/process_tools/utils/pointnet_lib/pytorch_utils.py b/dataset/process_tools/utils/pointnet_lib/pytorch_utils.py
new file mode 100644
index 0000000..09cb7bc
--- /dev/null
+++ b/dataset/process_tools/utils/pointnet_lib/pytorch_utils.py
@@ -0,0 +1,236 @@
+import torch.nn as nn
+from typing import List, Tuple
+
+
+class SharedMLP(nn.Sequential):
+
+ def __init__(
+ self,
+ args: List[int],
+ *,
+ bn: bool = False,
+ activation=nn.ReLU(inplace=True),
+ preact: bool = False,
+ first: bool = False,
+ name: str = "",
+ instance_norm: bool = False,
+ ):
+ super().__init__()
+
+ for i in range(len(args) - 1):
+ self.add_module(
+ name + 'layer{}'.format(i),
+ Conv2d(
+ args[i],
+ args[i + 1],
+ bn=(not first or not preact or (i != 0)) and bn,
+ activation=activation
+ if (not first or not preact or (i != 0)) else None,
+ preact=preact,
+ instance_norm=instance_norm
+ )
+ )
+
+
+class _ConvBase(nn.Sequential):
+
+ def __init__(
+ self,
+ in_size,
+ out_size,
+ kernel_size,
+ stride,
+ padding,
+ activation,
+ bn,
+ init,
+ conv=None,
+ batch_norm=None,
+ bias=True,
+ preact=False,
+ name="",
+ instance_norm=False,
+ instance_norm_func=None
+ ):
+ super().__init__()
+
+ bias = bias and (not bn)
+ conv_unit = conv(
+ in_size,
+ out_size,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ bias=bias
+ )
+ init(conv_unit.weight)
+ if bias:
+ nn.init.constant_(conv_unit.bias, 0)
+
+ if bn:
+ if not preact:
+ bn_unit = batch_norm(out_size)
+ else:
+ bn_unit = batch_norm(in_size)
+ if instance_norm:
+ if not preact:
+ in_unit = instance_norm_func(out_size, affine=False, track_running_stats=False)
+ else:
+ in_unit = instance_norm_func(in_size, affine=False, track_running_stats=False)
+
+ if preact:
+ if bn:
+ self.add_module(name + 'bn', bn_unit)
+
+ if activation is not None:
+ self.add_module(name + 'activation', activation)
+
+ if not bn and instance_norm:
+ self.add_module(name + 'in', in_unit)
+
+ self.add_module(name + 'conv', conv_unit)
+
+ if not preact:
+ if bn:
+ self.add_module(name + 'bn', bn_unit)
+
+ if activation is not None:
+ self.add_module(name + 'activation', activation)
+
+ if not bn and instance_norm:
+ self.add_module(name + 'in', in_unit)
+
+
+class _BNBase(nn.Sequential):
+
+ def __init__(self, in_size, batch_norm=None, name=""):
+ super().__init__()
+ self.add_module(name + "bn", batch_norm(in_size))
+
+ nn.init.constant_(self[0].weight, 1.0)
+ nn.init.constant_(self[0].bias, 0)
+
+
+class BatchNorm1d(_BNBase):
+
+ def __init__(self, in_size: int, *, name: str = ""):
+ super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name)
+
+
+class BatchNorm2d(_BNBase):
+
+ def __init__(self, in_size: int, name: str = ""):
+ super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name)
+
+
+class Conv1d(_ConvBase):
+
+ def __init__(
+ self,
+ in_size: int,
+ out_size: int,
+ *,
+ kernel_size: int = 1,
+ stride: int = 1,
+ padding: int = 0,
+ activation=nn.ReLU(inplace=True),
+ bn: bool = False,
+ init=nn.init.kaiming_normal_,
+ bias: bool = True,
+ preact: bool = False,
+ name: str = "",
+ instance_norm=False
+ ):
+ super().__init__(
+ in_size,
+ out_size,
+ kernel_size,
+ stride,
+ padding,
+ activation,
+ bn,
+ init,
+ conv=nn.Conv1d,
+ batch_norm=BatchNorm1d,
+ bias=bias,
+ preact=preact,
+ name=name,
+ instance_norm=instance_norm,
+ instance_norm_func=nn.InstanceNorm1d
+ )
+
+
+class Conv2d(_ConvBase):
+
+ def __init__(
+ self,
+ in_size: int,
+ out_size: int,
+ *,
+ kernel_size: Tuple[int, int] = (1, 1),
+ stride: Tuple[int, int] = (1, 1),
+ padding: Tuple[int, int] = (0, 0),
+ activation=nn.ReLU(inplace=True),
+ bn: bool = False,
+ init=nn.init.kaiming_normal_,
+ bias: bool = True,
+ preact: bool = False,
+ name: str = "",
+ instance_norm=False
+ ):
+ super().__init__(
+ in_size,
+ out_size,
+ kernel_size,
+ stride,
+ padding,
+ activation,
+ bn,
+ init,
+ conv=nn.Conv2d,
+ batch_norm=BatchNorm2d,
+ bias=bias,
+ preact=preact,
+ name=name,
+ instance_norm=instance_norm,
+ instance_norm_func=nn.InstanceNorm2d
+ )
+
+
+class FC(nn.Sequential):
+
+ def __init__(
+ self,
+ in_size: int,
+ out_size: int,
+ *,
+ activation=nn.ReLU(inplace=True),
+ bn: bool = False,
+ init=None,
+ preact: bool = False,
+ name: str = ""
+ ):
+ super().__init__()
+
+ fc = nn.Linear(in_size, out_size, bias=not bn)
+ if init is not None:
+ init(fc.weight)
+ if not bn:
+            nn.init.constant_(fc.bias, 0)
+
+ if preact:
+ if bn:
+ self.add_module(name + 'bn', BatchNorm1d(in_size))
+
+ if activation is not None:
+ self.add_module(name + 'activation', activation)
+
+ self.add_module(name + 'fc', fc)
+
+ if not preact:
+ if bn:
+ self.add_module(name + 'bn', BatchNorm1d(out_size))
+
+ if activation is not None:
+ self.add_module(name + 'activation', activation)
+
diff --git a/dataset/process_tools/utils/pointnet_lib/setup.py b/dataset/process_tools/utils/pointnet_lib/setup.py
new file mode 100644
index 0000000..99e59e3
--- /dev/null
+++ b/dataset/process_tools/utils/pointnet_lib/setup.py
@@ -0,0 +1,23 @@
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+setup(
+ name='pointnet2',
+ ext_modules=[
+ CUDAExtension('pointnet2_cuda', [
+ 'src/pointnet2_api.cpp',
+
+ 'src/ball_query.cpp',
+ 'src/ball_query_gpu.cu',
+ 'src/group_points.cpp',
+ 'src/group_points_gpu.cu',
+ 'src/interpolate.cpp',
+ 'src/interpolate_gpu.cu',
+ 'src/sampling.cpp',
+ 'src/sampling_gpu.cu',
+ ],
+ extra_compile_args={'cxx': ['-g'],
+ 'nvcc': ['-O2']})
+ ],
+ cmdclass={'build_ext': BuildExtension}
+)
diff --git a/dataset/process_tools/utils/pointnet_lib/src/ball_query.cpp b/dataset/process_tools/utils/pointnet_lib/src/ball_query.cpp
new file mode 100644
index 0000000..91c1768
--- /dev/null
+++ b/dataset/process_tools/utils/pointnet_lib/src/ball_query.cpp
@@ -0,0 +1,25 @@
+#include <torch/serialize/tensor.h>
+#include <vector>
+#include <THC/THC.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+#include "ball_query_gpu.h"
+
+extern THCState *state;
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
+#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)
+
+int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample,
+ at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor) {
+ CHECK_INPUT(new_xyz_tensor);
+ CHECK_INPUT(xyz_tensor);
+    const float *new_xyz = new_xyz_tensor.data<float>();
+    const float *xyz = xyz_tensor.data<float>();
+    int *idx = idx_tensor.data<int>();
+
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(); //THCState_getCurrentStream(state);
+ ball_query_kernel_launcher_fast(b, n, m, radius, nsample, new_xyz, xyz, idx, stream);
+ return 1;
+}
\ No newline at end of file
diff --git a/dataset/process_tools/utils/pointnet_lib/src/ball_query_gpu.cu b/dataset/process_tools/utils/pointnet_lib/src/ball_query_gpu.cu
new file mode 100644
index 0000000..f8840aa
--- /dev/null
+++ b/dataset/process_tools/utils/pointnet_lib/src/ball_query_gpu.cu
@@ -0,0 +1,67 @@
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "ball_query_gpu.h"
+#include "cuda_utils.h"
+
+
+__global__ void ball_query_kernel_fast(int b, int n, int m, float radius, int nsample,
+ const float *__restrict__ new_xyz, const float *__restrict__ xyz, int *__restrict__ idx) {
+ // new_xyz: (B, M, 3)
+ // xyz: (B, N, 3)
+ // output:
+ // idx: (B, M, nsample)
+ int bs_idx = blockIdx.y;
+ int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
+ if (bs_idx >= b || pt_idx >= m) return;
+
+ new_xyz += bs_idx * m * 3 + pt_idx * 3;
+ xyz += bs_idx * n * 3;
+ idx += bs_idx * m * nsample + pt_idx * nsample;
+
+ float radius2 = radius * radius;
+ float new_x = new_xyz[0];
+ float new_y = new_xyz[1];
+ float new_z = new_xyz[2];
+
+ int cnt = 0;
+ for (int k = 0; k < n; ++k) {
+ float x = xyz[k * 3 + 0];
+ float y = xyz[k * 3 + 1];
+ float z = xyz[k * 3 + 2];
+ float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z);
+ if (d2 < radius2){
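+            // the first in-radius point pre-fills all nsample slots so that
+            // unfilled entries still hold a valid index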
+ if (cnt == 0){
+ for (int l = 0; l < nsample; ++l) {
+ idx[l] = k;
+ }
+ }
+ idx[cnt] = k;
+ ++cnt;
+ if (cnt >= nsample) break;
+ }
+ }
+}
+
+
+void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample, \
+ const float *new_xyz, const float *xyz, int *idx, cudaStream_t stream) {
+ // new_xyz: (B, M, 3)
+ // xyz: (B, N, 3)
+ // output:
+ // idx: (B, M, nsample)
+
+ cudaError_t err;
+
+ dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row)
+ dim3 threads(THREADS_PER_BLOCK);
+
+    ball_query_kernel_fast<<<blocks, threads, 0, stream>>>(b, n, m, radius, nsample, new_xyz, xyz, idx);
+ // cudaDeviceSynchronize(); // for using printf in kernel function
+ err = cudaGetLastError();
+ if (cudaSuccess != err) {
+ fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+ exit(-1);
+ }
+}
\ No newline at end of file
diff --git a/dataset/process_tools/utils/pointnet_lib/src/ball_query_gpu.h b/dataset/process_tools/utils/pointnet_lib/src/ball_query_gpu.h
new file mode 100644
index 0000000..ffc831a
--- /dev/null
+++ b/dataset/process_tools/utils/pointnet_lib/src/ball_query_gpu.h
@@ -0,0 +1,15 @@
+#ifndef _BALL_QUERY_GPU_H
+#define _BALL_QUERY_GPU_H
+
+#include <torch/serialize/tensor.h>
+#include <vector>
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+
+int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample,
+ at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor);
+
+void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample,
+ const float *xyz, const float *new_xyz, int *idx, cudaStream_t stream);
+
+#endif
diff --git a/dataset/process_tools/utils/pointnet_lib/src/cuda_utils.h b/dataset/process_tools/utils/pointnet_lib/src/cuda_utils.h
new file mode 100644
index 0000000..7fe2796
--- /dev/null
+++ b/dataset/process_tools/utils/pointnet_lib/src/cuda_utils.h
@@ -0,0 +1,15 @@
+#ifndef _CUDA_UTILS_H
+#define _CUDA_UTILS_H
+
+#include <cmath>
+
+#define TOTAL_THREADS 1024
+#define THREADS_PER_BLOCK 256
+#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
+
+inline int opt_n_threads(int work_size) {
+    const int pow_2 = std::log(static_cast<float>(work_size)) / std::log(2.0);
+
+ return max(min(1 << pow_2, TOTAL_THREADS), 1);
+}
+#endif
diff --git a/dataset/process_tools/utils/pointnet_lib/src/group_points.cpp b/dataset/process_tools/utils/pointnet_lib/src/group_points.cpp
new file mode 100644
index 0000000..1bbbc1e
--- /dev/null
+++ b/dataset/process_tools/utils/pointnet_lib/src/group_points.cpp
@@ -0,0 +1,36 @@
+#include <torch/serialize/tensor.h>
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+#include <vector>
+#include <THC/THC.h>
+#include <ATen/cuda/CUDAContext.h>
+#include "group_points_gpu.h"
+
+extern THCState *state;
+
+
+int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample,
+ at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) {
+
+    float *grad_points = grad_points_tensor.data<float>();
+    const int *idx = idx_tensor.data<int>();
+    const float *grad_out = grad_out_tensor.data<float>();
+
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(); //THCState_getCurrentStream(state);
+
+ group_points_grad_kernel_launcher_fast(b, c, n, npoints, nsample, grad_out, idx, grad_points, stream);
+ return 1;
+}
+
+
+int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample,
+ at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor) {
+
+    const float *points = points_tensor.data<float>();
+    const int *idx = idx_tensor.data<int>();
+    float *out = out_tensor.data<float>();
+
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(); //THCState_getCurrentStream(state);
+
+ group_points_kernel_launcher_fast(b, c, n, npoints, nsample, points, idx, out, stream);
+ return 1;
+}
\ No newline at end of file
diff --git a/dataset/process_tools/utils/pointnet_lib/src/group_points_gpu.cu b/dataset/process_tools/utils/pointnet_lib/src/group_points_gpu.cu
new file mode 100644
index 0000000..c015a81
--- /dev/null
+++ b/dataset/process_tools/utils/pointnet_lib/src/group_points_gpu.cu
@@ -0,0 +1,86 @@
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "cuda_utils.h"
+#include "group_points_gpu.h"
+
+
+__global__ void group_points_grad_kernel_fast(int b, int c, int n, int npoints, int nsample,
+ const float *__restrict__ grad_out, const int *__restrict__ idx, float *__restrict__ grad_points) {
+ // grad_out: (B, C, npoints, nsample)
+ // idx: (B, npoints, nsample)
+ // output:
+ // grad_points: (B, C, N)
+ int bs_idx = blockIdx.z;
+ int c_idx = blockIdx.y;
+ int index = blockIdx.x * blockDim.x + threadIdx.x;
+ int pt_idx = index / nsample;
+ if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return;
+
+ int sample_idx = index % nsample;
+ grad_out += bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx;
+ idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx;
+
+ atomicAdd(grad_points + bs_idx * c * n + c_idx * n + idx[0] , grad_out[0]);
+}
+
+void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample,
+ const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) {
+ // grad_out: (B, C, npoints, nsample)
+ // idx: (B, npoints, nsample)
+ // output:
+ // grad_points: (B, C, N)
+ cudaError_t err;
+ dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row)
+ dim3 threads(THREADS_PER_BLOCK);
+
+    group_points_grad_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, nsample, grad_out, idx, grad_points);
+
+ err = cudaGetLastError();
+ if (cudaSuccess != err) {
+ fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+ exit(-1);
+ }
+}
+
+
+__global__ void group_points_kernel_fast(int b, int c, int n, int npoints, int nsample,
+ const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) {
+ // points: (B, C, N)
+ // idx: (B, npoints, nsample)
+ // output:
+ // out: (B, C, npoints, nsample)
+ int bs_idx = blockIdx.z;
+ int c_idx = blockIdx.y;
+ int index = blockIdx.x * blockDim.x + threadIdx.x;
+ int pt_idx = index / nsample;
+ if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return;
+
+ int sample_idx = index % nsample;
+
+ idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx;
+ int in_idx = bs_idx * c * n + c_idx * n + idx[0];
+ int out_idx = bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx;
+
+ out[out_idx] = points[in_idx];
+}
+
+
+void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample,
+ const float *points, const int *idx, float *out, cudaStream_t stream) {
+ // points: (B, C, N)
+ // idx: (B, npoints, nsample)
+ // output:
+ // out: (B, C, npoints, nsample)
+ cudaError_t err;
+ dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row)
+ dim3 threads(THREADS_PER_BLOCK);
+
+    group_points_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, nsample, points, idx, out);
+ // cudaDeviceSynchronize(); // for using printf in kernel function
+ err = cudaGetLastError();
+ if (cudaSuccess != err) {
+ fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+ exit(-1);
+ }
+}
diff --git a/dataset/process_tools/utils/pointnet_lib/src/group_points_gpu.h b/dataset/process_tools/utils/pointnet_lib/src/group_points_gpu.h
new file mode 100644
index 0000000..76c73ca
--- /dev/null
+++ b/dataset/process_tools/utils/pointnet_lib/src/group_points_gpu.h
@@ -0,0 +1,22 @@
+#ifndef _GROUP_POINTS_GPU_H
+#define _GROUP_POINTS_GPU_H
+
+#include <torch/serialize/tensor.h>
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+#include <vector>
+
+
+int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample,
+ at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor);
+
+void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample,
+ const float *points, const int *idx, float *out, cudaStream_t stream);
+
+int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample,
+ at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor);
+
+void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample,
+ const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream);
+
+#endif
diff --git a/dataset/process_tools/utils/pointnet_lib/src/interpolate.cpp b/dataset/process_tools/utils/pointnet_lib/src/interpolate.cpp
new file mode 100644
index 0000000..3efbaa1
--- /dev/null
+++ b/dataset/process_tools/utils/pointnet_lib/src/interpolate.cpp
@@ -0,0 +1,69 @@
+#include <torch/serialize/tensor.h>
+#include <vector>
+#include <THC/THC.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+#include "interpolate_gpu.h"
+
+extern THCState *state;
+
+
+void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor,
+ at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor) {
+    const float *unknown = unknown_tensor.data<float>();
+    const float *known = known_tensor.data<float>();
+    float *dist2 = dist2_tensor.data<float>();
+    int *idx = idx_tensor.data<int>();
+
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(); //THCState_getCurrentStream(state);
+ // cudaStream_t stream = THCState_getCurrentStream(state);
+ three_nn_kernel_launcher_fast(b, n, m, unknown, known, dist2, idx, stream);
+}
+
+void knn_wrapper_fast(int b, int n, int m, int k, at::Tensor unknown_tensor,
+ at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor) {
+    const float *unknown = unknown_tensor.data<float>();
+    const float *known = known_tensor.data<float>();
+    float *dist2 = dist2_tensor.data<float>();
+    int *idx = idx_tensor.data<int>();
+
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(); //THCState_getCurrentStream(state);
+ // cudaStream_t stream = THCState_getCurrentStream(state);
+ knn_kernel_launcher_fast(b, n, m, k, unknown, known, dist2, idx, stream);
+}
+
+
+void three_interpolate_wrapper_fast(int b, int c, int m, int n,
+ at::Tensor points_tensor,
+ at::Tensor idx_tensor,
+ at::Tensor weight_tensor,
+ at::Tensor out_tensor) {
+
+    const float *points = points_tensor.data<float>();
+    const float *weight = weight_tensor.data<float>();
+    float *out = out_tensor.data<float>();
+    const int *idx = idx_tensor.data<int>();
+
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(); //THCState_getCurrentStream(state);
+ // cudaStream_t stream = THCState_getCurrentStream(state);
+ three_interpolate_kernel_launcher_fast(b, c, m, n, points, idx, weight, out, stream);
+}
+
+void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m,
+ at::Tensor grad_out_tensor,
+ at::Tensor idx_tensor,
+ at::Tensor weight_tensor,
+ at::Tensor grad_points_tensor) {
+
+    const float *grad_out = grad_out_tensor.data<float>();
+    const float *weight = weight_tensor.data<float>();
+    float *grad_points = grad_points_tensor.data<float>();
+    const int *idx = idx_tensor.data<int>();
+
+ // cudaStream_t stream = THCState_getCurrentStream(state);
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(); //THCState_getCurrentStream(state);
+ three_interpolate_grad_kernel_launcher_fast(b, c, n, m, grad_out, idx, weight, grad_points, stream);
+}
\ No newline at end of file
diff --git a/dataset/process_tools/utils/pointnet_lib/src/interpolate_gpu.cu b/dataset/process_tools/utils/pointnet_lib/src/interpolate_gpu.cu
new file mode 100644
index 0000000..2bcac2e
--- /dev/null
+++ b/dataset/process_tools/utils/pointnet_lib/src/interpolate_gpu.cu
@@ -0,0 +1,233 @@
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "cuda_utils.h"
+#include "interpolate_gpu.h"
+
+
+__global__ void knn_kernel_fast(int b, int n, int m, int k, const float *__restrict__ unknown,
+ const float *__restrict__ known, float *__restrict__ dist2, int *__restrict__ idx) {
+ // unknown: (B, N, 3)
+ // known: (B, M, 3)
+ // output:
+ // dist2: (B, N, k)
+ // idx: (B, N, k)
+
+ int bs_idx = blockIdx.y;
+ int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
+ if (bs_idx >= b || pt_idx >= n) return;
+
+ unknown += bs_idx * n * 3 + pt_idx * 3;
+ known += bs_idx * m * 3;
+ dist2 += bs_idx * n * k + pt_idx * k;
+ idx += bs_idx * n * k + pt_idx * k;
+
+ float ux = unknown[0];
+ float uy = unknown[1];
+ float uz = unknown[2];
+
+ double best[200];
+ int besti[200];
+ for(int i = 0; i < k; i++){
+ best[i] = 1e40;
+ besti[i] = 0;
+ }
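+    // Insertion-sort each candidate distance into the running top-k: best/besti stay
+    // sorted in ascending squared distance (the fixed-size scratch buffers cap k at 200).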
+ for (int i = 0; i < m; ++i) {
+ float x = known[i * 3 + 0];
+ float y = known[i * 3 + 1];
+ float z = known[i * 3 + 2];
+ float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
+ for(int j = 0; j < k; j++){
+ if(d < best[j]){
+ for(int l = k - 1; l > j; l--){
+ best[l] = best[l - 1];
+ besti[l] = besti[l - 1];
+ }
+ best[j] = d;
+ besti[j] = i;
+ break;
+ }
+ }
+ }
+ for(int i = 0; i < k; i++){
+ idx[i] = besti[i];
+ dist2[i] = best[i];
+ }
+}
+
+
+void knn_kernel_launcher_fast(int b, int n, int m, int k, const float *unknown,
+ const float *known, float *dist2, int *idx, cudaStream_t stream) {
+ // unknown: (B, N, 3)
+ // known: (B, M, 3)
+ // output:
+ // dist2: (B, N, k)
+ // idx: (B, N, k)
+
+ cudaError_t err;
+ dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row)
+ dim3 threads(THREADS_PER_BLOCK);
+
+    knn_kernel_fast<<<blocks, threads, 0, stream>>>(b, n, m, k, unknown, known, dist2, idx);
+
+ err = cudaGetLastError();
+ if (cudaSuccess != err) {
+ fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+ exit(-1);
+ }
+}
+
+__global__ void three_nn_kernel_fast(int b, int n, int m, const float *__restrict__ unknown,
+ const float *__restrict__ known, float *__restrict__ dist2, int *__restrict__ idx) {
+ // unknown: (B, N, 3)
+ // known: (B, M, 3)
+ // output:
+ // dist2: (B, N, 3)
+ // idx: (B, N, 3)
+
+ int bs_idx = blockIdx.y;
+ int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
+ if (bs_idx >= b || pt_idx >= n) return;
+
+ unknown += bs_idx * n * 3 + pt_idx * 3;
+ known += bs_idx * m * 3;
+ dist2 += bs_idx * n * 3 + pt_idx * 3;
+ idx += bs_idx * n * 3 + pt_idx * 3;
+
+ float ux = unknown[0];
+ float uy = unknown[1];
+ float uz = unknown[2];
+
+ double best1 = 1e40, best2 = 1e40, best3 = 1e40;
+ int besti1 = 0, besti2 = 0, besti3 = 0;
+ for (int k = 0; k < m; ++k) {
+ float x = known[k * 3 + 0];
+ float y = known[k * 3 + 1];
+ float z = known[k * 3 + 2];
+ float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
+ if (d < best1) {
+ best3 = best2; besti3 = besti2;
+ best2 = best1; besti2 = besti1;
+ best1 = d; besti1 = k;
+ }
+ else if (d < best2) {
+ best3 = best2; besti3 = besti2;
+ best2 = d; besti2 = k;
+ }
+ else if (d < best3) {
+ best3 = d; besti3 = k;
+ }
+ }
+ dist2[0] = best1; dist2[1] = best2; dist2[2] = best3;
+ idx[0] = besti1; idx[1] = besti2; idx[2] = besti3;
+}
+
+
+void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown,
+ const float *known, float *dist2, int *idx, cudaStream_t stream) {
+ // unknown: (B, N, 3)
+ // known: (B, M, 3)
+ // output:
+ // dist2: (B, N, 3)
+ // idx: (B, N, 3)
+
+ cudaError_t err;
+ dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row)
+ dim3 threads(THREADS_PER_BLOCK);
+
+    three_nn_kernel_fast<<<blocks, threads, 0, stream>>>(b, n, m, unknown, known, dist2, idx);
+
+ err = cudaGetLastError();
+ if (cudaSuccess != err) {
+ fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+ exit(-1);
+ }
+}
+
+
+__global__ void three_interpolate_kernel_fast(int b, int c, int m, int n, const float *__restrict__ points,
+ const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ out) {
+ // points: (B, C, M)
+ // idx: (B, N, 3)
+ // weight: (B, N, 3)
+ // output:
+ // out: (B, C, N)
+
+ int bs_idx = blockIdx.z;
+ int c_idx = blockIdx.y;
+ int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+ if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;
+
+ weight += bs_idx * n * 3 + pt_idx * 3;
+ points += bs_idx * c * m + c_idx * m;
+ idx += bs_idx * n * 3 + pt_idx * 3;
+ out += bs_idx * c * n + c_idx * n;
+
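+    // Blend the features of the three nearest known points using the caller-provided
+    // weights (typically inverse-distance weights derived from the three_nn output).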
+ out[pt_idx] = weight[0] * points[idx[0]] + weight[1] * points[idx[1]] + weight[2] * points[idx[2]];
+}
+
+void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n,
+ const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream) {
+ // points: (B, C, M)
+ // idx: (B, N, 3)
+ // weight: (B, N, 3)
+ // output:
+ // out: (B, C, N)
+
+ cudaError_t err;
+ dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row)
+ dim3 threads(THREADS_PER_BLOCK);
+    three_interpolate_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, m, n, points, idx, weight, out);
+
+ err = cudaGetLastError();
+ if (cudaSuccess != err) {
+ fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+ exit(-1);
+ }
+}
+
+
+__global__ void three_interpolate_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out,
+ const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ grad_points) {
+ // grad_out: (B, C, N)
+ // weight: (B, N, 3)
+ // output:
+ // grad_points: (B, C, M)
+
+ int bs_idx = blockIdx.z;
+ int c_idx = blockIdx.y;
+ int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+ if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;
+
+ grad_out += bs_idx * c * n + c_idx * n + pt_idx;
+ weight += bs_idx * n * 3 + pt_idx * 3;
+ grad_points += bs_idx * c * m + c_idx * m;
+ idx += bs_idx * n * 3 + pt_idx * 3;
+
+
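+    // Scatter-add: each unknown point pushes its weighted gradient back to its three
+    // nearest known points; atomics guard concurrent writes to the same known point.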
+ atomicAdd(grad_points + idx[0], grad_out[0] * weight[0]);
+ atomicAdd(grad_points + idx[1], grad_out[0] * weight[1]);
+ atomicAdd(grad_points + idx[2], grad_out[0] * weight[2]);
+}
+
+void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out,
+ const int *idx, const float *weight, float *grad_points, cudaStream_t stream) {
+ // grad_out: (B, C, N)
+ // weight: (B, N, 3)
+ // output:
+ // grad_points: (B, C, M)
+
+ cudaError_t err;
+ dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row)
+ dim3 threads(THREADS_PER_BLOCK);
+    three_interpolate_grad_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, m, grad_out, idx, weight, grad_points);
+
+ err = cudaGetLastError();
+ if (cudaSuccess != err) {
+ fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+ exit(-1);
+ }
+}
\ No newline at end of file
diff --git a/dataset/process_tools/utils/pointnet_lib/src/interpolate_gpu.h b/dataset/process_tools/utils/pointnet_lib/src/interpolate_gpu.h
new file mode 100644
index 0000000..5ceb3ad
--- /dev/null
+++ b/dataset/process_tools/utils/pointnet_lib/src/interpolate_gpu.h
@@ -0,0 +1,36 @@
+#ifndef _INTERPOLATE_GPU_H
+#define _INTERPOLATE_GPU_H
+
+#include <torch/serialize/tensor.h>
+#include <vector>
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+
+
+void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor,
+ at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor);
+
+void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown,
+ const float *known, float *dist2, int *idx, cudaStream_t stream);
+
+void knn_wrapper_fast(int b, int n, int m, int k, at::Tensor unknown_tensor,
+ at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor);
+
+void knn_kernel_launcher_fast(int b, int n, int m, int k, const float *unknown,
+ const float *known, float *dist2, int *idx, cudaStream_t stream);
+
+
+void three_interpolate_wrapper_fast(int b, int c, int m, int n, at::Tensor points_tensor,
+ at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor out_tensor);
+
+void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n,
+ const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream);
+
+
+void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m, at::Tensor grad_out_tensor,
+ at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor grad_points_tensor);
+
+void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out,
+ const int *idx, const float *weight, float *grad_points, cudaStream_t stream);
+
+#endif
diff --git a/dataset/process_tools/utils/pointnet_lib/src/pointnet2_api.cpp b/dataset/process_tools/utils/pointnet_lib/src/pointnet2_api.cpp
new file mode 100644
index 0000000..def9a3c
--- /dev/null
+++ b/dataset/process_tools/utils/pointnet_lib/src/pointnet2_api.cpp
@@ -0,0 +1,25 @@
+#include <torch/serialize/tensor.h>
+#include <torch/extension.h>
+
+#include "ball_query_gpu.h"
+#include "group_points_gpu.h"
+#include "sampling_gpu.h"
+#include "interpolate_gpu.h"
+
+
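+// Bindings consumed by the Python-side pointnet2_utils wrappers (e.g. the
+// furthest_point_sample op used in dataset/process_tools/utils/sample_utils.py).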
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("ball_query_wrapper", &ball_query_wrapper_fast, "ball_query_wrapper_fast");
+
+ m.def("group_points_wrapper", &group_points_wrapper_fast, "group_points_wrapper_fast");
+ m.def("group_points_grad_wrapper", &group_points_grad_wrapper_fast, "group_points_grad_wrapper_fast");
+
+ m.def("gather_points_wrapper", &gather_points_wrapper_fast, "gather_points_wrapper_fast");
+ m.def("gather_points_grad_wrapper", &gather_points_grad_wrapper_fast, "gather_points_grad_wrapper_fast");
+
+ m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper, "furthest_point_sampling_wrapper");
+
+ m.def("knn_wrapper", &knn_wrapper_fast, "knn_wrapper_fast");
+ m.def("three_nn_wrapper", &three_nn_wrapper_fast, "three_nn_wrapper_fast");
+ m.def("three_interpolate_wrapper", &three_interpolate_wrapper_fast, "three_interpolate_wrapper_fast");
+ m.def("three_interpolate_grad_wrapper", &three_interpolate_grad_wrapper_fast, "three_interpolate_grad_wrapper_fast");
+}
diff --git a/dataset/process_tools/utils/pointnet_lib/src/sampling.cpp b/dataset/process_tools/utils/pointnet_lib/src/sampling.cpp
new file mode 100644
index 0000000..380132d
--- /dev/null
+++ b/dataset/process_tools/utils/pointnet_lib/src/sampling.cpp
@@ -0,0 +1,49 @@
+#include <torch/serialize/tensor.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <vector>
+#include <THC/THC.h>
+
+#include "sampling_gpu.h"
+
+extern THCState *state;
+
+
+int gather_points_wrapper_fast(int b, int c, int n, int npoints,
+ at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor){
+    const float *points = points_tensor.data<float>();
+    const int *idx = idx_tensor.data<int>();
+    float *out = out_tensor.data<float>();
+
+ // cudaStream_t stream = THCState_getCurrentStream(state);
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(); //THCState_getCurrentStream(state);
+ gather_points_kernel_launcher_fast(b, c, n, npoints, points, idx, out, stream);
+ return 1;
+}
+
+
+int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints,
+ at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) {
+
+    const float *grad_out = grad_out_tensor.data<float>();
+    const int *idx = idx_tensor.data<int>();
+    float *grad_points = grad_points_tensor.data<float>();
+
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(); //THCState_getCurrentStream(state);
+ // cudaStream_t stream = THCState_getCurrentStream(state);
+ gather_points_grad_kernel_launcher_fast(b, c, n, npoints, grad_out, idx, grad_points, stream);
+ return 1;
+}
+
+
+int furthest_point_sampling_wrapper(int b, int n, int m,
+ at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) {
+
+    const float *points = points_tensor.data<float>();
+    float *temp = temp_tensor.data<float>();
+    int *idx = idx_tensor.data<int>();
+
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(); //THCState_getCurrentStream(state);
+ // cudaStream_t stream = THCState_getCurrentStream(state);
+ furthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx, stream);
+ return 1;
+}
diff --git a/dataset/process_tools/utils/pointnet_lib/src/sampling_gpu.cu b/dataset/process_tools/utils/pointnet_lib/src/sampling_gpu.cu
new file mode 100644
index 0000000..9e49a60
--- /dev/null
+++ b/dataset/process_tools/utils/pointnet_lib/src/sampling_gpu.cu
@@ -0,0 +1,253 @@
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "cuda_utils.h"
+#include "sampling_gpu.h"
+
+
+__global__ void gather_points_kernel_fast(int b, int c, int n, int m,
+ const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) {
+ // points: (B, C, N)
+ // idx: (B, M)
+ // output:
+ // out: (B, C, M)
+
+ int bs_idx = blockIdx.z;
+ int c_idx = blockIdx.y;
+ int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
+ if (bs_idx >= b || c_idx >= c || pt_idx >= m) return;
+
+ out += bs_idx * c * m + c_idx * m + pt_idx;
+ idx += bs_idx * m + pt_idx;
+ points += bs_idx * c * n + c_idx * n;
+ out[0] = points[idx[0]];
+}
+
+void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints,
+ const float *points, const int *idx, float *out, cudaStream_t stream) {
+ // points: (B, C, N)
+ // idx: (B, npoints)
+ // output:
+ // out: (B, C, npoints)
+
+ cudaError_t err;
+ dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row)
+ dim3 threads(THREADS_PER_BLOCK);
+
+    gather_points_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, points, idx, out);
+
+ err = cudaGetLastError();
+ if (cudaSuccess != err) {
+ fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+ exit(-1);
+ }
+}
+
+__global__ void gather_points_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out,
+ const int *__restrict__ idx, float *__restrict__ grad_points) {
+ // grad_out: (B, C, M)
+ // idx: (B, M)
+ // output:
+ // grad_points: (B, C, N)
+
+ int bs_idx = blockIdx.z;
+ int c_idx = blockIdx.y;
+ int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
+ if (bs_idx >= b || c_idx >= c || pt_idx >= m) return;
+
+ grad_out += bs_idx * c * m + c_idx * m + pt_idx;
+ idx += bs_idx * m + pt_idx;
+ grad_points += bs_idx * c * n + c_idx * n;
+
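+    // Scatter the output gradient back to the gathered source point; atomicAdd is
+    // needed because the same index can appear multiple times in idx.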
+ atomicAdd(grad_points + idx[0], grad_out[0]);
+}
+
+void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints,
+ const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) {
+ // grad_out: (B, C, npoints)
+ // idx: (B, npoints)
+ // output:
+ // grad_points: (B, C, N)
+
+ cudaError_t err;
+ dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row)
+ dim3 threads(THREADS_PER_BLOCK);
+
+    gather_points_grad_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, grad_out, idx, grad_points);
+
+ err = cudaGetLastError();
+ if (cudaSuccess != err) {
+ fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+ exit(-1);
+ }
+}
+
+
+__device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, int idx1, int idx2){
+ const float v1 = dists[idx1], v2 = dists[idx2];
+ const int i1 = dists_i[idx1], i2 = dists_i[idx2];
+ dists[idx1] = max(v1, v2);
+ dists_i[idx1] = v2 > v1 ? i2 : i1;
+}
+
+template <unsigned int block_size>
+__global__ void furthest_point_sampling_kernel(int b, int n, int m,
+ const float *__restrict__ dataset, float *__restrict__ temp, int *__restrict__ idxs) {
+ // dataset: (B, N, 3)
+ // tmp: (B, N)
+ // output:
+ // idx: (B, M)
+
+ if (m <= 0) return;
+ __shared__ float dists[block_size];
+ __shared__ int dists_i[block_size];
+
+ int batch_index = blockIdx.x;
+ dataset += batch_index * n * 3;
+ temp += batch_index * n;
+ idxs += batch_index * m;
+
+ int tid = threadIdx.x;
+ const int stride = block_size;
+
+ int old = 0;
+ if (threadIdx.x == 0)
+ idxs[0] = old;
+
+ __syncthreads();
+ for (int j = 1; j < m; j++) {
+ int besti = 0;
+ float best = -1;
+ float x1 = dataset[old * 3 + 0];
+ float y1 = dataset[old * 3 + 1];
+ float z1 = dataset[old * 3 + 2];
+ for (int k = tid; k < n; k += stride) {
+ float x2, y2, z2;
+ x2 = dataset[k * 3 + 0];
+ y2 = dataset[k * 3 + 1];
+ z2 = dataset[k * 3 + 2];
+ // float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
+ // if (mag <= 1e-3)
+ // continue;
+
+ float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
+ float d2 = min(d, temp[k]);
+ temp[k] = d2;
+ besti = d2 > best ? k : besti;
+ best = d2 > best ? d2 : best;
+ }
+ dists[tid] = best;
+ dists_i[tid] = besti;
+ __syncthreads();
+
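+        // Parallel tree reduction over the shared-memory buffers: after the cascade
+        // below, dists[0]/dists_i[0] hold the farthest remaining point and its index.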
+ if (block_size >= 1024) {
+ if (tid < 512) {
+ __update(dists, dists_i, tid, tid + 512);
+ }
+ __syncthreads();
+ }
+
+ if (block_size >= 512) {
+ if (tid < 256) {
+ __update(dists, dists_i, tid, tid + 256);
+ }
+ __syncthreads();
+ }
+ if (block_size >= 256) {
+ if (tid < 128) {
+ __update(dists, dists_i, tid, tid + 128);
+ }
+ __syncthreads();
+ }
+ if (block_size >= 128) {
+ if (tid < 64) {
+ __update(dists, dists_i, tid, tid + 64);
+ }
+ __syncthreads();
+ }
+ if (block_size >= 64) {
+ if (tid < 32) {
+ __update(dists, dists_i, tid, tid + 32);
+ }
+ __syncthreads();
+ }
+ if (block_size >= 32) {
+ if (tid < 16) {
+ __update(dists, dists_i, tid, tid + 16);
+ }
+ __syncthreads();
+ }
+ if (block_size >= 16) {
+ if (tid < 8) {
+ __update(dists, dists_i, tid, tid + 8);
+ }
+ __syncthreads();
+ }
+ if (block_size >= 8) {
+ if (tid < 4) {
+ __update(dists, dists_i, tid, tid + 4);
+ }
+ __syncthreads();
+ }
+ if (block_size >= 4) {
+ if (tid < 2) {
+ __update(dists, dists_i, tid, tid + 2);
+ }
+ __syncthreads();
+ }
+ if (block_size >= 2) {
+ if (tid < 1) {
+ __update(dists, dists_i, tid, tid + 1);
+ }
+ __syncthreads();
+ }
+
+ old = dists_i[0];
+ if (tid == 0)
+ idxs[j] = old;
+ }
+}
+
+void furthest_point_sampling_kernel_launcher(int b, int n, int m,
+ const float *dataset, float *temp, int *idxs, cudaStream_t stream) {
+ // dataset: (B, N, 3)
+ // tmp: (B, N)
+ // output:
+ // idx: (B, M)
+
+ cudaError_t err;
+ unsigned int n_threads = opt_n_threads(n);
+
+ switch (n_threads) {
+    case 1024:
+    furthest_point_sampling_kernel<1024><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+    case 512:
+    furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+    case 256:
+    furthest_point_sampling_kernel<256><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+    case 128:
+    furthest_point_sampling_kernel<128><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+    case 64:
+    furthest_point_sampling_kernel<64><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+    case 32:
+    furthest_point_sampling_kernel<32><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+    case 16:
+    furthest_point_sampling_kernel<16><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+    case 8:
+    furthest_point_sampling_kernel<8><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+    case 4:
+    furthest_point_sampling_kernel<4><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+    case 2:
+    furthest_point_sampling_kernel<2><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+    case 1:
+    furthest_point_sampling_kernel<1><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+    default:
+    furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
+ }
+
+ err = cudaGetLastError();
+ if (cudaSuccess != err) {
+ fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+ exit(-1);
+ }
+}
diff --git a/dataset/process_tools/utils/pointnet_lib/src/sampling_gpu.h b/dataset/process_tools/utils/pointnet_lib/src/sampling_gpu.h
new file mode 100644
index 0000000..6200c59
--- /dev/null
+++ b/dataset/process_tools/utils/pointnet_lib/src/sampling_gpu.h
@@ -0,0 +1,29 @@
+#ifndef _SAMPLING_GPU_H
+#define _SAMPLING_GPU_H
+
+#include <torch/serialize/tensor.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <vector>
+
+
+int gather_points_wrapper_fast(int b, int c, int n, int npoints,
+ at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor);
+
+void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints,
+ const float *points, const int *idx, float *out, cudaStream_t stream);
+
+
+int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints,
+ at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor);
+
+void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints,
+ const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream);
+
+
+int furthest_point_sampling_wrapper(int b, int n, int m,
+ at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor);
+
+void furthest_point_sampling_kernel_launcher(int b, int n, int m,
+ const float *dataset, float *temp, int *idxs, cudaStream_t stream);
+
+#endif
diff --git a/dataset/process_tools/utils/read_utils.py b/dataset/process_tools/utils/read_utils.py
new file mode 100644
index 0000000..b14af83
--- /dev/null
+++ b/dataset/process_tools/utils/read_utils.py
@@ -0,0 +1,45 @@
+import os
+from os.path import join as pjoin
+import json
+import numpy as np
+from PIL import Image
+import pickle
+
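+# Loaders for a single rendered frame; `save_path` is the render output directory
+# containing the rgb/, depth/, segmentation/, npcs/, bbox/ and metafile/ subfolders.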
+
+def load_rgb_image(save_path, filename):
+ img = Image.open(pjoin(save_path, 'rgb', f'{filename}.png'))
+ return np.array(img)
+
+
+def load_depth_map(save_path, filename):
+ depth_dict = np.load(pjoin(save_path, 'depth', f'{filename}.npz'))
+ depth_map = depth_dict['depth_map']
+ return depth_map
+
+
+def load_anno_dict(save_path, filename):
+ anno_dict = {}
+
+ seg_path = pjoin(save_path, 'segmentation')
+ bbox_path = pjoin(save_path, 'bbox')
+ npcs_path = pjoin(save_path, 'npcs')
+
+ seg_dict = np.load(pjoin(seg_path, f'{filename}.npz'))
+ anno_dict['semantic_segmentation'] = seg_dict['semantic_segmentation']
+ anno_dict['instance_segmentation'] = seg_dict['instance_segmentation']
+
+ npcs_dict = np.load(pjoin(npcs_path, f'{filename}.npz'))
+ anno_dict['npcs_map'] = npcs_dict['npcs_map']
+
+ with open(pjoin(bbox_path, f'{filename}.pkl'), 'rb') as fd:
+ bbox_dict = pickle.load(fd)
+ anno_dict['bbox_pose_dict'] = bbox_dict['bbox_pose_dict']
+
+ return anno_dict
+
+
+def load_meta(save_path, filename):
+ with open(pjoin(save_path, 'metafile', f'{filename}.json'), 'r') as fd:
+ meta = json.load(fd)
+ return meta
+
diff --git a/dataset/process_tools/utils/sample_utils.py b/dataset/process_tools/utils/sample_utils.py
new file mode 100644
index 0000000..d24d2e5
--- /dev/null
+++ b/dataset/process_tools/utils/sample_utils.py
@@ -0,0 +1,73 @@
+import os
+import sys
+from os.path import join as pjoin
+import numpy as np
+from numpy.random.mtrand import sample
+
+import torch
+
+CUDA = torch.cuda.is_available()
+if CUDA:
+ import pointnet_lib.pointnet2_utils as futils
+
+
+def farthest_point_sample(xyz, npoint):
+ """
+ Copied from CAPTRA
+
+ Input:
+ xyz: pointcloud data, [B, N, 3], tensor
+ npoint: number of samples
+ Return:
+ centroids: sampled pointcloud index, [B, npoint]
+ """
+ device = xyz.device
+ B, N, C = xyz.shape
+ # return torch.randint(0, N, (B, npoint), dtype=torch.long).to(device)
+ if CUDA:
+ print('Use pointnet2_cuda!')
+ idx = futils.furthest_point_sample(xyz, npoint).long()
+ return idx
+
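+    # Pure-PyTorch fallback: iterative farthest point sampling with O(npoint * N) distance updates.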
+ centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
+ distance = torch.ones(B, N).to(device) * 1e10
+ farthest = torch.randint(0, N, (B, ), dtype=torch.long).to(device)
+ batch_indices = torch.arange(B, dtype=torch.long).to(device)
+ for i in range(npoint):
+ centroids[:, i] = farthest
+ centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
+ dist = torch.sum((xyz - centroid)**2, -1)
+ mask = dist < distance
+ distance[mask] = dist[mask]
+ farthest = torch.max(distance, -1)[1]
+ return centroids
+
+
+def FPS(pcs, npoint):
+ """
+ Input:
+ pcs: pointcloud data, [N, 3]
+ npoint: number of samples
+ Return:
+ sampled_pcs: [npoint, 3]
+ fps_idx: sampled pointcloud index, [npoint, ]
+ """
+ if pcs.shape[0] < npoint:
+ print('Error! shape[0] of point cloud is less than npoint!')
+ return None, None
+
+ if pcs.shape[0] == npoint:
+ return pcs, np.arange(pcs.shape[0])
+
+ pcs_tensor = torch.from_numpy(np.expand_dims(pcs, 0)).float()
+ fps_idx_tensor = farthest_point_sample(pcs_tensor, npoint)
+ fps_idx = fps_idx_tensor.cpu().numpy()[0]
+ sampled_pcs = pcs[fps_idx]
+ return sampled_pcs, fps_idx
+
+
+if __name__ == "__main__":
+ pc = np.random.random((50000, 3))
+ pc_sampled, idx = FPS(pc, 10000)
+ print(pc_sampled)
+ print(idx)
\ No newline at end of file
diff --git a/dataset/render_tools/meta/akb48_all_id_list.txt b/dataset/render_tools/meta/akb48_all_id_list.txt
new file mode 100644
index 0000000..e2563ec
--- /dev/null
+++ b/dataset/render_tools/meta/akb48_all_id_list.txt
@@ -0,0 +1,121 @@
+Box 100
+Box 102
+Box 117
+Box 124
+Box 128
+Box 135
+Box 141
+Box 142
+Box 153
+Box 164
+Box 165
+Box 167
+Box 169
+Box 174
+Box 178
+Box 184
+Box 186
+Box 194
+Box 195
+Box 196
+Box 198
+Box 201
+Box 204
+Box 208
+Box 209
+Box 210
+Box 40
+Box 41
+Box 45
+Box 55
+Box 58
+Box 60
+Box 63
+Box 64
+Box 68
+Box 78
+Box 89
+Box 94
+Box 96
+Box 99
+Bucket 0
+Bucket 1
+Bucket 10
+Bucket 11
+Bucket 12
+Bucket 13
+Bucket 14
+Bucket 15
+Bucket 16
+Bucket 17
+Bucket 18
+Bucket 19
+Bucket 2
+Bucket 20
+Bucket 21
+Bucket 22
+Bucket 23
+Bucket 24
+Bucket 26
+Bucket 27
+Bucket 29
+Bucket 3
+Bucket 30
+Bucket 31
+Bucket 32
+Bucket 34
+Bucket 35
+Bucket 36
+Bucket 37
+Bucket 38
+Bucket 39
+Bucket 4
+Bucket 5
+Bucket 6
+Bucket 7
+Bucket 8
+Bucket 9
+Drawer 275
+Drawer 276
+Drawer 278
+Drawer 282
+Drawer 283
+Drawer 288
+Drawer 289
+Drawer 290
+Drawer 292
+Drawer 294
+Drawer 297
+Drawer 298
+Drawer 300
+Drawer 301
+Drawer 302
+Drawer 304
+Drawer 308
+Drawer 309
+Drawer 311
+Drawer 312
+Drawer 313
+Drawer 314
+TrashCan 213
+TrashCan 219
+TrashCan 224
+TrashCan 225
+TrashCan 227
+TrashCan 229
+TrashCan 230
+TrashCan 232
+TrashCan 234
+TrashCan 237
+TrashCan 244
+TrashCan 245
+TrashCan 246
+TrashCan 247
+TrashCan 249
+TrashCan 250
+TrashCan 256
+TrashCan 257
+TrashCan 258
+TrashCan 260
+TrashCan 263
+TrashCan 270
diff --git a/dataset/render_tools/meta/akb48_all_split.json b/dataset/render_tools/meta/akb48_all_split.json
new file mode 100644
index 0000000..20ea733
--- /dev/null
+++ b/dataset/render_tools/meta/akb48_all_split.json
@@ -0,0 +1,150 @@
+{
+ "seen_category": {
+ "Drawer": {
+ "seen_instance": [
+ 275,
+ 276,
+ 278,
+ 290,
+ 292,
+ 294,
+ 297,
+ 298,
+ 300,
+ 301,
+ 302,
+ 304,
+ 308,
+ 309,
+ 311,
+ 312,
+ 313,
+ 314
+ ],
+ "unseen_instance": [
+ 282,
+ 283,
+ 288,
+ 289
+ ]
+ },
+ "Box": {
+ "seen_instance": [
+ 68,
+ 78,
+ 89,
+ 94,
+ 96,
+ 99,
+ 100,
+ 102,
+ 117,
+ 124,
+ 128,
+ 135,
+ 141,
+ 142,
+ 153,
+ 164,
+ 165,
+ 167,
+ 169,
+ 174,
+ 178,
+ 184,
+ 186,
+ 194,
+ 195,
+ 196,
+ 198,
+ 201,
+ 204,
+ 208,
+ 209,
+ 210
+ ],
+ "unseen_instance": [
+ 40,
+ 41,
+ 45,
+ 55,
+ 58,
+ 60,
+ 63,
+ 64
+ ]
+ },
+ "Bucket": {
+ "seen_instance": [
+ 7,
+ 8,
+ 9,
+ 10,
+ 11,
+ 12,
+ 13,
+ 14,
+ 15,
+ 16,
+ 17,
+ 18,
+ 19,
+ 20,
+ 21,
+ 22,
+ 23,
+ 24,
+ 26,
+ 27,
+ 29,
+ 30,
+ 31,
+ 32,
+ 34,
+ 35,
+ 36,
+ 37,
+ 38,
+ 39
+ ],
+ "unseen_instance": [
+ 0,
+ 1,
+ 2,
+ 3,
+ 4,
+ 5,
+ 6
+ ]
+ },
+ "TrashCan": {
+ "seen_instance": [
+ 229,
+ 230,
+ 232,
+ 234,
+ 237,
+ 244,
+ 245,
+ 246,
+ 247,
+ 249,
+ 250,
+ 256,
+ 257,
+ 258,
+ 260,
+ 263,
+ 270
+ ],
+ "unseen_instance": [
+ 213,
+ 219,
+ 224,
+ 225,
+ 227
+ ]
+ }
+ },
+ "unseen_category": {}
+}
\ No newline at end of file
diff --git a/dataset/render_tools/meta/partnet_all_id_list.txt b/dataset/render_tools/meta/partnet_all_id_list.txt
new file mode 100644
index 0000000..8a2936f
--- /dev/null
+++ b/dataset/render_tools/meta/partnet_all_id_list.txt
@@ -0,0 +1,1045 @@
+Box 100129
+Box 100141
+Box 100162
+Box 100189
+Box 100191
+Box 100197
+Box 100202
+Box 100214
+Box 100221
+Box 100224
+Box 100234
+Box 100243
+Box 100247
+Box 100426
+Box 100658
+Box 100664
+Box 100671
+Box 100676
+Box 100685
+Box 102373
+Box 102377
+Box 102379
+Box 102456
+Box 47645
+Box 48492
+Bucket 100431
+Bucket 100432
+Bucket 100435
+Bucket 100438
+Bucket 100439
+Bucket 100441
+Bucket 100443
+Bucket 100446
+Bucket 100448
+Bucket 100452
+Bucket 100454
+Bucket 100460
+Bucket 100461
+Bucket 100462
+Bucket 100464
+Bucket 100465
+Bucket 100466
+Bucket 100468
+Bucket 100469
+Bucket 100470
+Bucket 100472
+Bucket 100473
+Bucket 100477
+Bucket 100484
+Bucket 100486
+Bucket 102352
+Bucket 102358
+Bucket 102359
+Bucket 102365
+Bucket 102367
+Bucket 102369
+Camera 101352
+Camera 101362
+Camera 102394
+Camera 102398
+Camera 102403
+Camera 102407
+Camera 102408
+Camera 102411
+Camera 102417
+Camera 102431
+Camera 102432
+Camera 102434
+Camera 102442
+Camera 102472
+Camera 102505
+Camera 102520
+Camera 102523
+Camera 102527
+Camera 102528
+Camera 102532
+Camera 102536
+Camera 102539
+Camera 102542
+Camera 102829
+Camera 102831
+Camera 102834
+Camera 102845
+Camera 102873
+Camera 102874
+Camera 102876
+Camera 102882
+Camera 102892
+CoffeeMachine 102145
+CoffeeMachine 102901
+CoffeeMachine 103002
+CoffeeMachine 103016
+CoffeeMachine 103030
+CoffeeMachine 103031
+CoffeeMachine 103037
+CoffeeMachine 103038
+CoffeeMachine 103041
+CoffeeMachine 103043
+CoffeeMachine 103046
+CoffeeMachine 103048
+CoffeeMachine 103057
+CoffeeMachine 103060
+CoffeeMachine 103064
+CoffeeMachine 103065
+CoffeeMachine 103069
+CoffeeMachine 103072
+CoffeeMachine 103074
+CoffeeMachine 103075
+CoffeeMachine 103079
+CoffeeMachine 103080
+CoffeeMachine 103082
+CoffeeMachine 103084
+CoffeeMachine 103086
+CoffeeMachine 103092
+CoffeeMachine 103098
+CoffeeMachine 103101
+CoffeeMachine 103105
+CoffeeMachine 103110
+CoffeeMachine 103118
+CoffeeMachine 103121
+CoffeeMachine 103123
+CoffeeMachine 103124
+CoffeeMachine 103127
+CoffeeMachine 103128
+CoffeeMachine 103129
+CoffeeMachine 103137
+CoffeeMachine 103143
+CoffeeMachine 103144
+CoffeeMachine 103146
+Dishwasher 11622
+Dishwasher 11661
+Dishwasher 11700
+Dishwasher 11826
+Dishwasher 12065
+Dishwasher 12085
+Dishwasher 12092
+Dishwasher 12259
+Dishwasher 12349
+Dishwasher 12428
+Dishwasher 12480
+Dishwasher 12484
+Dishwasher 12530
+Dishwasher 12531
+Dishwasher 12540
+Dishwasher 12542
+Dishwasher 12543
+Dishwasher 12552
+Dishwasher 12553
+Dishwasher 12558
+Dishwasher 12559
+Dishwasher 12560
+Dishwasher 12561
+Dishwasher 12562
+Dishwasher 12563
+Dishwasher 12565
+Dishwasher 12579
+Dishwasher 12580
+Dishwasher 12583
+Dishwasher 12584
+Dishwasher 12587
+Dishwasher 12590
+Dishwasher 12592
+Dishwasher 12594
+Dishwasher 12597
+Dishwasher 12605
+Dishwasher 12606
+Dishwasher 12612
+Dishwasher 12614
+Dishwasher 12621
+Dishwasher 12654
+Door 8867
+Door 8897
+Door 8903
+Door 8919
+Door 8961
+Door 8983
+Door 8994
+Door 8997
+Door 9016
+Door 9070
+Door 9117
+Door 9263
+Door 9288
+Door 9393
+Keyboard 12727
+Keyboard 12738
+Keyboard 12829
+Keyboard 12834
+Keyboard 12836
+Keyboard 12838
+Keyboard 12880
+Keyboard 12902
+Keyboard 12917
+Keyboard 12923
+Keyboard 12953
+Keyboard 12956
+Keyboard 12965
+Keyboard 12968
+Keyboard 12977
+Keyboard 12996
+Keyboard 12999
+Keyboard 13004
+Keyboard 13023
+Keyboard 13062
+Keyboard 13064
+Keyboard 13075
+Keyboard 13082
+Keyboard 13086
+Keyboard 13095
+Keyboard 13100
+Keyboard 13106
+Keyboard 13120
+Keyboard 13136
+Keyboard 13153
+Keyboard 7619
+KitchenPot 100015
+KitchenPot 100017
+KitchenPot 100021
+KitchenPot 100023
+KitchenPot 100025
+KitchenPot 100028
+KitchenPot 100032
+KitchenPot 100033
+KitchenPot 100040
+KitchenPot 100045
+KitchenPot 100047
+KitchenPot 100051
+KitchenPot 100054
+KitchenPot 100056
+KitchenPot 100057
+KitchenPot 100058
+KitchenPot 100060
+KitchenPot 100619
+KitchenPot 100623
+KitchenPot 102080
+Laptop 10040
+Laptop 10090
+Laptop 10098
+Laptop 10101
+Laptop 10108
+Laptop 10125
+Laptop 10211
+Laptop 10213
+Laptop 10238
+Laptop 10239
+Laptop 10243
+Laptop 10248
+Laptop 10269
+Laptop 10270
+Laptop 10280
+Laptop 10289
+Laptop 10305
+Laptop 10306
+Laptop 10356
+Laptop 10383
+Laptop 10626
+Laptop 10707
+Laptop 10885
+Laptop 11030
+Laptop 11075
+Laptop 11141
+Laptop 11156
+Laptop 11242
+Laptop 11395
+Laptop 11405
+Laptop 11406
+Laptop 11429
+Laptop 11477
+Laptop 11581
+Laptop 11586
+Laptop 11691
+Laptop 11778
+Laptop 11854
+Laptop 11876
+Laptop 11945
+Laptop 12073
+Laptop 9748
+Laptop 9912
+Laptop 9918
+Laptop 9960
+Laptop 9968
+Laptop 9992
+Laptop 9996
+Microwave 7119
+Microwave 7128
+Microwave 7167
+Microwave 7221
+Microwave 7236
+Microwave 7263
+Microwave 7265
+Microwave 7273
+Microwave 7292
+Microwave 7296
+Microwave 7304
+Microwave 7306
+Microwave 7310
+Microwave 7320
+Microwave 7349
+Microwave 7366
+Oven 101773
+Oven 101908
+Oven 101917
+Oven 101921
+Oven 101930
+Oven 101931
+Oven 101943
+Oven 101946
+Oven 101947
+Oven 101971
+Oven 102018
+Oven 102044
+Oven 102055
+Oven 102060
+Oven 7120
+Oven 7130
+Oven 7138
+Oven 7179
+Oven 7187
+Oven 7201
+Oven 7220
+Oven 7290
+Oven 7332
+Phone 103251
+Phone 103285
+Phone 103347
+Phone 103593
+Phone 103699
+Phone 103813
+Phone 103814
+Phone 103828
+Phone 103886
+Phone 103892
+Phone 103917
+Phone 103925
+Phone 103927
+Phone 103935
+Phone 103941
+Printer 100279
+Printer 103811
+Printer 103859
+Printer 103863
+Printer 103866
+Printer 103867
+Printer 103869
+Printer 103872
+Printer 103878
+Printer 103894
+Printer 103972
+Printer 103974
+Printer 103978
+Printer 103981
+Printer 103988
+Printer 103989
+Printer 103996
+Printer 104000
+Printer 104004
+Printer 104006
+Printer 104007
+Printer 104009
+Printer 104011
+Printer 104013
+Printer 104016
+Printer 104020
+Printer 104027
+Printer 104030
+Refrigerator 10068
+Refrigerator 10143
+Refrigerator 10144
+Refrigerator 10373
+Refrigerator 10489
+Refrigerator 10586
+Refrigerator 10620
+Refrigerator 10627
+Refrigerator 10655
+Refrigerator 10685
+Refrigerator 10751
+Refrigerator 10797
+Refrigerator 10849
+Refrigerator 10867
+Refrigerator 10900
+Refrigerator 10905
+Refrigerator 10944
+Refrigerator 11178
+Refrigerator 11211
+Refrigerator 11231
+Refrigerator 11260
+Refrigerator 11299
+Refrigerator 11304
+Refrigerator 11709
+Refrigerator 11712
+Refrigerator 11846
+Refrigerator 12036
+Refrigerator 12038
+Refrigerator 12042
+Refrigerator 12050
+Refrigerator 12054
+Refrigerator 12055
+Refrigerator 12059
+Refrigerator 12248
+Refrigerator 12249
+Refrigerator 12250
+Refrigerator 12252
+Remote 100013
+Remote 100269
+Remote 100270
+Remote 100385
+Remote 100392
+Remote 100394
+Remote 100395
+Remote 100405
+Remote 100408
+Remote 100412
+Remote 100706
+Remote 100712
+Remote 100809
+Remote 100811
+Remote 100814
+Remote 100816
+Remote 100819
+Remote 100828
+Remote 100991
+Remote 100993
+Remote 100997
+Remote 100999
+Remote 101002
+Remote 101004
+Remote 101007
+Remote 101010
+Remote 101011
+Remote 101014
+Remote 101015
+Remote 101016
+Remote 101023
+Remote 101028
+Remote 101034
+Remote 101104
+Remote 101117
+Remote 101118
+Remote 101121
+Remote 101131
+Remote 101133
+Remote 101139
+Remote 101142
+Remote 102130
+Remote 104036
+Remote 104038
+Remote 104039
+Remote 104040
+Remote 104041
+Remote 104044
+Remote 104045
+Safe 101363
+Safe 101564
+Safe 101579
+Safe 101583
+Safe 101584
+Safe 101591
+Safe 101593
+Safe 101594
+Safe 101599
+Safe 101603
+Safe 101605
+Safe 101611
+Safe 101612
+Safe 101613
+Safe 101619
+Safe 101623
+Safe 102278
+Safe 102301
+Safe 102309
+Safe 102311
+Safe 102316
+Safe 102318
+Safe 102380
+Safe 102381
+Safe 102384
+Safe 102387
+Safe 102389
+Safe 102418
+Safe 102423
+StorageFurniture 35059
+StorageFurniture 38516
+StorageFurniture 40147
+StorageFurniture 40417
+StorageFurniture 40453
+StorageFurniture 41003
+StorageFurniture 41004
+StorageFurniture 41083
+StorageFurniture 41085
+StorageFurniture 41086
+StorageFurniture 41452
+StorageFurniture 41510
+StorageFurniture 41529
+StorageFurniture 44781
+StorageFurniture 44817
+StorageFurniture 44853
+StorageFurniture 44962
+StorageFurniture 45001
+StorageFurniture 45007
+StorageFurniture 45087
+StorageFurniture 45091
+StorageFurniture 45092
+StorageFurniture 45130
+StorageFurniture 45132
+StorageFurniture 45134
+StorageFurniture 45135
+StorageFurniture 45146
+StorageFurniture 45159
+StorageFurniture 45162
+StorageFurniture 45164
+StorageFurniture 45166
+StorageFurniture 45168
+StorageFurniture 45173
+StorageFurniture 45176
+StorageFurniture 45177
+StorageFurniture 45178
+StorageFurniture 45189
+StorageFurniture 45194
+StorageFurniture 45203
+StorageFurniture 45212
+StorageFurniture 45213
+StorageFurniture 45219
+StorageFurniture 45235
+StorageFurniture 45238
+StorageFurniture 45243
+StorageFurniture 45244
+StorageFurniture 45247
+StorageFurniture 45249
+StorageFurniture 45261
+StorageFurniture 45262
+StorageFurniture 45267
+StorageFurniture 45271
+StorageFurniture 45290
+StorageFurniture 45297
+StorageFurniture 45305
+StorageFurniture 45323
+StorageFurniture 45332
+StorageFurniture 45354
+StorageFurniture 45372
+StorageFurniture 45374
+StorageFurniture 45378
+StorageFurniture 45384
+StorageFurniture 45385
+StorageFurniture 45387
+StorageFurniture 45397
+StorageFurniture 45403
+StorageFurniture 45415
+StorageFurniture 45419
+StorageFurniture 45420
+StorageFurniture 45423
+StorageFurniture 45427
+StorageFurniture 45443
+StorageFurniture 45444
+StorageFurniture 45448
+StorageFurniture 45463
+StorageFurniture 45503
+StorageFurniture 45504
+StorageFurniture 45505
+StorageFurniture 45516
+StorageFurniture 45523
+StorageFurniture 45524
+StorageFurniture 45526
+StorageFurniture 45573
+StorageFurniture 45575
+StorageFurniture 45594
+StorageFurniture 45600
+StorageFurniture 45606
+StorageFurniture 45612
+StorageFurniture 45621
+StorageFurniture 45622
+StorageFurniture 45623
+StorageFurniture 45632
+StorageFurniture 45633
+StorageFurniture 45636
+StorageFurniture 45638
+StorageFurniture 45645
+StorageFurniture 45661
+StorageFurniture 45662
+StorageFurniture 45667
+StorageFurniture 45670
+StorageFurniture 45671
+StorageFurniture 45676
+StorageFurniture 45677
+StorageFurniture 45687
+StorageFurniture 45689
+StorageFurniture 45691
+StorageFurniture 45693
+StorageFurniture 45694
+StorageFurniture 45696
+StorageFurniture 45699
+StorageFurniture 45710
+StorageFurniture 45717
+StorageFurniture 45725
+StorageFurniture 45746
+StorageFurniture 45747
+StorageFurniture 45749
+StorageFurniture 45756
+StorageFurniture 45759
+StorageFurniture 45767
+StorageFurniture 45779
+StorageFurniture 45780
+StorageFurniture 45783
+StorageFurniture 45784
+StorageFurniture 45790
+StorageFurniture 45801
+StorageFurniture 45841
+StorageFurniture 45850
+StorageFurniture 45853
+StorageFurniture 45855
+StorageFurniture 45908
+StorageFurniture 45910
+StorageFurniture 45915
+StorageFurniture 45916
+StorageFurniture 45922
+StorageFurniture 45936
+StorageFurniture 45937
+StorageFurniture 45940
+StorageFurniture 45948
+StorageFurniture 45949
+StorageFurniture 45950
+StorageFurniture 45961
+StorageFurniture 45963
+StorageFurniture 45964
+StorageFurniture 45984
+StorageFurniture 46002
+StorageFurniture 46014
+StorageFurniture 46019
+StorageFurniture 46029
+StorageFurniture 46033
+StorageFurniture 46037
+StorageFurniture 46044
+StorageFurniture 46045
+StorageFurniture 46057
+StorageFurniture 46060
+StorageFurniture 46084
+StorageFurniture 46092
+StorageFurniture 46107
+StorageFurniture 46108
+StorageFurniture 46109
+StorageFurniture 46117
+StorageFurniture 46120
+StorageFurniture 46123
+StorageFurniture 46127
+StorageFurniture 46130
+StorageFurniture 46132
+StorageFurniture 46134
+StorageFurniture 46145
+StorageFurniture 46166
+StorageFurniture 46172
+StorageFurniture 46179
+StorageFurniture 46180
+StorageFurniture 46197
+StorageFurniture 46199
+StorageFurniture 46230
+StorageFurniture 46236
+StorageFurniture 46277
+StorageFurniture 46334
+StorageFurniture 46380
+StorageFurniture 46401
+StorageFurniture 46408
+StorageFurniture 46417
+StorageFurniture 46427
+StorageFurniture 46430
+StorageFurniture 46437
+StorageFurniture 46440
+StorageFurniture 46443
+StorageFurniture 46452
+StorageFurniture 46456
+StorageFurniture 46462
+StorageFurniture 46466
+StorageFurniture 46480
+StorageFurniture 46481
+StorageFurniture 46490
+StorageFurniture 46537
+StorageFurniture 46544
+StorageFurniture 46549
+StorageFurniture 46563
+StorageFurniture 46598
+StorageFurniture 46616
+StorageFurniture 46641
+StorageFurniture 46653
+StorageFurniture 46655
+StorageFurniture 46699
+StorageFurniture 46700
+StorageFurniture 46732
+StorageFurniture 46741
+StorageFurniture 46744
+StorageFurniture 46762
+StorageFurniture 46768
+StorageFurniture 46787
+StorageFurniture 46801
+StorageFurniture 46825
+StorageFurniture 46847
+StorageFurniture 46856
+StorageFurniture 46859
+StorageFurniture 46874
+StorageFurniture 46879
+StorageFurniture 46889
+StorageFurniture 46893
+StorageFurniture 46896
+StorageFurniture 46906
+StorageFurniture 46922
+StorageFurniture 46944
+StorageFurniture 46955
+StorageFurniture 46966
+StorageFurniture 46981
+StorageFurniture 47021
+StorageFurniture 47024
+StorageFurniture 47088
+StorageFurniture 47089
+StorageFurniture 47099
+StorageFurniture 47133
+StorageFurniture 47178
+StorageFurniture 47180
+StorageFurniture 47182
+StorageFurniture 47183
+StorageFurniture 47185
+StorageFurniture 47187
+StorageFurniture 47207
+StorageFurniture 47227
+StorageFurniture 47233
+StorageFurniture 47235
+StorageFurniture 47238
+StorageFurniture 47252
+StorageFurniture 47254
+StorageFurniture 47278
+StorageFurniture 47281
+StorageFurniture 47290
+StorageFurniture 47296
+StorageFurniture 47315
+StorageFurniture 47316
+StorageFurniture 47388
+StorageFurniture 47391
+StorageFurniture 47419
+StorageFurniture 47438
+StorageFurniture 47443
+StorageFurniture 47514
+StorageFurniture 47529
+StorageFurniture 47565
+StorageFurniture 47570
+StorageFurniture 47577
+StorageFurniture 47578
+StorageFurniture 47585
+StorageFurniture 47595
+StorageFurniture 47601
+StorageFurniture 47613
+StorageFurniture 47632
+StorageFurniture 47669
+StorageFurniture 47686
+StorageFurniture 47701
+StorageFurniture 47711
+StorageFurniture 47729
+StorageFurniture 47742
+StorageFurniture 47747
+StorageFurniture 47808
+StorageFurniture 47817
+StorageFurniture 47853
+StorageFurniture 47926
+StorageFurniture 47944
+StorageFurniture 47963
+StorageFurniture 47976
+StorageFurniture 48010
+StorageFurniture 48018
+StorageFurniture 48023
+StorageFurniture 48036
+StorageFurniture 48051
+StorageFurniture 48063
+StorageFurniture 48167
+StorageFurniture 48169
+StorageFurniture 48177
+StorageFurniture 48243
+StorageFurniture 48253
+StorageFurniture 48258
+StorageFurniture 48263
+StorageFurniture 48271
+StorageFurniture 48356
+StorageFurniture 48379
+StorageFurniture 48381
+StorageFurniture 48413
+StorageFurniture 48452
+StorageFurniture 48467
+StorageFurniture 48490
+StorageFurniture 48491
+StorageFurniture 48497
+StorageFurniture 48513
+StorageFurniture 48517
+StorageFurniture 48519
+StorageFurniture 48623
+StorageFurniture 48700
+StorageFurniture 48721
+StorageFurniture 48740
+StorageFurniture 48797
+StorageFurniture 48855
+StorageFurniture 48859
+StorageFurniture 48876
+StorageFurniture 48878
+StorageFurniture 49025
+StorageFurniture 49038
+StorageFurniture 49042
+StorageFurniture 49062
+StorageFurniture 49132
+StorageFurniture 49133
+StorageFurniture 49140
+StorageFurniture 49188
+Suitcase 100249
+Suitcase 100550
+Suitcase 100767
+Suitcase 100776
+Suitcase 100825
+Suitcase 101668
+Suitcase 101673
+Suitcase 103755
+Suitcase 103761
+Suitcase 103762
+Table 19179
+Table 19825
+Table 19836
+Table 19855
+Table 19898
+Table 20043
+Table 20279
+Table 20411
+Table 20453
+Table 20555
+Table 20985
+Table 21467
+Table 22241
+Table 22301
+Table 22339
+Table 22367
+Table 22433
+Table 22508
+Table 22692
+Table 23372
+Table 23472
+Table 23511
+Table 23724
+Table 23782
+Table 23807
+Table 24644
+Table 24931
+Table 25144
+Table 25308
+Table 25493
+Table 25913
+Table 26073
+Table 26387
+Table 26503
+Table 26525
+Table 26608
+Table 26652
+Table 26657
+Table 26670
+Table 26806
+Table 26875
+Table 26899
+Table 27044
+Table 27189
+Table 27267
+Table 27619
+Table 28668
+Table 29133
+Table 29525
+Table 29557
+Table 29921
+Table 30238
+Table 30341
+Table 30663
+Table 30666
+Table 30739
+Table 30857
+Table 30869
+Table 31249
+Table 31601
+Table 32052
+Table 32086
+Table 32174
+Table 32259
+Table 32324
+Table 32354
+Table 32566
+Table 32601
+Table 32746
+Table 32761
+Table 32932
+Table 33116
+Table 33914
+Table 33930
+Table 34178
+Table 34610
+Table 34617
+Toaster 103466
+Toaster 103469
+Toaster 103475
+Toaster 103477
+Toaster 103482
+Toaster 103485
+Toaster 103486
+Toaster 103502
+Toaster 103514
+Toaster 103524
+Toaster 103545
+Toaster 103547
+Toaster 103548
+Toaster 103549
+Toaster 103553
+Toaster 103556
+Toaster 103558
+Toaster 103560
+Toaster 103561
+Toilet 101319
+Toilet 101320
+Toilet 101323
+Toilet 102619
+Toilet 102620
+Toilet 102621
+Toilet 102622
+Toilet 102625
+Toilet 102628
+Toilet 102629
+Toilet 102630
+Toilet 102631
+Toilet 102632
+Toilet 102634
+Toilet 102636
+Toilet 102639
+Toilet 102641
+Toilet 102643
+Toilet 102645
+Toilet 102646
+Toilet 102647
+Toilet 102648
+Toilet 102650
+Toilet 102651
+Toilet 102652
+Toilet 102654
+Toilet 102655
+Toilet 102657
+Toilet 102658
+Toilet 102660
+Toilet 102662
+Toilet 102663
+Toilet 102664
+Toilet 102665
+Toilet 102666
+Toilet 102667
+Toilet 102668
+Toilet 102669
+Toilet 102670
+Toilet 102675
+Toilet 102676
+Toilet 102677
+Toilet 102678
+Toilet 102679
+Toilet 102682
+Toilet 102684
+Toilet 102685
+Toilet 102687
+Toilet 102688
+Toilet 102689
+Toilet 102690
+Toilet 102692
+Toilet 102697
+Toilet 102698
+Toilet 102699
+Toilet 102701
+Toilet 102702
+Toilet 102703
+Toilet 102704
+Toilet 102706
+Toilet 102707
+Toilet 102708
+Toilet 102710
+Toilet 103230
+Toilet 103233
+Toilet 103234
+TrashCan 100731
+TrashCan 100732
+TrashCan 101377
+TrashCan 101378
+TrashCan 101380
+TrashCan 102154
+TrashCan 102155
+TrashCan 102156
+TrashCan 102158
+TrashCan 102160
+TrashCan 102165
+TrashCan 102171
+TrashCan 102181
+TrashCan 102186
+TrashCan 102187
+TrashCan 102189
+TrashCan 102192
+TrashCan 102194
+TrashCan 102200
+TrashCan 102201
+TrashCan 102202
+TrashCan 102209
+TrashCan 102210
+TrashCan 102218
+TrashCan 102219
+TrashCan 102227
+TrashCan 102229
+TrashCan 102234
+TrashCan 102244
+TrashCan 102252
+TrashCan 102254
+TrashCan 102256
+TrashCan 102257
+TrashCan 102992
+TrashCan 102996
+TrashCan 103008
+TrashCan 103010
+TrashCan 103013
+TrashCan 103633
+TrashCan 103634
+TrashCan 103646
+TrashCan 103647
+TrashCan 11229
+TrashCan 11259
+TrashCan 11279
+TrashCan 11361
+TrashCan 11818
+TrashCan 11951
+TrashCan 12231
+TrashCan 12447
+TrashCan 12483
+TrashCan 4108
+WashingMachine 100282
+WashingMachine 100283
+WashingMachine 103351
+WashingMachine 103361
+WashingMachine 103369
+WashingMachine 103425
+WashingMachine 103452
+WashingMachine 103480
+WashingMachine 103490
+WashingMachine 103508
+WashingMachine 103518
+WashingMachine 103521
+WashingMachine 103528
+WashingMachine 103775
+WashingMachine 103776
+WashingMachine 103778
+WashingMachine 103781
diff --git a/dataset/render_tools/meta/partnet_all_split.json b/dataset/render_tools/meta/partnet_all_split.json
new file mode 100644
index 0000000..178308d
--- /dev/null
+++ b/dataset/render_tools/meta/partnet_all_split.json
@@ -0,0 +1,1179 @@
+{
+ "seen_category": {
+ "Toaster": {
+ "seen_instance": [
+ 103466,
+ 103475,
+ 103477,
+ 103482,
+ 103485,
+ 103502,
+ 103514,
+ 103545,
+ 103547,
+ 103548,
+ 103549,
+ 103553,
+ 103556,
+ 103558,
+ 103560
+ ],
+ "unseen_instance": [
+ 103469,
+ 103486,
+ 103524,
+ 103561
+ ]
+ },
+ "StorageFurniture": {
+ "seen_instance": [
+ 38516,
+ 40147,
+ 40417,
+ 41003,
+ 41083,
+ 41085,
+ 41086,
+ 41452,
+ 41510,
+ 41529,
+ 44781,
+ 44817,
+ 44853,
+ 44962,
+ 45001,
+ 45087,
+ 45091,
+ 45092,
+ 45130,
+ 45132,
+ 45135,
+ 45146,
+ 45162,
+ 45164,
+ 45168,
+ 45176,
+ 45177,
+ 45178,
+ 45189,
+ 45194,
+ 45203,
+ 45213,
+ 45219,
+ 45235,
+ 45238,
+ 45243,
+ 45244,
+ 45247,
+ 45249,
+ 45262,
+ 45267,
+ 45271,
+ 45290,
+ 45297,
+ 45305,
+ 45323,
+ 45332,
+ 45354,
+ 45374,
+ 45378,
+ 45384,
+ 45387,
+ 45397,
+ 45415,
+ 45419,
+ 45420,
+ 45423,
+ 45427,
+ 45443,
+ 45448,
+ 45463,
+ 45503,
+ 45504,
+ 45505,
+ 45516,
+ 45523,
+ 45524,
+ 45526,
+ 45573,
+ 45575,
+ 45594,
+ 45600,
+ 45612,
+ 45621,
+ 45622,
+ 45632,
+ 45636,
+ 45638,
+ 45645,
+ 45661,
+ 45662,
+ 45667,
+ 45670,
+ 45677,
+ 45687,
+ 45689,
+ 45693,
+ 45694,
+ 45696,
+ 45699,
+ 45710,
+ 45717,
+ 45725,
+ 45747,
+ 45749,
+ 45756,
+ 45780,
+ 45784,
+ 45790,
+ 45801,
+ 45850,
+ 45853,
+ 45908,
+ 45915,
+ 45916,
+ 45922,
+ 45936,
+ 45937,
+ 45940,
+ 45948,
+ 45950,
+ 45961,
+ 45964,
+ 45984,
+ 46002,
+ 46029,
+ 46033,
+ 46037,
+ 46044,
+ 46045,
+ 46057,
+ 46060,
+ 46084,
+ 46092,
+ 46108,
+ 46117,
+ 46120,
+ 46123,
+ 46127,
+ 46130,
+ 46132,
+ 46134,
+ 46166,
+ 46180,
+ 46197,
+ 46199,
+ 46230,
+ 46236,
+ 46334,
+ 46401,
+ 46408,
+ 46417,
+ 46427,
+ 46430,
+ 46440,
+ 46443,
+ 46452,
+ 46456,
+ 46462,
+ 46466,
+ 46480,
+ 46490,
+ 46537,
+ 46544,
+ 46549,
+ 46563,
+ 46616,
+ 46700,
+ 46732,
+ 46741,
+ 46762,
+ 46768,
+ 46787,
+ 46801,
+ 46825,
+ 46847,
+ 46856,
+ 46859,
+ 46874,
+ 46879,
+ 46893,
+ 46896,
+ 46906,
+ 46922,
+ 46944,
+ 46955,
+ 46966,
+ 46981,
+ 47024,
+ 47088,
+ 47089,
+ 47099,
+ 47178,
+ 47180,
+ 47183,
+ 47185,
+ 47187,
+ 47207,
+ 47227,
+ 47235,
+ 47238,
+ 47252,
+ 47254,
+ 47278,
+ 47281,
+ 47290,
+ 47296,
+ 47315,
+ 47316,
+ 47388,
+ 47391,
+ 47438,
+ 47514,
+ 47529,
+ 47565,
+ 47570,
+ 47577,
+ 47578,
+ 47595,
+ 47601,
+ 47613,
+ 47632,
+ 47669,
+ 47686,
+ 47701,
+ 47711,
+ 47729,
+ 47742,
+ 47747,
+ 47808,
+ 47817,
+ 47853,
+ 47926,
+ 47944,
+ 47963,
+ 47976,
+ 48010,
+ 48018,
+ 48023,
+ 48051,
+ 48063,
+ 48167,
+ 48169,
+ 48177,
+ 48243,
+ 48253,
+ 48258,
+ 48263,
+ 48271,
+ 48356,
+ 48379,
+ 48413,
+ 48452,
+ 48467,
+ 48490,
+ 48491,
+ 48513,
+ 48517,
+ 48519,
+ 48623,
+ 48700,
+ 48740,
+ 48855,
+ 48859,
+ 48876,
+ 48878,
+ 49025,
+ 49042,
+ 49062,
+ 49132,
+ 49133,
+ 49140
+ ],
+ "unseen_instance": [
+ 35059,
+ 40453,
+ 41004,
+ 45007,
+ 45134,
+ 45159,
+ 45166,
+ 45173,
+ 45212,
+ 45261,
+ 45372,
+ 45385,
+ 45403,
+ 45444,
+ 45606,
+ 45623,
+ 45633,
+ 45671,
+ 45676,
+ 45691,
+ 45746,
+ 45759,
+ 45767,
+ 45779,
+ 45783,
+ 45841,
+ 45855,
+ 45910,
+ 45949,
+ 45963,
+ 46014,
+ 46019,
+ 46107,
+ 46109,
+ 46145,
+ 46172,
+ 46179,
+ 46277,
+ 46380,
+ 46437,
+ 46481,
+ 46598,
+ 46641,
+ 46653,
+ 46655,
+ 46699,
+ 46744,
+ 46889,
+ 47021,
+ 47133,
+ 47182,
+ 47233,
+ 47419,
+ 47443,
+ 47585,
+ 48036,
+ 48381,
+ 48497,
+ 48721,
+ 48797,
+ 49038,
+ 49188
+ ]
+ },
+ "Bucket": {
+ "seen_instance": [
+ 100441,
+ 100443,
+ 100446,
+ 100448,
+ 100452,
+ 100454,
+ 100460,
+ 100461,
+ 100462,
+ 100464,
+ 100465,
+ 100466,
+ 100468,
+ 100469,
+ 100470,
+ 100472,
+ 100473,
+ 100477,
+ 100484,
+ 102352,
+ 102358,
+ 102359,
+ 102365,
+ 102367,
+ 102369
+ ],
+ "unseen_instance": [
+ 100431,
+ 100432,
+ 100435,
+ 100438,
+ 100439,
+ 100486
+ ]
+ },
+ "WashingMachine": {
+ "seen_instance": [
+ 100282,
+ 103351,
+ 103361,
+ 103425,
+ 103452,
+ 103480,
+ 103490,
+ 103508,
+ 103518,
+ 103521,
+ 103775,
+ 103776,
+ 103778,
+ 103781
+ ],
+ "unseen_instance": [
+ 100283,
+ 103369,
+ 103528
+ ]
+ },
+ "CoffeeMachine": {
+ "seen_instance": [
+ 102145,
+ 103002,
+ 103016,
+ 103030,
+ 103031,
+ 103038,
+ 103041,
+ 103043,
+ 103046,
+ 103048,
+ 103057,
+ 103060,
+ 103064,
+ 103065,
+ 103069,
+ 103072,
+ 103075,
+ 103079,
+ 103080,
+ 103084,
+ 103086,
+ 103098,
+ 103101,
+ 103105,
+ 103110,
+ 103118,
+ 103121,
+ 103123,
+ 103124,
+ 103127,
+ 103137,
+ 103144,
+ 103146
+ ],
+ "unseen_instance": [
+ 102901,
+ 103037,
+ 103074,
+ 103082,
+ 103092,
+ 103128,
+ 103129,
+ 103143
+ ]
+ },
+ "Microwave": {
+ "seen_instance": [
+ 7119,
+ 7128,
+ 7167,
+ 7221,
+ 7236,
+ 7265,
+ 7273,
+ 7292,
+ 7304,
+ 7306,
+ 7310,
+ 7320,
+ 7366
+ ],
+ "unseen_instance": [
+ 7263,
+ 7296,
+ 7349
+ ]
+ },
+ "Box": {
+ "seen_instance": [
+ 47645,
+ 48492,
+ 100129,
+ 100141,
+ 100162,
+ 100197,
+ 100202,
+ 100214,
+ 100224,
+ 100234,
+ 100426,
+ 100658,
+ 100664,
+ 100671,
+ 100676,
+ 100685,
+ 102373,
+ 102377,
+ 102379,
+ 102456
+ ],
+ "unseen_instance": [
+ 100189,
+ 100191,
+ 100221,
+ 100243,
+ 100247
+ ]
+ },
+ "Remote": {
+ "seen_instance": [
+ 100013,
+ 100269,
+ 100270,
+ 100392,
+ 100394,
+ 100405,
+ 100408,
+ 100412,
+ 100712,
+ 100809,
+ 100811,
+ 100814,
+ 100816,
+ 100819,
+ 100828,
+ 100991,
+ 100993,
+ 100997,
+ 100999,
+ 101002,
+ 101007,
+ 101010,
+ 101011,
+ 101014,
+ 101015,
+ 101016,
+ 101023,
+ 101034,
+ 101117,
+ 101118,
+ 101131,
+ 101139,
+ 101142,
+ 102130,
+ 104036,
+ 104039,
+ 104041,
+ 104044,
+ 104045
+ ],
+ "unseen_instance": [
+ 100385,
+ 100395,
+ 100706,
+ 101004,
+ 101028,
+ 101104,
+ 101121,
+ 101133,
+ 104038,
+ 104040
+ ]
+ },
+ "Toilet": {
+ "seen_instance": [
+ 101319,
+ 101320,
+ 102619,
+ 102620,
+ 102621,
+ 102622,
+ 102625,
+ 102628,
+ 102629,
+ 102632,
+ 102634,
+ 102639,
+ 102641,
+ 102643,
+ 102645,
+ 102647,
+ 102648,
+ 102650,
+ 102652,
+ 102654,
+ 102655,
+ 102657,
+ 102658,
+ 102660,
+ 102662,
+ 102663,
+ 102664,
+ 102665,
+ 102666,
+ 102667,
+ 102668,
+ 102669,
+ 102675,
+ 102677,
+ 102678,
+ 102679,
+ 102682,
+ 102684,
+ 102685,
+ 102687,
+ 102688,
+ 102690,
+ 102692,
+ 102698,
+ 102699,
+ 102701,
+ 102702,
+ 102703,
+ 102704,
+ 102707,
+ 102710,
+ 103230,
+ 103233
+ ],
+ "unseen_instance": [
+ 101323,
+ 102630,
+ 102631,
+ 102636,
+ 102646,
+ 102651,
+ 102670,
+ 102676,
+ 102689,
+ 102697,
+ 102706,
+ 102708,
+ 103234
+ ]
+ },
+ "Keyboard": {
+ "seen_instance": [
+ 7619,
+ 12727,
+ 12829,
+ 12834,
+ 12836,
+ 12838,
+ 12880,
+ 12902,
+ 12923,
+ 12953,
+ 12956,
+ 12965,
+ 12968,
+ 12996,
+ 12999,
+ 13004,
+ 13023,
+ 13062,
+ 13064,
+ 13095,
+ 13100,
+ 13106,
+ 13120,
+ 13136,
+ 13153
+ ],
+ "unseen_instance": [
+ 12738,
+ 12917,
+ 12977,
+ 13075,
+ 13082,
+ 13086
+ ]
+ },
+ "Printer": {
+ "seen_instance": [
+ 103811,
+ 103859,
+ 103866,
+ 103867,
+ 103869,
+ 103872,
+ 103878,
+ 103894,
+ 103974,
+ 103978,
+ 103981,
+ 103988,
+ 103989,
+ 103996,
+ 104004,
+ 104006,
+ 104007,
+ 104009,
+ 104013,
+ 104020,
+ 104027,
+ 104030
+ ],
+ "unseen_instance": [
+ 100279,
+ 103863,
+ 103972,
+ 104000,
+ 104011,
+ 104016
+ ]
+ },
+ "Dishwasher": {
+ "seen_instance": [
+ 11622,
+ 11661,
+ 11826,
+ 12085,
+ 12092,
+ 12259,
+ 12349,
+ 12480,
+ 12531,
+ 12540,
+ 12542,
+ 12543,
+ 12552,
+ 12553,
+ 12560,
+ 12561,
+ 12562,
+ 12563,
+ 12565,
+ 12579,
+ 12580,
+ 12583,
+ 12584,
+ 12587,
+ 12590,
+ 12592,
+ 12594,
+ 12605,
+ 12606,
+ 12612,
+ 12614,
+ 12621,
+ 12654
+ ],
+ "unseen_instance": [
+ 11700,
+ 12065,
+ 12428,
+ 12484,
+ 12530,
+ 12558,
+ 12559,
+ 12597
+ ]
+ },
+ "Camera": {
+ "seen_instance": [
+ 101352,
+ 101362,
+ 102394,
+ 102407,
+ 102408,
+ 102411,
+ 102417,
+ 102431,
+ 102432,
+ 102434,
+ 102442,
+ 102505,
+ 102520,
+ 102523,
+ 102527,
+ 102528,
+ 102532,
+ 102539,
+ 102542,
+ 102829,
+ 102831,
+ 102834,
+ 102845,
+ 102876,
+ 102892
+ ],
+ "unseen_instance": [
+ 102398,
+ 102403,
+ 102472,
+ 102536,
+ 102873,
+ 102874,
+ 102882
+ ]
+ }
+ },
+ "unseen_category": {
+ "TrashCan": {
+ "seen_instance": [],
+ "unseen_instance": [
+ 4108,
+ 11229,
+ 11259,
+ 11279,
+ 11361,
+ 11818,
+ 11951,
+ 12231,
+ 12447,
+ 12483,
+ 100731,
+ 100732,
+ 101377,
+ 101378,
+ 101380,
+ 102154,
+ 102155,
+ 102156,
+ 102158,
+ 102160,
+ 102165,
+ 102171,
+ 102181,
+ 102186,
+ 102187,
+ 102189,
+ 102192,
+ 102194,
+ 102200,
+ 102201,
+ 102202,
+ 102209,
+ 102210,
+ 102218,
+ 102219,
+ 102227,
+ 102229,
+ 102234,
+ 102244,
+ 102252,
+ 102254,
+ 102256,
+ 102257,
+ 102992,
+ 102996,
+ 103008,
+ 103010,
+ 103013,
+ 103633,
+ 103634,
+ 103646,
+ 103647
+ ]
+ },
+ "Safe": {
+ "seen_instance": [],
+ "unseen_instance": [
+ 101363,
+ 101564,
+ 101579,
+ 101583,
+ 101584,
+ 101591,
+ 101593,
+ 101594,
+ 101599,
+ 101603,
+ 101605,
+ 101611,
+ 101612,
+ 101613,
+ 101619,
+ 101623,
+ 102278,
+ 102301,
+ 102309,
+ 102311,
+ 102316,
+ 102318,
+ 102380,
+ 102381,
+ 102384,
+ 102387,
+ 102389,
+ 102418,
+ 102423
+ ]
+ },
+ "KitchenPot": {
+ "seen_instance": [],
+ "unseen_instance": [
+ 100015,
+ 100017,
+ 100021,
+ 100023,
+ 100025,
+ 100028,
+ 100032,
+ 100033,
+ 100040,
+ 100045,
+ 100047,
+ 100051,
+ 100054,
+ 100056,
+ 100057,
+ 100058,
+ 100060,
+ 100619,
+ 100623,
+ 102080
+ ]
+ },
+ "Table": {
+ "seen_instance": [],
+ "unseen_instance": [
+ 19179,
+ 19825,
+ 19836,
+ 19855,
+ 19898,
+ 20043,
+ 20279,
+ 20411,
+ 20453,
+ 20555,
+ 20985,
+ 21467,
+ 22241,
+ 22301,
+ 22339,
+ 22367,
+ 22433,
+ 22508,
+ 22692,
+ 23372,
+ 23472,
+ 23511,
+ 23724,
+ 23782,
+ 23807,
+ 24644,
+ 24931,
+ 25144,
+ 25308,
+ 25493,
+ 25913,
+ 26073,
+ 26387,
+ 26503,
+ 26525,
+ 26608,
+ 26652,
+ 26657,
+ 26670,
+ 26806,
+ 26875,
+ 26899,
+ 27044,
+ 27189,
+ 27267,
+ 27619,
+ 28668,
+ 29133,
+ 29525,
+ 29557,
+ 29921,
+ 30238,
+ 30341,
+ 30663,
+ 30666,
+ 30739,
+ 30857,
+ 30869,
+ 31249,
+ 31601,
+ 32052,
+ 32086,
+ 32174,
+ 32259,
+ 32324,
+ 32354,
+ 32566,
+ 32601,
+ 32746,
+ 32761,
+ 32932,
+ 33116,
+ 33914,
+ 33930,
+ 34178,
+ 34610,
+ 34617
+ ]
+ },
+ "Refrigerator": {
+ "seen_instance": [],
+ "unseen_instance": [
+ 10068,
+ 10143,
+ 10144,
+ 10373,
+ 10489,
+ 10586,
+ 10620,
+ 10627,
+ 10655,
+ 10685,
+ 10751,
+ 10797,
+ 10849,
+ 10867,
+ 10900,
+ 10905,
+ 10944,
+ 11178,
+ 11211,
+ 11231,
+ 11260,
+ 11299,
+ 11304,
+ 11709,
+ 11712,
+ 11846,
+ 12036,
+ 12038,
+ 12042,
+ 12050,
+ 12054,
+ 12055,
+ 12059,
+ 12248,
+ 12249,
+ 12250,
+ 12252
+ ]
+ },
+ "Suitcase": {
+ "seen_instance": [],
+ "unseen_instance": [
+ 100249,
+ 100550,
+ 100767,
+ 100776,
+ 100825,
+ 101668,
+ 101673,
+ 103755,
+ 103761,
+ 103762
+ ]
+ },
+ "Laptop": {
+ "seen_instance": [],
+ "unseen_instance": [
+ 9748,
+ 9912,
+ 9918,
+ 9960,
+ 9968,
+ 9992,
+ 9996,
+ 10040,
+ 10090,
+ 10098,
+ 10101,
+ 10108,
+ 10125,
+ 10211,
+ 10213,
+ 10238,
+ 10239,
+ 10243,
+ 10248,
+ 10269,
+ 10270,
+ 10280,
+ 10289,
+ 10305,
+ 10306,
+ 10356,
+ 10383,
+ 10626,
+ 10707,
+ 10885,
+ 11030,
+ 11075,
+ 11141,
+ 11156,
+ 11242,
+ 11395,
+ 11405,
+ 11406,
+ 11429,
+ 11477,
+ 11581,
+ 11586,
+ 11691,
+ 11778,
+ 11854,
+ 11876,
+ 11945,
+ 12073
+ ]
+ },
+ "Phone": {
+ "seen_instance": [],
+ "unseen_instance": [
+ 103251,
+ 103285,
+ 103347,
+ 103593,
+ 103699,
+ 103813,
+ 103814,
+ 103828,
+ 103886,
+ 103892,
+ 103917,
+ 103925,
+ 103927,
+ 103935,
+ 103941
+ ]
+ },
+ "Door": {
+ "seen_instance": [],
+ "unseen_instance": [
+ 8867,
+ 8897,
+ 8903,
+ 8919,
+ 8961,
+ 8983,
+ 8994,
+ 8997,
+ 9016,
+ 9070,
+ 9117,
+ 9263,
+ 9288,
+ 9393
+ ]
+ },
+ "Oven": {
+ "seen_instance": [],
+ "unseen_instance": [
+ 7120,
+ 7130,
+ 7138,
+ 7179,
+ 7187,
+ 7201,
+ 7220,
+ 7290,
+ 7332,
+ 101773,
+ 101908,
+ 101917,
+ 101921,
+ 101930,
+ 101931,
+ 101943,
+ 101946,
+ 101947,
+ 101971,
+ 102018,
+ 102044,
+ 102055,
+ 102060
+ ]
+ }
+ }
+}
\ No newline at end of file
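The split file above pairs each object category with `seen_instance` and `unseen_instance` model-id lists, grouped by whether the category itself is seen during training. A minimal sketch for tallying the split; the JSON path below is a placeholder for wherever this file lives in your checkout:
```
import json

# Placeholder path -- point this at the split JSON shown above.
with open('path/to/gapartnet_split.json') as f:
    split = json.load(f)

for group, categories in split.items():
    n_seen = sum(len(c['seen_instance']) for c in categories.values())
    n_unseen = sum(len(c['unseen_instance']) for c in categories.values())
    print(f'{group}: {len(categories)} categories, '
          f'{n_seen} seen / {n_unseen} unseen instances')
```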
diff --git a/dataset/render_tools/render.py b/dataset/render_tools/render.py
new file mode 100644
index 0000000..4fb0ea6
--- /dev/null
+++ b/dataset/render_tools/render.py
@@ -0,0 +1,148 @@
+import os
+import sys
+from os.path import join as pjoin
+import numpy as np
+from argparse import ArgumentParser
+
+from utils.config_utils import ID_PATH, DATASET_PATH, CAMERA_POSITION_RANGE, TARGET_GAPARTS, BACKGROUND_RGB, SAVE_PATH
+from utils.read_utils import get_id_category, read_joints_from_urdf_file, save_rgb_image, save_depth_map, save_anno_dict, save_meta
+from utils.render_utils import get_cam_pos, set_all_scene, render_rgb_image, render_depth_map, \
+ render_sem_ins_seg_map, add_background_color_for_image, get_camera_pos_mat, merge_joint_qpos
+from utils.pose_utils import query_part_pose_from_joint_qpos, get_NPCS_map_from_oriented_bbox
+
+
+def render_one_image(model_id, camera_idx, render_idx, height, width, use_raytracing=False, replace_texture=False):
+ # 1. read the id list to get the category
+ category = get_id_category(model_id, ID_PATH)
+ if category is None:
+ raise ValueError(f'Cannot find the category of model {model_id}')
+
+ # 2. read the urdf file, get the kinematic chain, and collect all the joints information
+ data_path = pjoin(DATASET_PATH, str(model_id))
+ joints_dict = read_joints_from_urdf_file(data_path, 'mobility_annotation_gapartnet.urdf')
+
+ # 3. generate the joint qpos randomly in the limit range
+ joint_qpos = {}
+ for joint_name in joints_dict:
+ joint_type = joints_dict[joint_name]['type']
+ if joint_type == 'prismatic' or joint_type == 'revolute':
+ joint_limit = joints_dict[joint_name]['limit']
+ joint_qpos[joint_name] = np.random.uniform(joint_limit[0], joint_limit[1])
+ elif joint_type == 'fixed':
+ joint_qpos[joint_name] = 0.0 # ! the qpos of fixed joint must be 0.0
+ elif joint_type == 'continuous':
+ joint_qpos[joint_name] = np.random.uniform(-10000.0, 10000.0)
+ else:
+ raise ValueError(f'Unknown joint type {joint_type}')
+
+ # 4. generate the camera pose randomly in the specified range
+ camera_range = CAMERA_POSITION_RANGE[category][camera_idx]
+ camera_pos = get_cam_pos(
+ theta_min=camera_range['theta_min'], theta_max=camera_range['theta_max'],
+ phi_min=camera_range['phi_min'], phi_max=camera_range['phi_max'],
+ dis_min=camera_range['distance_min'], dis_max=camera_range['distance_max']
+ )
+
+ # 5. pass the joint qpos and the augmentation parameters to set up render environment and robot
+ scene, camera, engine, robot = set_all_scene(data_path=data_path,
+ urdf_file='mobility_annotation_gapartnet.urdf',
+ cam_pos=camera_pos,
+ width=width,
+ height=height,
+ use_raytracing=False,
+ joint_qpos_dict=joint_qpos)
+
+ # 6. use qpos to calculate the gapart poses
+ link_pose_dict = query_part_pose_from_joint_qpos(data_path=data_path, anno_file='link_annotation_gapartnet.json', joint_qpos=joint_qpos, joints_dict=joints_dict, target_parts=TARGET_GAPARTS, robot=robot)
+
+ # 7. render the rgb, depth, mask, valid(visible) gapart
+ rgb_image = render_rgb_image(camera=camera)
+ depth_map = render_depth_map(camera=camera)
+ sem_seg_map, ins_seg_map, valid_linkName_to_instId = render_sem_ins_seg_map(scene=scene, camera=camera, link_pose_dict=link_pose_dict, depth_map=depth_map)
+ valid_link_pose_dict = {link_name: link_pose_dict[link_name] for link_name in valid_linkName_to_instId.keys()}
+
+ # 8. acquire camera intrinsic and extrinsic matrix
+ camera_intrinsic, world2camera_rotation, camera2world_translation = get_camera_pos_mat(camera)
+
+ # 9. calculate NPCS map
+ valid_linkPose_RTS_dict, valid_NPCS_map = get_NPCS_map_from_oriented_bbox(depth_map, ins_seg_map, valid_linkName_to_instId, valid_link_pose_dict, camera_intrinsic, world2camera_rotation, camera2world_translation)
+
+ # 10. (optional) use texture to render rgb to replace the previous rgb (texture issue during cutting the mesh)
+ if replace_texture:
+ texture_joints_dict = read_joints_from_urdf_file(data_path, 'mobility_texture_gapartnet.urdf')
+ texture_joint_qpos = merge_joint_qpos(joint_qpos, joints_dict, texture_joints_dict)
+ scene, camera, engine, robot = set_all_scene(data_path=data_path,
+ urdf_file='mobility_texture_gapartnet.urdf',
+ cam_pos=camera_pos,
+ width=width,
+ height=height,
+ use_raytracing=use_raytracing,
+ joint_qpos_dict=texture_joint_qpos,
+ engine=engine)
+ rgb_image = render_rgb_image(camera=camera)
+
+ # 11. add background color
+ rgb_image = add_background_color_for_image(rgb_image, depth_map, BACKGROUND_RGB)
+
+ # 12. save the rendered results
+ save_name = f"{category}_{model_id}_{camera_idx}_{render_idx}"
+ if not os.path.exists(SAVE_PATH):
+ os.mkdir(SAVE_PATH)
+
+ save_rgb_image(rgb_image, SAVE_PATH, save_name)
+
+ save_depth_map(depth_map, SAVE_PATH, save_name)
+
+ bbox_pose_dict = {}
+ for link_name in valid_link_pose_dict:
+ bbox_pose_dict[link_name] = {
+ 'bbox': valid_link_pose_dict[link_name]['bbox'],
+ 'category_id': valid_link_pose_dict[link_name]['category_id'],
+ 'instance_id': valid_linkName_to_instId[link_name],
+ 'pose_RTS_param': valid_linkPose_RTS_dict[link_name],
+ }
+ anno_dict = {
+ 'semantic_segmentation': sem_seg_map,
+ 'instance_segmentation': ins_seg_map,
+ 'npcs_map': valid_NPCS_map,
+ 'bbox_pose_dict': bbox_pose_dict,
+ }
+ save_anno_dict(anno_dict, SAVE_PATH, save_name)
+
+ metafile = {
+ 'model_id': model_id,
+ 'category': category,
+ 'camera_idx': camera_idx,
+ 'render_idx': render_idx,
+ 'width': width,
+ 'height': height,
+ 'joint_qpos': joint_qpos,
+ 'camera_pos': camera_pos.reshape(-1).tolist(),
+ 'camera_intrinsic': camera_intrinsic.reshape(-1).tolist(),
+ 'world2camera_rotation': world2camera_rotation.reshape(-1).tolist(),
+ 'camera2world_translation': camera2world_translation.reshape(-1).tolist(),
+ 'target_gaparts': TARGET_GAPARTS,
+ 'use_raytracing': use_raytracing,
+ 'replace_texture': replace_texture,
+ }
+ save_meta(metafile, SAVE_PATH, save_name)
+
+ print(f"Rendered {save_name} successfully!")
+
+
+if __name__ == "__main__":
+ parser = ArgumentParser()
+ parser.add_argument('--model_id', type=int, default=41083, help='Specify the model id to render')
+ parser.add_argument('--camera_idx', type=int, default=0, help='Specify the camera range index to render')
+ parser.add_argument('--render_idx', type=int, default=0, help='Specify the render index to render')
+ parser.add_argument('--height', type=int, default=800, help='Specify the height of the rendered image')
+ parser.add_argument('--width', type=int, default=800, help='Specify the width of the rendered image')
+ parser.add_argument('--ray_tracing', type=bool, default=False, help='Specify whether to use ray tracing in rendering')
+ parser.add_argument('--replace_texture', type=bool, default=False, help='Specify whether to replace the texture of the rendered image using the original model')
+
+ args = parser.parse_args()
+
+ render_one_image(args.model_id, args.camera_idx, args.render_idx, args.height, args.width, args.ray_tracing, args.replace_texture)
+
+ print("Done!")
+
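`render.py` can also be driven directly from Python instead of the command line. A minimal sketch, assuming it is run from `dataset/render_tools/` with `DATASET_PATH`/`SAVE_PATH` in `utils/config_utils.py` pointing at valid locations and SAPIEN installed:
```
# Sketch: render three random views of the default model by calling the function directly.
from render import render_one_image

for render_idx in range(3):
    render_one_image(model_id=41083, camera_idx=0, render_idx=render_idx,
                     height=800, width=800,
                     use_raytracing=False, replace_texture=False)
```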
diff --git a/dataset/render_tools/render_all_partnet.py b/dataset/render_tools/render_all_partnet.py
new file mode 100644
index 0000000..fc868e5
--- /dev/null
+++ b/dataset/render_tools/render_all_partnet.py
@@ -0,0 +1,52 @@
+import os
+import sys
+from argparse import ArgumentParser
+
+sys.path.append('./utils')
+from utils.config_utils import ID_PATH, CAMERA_POSITION_RANGE, HEIGHT, WIDTH
+
+if __name__ == "__main__":
+ parser = ArgumentParser()
+ parser.add_argument('--ray_tracing', type=bool, default=False, help='Specify whether to use ray tracing in rendering')
+ parser.add_argument('--replace_texture', type=bool, default=False, help='Specify whether to replace the texture of the rendered image using the original model')
+ parser.add_argument('--start_idx', type=int, default=0, help='Specify the start index of the model id to render')
+ parser.add_argument('--num_render', type=int, default=32, help='Specify the number of renderings for each model id each camera range')
+ parser.add_argument('--log_dir', type=str, default='./log_render.txt', help='Specify the log file')
+
+ args = parser.parse_args()
+
+ ray_tracing = args.ray_tracing
+ replace_texture = args.replace_texture
+ start_idx = args.start_idx
+ num_render = args.num_render
+ log_dir = args.log_dir
+
+ model_id_list = []
+ with open(ID_PATH, 'r') as fd:
+ for line in fd:
+ ls = line.strip().split(' ')
+ model_id_list.append((ls[0], int(ls[1])))
+
+ total_to_render = len(model_id_list)
+ cnt = 0
+
+ for category, model_id in model_id_list:
+ print(f'Still to render: {total_to_render-cnt}\n')
+
+ for pos_idx in range(len(CAMERA_POSITION_RANGE[category])):
+ for render_idx in range(num_render):
+ print(f'Rendering: {category} : {model_id} : {pos_idx} : {start_idx + render_idx}\n')
+
+ render_string = f'python -u render.py --model_id {model_id} --camera_idx {pos_idx} --render_idx {start_idx + render_idx} --height {HEIGHT} --width {WIDTH}'
+ if ray_tracing:
+ render_string += ' --ray_tracing True'
+ if replace_texture:
+ render_string += ' --replace_texture True'
+ render_string += f' 2>&1 | tee -a {log_dir}'
+
+ os.system(render_string)
+
+ print(f'Render Over: {category} : {model_id}\n')
+ cnt += 1
+
+ print("Over!!!")
diff --git a/dataset/render_tools/requirements.txt b/dataset/render_tools/requirements.txt
new file mode 100644
index 0000000..9066538
--- /dev/null
+++ b/dataset/render_tools/requirements.txt
@@ -0,0 +1,5 @@
+sapien == 2.1.0 # for Linux
+# sapien == 1.0.0rc2 # for macOS
+opencv-contrib-python
+open3d
+transforms3d
diff --git a/dataset/render_tools/utils/config_utils.py b/dataset/render_tools/utils/config_utils.py
new file mode 100644
index 0000000..b56ab30
--- /dev/null
+++ b/dataset/render_tools/utils/config_utils.py
@@ -0,0 +1,222 @@
+import os
+import sys
+from os.path import join as pjoin
+import json
+import numpy as np
+
+
+# TODO: Set the path to the dataset
+DATASET_PATH = './dataset/partnet_all_annotated_new'
+
+SAVE_PATH = './example_rendered'
+
+VISU_SAVE_PATH = './visu'
+
+ID_PATH = './meta/partnet_all_id_list.txt'
+
+TARGET_GAPARTS = [
+ 'line_fixed_handle', 'round_fixed_handle', 'slider_button', 'hinge_door', 'slider_drawer',
+ 'slider_lid', 'hinge_lid', 'hinge_knob', 'hinge_handle'
+]
+
+OBJECT_CATEGORIES = [
+ 'Box', 'Camera', 'CoffeeMachine', 'Dishwasher', 'KitchenPot', 'Microwave', 'Oven', 'Phone', 'Refrigerator',
+ 'Remote', 'Safe', 'StorageFurniture', 'Table', 'Toaster', 'TrashCan', 'WashingMachine', 'Keyboard', 'Laptop', 'Door', 'Printer',
+ 'Suitcase', 'Bucket', 'Toilet'
+]
+
+CAMERA_POSITION_RANGE = {
+ 'Box': [{
+ 'theta_min': 30.0,
+ 'theta_max': 80.0,
+ 'phi_min': 120.0,
+ 'phi_max': 240.0,
+ 'distance_min': 3.5,
+ 'distance_max': 4.5
+ }],
+ 'Camera': [{
+ 'theta_min': 30.0,
+ 'theta_max': 80.0,
+ 'phi_min': 120.0,
+ 'phi_max': 240.0,
+ 'distance_min': 3.5,
+ 'distance_max': 4.5
+ }, {
+ 'theta_min': 30.0,
+ 'theta_max': 80.0,
+ 'phi_min': -60.0,
+ 'phi_max': 60.0,
+ 'distance_min': 3.5,
+ 'distance_max': 4.5
+ }],
+ 'CoffeeMachine': [{
+ 'theta_min': 30.0,
+ 'theta_max': 80.0,
+ 'phi_min': 120.0,
+ 'phi_max': 240.0,
+ 'distance_min': 3.5,
+ 'distance_max': 4.1
+ }],
+ 'Dishwasher': [{
+ 'theta_min': 30.0,
+ 'theta_max': 80.0,
+ 'phi_min': 120.0,
+ 'phi_max': 240.0,
+ 'distance_min': 3.5,
+ 'distance_max': 4.1
+ }],
+ 'KitchenPot': [{
+ 'theta_min': 30.0,
+ 'theta_max': 80.0,
+ 'phi_min': 120.0,
+ 'phi_max': 240.0,
+ 'distance_min': 3.5,
+ 'distance_max': 4.1
+ }],
+ 'Microwave': [{
+ 'theta_min': 30.0,
+ 'theta_max': 80.0,
+ 'phi_min': 120.0,
+ 'phi_max': 240.0,
+ 'distance_min': 3.5,
+ 'distance_max': 4.1
+ }],
+ 'Oven': [{
+ 'theta_min': 30.0,
+ 'theta_max': 80.0,
+ 'phi_min': 120.0,
+ 'phi_max': 240.0,
+ 'distance_min': 3.5,
+ 'distance_max': 4.1
+ }],
+ 'Phone': [{
+ 'theta_min': 30.0,
+ 'theta_max': 80.0,
+ 'phi_min': 120.0,
+ 'phi_max': 240.0,
+ 'distance_min': 3.5,
+ 'distance_max': 4.1
+ }],
+ 'Refrigerator': [{
+ 'theta_min': 30.0,
+ 'theta_max': 80.0,
+ 'phi_min': 120.0,
+ 'phi_max': 240.0,
+ 'distance_min': 3.5,
+ 'distance_max': 4.1
+ }],
+ 'Remote': [{
+ 'theta_min': 30.0,
+ 'theta_max': 80.0,
+ 'phi_min': 120.0,
+ 'phi_max': 240.0,
+ 'distance_min': 3.5,
+ 'distance_max': 4.1
+ }],
+ 'Safe': [{
+ 'theta_min': 30.0,
+ 'theta_max': 80.0,
+ 'phi_min': 120.0,
+ 'phi_max': 240.0,
+ 'distance_min': 3.5,
+ 'distance_max': 4.1
+ }],
+ 'StorageFurniture': [{
+ 'theta_min': 30.0,
+ 'theta_max': 80.0,
+ 'phi_min': 120.0,
+ 'phi_max': 240.0,
+ 'distance_min': 4.1,
+ 'distance_max': 5.2
+ }],
+ 'Table': [{
+ 'theta_min': 30.0,
+ 'theta_max': 80.0,
+ 'phi_min': 120.0,
+ 'phi_max': 240.0,
+ 'distance_min': 3.8,
+ 'distance_max': 4.5
+ }],
+ 'Toaster': [{
+ 'theta_min': 30.0,
+ 'theta_max': 80.0,
+ 'phi_min': 120.0,
+ 'phi_max': 240.0,
+ 'distance_min': 3.5,
+ 'distance_max': 4.1
+ }],
+ 'TrashCan': [{
+ 'theta_min': 30.0,
+ 'theta_max': 80.0,
+ 'phi_min': 120.0,
+ 'phi_max': 240.0,
+ 'distance_min': 4,
+ 'distance_max': 5.5
+ }],
+ 'WashingMachine': [{
+ 'theta_min': 30.0,
+ 'theta_max': 80.0,
+ 'phi_min': 120.0,
+ 'phi_max': 240.0,
+ 'distance_min': 3.5,
+ 'distance_max': 4.1
+ }],
+ 'Keyboard': [{
+ 'theta_min': 30.0,
+ 'theta_max': 80.0,
+ 'phi_min': 120.0,
+ 'phi_max': 240.0,
+ 'distance_min': 3,
+ 'distance_max': 3.5
+ }],
+ 'Laptop': [{
+ 'theta_min': 30.0,
+ 'theta_max': 80.0,
+ 'phi_min': 120.0,
+ 'phi_max': 240.0,
+ 'distance_min': 3.5,
+ 'distance_max': 4.1
+ }],
+ 'Door': [{
+ 'theta_min': 30.0,
+ 'theta_max': 80.0,
+ 'phi_min': 120.0,
+ 'phi_max': 240.0,
+ 'distance_min': 3.5,
+ 'distance_max': 4.1
+ }],
+ 'Printer': [{
+ 'theta_min': 30.0,
+ 'theta_max': 80.0,
+ 'phi_min': 120.0,
+ 'phi_max': 240.0,
+ 'distance_min': 3.5,
+ 'distance_max': 4.1
+ }],
+ 'Suitcase': [{
+ 'theta_min': 30.0,
+ 'theta_max': 80.0,
+ 'phi_min': 120.0,
+ 'phi_max': 240.0,
+ 'distance_min': 3.5,
+ 'distance_max': 4.1
+ }],
+ 'Bucket': [{
+ 'theta_min': 30.0,
+ 'theta_max': 80.0,
+ 'phi_min': 120.0,
+ 'phi_max': 240.0,
+ 'distance_min': 3.5,
+ 'distance_max': 4.1
+ }],
+ 'Toilet': [{
+ 'theta_min': 30.0,
+ 'theta_max': 80.0,
+ 'phi_min': 120.0,
+ 'phi_max': 240.0,
+ 'distance_min': 3.5,
+ 'distance_max': 4.1
+ }]
+}
+
+BACKGROUND_RGB = np.array([216, 206, 189], dtype=np.uint8)
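Each camera range above is a spherical-coordinate band: `theta` is the polar angle measured from the +z axis, `phi` is the azimuth (both in degrees), and `distance_min`/`distance_max` bound the radius. A small sketch of sampling one camera position from such a range with `get_cam_pos` from `utils/render_utils.py`:
```
# Sketch: sample a camera position from the 'Microwave' range defined above.
from utils.config_utils import CAMERA_POSITION_RANGE
from utils.render_utils import get_cam_pos

r = CAMERA_POSITION_RANGE['Microwave'][0]
cam_pos = get_cam_pos(r['theta_min'], r['theta_max'],
                      r['phi_min'], r['phi_max'],
                      r['distance_min'], r['distance_max'])
print(cam_pos)  # an (x, y, z) position on the sampled sphere around the object
```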
diff --git a/dataset/render_tools/utils/pose_utils.py b/dataset/render_tools/utils/pose_utils.py
new file mode 100644
index 0000000..d20c593
--- /dev/null
+++ b/dataset/render_tools/utils/pose_utils.py
@@ -0,0 +1,153 @@
+from os.path import join as pjoin
+import json
+import math
+import numpy as np
+import transforms3d.euler as t
+import transforms3d.axangles as tax
+import sapien.core as sapien
+
+
+def query_part_pose_from_joint_qpos(data_path, anno_file, joint_qpos, joints_dict, target_parts, robot: sapien.KinematicArticulation):
+ anno_path = pjoin(data_path, anno_file)
+ anno_list = json.load(open(anno_path, 'r'))
+
+ target_links = {}
+ for link_dict in anno_list:
+ link_name = link_dict['link_name']
+ is_gapart = link_dict['is_gapart']
+ part_class = link_dict['category']
+ bbox = link_dict['bbox']
+ if is_gapart and part_class in target_parts:
+ target_links[link_name] = {
+ 'category_id': target_parts.index(part_class),
+ 'bbox': np.array(bbox, dtype=np.float32).reshape(-1, 3)
+ }
+
+ joint_states = {}
+ for joint in robot.get_joints():
+ joint_name = joint.get_name()
+ if joint_name in joints_dict:
+ joint_pose = joint.get_parent_link().pose * joint.get_pose_in_parent()
+ joint_states[joint_name] = {
+ 'origin': joint_pose.p,
+ 'axis': joint_pose.to_transformation_matrix()[:3,:3] @ [1,0,0]
+ }
+
+ child_link_to_joint_name = {}
+ for joint_name, joint_dict in joints_dict.items():
+ child_link_to_joint_name[joint_dict['child']] = joint_name
+
+ result_dict = {}
+
+ for link_name, link_dict in target_links.items():
+ joint_names_to_base = []
+ cur_name = link_name
+ while cur_name in child_link_to_joint_name:
+ joint_name = child_link_to_joint_name[cur_name]
+ joint_names_to_base.append(joint_name)
+ cur_name = joints_dict[joint_name]['parent']
+ assert cur_name == 'base'
+ joint_names_to_base = joint_names_to_base[:-1] # remove the last joint to 'base'
+
+ bbox = link_dict['bbox']
+ part_class = link_dict['category_id']
+ for joint_name in joint_names_to_base[::-1]:
+ joint_type = joints_dict[joint_name]['type']
+ origin = joint_states[joint_name]['origin']
+ axis = joint_states[joint_name]['axis']
+ axis = axis / np.linalg.norm(axis)
+ if joint_type == "fixed":
+ continue
+ elif joint_type == "prismatic":
+ bbox = bbox + axis * joint_qpos[joint_name]
+ elif joint_type == "revolute" or joint_type == "continuous":
+ rotation_mat = t.axangle2mat(axis.reshape(-1).tolist(), joint_qpos[joint_name]).T
+ bbox = np.dot(bbox - origin, rotation_mat) + origin
+
+ result_dict[link_name] = {
+ 'category_id': part_class,
+ 'bbox': bbox
+ }
+
+ return result_dict
+
+
+def backproject_depth_into_pointcloud(depth_map, ins_seg_map, valid_linkName_to_instId, camera_intrinsic, eps=1e-6):
+ part_pcs_dict = {}
+
+ for link_name, inst_id in valid_linkName_to_instId.items():
+ mask = (ins_seg_map == inst_id).astype(np.int32)
+ area = int(sum(sum(mask > 0)))
+ assert area > 0, 'link {} has no area'.format(link_name)
+ ys, xs = (mask > 0).nonzero()
+ part_pcs = []
+ for y, x in zip(ys, xs):
+ if abs(depth_map[y][x]) < eps:
+ continue
+ z_proj = float(depth_map[y][x])
+ x_proj = (float(x) - camera_intrinsic[0, 2]) * z_proj / camera_intrinsic[0, 0]
+ y_proj = (float(y) - camera_intrinsic[1, 2]) * z_proj / camera_intrinsic[1, 1]
+ part_pcs.append([x_proj, y_proj, z_proj])
+ assert len(part_pcs) > 0, 'link {} has no valid point'.format(link_name)
+ part_pcs_dict[link_name] = np.array(part_pcs).reshape(-1, 3)
+
+ return part_pcs_dict
+
+
+def compute_rotation_matrix(b1, b2):
+ c1 = np.mean(b1, axis=0)
+ c2 = np.mean(b2, axis=0)
+ H = np.dot((b1 - c1).T, (b2 - c2))
+ U, s, Vt = np.linalg.svd(H)
+ R = np.dot(Vt.T, U.T)
+
+ if np.linalg.det(R) < 0:
+ R[0, :] *= -1
+
+ return R.T
+
+
+def get_NPCS_map_from_oriented_bbox(depth_map, inst_seg_map, linkName_to_instId, link_pose_dict, camera_intrinsic, world2camera_rotation, camera2world_translation):
+ NPCS_RTS_dict = {}
+ for link_name in linkName_to_instId.keys():
+ bbox = link_pose_dict[link_name]['bbox']
+ T = bbox.mean(axis=0)
+ s_x = np.linalg.norm(bbox[1] - bbox[0])
+ s_y = np.linalg.norm(bbox[1] - bbox[2])
+ s_z = np.linalg.norm(bbox[0] - bbox[4])
+ S = np.array([s_x, s_y, s_z])
+ scaler = np.linalg.norm(S)
+ bbox_scaled = (bbox - T) / scaler
+ bbox_canon = np.array([
+ [-s_x / 2, s_y / 2, s_z / 2],
+ [s_x / 2, s_y / 2, s_z / 2],
+ [s_x / 2, -s_y / 2, s_z / 2],
+ [-s_x / 2, -s_y / 2, s_z / 2],
+ [-s_x / 2, s_y / 2, -s_z / 2],
+ [s_x / 2, s_y / 2, -s_z / 2],
+ [s_x / 2, -s_y / 2, -s_z / 2],
+ [-s_x / 2, -s_y / 2, -s_z / 2]
+ ]) / scaler
+ R = compute_rotation_matrix(bbox_canon, bbox_scaled)
+ NPCS_RTS_dict[link_name] = {'R': R, 'T': T, 'S': S, 'scaler': scaler}
+
+ height, width = depth_map.shape
+ canon_position_map = np.zeros((height, width, 3), dtype=np.float32)
+
+ instId_to_linkName = {v: k for k, v in linkName_to_instId.items()}
+ assert len(instId_to_linkName) == len(linkName_to_instId)
+ for y in range(height):
+ for x in range(width):
+ if inst_seg_map[y][x] < 0:
+ continue
+ z_proj = float(depth_map[y][x])
+ x_proj = (float(x) - camera_intrinsic[0, 2]) * z_proj / camera_intrinsic[0, 0]
+ y_proj = (float(y) - camera_intrinsic[1, 2]) * z_proj / camera_intrinsic[1, 1]
+ pixel_camera_position = np.array([x_proj, y_proj, z_proj])
+ pixel_world_position = pixel_camera_position @ world2camera_rotation.T + camera2world_translation
+ RTS_param = NPCS_RTS_dict[instId_to_linkName[inst_seg_map[y][x]]]
+ pixel_npcs_position = ((pixel_world_position - RTS_param['T']) / RTS_param['scaler']) @ RTS_param['R'].T
+ canon_position_map[y][x] = pixel_npcs_position
+
+ return NPCS_RTS_dict, canon_position_map
+
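The RTS parameters returned by `get_NPCS_map_from_oriented_bbox` define the normalization `p_npcs = ((p_world - T) / scaler) @ R.T`. Since `R` is orthonormal, the inverse mapping is straightforward; a short sketch, not part of the original utilities:
```
# Sketch: map NPCS coordinates back to world coordinates for one part.
import numpy as np

def npcs_to_world(p_npcs, R, T, scaler):
    # Inverse of p_npcs = ((p_world - T) / scaler) @ R.T
    return (np.asarray(p_npcs) @ R) * scaler + T
```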
diff --git a/dataset/render_tools/utils/read_utils.py b/dataset/render_tools/utils/read_utils.py
new file mode 100644
index 0000000..d0fdf83
--- /dev/null
+++ b/dataset/render_tools/utils/read_utils.py
@@ -0,0 +1,147 @@
+import os
+from os.path import join as pjoin
+import xml.etree.ElementTree as ET
+import json
+import numpy as np
+from PIL import Image
+import pickle
+
+
+def get_id_category(target_id, id_path):
+ category = None
+ with open(id_path, 'r') as fd:
+ for line in fd:
+ cat = line.rstrip('\n').split(' ')[0]
+ id = int(line.rstrip('\n').split(' ')[1])
+ if id == target_id:
+ category = cat
+ break
+ return category
+
+
+def read_joints_from_urdf_file(data_path, urdf_name):
+ urdf_file = pjoin(data_path, urdf_name)
+ tree_urdf = ET.parse(urdf_file)
+ root_urdf = tree_urdf.getroot()
+
+ joint_dict = {}
+ for joint in root_urdf.iter('joint'):
+ joint_name = joint.attrib['name']
+ joint_type = joint.attrib['type']
+ for child in joint.iter('child'):
+ joint_child = child.attrib['link']
+ for parent in joint.iter('parent'):
+ joint_parent = parent.attrib['link']
+ for origin in joint.iter('origin'):
+ if 'xyz' in origin.attrib:
+ joint_xyz = [float(x) for x in origin.attrib['xyz'].split()]
+ else:
+ joint_xyz = [0, 0, 0]
+ if 'rpy' in origin.attrib:
+ joint_rpy = [float(x) for x in origin.attrib['rpy'].split()]
+ else:
+ joint_rpy = [0, 0, 0]
+ if joint_type == 'prismatic' or joint_type == 'revolute' or joint_type == 'continuous':
+ for axis in joint.iter('axis'):
+ joint_axis = [float(x) for x in axis.attrib['xyz'].split()]
+ else:
+ joint_axis = None
+ if joint_type == 'prismatic' or joint_type == 'revolute':
+ for limit in joint.iter('limit'):
+ joint_limit = [float(limit.attrib['lower']), float(limit.attrib['upper'])]
+ else:
+ joint_limit = None
+
+ joint_dict[joint_name] = {
+ 'type': joint_type,
+ 'parent': joint_parent,
+ 'child': joint_child,
+ 'xyz': joint_xyz,
+ 'rpy': joint_rpy,
+ 'axis': joint_axis,
+ 'limit': joint_limit
+ }
+
+ return joint_dict
+
+
+def save_rgb_image(rgb_img, save_path, filename):
+ rgb_path = pjoin(save_path, 'rgb')
+ if not os.path.exists(rgb_path): os.mkdir(rgb_path)
+
+ new_image = Image.fromarray(rgb_img)
+ new_image.save(pjoin(rgb_path, f'{filename}.png'))
+
+
+def save_depth_map(depth_map, save_path, filename):
+ depth_path = pjoin(save_path, 'depth')
+ if not os.path.exists(depth_path): os.mkdir(depth_path)
+
+ np.savez_compressed(pjoin(depth_path, f'{filename}.npz'), depth_map=depth_map)
+
+
+def save_anno_dict(anno_dict, save_path, filename):
+ seg_path = pjoin(save_path, 'segmentation')
+ bbox_path = pjoin(save_path, 'bbox')
+ npcs_path = pjoin(save_path, 'npcs')
+
+ if not os.path.exists(seg_path): os.mkdir(seg_path)
+ if not os.path.exists(bbox_path): os.mkdir(bbox_path)
+ if not os.path.exists(npcs_path): os.mkdir(npcs_path)
+
+ np.savez_compressed(pjoin(seg_path, f'{filename}.npz'),
+ semantic_segmentation=anno_dict['semantic_segmentation'],
+ instance_segmentation=anno_dict['instance_segmentation'])
+
+ np.savez_compressed(pjoin(npcs_path, f'{filename}.npz'), npcs_map=anno_dict['npcs_map'])
+
+ with open(pjoin(bbox_path, f'{filename}.pkl'), 'wb') as fd:
+ bbox_dict = {'bbox_pose_dict': anno_dict['bbox_pose_dict']}
+ pickle.dump(bbox_dict, fd)
+
+
+def save_meta(meta, save_path, filename):
+ meta_path = pjoin(save_path, 'metafile')
+ if not os.path.exists(meta_path): os.mkdir(meta_path)
+
+ with open(pjoin(meta_path, f'{filename}.json'), 'w') as fd:
+ json.dump(meta, fd)
+
+
+def load_rgb_image(save_path, filename):
+ img = Image.open(pjoin(save_path, 'rgb', f'{filename}.png'))
+ return np.array(img)
+
+
+def load_depth_map(save_path, filename):
+ depth_dict = np.load(pjoin(save_path, 'depth', f'{filename}.npz'))
+ depth_map = depth_dict['depth_map']
+ return depth_map
+
+
+def load_anno_dict(save_path, filename):
+ anno_dict = {}
+
+ seg_path = pjoin(save_path, 'segmentation')
+ bbox_path = pjoin(save_path, 'bbox')
+ npcs_path = pjoin(save_path, 'npcs')
+
+ seg_dict = np.load(pjoin(seg_path, f'{filename}.npz'))
+ anno_dict['semantic_segmentation'] = seg_dict['semantic_segmentation']
+ anno_dict['instance_segmentation'] = seg_dict['instance_segmentation']
+
+ npcs_dict = np.load(pjoin(npcs_path, f'{filename}.npz'))
+ anno_dict['npcs_map'] = npcs_dict['npcs_map']
+
+ with open(pjoin(bbox_path, f'{filename}.pkl'), 'rb') as fd:
+ bbox_dict = pickle.load(fd)
+ anno_dict['bbox_pose_dict'] = bbox_dict['bbox_pose_dict']
+
+ return anno_dict
+
+
+def load_meta(save_path, filename):
+ with open(pjoin(save_path, 'metafile', f'{filename}.json'), 'r') as fd:
+ meta = json.load(fd)
+ return meta
+
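The `load_*` helpers mirror the `save_*` functions above. A minimal sketch of reading back one rendered frame; the file name here is hypothetical and follows the `{category}_{model_id}_{camera_idx}_{render_idx}` convention used by `render.py`:
```
# Sketch: reload all artifacts saved for one rendered frame.
from utils.read_utils import load_rgb_image, load_depth_map, load_anno_dict, load_meta

save_path = './example_rendered'   # SAVE_PATH used during rendering
filename = 'Microwave_7119_0_0'    # hypothetical example name
rgb = load_rgb_image(save_path, filename)
depth = load_depth_map(save_path, filename)
anno = load_anno_dict(save_path, filename)
meta = load_meta(save_path, filename)
print(rgb.shape, depth.shape, anno['npcs_map'].shape, meta['category'])
```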
diff --git a/dataset/render_tools/utils/render_utils.py b/dataset/render_tools/utils/render_utils.py
new file mode 100644
index 0000000..7a88403
--- /dev/null
+++ b/dataset/render_tools/utils/render_utils.py
@@ -0,0 +1,245 @@
+import os
+from os.path import join as pjoin
+import math
+import numpy as np
+import sapien.core as sapien
+import transforms3d.euler as t
+import transforms3d.axangles as tax
+
+
+def get_cam_pos(theta_min, theta_max, phi_min, phi_max, dis_min, dis_max):
+ theta = np.random.uniform(low=theta_min, high=theta_max)
+ phi = np.random.uniform(low=phi_min, high=phi_max)
+ distance = np.random.uniform(low=dis_min, high=dis_max)
+ x = math.sin(math.pi / 180 * theta) * math.cos(math.pi / 180 * phi) * distance
+ y = math.sin(math.pi / 180 * theta) * math.sin(math.pi / 180 * phi) * distance
+ z = math.cos(math.pi / 180 * theta) * distance
+ return np.array([x, y, z])
+
+
+def set_all_scene(data_path,
+ urdf_file,
+ cam_pos,
+ width,
+ height,
+ joint_qpos_dict,
+ engine=None,
+ use_raytracing=False):
+
+ # set the sapien environment
+ if engine is None:
+ engine = sapien.Engine()
+ if use_raytracing:
+ config = sapien.KuafuConfig()
+ config.spp = 256
+ config.use_denoiser = True
+ renderer = sapien.KuafuRenderer(config)
+ else:
+ renderer = sapien.VulkanRenderer(offscreen_only=True)
+ engine.set_renderer(renderer)
+
+ scene = engine.create_scene()
+ scene.set_timestep(1 / 100.0)
+
+ # load model
+ loader = scene.create_urdf_loader()
+ loader.fix_root_link = True
+ urdf_path = os.path.join(data_path, urdf_file)
+ robot = loader.load_kinematic(urdf_path)
+ assert robot, 'URDF not loaded.'
+
+ joints = robot.get_joints()
+ qpos = []
+ for joint in joints:
+ if joint.get_parent_link() is None:
+ continue
+ joint_name = joint.get_name()
+ joint_type = joint.type
+ if joint_type == 'revolute' or joint_type == 'prismatic' or joint_type == 'continuous':
+ qpos.append(joint_qpos_dict[joint_name])
+ qpos = np.array(qpos)
+ assert qpos.shape[0] == robot.get_qpos().shape[0], 'qpos shape not match.'
+ robot.set_qpos(qpos=qpos)
+
+ # NOTE: the lighting API differs across SAPIEN versions; older versions set lights on the renderer scene instead (see the commented-out block below)
+ scene.set_ambient_light([0.5, 0.5, 0.5])
+ scene.add_directional_light([0, 1, -1], [0.5, 0.5, 0.5], shadow=True)
+ scene.add_point_light([1, 2, 2], [1, 1, 1], shadow=True)
+ scene.add_point_light([1, -2, 2], [1, 1, 1], shadow=True)
+ scene.add_point_light([-1, 0, 1], [1, 1, 1], shadow=True)
+
+ # rscene = scene.get_renderer_scene()
+ # rscene.set_ambient_light([0.5, 0.5, 0.5])
+ # rscene.add_directional_light([0, 1, -1], [0.5, 0.5, 0.5], shadow=True)
+ # rscene.add_point_light([1, 2, 2], [1, 1, 1], shadow=True)
+ # rscene.add_point_light([1, -2, 2], [1, 1, 1], shadow=True)
+ # rscene.add_point_light([-1, 0, 1], [1, 1, 1], shadow=True)
+
+ camera_mount_actor = scene.create_actor_builder().build_kinematic()
+ camera = scene.add_mounted_camera(
+ name="camera",
+ actor=camera_mount_actor,
+ pose=sapien.Pose(), # relative to the mounted actor
+ width=width,
+ height=height,
+ fovx=np.deg2rad(35.0),
+ fovy=np.deg2rad(35.0),
+ near=0.1,
+ far=100.0,
+ )
+
+ forward = -cam_pos / np.linalg.norm(cam_pos)
+ left = np.cross([0, 0, 1], forward)
+ left = left / np.linalg.norm(left)
+ up = np.cross(forward, left)
+
+ mat44 = np.eye(4)
+ mat44[:3, :3] = np.stack([forward, left, up], axis=1)
+ mat44[:3, 3] = cam_pos
+ camera_mount_actor.set_pose(sapien.Pose.from_transformation_matrix(mat44))
+
+ scene.step()
+ scene.update_render()
+ camera.take_picture()
+
+ return scene, camera, engine, robot
+
+
+def render_rgb_image(camera):
+ rgba = camera.get_float_texture('Color')
+ rgb = rgba[:, :, :3]
+ rgb_img = (rgb * 255).clip(0, 255).astype("uint8")
+ return rgb_img
+
+
+def render_depth_map(camera):
+ position = camera.get_float_texture('Position')
+ depth_map = -position[..., 2]
+ return depth_map
+
+
+def get_visid2gapart_mapping_dict(scene: sapien.Scene, linkId2catName, target_parts_list):
+ # map visual id to instance name
+ visId2instName = {}
+ # map instance name to category id(index+1, 0 for others)
+ instName2catId = {}
+ for articulation in scene.get_all_articulations():
+ for link in articulation.get_links():
+ link_name = link.get_name()
+ if link_name == 'base':
+ continue
+ link_id = int(link_name.split('_')[-1]) + 1
+ for visual in link.get_visual_bodies():
+ visual_name = visual.get_name()
+ if visual_name.find('handle') != -1 and linkId2catName[link_id].find('handle') == -1:
+ # visual name contains 'handle' but the link name does not: treat it as a fixed handle
+ inst_name = link_name + ':' + linkId2catName[link_id] + '/' + visual_name.split(
+ '-')[0] + ':' + 'fixed_handle'
+ visual_id = visual.get_visual_id()
+ visId2instName[visual_id] = inst_name
+ if inst_name not in instName2catId.keys():
+ instName2catId[inst_name] = target_parts_list.index('fixed_handle') + 1
+ elif linkId2catName[link_id] in target_parts_list:
+ inst_name = link_name + ':' + linkId2catName[link_id]
+ visual_id = visual.get_visual_id()
+ visId2instName[visual_id] = inst_name
+ if inst_name not in instName2catId.keys():
+ instName2catId[inst_name] = target_parts_list.index(linkId2catName[link_id]) + 1
+ else:
+ inst_name = 'others'
+ visual_id = visual.get_visual_id()
+ visId2instName[visual_id] = inst_name
+ if inst_name not in instName2catId.keys():
+ instName2catId[inst_name] = 0
+ return visId2instName, instName2catId
+
+
+def render_sem_ins_seg_map(scene: sapien.Scene, camera, link_pose_dict, depth_map, eps=1e-6):
+ vis_id_to_link_name = {}
+ for articulation in scene.get_all_articulations():
+ for link in articulation.get_links():
+ link_name = link.get_name()
+ if link_name not in link_pose_dict:
+ continue
+ for visual in link.get_visual_bodies():
+ visual_id = visual.get_visual_id()
+ vis_id_to_link_name[visual_id] = link_name
+
+ seg_labels = camera.get_uint32_texture("Segmentation")
+ seg_labels_by_visual_id = seg_labels[..., 0].astype(np.uint16) # H x W, save each pixel's visual id
+ height, width = seg_labels_by_visual_id.shape
+
+ sem_seg_map = np.ones((height, width), dtype=np.int32) * (-1) # -2 for background, -1 for others, 0~N-1 for N categories
+ ins_seg_map = np.ones((height, width), dtype=np.int32) * (-1) # -2 for background, -1 for others, 0~M-1 for M instances
+
+ valid_linkName_to_instId_mapping = {}
+ part_ins_cnt = 0
+ for link_name in link_pose_dict.keys():
+ mask = np.zeros((height, width), dtype=np.int32)
+ for vis_id in vis_id_to_link_name.keys():
+ if vis_id_to_link_name[vis_id] == link_name:
+ mask += (seg_labels_by_visual_id == vis_id).astype(np.int32)
+ area = int(sum(sum(mask > 0)))
+ if area == 0:
+ continue
+ sem_seg_map[mask > 0] = link_pose_dict[link_name]['category_id']
+ ins_seg_map[mask > 0] = part_ins_cnt
+ valid_linkName_to_instId_mapping[link_name] = part_ins_cnt
+ part_ins_cnt += 1
+
+ empty_mask = abs(depth_map) < eps
+ sem_seg_map[empty_mask] = -2
+ ins_seg_map[empty_mask] = -2
+
+ return sem_seg_map, ins_seg_map, valid_linkName_to_instId_mapping
+
+
+def add_background_color_for_image(rgb_image, depth_map, background_rgb, eps=1e-6):
+ background_mask = abs(depth_map) < eps
+ rgb_image[background_mask] = background_rgb
+
+ return rgb_image
+
+
+def get_camera_pos_mat(camera):
+ K = camera.get_camera_matrix()[:3, :3]
+ Rtilt = camera.get_model_matrix()
+ Rtilt_rot = Rtilt[:3, :3] @ np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]])
+ Rtilt_trl = Rtilt[:3, 3]
+
+ return K, Rtilt_rot, Rtilt_trl
+
+
+def merge_joint_qpos(joint_qpos_dict, new_joint_dict, old_joint_dict):
+ old_joint_qpos_dict = {}
+ for joint_name in new_joint_dict:
+ if joint_name not in old_joint_dict:
+ assert new_joint_dict[joint_name]['type'] == 'fixed'
+ continue
+ old_joint_qpos_dict[joint_name] = joint_qpos_dict[joint_name]
+ for joint_name in old_joint_dict:
+ assert joint_name in old_joint_qpos_dict
+ return old_joint_qpos_dict
+
+
+def backproject_depth_into_pointcloud(depth_map, ins_seg_map, valid_linkName_to_instId, camera_intrinsic, eps=1e-6):
+ part_pcs_dict = {}
+
+ for link_name, inst_id in valid_linkName_to_instId.items():
+ mask = (ins_seg_map == inst_id).astype(np.int32)
+ area = int(sum(sum(mask > 0)))
+ assert area > 0, 'link {} has no area'.format(link_name)
+ ys, xs = (mask > 0).nonzero()
+ part_pcs = []
+ for y, x in zip(ys, xs):
+ if abs(depth_map[y][x]) < eps:
+ continue
+ z_proj = float(depth_map[y][x])
+ x_proj = (float(x) - camera_intrinsic[0, 2]) * z_proj / camera_intrinsic[0, 0]
+ y_proj = (float(y) - camera_intrinsic[1, 2]) * z_proj / camera_intrinsic[1, 1]
+ part_pcs.append([x_proj, y_proj, z_proj])
+ assert len(part_pcs) > 0, 'link {} has no valid point'.format(link_name)
+ part_pcs_dict[link_name] = np.array(part_pcs).reshape(-1, 3)
+
+ return part_pcs_dict
+
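`backproject_depth_into_pointcloud` loops over pixels with the standard pinhole model `x = (u - cx) * z / fx`, `y = (v - cy) * z / fy`. For larger images, a vectorized NumPy version of the same backprojection is much faster; a sketch, not part of the original utilities:
```
# Sketch: vectorized depth backprojection under the same pinhole model as above.
import numpy as np

def backproject_depth_vectorized(depth_map, camera_intrinsic, eps=1e-6):
    vs, us = np.nonzero(np.abs(depth_map) >= eps)   # pixel rows (v) and columns (u)
    zs = depth_map[vs, us].astype(np.float64)
    xs = (us - camera_intrinsic[0, 2]) * zs / camera_intrinsic[0, 0]
    ys = (vs - camera_intrinsic[1, 2]) * zs / camera_intrinsic[1, 1]
    return np.stack([xs, ys, zs], axis=1)           # (N, 3) points in the camera frame
```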
diff --git a/dataset/render_tools/utils/visu_utils.py b/dataset/render_tools/utils/visu_utils.py
new file mode 100644
index 0000000..91a7690
--- /dev/null
+++ b/dataset/render_tools/utils/visu_utils.py
@@ -0,0 +1,299 @@
+import os
+from os.path import join as pjoin
+import math
+import numpy as np
+import open3d as o3d
+import transforms3d.euler as t
+import matplotlib.pyplot as plt
+import cv2
+from PIL import Image
+
+
+def save_image(img_array, save_path, filename):
+ img = Image.fromarray(img_array)
+ img.save(pjoin(save_path, '{}.png'.format(filename)))
+ print('{} saved!'.format(filename))
+
+
+def visu_depth_map(depth_map, eps=1e-6):
+ object_mask = (abs(depth_map) >= eps)
+ empty_mask = (abs(depth_map) < eps)
+ new_map = depth_map - depth_map[object_mask].min()
+ new_map = new_map / new_map.max()
+ new_map = np.clip(new_map * 255, 0, 255).astype('uint8')
+ colored_depth_map = cv2.applyColorMap(new_map, cv2.COLORMAP_JET)
+ colored_depth_map[empty_mask] = np.array([0, 0, 0])
+ return colored_depth_map
+
+
+def visu_2D_seg_map(seg_map):
+ H, W = seg_map.shape
+ seg_image = np.zeros((H, W, 3)).astype("uint8")
+
+ cmap = plt.cm.get_cmap('tab20', 20)
+ cmap = cmap.colors[:, 0:3]
+ cmap = (cmap * 255).clip(0, 255).astype("uint8")
+
+ for y in range(0, H):
+ for x in range(0, W):
+ if seg_map[y, x] == -2:
+ continue
+ if seg_map[y, x] == -1:
+ seg_image[y, x] = cmap[14]
+ else:
+ seg_image[y, x] = cmap[int(seg_map[y, x]) % 20]
+
+ return seg_image
+
+
+def visu_3D_bbox_semantic(rgb_image, bboxes_pose_dict, meta):
+ image = np.copy(rgb_image)
+
+ Rtilt_rot = np.array(meta['world2camera_rotation']).reshape(3, 3)
+ Rtilt_trl = np.array(meta['camera2world_translation']).reshape(1, 3)
+ K = np.array(meta['camera_intrinsic']).reshape(3, 3)
+
+ cmap = plt.cm.get_cmap('tab20', 20)
+ cmap = cmap.colors[:, 0:3]
+ cmap = (cmap * 255).clip(0, 255).astype("uint8")
+ lines = [[0, 1], [1, 2], [2, 3], [0, 3], [4, 5], [5, 6], [6, 7], [4, 7], [0, 4], [1, 5], [2, 6], [3, 7]]
+
+ for link_name, part_dict in bboxes_pose_dict.items():
+ category_id = part_dict['category_id']
+ bbox = part_dict['bbox']
+ bbox_camera = (bbox - Rtilt_trl) @ Rtilt_rot
+ color = tuple(int(x) for x in cmap[category_id])
+ for line in lines:
+ x_start = int(bbox_camera[line[0], 0] * K[0][0] / bbox_camera[line[0], 2] + K[0][2])
+ y_start = int(bbox_camera[line[0], 1] * K[1][1] / bbox_camera[line[0], 2] + K[1][2])
+ x_end = int(bbox_camera[line[1], 0] * K[0][0] / bbox_camera[line[1], 2] + K[0][2])
+ y_end = int(bbox_camera[line[1], 1] * K[1][1] / bbox_camera[line[1], 2] + K[1][2])
+ start = (x_start, y_start)
+ end = (x_end, y_end)
+ thickness = 2
+ linetype = 4
+ cv2.line(image, start, end, color, thickness, linetype)
+ return image
+
+
+def visu_3D_bbox_pose_in_color(rgb_image, bboxes_pose_dict, meta):
+ image = np.copy(rgb_image)
+
+ Rtilt_rot = np.array(meta['world2camera_rotation']).reshape(3, 3)
+ Rtilt_trl = np.array(meta['camera2world_translation']).reshape(1, 3)
+ K = np.array(meta['camera_intrinsic']).reshape(3, 3)
+
+ cmap = plt.cm.get_cmap('tab20', 20)
+ cmap = cmap.colors[:, 0:3]
+ cmap = (cmap * 255).clip(0, 255).astype("uint8")
+ lines = [[0, 1], [1, 2], [2, 3], [0, 3], [4, 5], [5, 6], [6, 7], [4, 7], [0, 4], [1, 5], [2, 6], [3, 7]]
+ colors = [
+ cmap[0], cmap[2], cmap[4], cmap[6], cmap[8], cmap[10], cmap[12], cmap[16], cmap[14], cmap[14], cmap[14],
+ cmap[14]
+ ]
+
+ for link_name, part_dict in bboxes_pose_dict.items():
+ category_id = part_dict['category_id']
+ bbox = part_dict['bbox']
+ bbox_camera = (bbox - Rtilt_trl) @ Rtilt_rot
+ for i, line in enumerate(lines):
+ x_start = int(bbox_camera[line[0], 0] * K[0][0] / bbox_camera[line[0], 2] + K[0][2])
+ y_start = int(bbox_camera[line[0], 1] * K[1][1] / bbox_camera[line[0], 2] + K[1][2])
+ x_end = int(bbox_camera[line[1], 0] * K[0][0] / bbox_camera[line[1], 2] + K[0][2])
+ y_end = int(bbox_camera[line[1], 1] * K[1][1] / bbox_camera[line[1], 2] + K[1][2])
+ start = (x_start, y_start)
+ end = (x_end, y_end)
+ thickness = 2
+ linetype = 4
+ color = tuple(int(x) for x in colors[i])
+ cv2.line(image, start, end, color, thickness, linetype)
+ return image
+
+
+def visu_NPCS_map(npcs_map, ins_seg_map):
+ npcs_image = npcs_map + np.array([0.5, 0.5, 0.5])
+ assert (npcs_image > 0).all(), 'NPCS map error!'
+ assert (npcs_image < 1).all(), 'NPCS map error!'
+ empty_mask = (ins_seg_map == -2)
+ npcs_image[empty_mask] = np.array([0, 0, 0])
+ npcs_image = (np.clip(npcs_image, 0, 1) * 255).astype('uint8')
+
+ return npcs_image
+
+
+def get_recovery_whole_point_cloud_camera(rgb_image, depth_map, meta, eps=1e-6):
+ height, width = depth_map.shape
+ K = meta['camera_intrinsic']
+ K = np.array(K).reshape(3, 3)
+
+ point_cloud = []
+ per_point_rgb = []
+
+ for y_ in range(height):
+ for x_ in range(width):
+ if abs(depth_map[y_][x_]) < eps:
+ continue
+ z_new = float(depth_map[y_][x_])
+ x_new = (x_ - K[0][2]) * z_new / K[0][0]
+ y_new = (y_ - K[1][2]) * z_new / K[1][1]
+ point_cloud.append([x_new, y_new, z_new])
+ per_point_rgb.append([
+ float(rgb_image[y_][x_][0]) / 255,
+ float(rgb_image[y_][x_][1]) / 255,
+ float(rgb_image[y_][x_][2]) / 255
+ ])
+
+ point_cloud = np.array(point_cloud)
+ per_point_rgb = np.array(per_point_rgb)
+
+ return point_cloud, per_point_rgb
+
+
+def get_recovery_part_point_cloud_camera(rgb_image, depth_map, mask, meta, eps=1e-6):
+ height, width = depth_map.shape
+ K = meta['camera_intrinsic']
+ K = np.array(K).reshape(3, 3)
+
+ point_cloud = []
+ per_point_rgb = []
+
+ for y_ in range(height):
+ for x_ in range(width):
+ if abs(depth_map[y_][x_]) < eps:
+ continue
+ if not mask[y_][x_]:
+ continue
+ z_new = float(depth_map[y_][x_])
+ x_new = (x_ - K[0][2]) * z_new / K[0][0]
+ y_new = (y_ - K[1][2]) * z_new / K[1][1]
+ point_cloud.append([x_new, y_new, z_new])
+ per_point_rgb.append([
+ float(rgb_image[y_][x_][0]) / 255,
+ float(rgb_image[y_][x_][1]) / 255,
+ float(rgb_image[y_][x_][2]) / 255
+ ])
+
+ point_cloud = np.array(point_cloud)
+ per_point_rgb = np.array(per_point_rgb)
+
+ return point_cloud, per_point_rgb
+
+
+def draw_bbox_in_3D_semantic(bbox, category_id):
+ cmap = plt.cm.get_cmap('tab20', 20)
+ cmap = cmap.colors[:, 0:3]
+
+ points = []
+ for i in range(bbox.shape[0]):
+ points.append(bbox[i].reshape(-1).tolist())
+ lines = [[0, 1], [1, 2], [2, 3], [0, 3], [4, 5], [5, 6], [6, 7], [4, 7], [0, 4], [1, 5], [2, 6], [3, 7]]
+ # Use the same color for all lines
+ colors = [cmap[category_id] for _ in range(len(lines))]
+ line_set = o3d.geometry.LineSet()
+ line_set.points = o3d.utility.Vector3dVector(points)
+ line_set.lines = o3d.utility.Vector2iVector(lines)
+ line_set.colors = o3d.utility.Vector3dVector(colors)
+ return line_set
+
+
+def draw_bbox_in_3D_pose_color(bbox):
+ cmap = plt.cm.get_cmap('tab20', 20)
+ cmap = cmap.colors[:, 0:3]
+
+ points = []
+ for i in range(bbox.shape[0]):
+ points.append(bbox[i].reshape(-1).tolist())
+ lines = [[0, 1], [1, 2], [2, 3], [0, 3], [4, 5], [5, 6], [6, 7], [4, 7], [0, 4], [1, 5], [2, 6], [3, 7]]
+ # Color each edge by its index so the box orientation is distinguishable (the four connecting edges share one color)
+ colors = [
+ cmap[0], cmap[2], cmap[4], cmap[6], cmap[8], cmap[10], cmap[12], cmap[16], cmap[14], cmap[14], cmap[14],
+ cmap[14]
+ ]
+ line_set = o3d.geometry.LineSet()
+ line_set.points = o3d.utility.Vector3dVector(points)
+ line_set.lines = o3d.utility.Vector2iVector(lines)
+ line_set.colors = o3d.utility.Vector3dVector(colors)
+ return line_set
+
+
+def visu_point_cloud_with_bbox_semantic(rgb_image, depth_map, bbox_pose_dict, meta):
+
+ point_cloud, per_point_rgb = get_recovery_whole_point_cloud_camera(rgb_image, depth_map, meta)
+ Rtilt_rot = np.array(meta['world2camera_rotation']).reshape(3, 3)
+ Rtilt_trl = np.array(meta['camera2world_translation']).reshape(1, 3)
+ point_cloud_world = point_cloud @ Rtilt_rot.T + Rtilt_trl
+
+ vis_list = []
+ for link_name, part_dict in bbox_pose_dict.items():
+ category_id = part_dict['category_id']
+ bbox = part_dict['bbox']
+ bbox_t = draw_bbox_in_3D_semantic(bbox, category_id)
+ vis_list.append(bbox_t)
+
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(point_cloud_world)
+ pcd.colors = o3d.utility.Vector3dVector(per_point_rgb)
+ vis_list.append(pcd)
+ coord_frame = o3d.geometry.TriangleMesh.create_coordinate_frame()
+ vis_list.append(coord_frame)
+ o3d.visualization.draw_geometries(vis_list)
+
+
+def visu_point_cloud_with_bbox_pose_color(rgb_image, depth_map, bbox_pose_dict, meta):
+
+ point_cloud, per_point_rgb = get_recovery_whole_point_cloud_camera(rgb_image, depth_map, meta)
+ Rtilt_rot = np.array(meta['world2camera_rotation']).reshape(3, 3)
+ Rtilt_trl = np.array(meta['camera2world_translation']).reshape(1, 3)
+ point_cloud_world = point_cloud @ Rtilt_rot.T + Rtilt_trl
+
+ vis_list = []
+ for link_name, part_dict in bbox_pose_dict.items():
+ category_id = part_dict['category_id']
+ bbox = part_dict['bbox']
+ bbox_t = draw_bbox_in_3D_pose_color(bbox)
+ vis_list.append(bbox_t)
+
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(point_cloud_world)
+ pcd.colors = o3d.utility.Vector3dVector(per_point_rgb)
+ vis_list.append(pcd)
+ coord_frame = o3d.geometry.TriangleMesh.create_coordinate_frame()
+ vis_list.append(coord_frame)
+ o3d.visualization.draw_geometries(vis_list)
+
+
+def visu_NPCS_in_3D_with_bbox_pose_color(rgb_image, depth_map, anno_dict, meta):
+ ins_seg_map = anno_dict['instance_segmentation']
+ bbox_pose_dict = anno_dict['bbox_pose_dict']
+ npcs_map = anno_dict['npcs_map']
+ Rtilt_rot = np.array(meta['world2camera_rotation']).reshape(3, 3)
+ Rtilt_trl = np.array(meta['camera2world_translation']).reshape(1, 3)
+
+ for link_name, part_dict in bbox_pose_dict.items():
+ category_id = part_dict['category_id']
+ instance_id = part_dict['instance_id']
+ bbox_world = part_dict['bbox']
+ mask = (ins_seg_map == instance_id)
+ point_cloud, per_point_rgb = get_recovery_part_point_cloud_camera(rgb_image, depth_map, mask, meta)
+ point_cloud_world = point_cloud @ Rtilt_rot.T + Rtilt_trl
+ RTS_param = part_dict['pose_RTS_param']
+ R, T, S, scaler = RTS_param['R'], RTS_param['T'], RTS_param['S'], RTS_param['scaler']
+ point_cloud_canon = npcs_map[mask]
+ bbox_canon = ((bbox_world - T) / scaler) @ R.T
+
+ vis_list = []
+ vis_list.append(draw_bbox_in_3D_pose_color(bbox_world))
+ vis_list.append(draw_bbox_in_3D_pose_color(bbox_canon))
+
+ pcd_1 = o3d.geometry.PointCloud()
+ pcd_1.points = o3d.utility.Vector3dVector(point_cloud_world)
+ pcd_1.colors = o3d.utility.Vector3dVector(per_point_rgb)
+ vis_list.append(pcd_1)
+ pcd_2 = o3d.geometry.PointCloud()
+ pcd_2.points = o3d.utility.Vector3dVector(point_cloud_canon)
+ pcd_2.colors = o3d.utility.Vector3dVector(per_point_rgb)
+ vis_list.append(pcd_2)
+ coord_frame = o3d.geometry.TriangleMesh.create_coordinate_frame()
+ vis_list.append(coord_frame)
+ o3d.visualization.draw_geometries(vis_list)
+
diff --git a/dataset/render_tools/visualize.py b/dataset/render_tools/visualize.py
new file mode 100644
index 0000000..e73fa1e
--- /dev/null
+++ b/dataset/render_tools/visualize.py
@@ -0,0 +1,71 @@
+from json import load
+import os
+from os.path import join as pjoin
+from argparse import ArgumentParser
+
+from utils.config_utils import SAVE_PATH, VISU_SAVE_PATH
+from utils.read_utils import load_rgb_image, load_depth_map, load_anno_dict, load_meta
+from utils.visu_utils import save_image, visu_depth_map, visu_2D_seg_map, visu_3D_bbox_semantic, visu_3D_bbox_pose_in_color, \
+ visu_NPCS_map, visu_point_cloud_with_bbox_semantic, visu_point_cloud_with_bbox_pose_color, visu_NPCS_in_3D_with_bbox_pose_color
+
+
+if __name__ == "__main__":
+
+ parser = ArgumentParser()
+ parser.add_argument('--model_id', type=int)
+ parser.add_argument('--category', type=str)
+ parser.add_argument('--render_index', type=int, default=0)
+ parser.add_argument('--camera_position_index', type=int, default=0)
+
+ CONFS = parser.parse_args()
+
+ MODEL_ID = CONFS.model_id
+ CATEGORY = CONFS.category
+ RENDER_INDEX = CONFS.render_index
+ CAMERA_POSITION_INDEX = CONFS.camera_position_index
+
+ filename = '{}_{}_{}_{}'.format(CATEGORY, MODEL_ID, CAMERA_POSITION_INDEX, RENDER_INDEX)
+ save_path = pjoin(VISU_SAVE_PATH, filename)
+ if not os.path.exists(save_path):
+ os.makedirs(save_path)
+
+ rgb_image = load_rgb_image(SAVE_PATH, filename)
+ depth_map = load_depth_map(SAVE_PATH, filename)
+ anno_dict = load_anno_dict(SAVE_PATH, filename)
+ metafile = load_meta(SAVE_PATH, filename)
+
+ # depth map
+ colored_depth_image = visu_depth_map(depth_map)
+ save_image(colored_depth_image, save_path, '{}_depth'.format(filename))
+
+ # semantic segmentation
+ sem_seg_image = visu_2D_seg_map(anno_dict['semantic_segmentation'])
+ save_image(sem_seg_image, save_path, '{}_semseg'.format(filename))
+
+ # instance segmentation
+ ins_seg_image = visu_2D_seg_map(anno_dict['instance_segmentation'])
+ save_image(ins_seg_image, save_path, '{}_insseg'.format(filename))
+
+ # 3D bbox with category
+ bbox_3D_image_category = visu_3D_bbox_semantic(rgb_image, anno_dict['bbox_pose_dict'], metafile)
+ save_image(bbox_3D_image_category, save_path, '{}_bbox3Dcat'.format(filename))
+
+ # 3D bbox with pose color
+ bbox_3D_image_pose_color = visu_3D_bbox_pose_in_color(rgb_image, anno_dict['bbox_pose_dict'], metafile)
+ save_image(bbox_3D_image_pose_color, save_path, '{}_bbox3Dposecolor'.format(filename))
+
+ # NPCS image
+ npcs_image = visu_NPCS_map(anno_dict['npcs_map'], anno_dict['instance_segmentation'])
+ save_image(npcs_image, save_path, '{}_NPCS'.format(filename))
+
+ # point cloud with 3D semantic bbox
+ visu_point_cloud_with_bbox_semantic(rgb_image, depth_map, anno_dict['bbox_pose_dict'], metafile)
+
+ # point cloud with 3D pose color bbox
+ visu_point_cloud_with_bbox_pose_color(rgb_image, depth_map, anno_dict['bbox_pose_dict'], metafile)
+
+ # point cloud of NPCS and 3D pose color bbox
+ visu_NPCS_in_3D_with_bbox_pose_color(rgb_image, depth_map, anno_dict, metafile)
+
+ print('Done!')
+
\ No newline at end of file
diff --git a/gapartnet/.gitignore b/gapartnet/.gitignore
new file mode 100644
index 0000000..14500f5
--- /dev/null
+++ b/gapartnet/.gitignore
@@ -0,0 +1,12 @@
+wandb/
+raw*
+_raw*
+__pycache__*
+*__pycache__/
+__pycache__/
+*/__pycache__*
+GAPartNet_data/
+gallery/
+GAPartNet_result/
+.ipynb_checkpoints/
+GAPartNet_data
diff --git a/gapartnet/README.md b/gapartnet/README.md
new file mode 100644
index 0000000..ce090d9
--- /dev/null
+++ b/gapartnet/README.md
@@ -0,0 +1,3 @@
+Run the code following the instructions in the README.md in the parent folder.
+
+We publish our checkpoint at our dataset link; follow the dataset download instructions in the parent folder.
diff --git a/gapartnet/ckpt/.gitignore b/gapartnet/ckpt/.gitignore
new file mode 100644
index 0000000..53979fe
--- /dev/null
+++ b/gapartnet/ckpt/.gitignore
@@ -0,0 +1,2 @@
+*.ckpt
+ckpt*
\ No newline at end of file
diff --git a/gapartnet/data/.gitignore b/gapartnet/data/.gitignore
new file mode 100644
index 0000000..e5853e4
--- /dev/null
+++ b/gapartnet/data/.gitignore
@@ -0,0 +1,2 @@
+GAPartNet_All
+image*
\ No newline at end of file
diff --git a/gapartnet/data/README.md b/gapartnet/data/README.md
new file mode 100644
index 0000000..6415502
--- /dev/null
+++ b/gapartnet/data/README.md
@@ -0,0 +1,23 @@
+Place or link the data folder here.
+An example layout of the data folder is as follows:
+```
+data
+ - GAPartNet_All
+ - train
+ - meta
+ ...(xxx.txt)
+ - pth
+ ...(xxx.pth)
+ - train_gt
+ ...(xxx.txt)
+ - val
+ - val_gt
+ - test_intra
+ - test_intra_gt
+ - test_inter
+ - test_inter_gt
+ - image_kuafu (optional, only used for visualization)
+ ...(xxx.png)
+ - .gitignore
+ - README.md
+```
\ No newline at end of file
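A quick sanity-check sketch for the layout above (run from the `gapartnet` folder; it only verifies that the expected split directories exist):
```
# Sketch: verify the expected GAPartNet_All split folders exist under data/.
import os

root = 'data/GAPartNet_All'
for name in ['train', 'train_gt', 'val', 'val_gt',
             'test_intra', 'test_intra_gt', 'test_inter', 'test_inter_gt']:
    print(name, 'ok' if os.path.isdir(os.path.join(root, name)) else 'MISSING')
```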
diff --git a/gapartnet/data/nopart.txt b/gapartnet/data/nopart.txt
new file mode 100644
index 0000000..e167454
--- /dev/null
+++ b/gapartnet/data/nopart.txt
@@ -0,0 +1 @@
+data/GAPartNet_All/test_inter/pth/Phone_103828_00_004.pth data/GAPartNet_All/test_inter/pth/Phone_103828_00_031.pth data/GAPartNet_All/test_inter/pth/Phone_103941_00_029.pth data/GAPartNet_All/test_inter/pth/Table_21467_00_006.pth data/GAPartNet_All/test_inter/pth/Phone_103886_00_017.pth data/GAPartNet_All/test_inter/pth/Phone_103941_00_020.pth data/GAPartNet_All/test_inter/pth/Phone_103593_00_001.pth data/GAPartNet_All/test_inter/pth/Table_21467_00_028.pth data/GAPartNet_All/test_inter/pth/Phone_103886_00_009.pth data/GAPartNet_All/test_inter/pth/Phone_103941_00_021.pth data/GAPartNet_All/test_inter/pth/Phone_103828_00_006.pth data/GAPartNet_All/test_inter/pth/Phone_103886_00_000.pth data/GAPartNet_All/test_inter/pth/Phone_103593_00_010.pth data/GAPartNet_All/test_inter/pth/Phone_103828_00_012.pth data/GAPartNet_All/test_inter/pth/Table_26899_00_031.pth data/GAPartNet_All/test_inter/pth/Phone_103593_00_022.pth data/GAPartNet_All/test_inter/pth/Phone_103886_00_006.pth data/GAPartNet_All/test_inter/pth/Phone_103886_00_025.pth data/GAPartNet_All/test_inter/pth/Phone_103828_00_030.pth data/GAPartNet_All/test_inter/pth/Phone_103941_00_030.pth data/GAPartNet_All/test_inter/pth/Phone_103828_00_003.pth data/GAPartNet_All/test_inter/pth/Table_30739_00_016.pth data/GAPartNet_All/test_inter/pth/Phone_103593_00_014.pth data/GAPartNet_All/test_inter/pth/Phone_103886_00_021.pth data/GAPartNet_All/test_inter/pth/Phone_103593_00_011.pth data/GAPartNet_All/test_inter/pth/TrashCan_102192_00_030.pth data/GAPartNet_All/test_inter/pth/Table_21467_00_025.pth data/GAPartNet_All/test_inter/pth/Phone_103941_00_011.pth data/GAPartNet_All/test_inter/pth/Phone_103593_00_016.pth data/GAPartNet_All/test_inter/pth/Phone_103828_00_005.pth data/GAPartNet_All/test_inter/pth/Phone_103886_00_001.pth data/GAPartNet_All/test_inter/pth/Phone_103593_00_006.pth data/GAPartNet_All/test_inter/pth/Phone_103941_00_019.pth data/GAPartNet_All/test_inter/pth/Phone_103593_00_000.pth data/GAPartNet_All/test_inter/pth/Phone_103828_00_023.pth data/GAPartNet_All/test_inter/pth/Phone_103886_00_016.pth data/GAPartNet_All/test_inter/pth/Phone_103886_00_015.pth data/GAPartNet_All/test_inter/pth/Phone_103886_00_011.pth data/GAPartNet_All/test_inter/pth/Phone_103828_00_008.pth data/GAPartNet_All/test_inter/pth/Phone_103886_00_003.pth data/GAPartNet_All/test_inter/pth/Phone_103593_00_031.pth data/GAPartNet_All/test_inter/pth/TrashCan_102256_00_007.pth data/GAPartNet_All/test_inter/pth/Phone_103941_00_000.pth data/GAPartNet_All/test_inter/pth/Table_26899_00_006.pth data/GAPartNet_All/test_inter/pth/Phone_103828_00_028.pth data/GAPartNet_All/test_inter/pth/Phone_103828_00_018.pth data/GAPartNet_All/test_inter/pth/Phone_103886_00_024.pth data/GAPartNet_All/test_inter/pth/Table_26899_00_015.pth data/GAPartNet_All/test_inter/pth/Phone_103593_00_017.pth data/GAPartNet_All/test_inter/pth/Phone_103828_00_025.pth data/GAPartNet_All/test_inter/pth/TrashCan_102192_00_024.pth data/GAPartNet_All/test_inter/pth/Phone_103828_00_001.pth data/GAPartNet_All/test_inter/pth/Phone_103828_00_007.pth data/GAPartNet_All/test_inter/pth/Phone_103886_00_019.pth data/GAPartNet_All/test_inter/pth/Table_21467_00_012.pth data/GAPartNet_All/test_inter/pth/Phone_103593_00_027.pth data/GAPartNet_All/test_inter/pth/Phone_103828_00_019.pth data/GAPartNet_All/test_inter/pth/Phone_103886_00_004.pth data/GAPartNet_All/test_inter/pth/Phone_103886_00_013.pth data/GAPartNet_All/test_inter/pth/Phone_103886_00_027.pth data/GAPartNet_All/test_inter/pth/Phone_103593_00_023.pth 
data/GAPartNet_All/test_inter/pth/Phone_103828_00_024.pth data/GAPartNet_All/test_inter/pth/Phone_103593_00_019.pth data/GAPartNet_All/test_inter/pth/Table_21467_00_019.pth data/GAPartNet_All/test_inter/pth/Phone_103828_00_016.pth data/GAPartNet_All/test_inter/pth/Phone_103828_00_021.pth data/GAPartNet_All/test_inter/pth/Phone_103886_00_008.pth data/GAPartNet_All/test_inter/pth/Phone_103828_00_014.pth data/GAPartNet_All/test_inter/pth/Phone_103886_00_018.pth data/GAPartNet_All/test_inter/pth/Phone_103886_00_026.pth data/GAPartNet_All/test_inter/pth/Phone_103828_00_027.pth data/GAPartNet_All/test_inter/pth/Phone_103828_00_010.pth data/GAPartNet_All/test_inter/pth/Phone_103593_00_008.pth data/GAPartNet_All/test_inter/pth/Phone_103886_00_007.pth data/GAPartNet_All/test_inter/pth/Phone_103886_00_005.pth data/GAPartNet_All/test_inter/pth/Phone_103593_00_021.pth data/GAPartNet_All/test_inter/pth/Phone_103593_00_004.pth data/GAPartNet_All/test_inter/pth/Phone_103828_00_029.pth data/GAPartNet_All/test_inter/pth/Phone_103886_00_002.pth data/GAPartNet_All/test_inter/pth/Phone_103886_00_010.pth data/GAPartNet_All/test_inter/pth/Phone_103593_00_025.pth data/GAPartNet_All/test_inter/pth/Phone_103886_00_031.pth data/GAPartNet_All/test_inter/pth/Table_21467_00_018.pth data/GAPartNet_All/test_inter/pth/Phone_103828_00_017.pth data/GAPartNet_All/test_inter/pth/Phone_103593_00_002.pth data/GAPartNet_All/test_inter/pth/Phone_103941_00_010.pth data/GAPartNet_All/test_inter/pth/Phone_103886_00_029.pth data/GAPartNet_All/test_inter/pth/Phone_103886_00_028.pth data/GAPartNet_All/test_inter/pth/Phone_103886_00_020.pth data/GAPartNet_All/test_inter/pth/Phone_103593_00_013.pth data/GAPartNet_All/test_inter/pth/Table_26899_00_010.pth data/GAPartNet_All/test_inter/pth/Phone_103593_00_026.pth data/GAPartNet_All/test_inter/pth/Phone_103593_00_005.pth data/GAPartNet_All/test_inter/pth/Phone_103886_00_014.pth data/GAPartNet_All/test_inter/pth/Phone_103593_00_024.pth data/GAPartNet_All/test_inter/pth/Phone_103828_00_020.pth data/GAPartNet_All/test_inter/pth/TrashCan_102192_00_018.pth data/GAPartNet_All/test_inter/pth/Phone_103593_00_030.pth data/GAPartNet_All/test_inter/pth/Phone_103886_00_030.pth data/GAPartNet_All/test_inter/pth/Table_21467_00_027.pth data/GAPartNet_All/test_inter/pth/Phone_103828_00_022.pth data/GAPartNet_All/test_inter/pth/Phone_103828_00_026.pth data/GAPartNet_All/test_inter/pth/Phone_103593_00_020.pth data/GAPartNet_All/test_inter/pth/TrashCan_102200_00_020.pth data/GAPartNet_All/test_inter/pth/Phone_103593_00_007.pth data/GAPartNet_All/test_inter/pth/Phone_103941_00_028.pth data/GAPartNet_All/test_inter/pth/Phone_103593_00_018.pth data/GAPartNet_All/test_inter/pth/Phone_103828_00_013.pth data/GAPartNet_All/test_inter/pth/Table_21467_00_016.pth data/GAPartNet_All/test_inter/pth/Phone_103593_00_009.pth data/GAPartNet_All/test_inter/pth/Phone_103886_00_023.pth data/GAPartNet_All/test_inter/pth/Phone_103828_00_000.pth data/GAPartNet_All/test_inter/pth/Table_30739_00_003.pth data/GAPartNet_All/test_inter/pth/Phone_103593_00_028.pth data/GAPartNet_All/test_inter/pth/Phone_103828_00_009.pth data/GAPartNet_All/test_inter/pth/Phone_103593_00_012.pth data/GAPartNet_All/test_inter/pth/Phone_103828_00_002.pth data/GAPartNet_All/test_inter/pth/Phone_103828_00_011.pth data/GAPartNet_All/test_inter/pth/Phone_103886_00_012.pth data/GAPartNet_All/test_inter/pth/Phone_103828_00_015.pth data/GAPartNet_All/test_inter/pth/Phone_103886_00_022.pth data/GAPartNet_All/train/pth/Camera_102417_00_020.pth 
data/GAPartNet_All/train/pth/Camera_102417_00_021.pth data/GAPartNet_All/train/pth/Toilet_102682_00_004.pth data/GAPartNet_All/train/pth/Camera_102417_00_003.pth data/GAPartNet_All/train/pth/Camera_102417_00_022.pth data/GAPartNet_All/train/pth/Toilet_102682_00_017.pth data/GAPartNet_All/train/pth/CoffeeMachine_103127_00_021.pth data/GAPartNet_All/train/pth/Camera_102417_00_016.pth data/GAPartNet_All/train/pth/Camera_102532_00_026.pth data/GAPartNet_All/train/pth/Toilet_102682_00_009.pth data/GAPartNet_All/train/pth/CoffeeMachine_103127_00_011.pth data/GAPartNet_All/train/pth/Toilet_102682_00_023.pth data/GAPartNet_All/train/pth/Camera_102417_00_012.pth data/GAPartNet_All/train/pth/Camera_102417_00_017.pth data/GAPartNet_All/train/pth/Camera_102417_00_004.pth data/GAPartNet_All/train/pth/Camera_102417_00_019.pth data/GAPartNet_All/train/pth/Camera_102417_00_006.pth data/GAPartNet_All/train/pth/Camera_102417_00_023.pth data/GAPartNet_All/train/pth/Camera_102417_00_005.pth data/GAPartNet_All/train/pth/Camera_102417_00_014.pth data/GAPartNet_All/train/pth/CoffeeMachine_103127_00_013.pth data/GAPartNet_All/train/pth/Bucket_100464_00_026.pth data/GAPartNet_All/train/pth/Camera_102417_00_000.pth data/GAPartNet_All/train/pth/Camera_102417_00_009.pth data/GAPartNet_All/train/pth/Camera_102417_00_027.pth data/GAPartNet_All/train/pth/Camera_102417_00_002.pth data/GAPartNet_All/train/pth/Camera_102417_00_010.pth data/GAPartNet_All/train/pth/CoffeeMachine_103127_00_014.pth data/GAPartNet_All/train/pth/Camera_102417_00_024.pth data/GAPartNet_All/train/pth/Toilet_102682_00_015.pth data/GAPartNet_All/train/pth/Camera_102417_00_013.pth data/GAPartNet_All/train/pth/Camera_102417_00_008.pth data/GAPartNet_All/train/pth/CoffeeMachine_103127_00_023.pth data/GAPartNet_All/train/pth/Camera_102417_00_011.pth data/GAPartNet_All/train/pth/CoffeeMachine_103127_00_015.pth data/GAPartNet_All/train/pth/Camera_102417_00_026.pth data/GAPartNet_All/train/pth/Camera_102417_00_018.pth data/GAPartNet_All/train/pth/Camera_102417_00_025.pth data/GAPartNet_All/val/pth/Camera_102417_00_029.pth data/GAPartNet_All/val/pth/Camera_102417_00_031.pth data/GAPartNet_All/val/pth/Camera_102417_00_030.pth data/GAPartNet_All/val/pth/Camera_102417_00_028.pth data/GAPartNet_All/val/pth/CoffeeMachine_103127_00_028.pth data/GAPartNet_All/val/pth/Printer_103859_00_028.pth data/GAPartNet_All/test_intra/pth/Camera_102472_00_020.pth data/GAPartNet_All/test_intra/pth/Camera_102472_00_004.pth data/GAPartNet_All/test_intra/pth/CoffeeMachine_103082_00_029.pth data/GAPartNet_All/test_intra/pth/Camera_102472_00_009.pth data/GAPartNet_All/test_intra/pth/Camera_102472_00_025.pth data/GAPartNet_All/test_intra/pth/Camera_102472_00_027.pth data/GAPartNet_All/test_intra/pth/Camera_102472_00_000.pth data/GAPartNet_All/test_intra/pth/CoffeeMachine_103082_00_020.pth data/GAPartNet_All/test_intra/pth/CoffeeMachine_103082_00_015.pth data/GAPartNet_All/test_intra/pth/Camera_102472_00_008.pth data/GAPartNet_All/test_intra/pth/Camera_102472_00_026.pth data/GAPartNet_All/test_intra/pth/Camera_102472_00_002.pth data/GAPartNet_All/test_intra/pth/Camera_102472_00_016.pth data/GAPartNet_All/test_intra/pth/Camera_102472_00_006.pth data/GAPartNet_All/test_intra/pth/Camera_102472_00_007.pth data/GAPartNet_All/test_intra/pth/Camera_102472_00_031.pth data/GAPartNet_All/test_intra/pth/Camera_102472_00_005.pth data/GAPartNet_All/test_intra/pth/Camera_102472_00_017.pth data/GAPartNet_All/test_intra/pth/CoffeeMachine_103082_00_030.pth 
data/GAPartNet_All/test_intra/pth/Camera_102472_00_030.pth data/GAPartNet_All/test_intra/pth/Camera_102472_00_021.pth data/GAPartNet_All/test_intra/pth/Camera_102472_00_003.pth data/GAPartNet_All/test_intra/pth/Camera_102472_00_015.pth data/GAPartNet_All/test_intra/pth/Camera_102472_00_029.pth data/GAPartNet_All/test_intra/pth/Camera_102472_00_013.pth data/GAPartNet_All/test_intra/pth/Camera_102472_00_018.pth data/GAPartNet_All/test_intra/pth/CoffeeMachine_103082_00_021.pth data/GAPartNet_All/test_intra/pth/Camera_102472_00_001.pth data/GAPartNet_All/test_intra/pth/CoffeeMachine_103082_00_025.pth data/GAPartNet_All/test_intra/pth/Camera_102472_00_019.pth data/GAPartNet_All/test_intra/pth/Camera_102472_00_010.pth data/GAPartNet_All/test_intra/pth/Camera_102472_00_023.pth data/GAPartNet_All/test_intra/pth/CoffeeMachine_103082_00_001.pth data/GAPartNet_All/test_intra/pth/Camera_102472_00_014.pth data/GAPartNet_All/test_intra/pth/Camera_102472_00_011.pth data/GAPartNet_All/test_intra/pth/Camera_102472_00_012.pth data/GAPartNet_All/test_intra/pth/Camera_102472_00_022.pth data/GAPartNet_All/test_intra/pth/CoffeeMachine_103082_00_027.pth data/GAPartNet_All/test_intra/pth/Camera_102472_00_024.pth data/GAPartNet_All/test_intra/pth/CoffeeMachine_103082_00_009.pth data/GAPartNet_All/test_intra/pth/Camera_102472_00_028.pth
\ No newline at end of file
diff --git a/gapartnet/dataset/data_utils.py b/gapartnet/dataset/data_utils.py
new file mode 100644
index 0000000..d1109c4
--- /dev/null
+++ b/gapartnet/dataset/data_utils.py
@@ -0,0 +1,36 @@
+from typing import Any, Iterator
+
+import torch
+import torch.distributed as dist
+import torchdata.datapipes as dp
+
+
+def trivial_batch_collator(batch):
+ """
+ A batch collator that does nothing.
+ """
+ return batch
+
+
+@dp.functional_datapipe("distributed_sharding_filter")
+class DistributedShardingFilter(dp.iter.ShardingFilter):
+ def __init__(self, source_datapipe: dp.iter.IterDataPipe) -> None:
+ super().__init__(source_datapipe)
+
+ self.rank = 0
+ self.world_size = 1
+ if dist.is_available() and dist.is_initialized():
+ self.rank = dist.get_rank()
+ self.world_size = dist.get_world_size()
+ self.apply_sharding(self.world_size, self.rank)
+
+ def __iter__(self) -> Iterator[Any]:
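+        # Shard across DDP ranks first (done in __init__), then refine with the DataLoader
+        # worker info so every (rank, worker) pair reads a disjoint subset of the samples.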
+ num_workers = self.world_size
+ worker_id = self.rank
+ worker_info = torch.utils.data.get_worker_info()
+ if worker_info is not None:
+ worker_id = worker_id + worker_info.id * num_workers
+ num_workers *= worker_info.num_workers
+ self.apply_sharding(num_workers, worker_id)
+
+ yield from super().__iter__()
diff --git a/gapartnet/dataset/gapartnet.py b/gapartnet/dataset/gapartnet.py
new file mode 100644
index 0000000..03170d1
--- /dev/null
+++ b/gapartnet/dataset/gapartnet.py
@@ -0,0 +1,503 @@
+import copy
+import json
+from functools import partial
+from pathlib import Path
+from typing import Optional, Tuple, Union
+
+import numpy as np
+from lightning.pytorch import LightningDataModule
+import torch
+import torchdata.datapipes as dp
+from epic_ops.voxelize import voxelize
+from torch.utils.data import DataLoader
+from torch.utils.data import Dataset
+import random
+from glob import glob
+
+from structure.point_cloud import PointCloud
+from dataset import data_utils
+from misc.info import OBJECT_NAME2ID
+
+
+class GAPartNetDataset(Dataset):
+ def __init__(
+ self,
+ root_dir: Union[str, Path] = "",
+ shuffle: bool = False,
+ max_points: int = 20000,
+ augmentation: bool = False,
+ voxel_size: Tuple[float, float, float] = (1 / 100, 1 / 100, 1 / 100),
+ few_shot = False,
+ few_shot_num = 512,
+ pos_jitter: float = 0.,
+ color_jitter: float = 0.,
+ flip_prob: float = 0.,
+ rotate_prob: float = 0.,
+ nopart_path: str = "data/nopart.txt",
+ no_label = False,
+ ):
+ file_paths=glob(str(root_dir) + "/*.pth")
+ self.nopart_files = open(nopart_path, "r").readlines()[0].split(" ")
+ self.nopart_names = [p.split("/")[-1].split(".")[0] for p in self.nopart_files]
+ file_paths = [path for path in file_paths
+ if path.split("/")[-1].split(".")[0] not in self.nopart_names]
+ if shuffle:
+ random.shuffle(file_paths)
+ if few_shot:
+ file_paths = file_paths[:few_shot_num]
+ self.pc_paths = file_paths
+ self.no_label = no_label
+ self.augmentation = augmentation
+ self.pos_jitter = pos_jitter
+ self.color_jitter = color_jitter
+ self.flip_prob = flip_prob
+ self.rotate_prob = rotate_prob
+ self.voxel_size = voxel_size
+ self.max_points = max_points
+
+ def __len__(self):
+ return len(self.pc_paths)
+
+ def __getitem__(self, idx):
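+        # Per-sample pipeline: load .pth -> sanity-check labels -> enforce point budget ->
+        # compact instance ids -> (optional) augment -> build instance info -> to tensor -> voxelize.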
+ path = self.pc_paths[idx]
+ file = load_data(path, no_label = self.no_label)
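+        # Samples with no labeled part should already be filtered out via nopart.txt;
+        # drop into the debugger if one slips through.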
+ if not bool((file.instance_labels != -100).any()):
+ import ipdb; ipdb.set_trace()
+ file = downsample(file, max_points=self.max_points)
+ file = compact_instance_labels(file)
+ if self.augmentation:
+ file = apply_augmentations(file,
+ pos_jitter=self.pos_jitter,
+ color_jitter=self.color_jitter,
+ flip_prob=self.flip_prob,
+ rotate_prob=self.rotate_prob,)
+ file = generate_inst_info(file)
+ file = file.to_tensor()
+ file = apply_voxelization(file, voxel_size=self.voxel_size)
+ return file
+
+
+def apply_augmentations(
+ pc: PointCloud,
+ *,
+ pos_jitter: float = 0.,
+ color_jitter: float = 0.,
+ flip_prob: float = 0.,
+ rotate_prob: float = 0.,
+) -> PointCloud:
+ pc = copy.copy(pc)
+
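+    # Compose a random 3x3 linear transform: Gaussian jitter of the map (pos_jitter),
+    # an optional flip of the x axis (flip_prob) and an optional rotation about the z axis
+    # (rotate_prob); color jitter is applied separately below.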
+ m = np.eye(3)
+ if pos_jitter > 0:
+ m += np.random.randn(3, 3) * pos_jitter
+
+ if flip_prob > 0:
+ if np.random.rand() < flip_prob:
+ m[0, 0] = -m[0, 0]
+
+ if rotate_prob > 0:
+        if np.random.rand() < rotate_prob:
+ theta = np.random.rand() * np.pi * 2
+ m = m @ np.asarray([
+ [np.cos(theta), np.sin(theta), 0],
+ [-np.sin(theta), np.cos(theta), 0],
+ [0, 0, 1],
+ ])
+
+ pc.points = pc.points.copy()
+ pc.points[:, :3] = pc.points[:, :3] @ m
+
+ if color_jitter > 0:
+ pc.points[:, 3:] += np.random.randn(
+ 1, pc.points.shape[1] - 3
+ ) * color_jitter
+
+ return pc
+
+
+def downsample(pc: PointCloud, *, max_points: int = 20000) -> PointCloud:
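+    # Point clouds are expected to already contain at most max_points points; this is a
+    # sanity check rather than an actual subsampling step.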
+ pc = copy.copy(pc)
+
+ num_points = pc.points.shape[0]
+
+ if num_points > max_points:
+ assert False, (num_points, max_points)
+
+ return pc
+
+
+def compact_instance_labels(pc: PointCloud) -> PointCloud:
+ pc = copy.copy(pc)
+
+ valid_mask = pc.instance_labels >= 0
+ instance_labels = pc.instance_labels[valid_mask]
+ _, instance_labels = np.unique(instance_labels, return_inverse=True)
+ pc.instance_labels[valid_mask] = instance_labels
+
+ return pc
+
+
+def generate_inst_info(pc: PointCloud) -> PointCloud:
+ pc = copy.copy(pc)
+
+ num_points = pc.points.shape[0]
+
+ num_instances = int(pc.instance_labels.max()) + 1
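+    # Per-point 9-dim region feature: [mean_xyz, min_xyz, max_xyz] of the instance the point belongs to.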
+ instance_regions = np.zeros((num_points, 9), dtype=np.float32)
+ num_points_per_instance = []
+ instance_sem_labels = []
+
+ assert num_instances > 0
+
+ for i in range(num_instances):
+ indices = np.where(pc.instance_labels == i)[0]
+
+ xyz_i = pc.points[indices, :3]
+ min_i = xyz_i.min(0)
+ max_i = xyz_i.max(0)
+ mean_i = xyz_i.mean(0)
+ instance_regions[indices, 0:3] = mean_i
+ instance_regions[indices, 3:6] = min_i
+ instance_regions[indices, 6:9] = max_i
+
+ num_points_per_instance.append(indices.shape[0])
+ instance_sem_labels.append(int(pc.sem_labels[indices[0]]))
+
+ pc.num_instances = num_instances
+ pc.instance_regions = instance_regions
+ pc.num_points_per_instance = np.asarray(num_points_per_instance, dtype=np.int32)
+ pc.instance_sem_labels = np.asarray(instance_sem_labels, dtype=np.int32)
+
+ return pc
+
+
+def apply_voxelization(
+ pc: PointCloud, *, voxel_size: Tuple[float, float, float]
+) -> PointCloud:
+ pc = copy.copy(pc)
+
+ num_points = pc.points.shape[0]
+ pt_xyz = pc.points[:, :3]
+ points_range_min = pt_xyz.min(0)[0] - 1e-4
+ points_range_max = pt_xyz.max(0)[0] + 1e-4
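+    # Voxelize the single cloud as a batch of one (offsets [0, num_points]); point features are
+    # mean-pooled per voxel and pc_voxel_id maps every point back to its voxel.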
+ voxel_features, voxel_coords, _, pc_voxel_id = voxelize(
+ pt_xyz, pc.points,
+ batch_offsets=torch.as_tensor([0, num_points], dtype=torch.int64, device = pt_xyz.device),
+ voxel_size=torch.as_tensor(voxel_size, device = pt_xyz.device),
+ points_range_min=torch.as_tensor(points_range_min, device = pt_xyz.device),
+ points_range_max=torch.as_tensor(points_range_max, device = pt_xyz.device),
+ reduction="mean",
+ )
+ assert (pc_voxel_id >= 0).all()
+
+ voxel_coords_range = (voxel_coords.max(0)[0] + 1).clamp(min=128, max=None)
+
+ pc.voxel_features = voxel_features
+ pc.voxel_coords = voxel_coords
+ pc.voxel_coords_range = voxel_coords_range.tolist()
+ pc.pc_voxel_id = pc_voxel_id
+
+ return pc
+
+
+def load_data(file_path: str, no_label: bool = False):
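+    # Each .pth sample is expected to be a tuple (xyz, rgb, semantic labels, instance labels,
+    # NPCS coords); cf. the dict-based loader in gapartnet_new.py.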
+ if not no_label:
+ pc_data = torch.load(file_path)
+ else:
+        # test-time data, e.g. a real-world point cloud without GT semantic labels (not supported yet).
+ raise NotImplementedError
+
+ pc_id = file_path.split("/")[-1].split(".")[0]
+ object_cat = OBJECT_NAME2ID[pc_id.split("_")[0]]
+
+
+ return PointCloud(
+ pc_id=pc_id,
+ obj_cat=object_cat,
+ points=np.concatenate(
+ [pc_data[0], pc_data[1]],
+ axis=-1, dtype=np.float32,
+ ),
+ sem_labels=pc_data[2].astype(np.int64),
+ instance_labels=pc_data[3].astype(np.int32),
+ gt_npcs=pc_data[4].astype(np.float32),
+ )
+
+def from_folder(
+ root_dir: Union[str, Path] = "",
+ split: str = "train_new",
+ shuffle: bool = False,
+ max_points: int = 20000,
+ augmentation: bool = False,
+ voxel_size: Tuple[float, float, float] = (1 / 100, 1 / 100, 1 / 100),
+ pos_jitter: float = 0.,
+ color_jitter: float = 0.1,
+ flip_prob: float = 0.,
+ rotate_prob: float = 0.,
+):
+ root_dir = Path(root_dir)
+
+ with open(root_dir / f"{split}.json") as f:
+ file_names = json.load(f)
+
+ pipe = dp.iter.IterableWrapper(file_names)
+
+ # pipe = pipe.filter(filter_fn=lambda x: x == "pth_new/StorageFurniture_41004_00_013.pth")
+
+ pipe = pipe.distributed_sharding_filter()
+ if shuffle:
+ pipe = pipe.shuffle()
+
+ # Load data
+ pipe = pipe.map(partial(load_data, root_dir=root_dir))
+ # Remove empty samples
+ pipe = pipe.filter(filter_fn=lambda x: bool((x.instance_labels != -100).any()))
+
+ # Downsample
+ # TODO: Crop
+ pipe = pipe.map(partial(downsample, max_points=max_points))
+ pipe = pipe.map(compact_instance_labels)
+
+ # Augmentations
+ if augmentation:
+ pipe = pipe.map(partial(
+ apply_augmentations,
+ pos_jitter=pos_jitter,
+ color_jitter=color_jitter,
+ flip_prob=flip_prob,
+ rotate_prob=rotate_prob,
+ ))
+
+ # Generate instance info
+ pipe = pipe.map(generate_inst_info)
+
+ # To tensor
+ pipe = pipe.map(lambda pc: pc.to_tensor())
+
+ # Voxelization
+ pipe = pipe.map(partial(apply_voxelization, voxel_size=voxel_size))
+
+ return pipe
+
+
+class GAPartNetInst(LightningDataModule):
+ def __init__(
+ self,
+ root_dir: str,
+ max_points: int = 20000,
+ voxel_size: Tuple[float, float, float] = (1 / 100, 1 / 100, 1 / 100),
+ train_batch_size: int = 32,
+ val_batch_size: int = 32,
+ test_batch_size: int = 32,
+ num_workers: int = 16,
+ pos_jitter: float = 0.,
+ color_jitter: float = 0.1,
+ flip_prob: float = 0.,
+ rotate_prob: float = 0.,
+ train_few_shot: bool = False,
+ val_few_shot: bool = False,
+ intra_few_shot: bool = False,
+ inter_few_shot: bool = False,
+ few_shot_num: int = 256,
+ ):
+ super().__init__()
+ self.save_hyperparameters()
+
+ self.root_dir = root_dir
+ self.max_points = max_points
+ self.voxel_size = voxel_size
+
+ self.train_batch_size = train_batch_size
+ self.val_batch_size = val_batch_size
+ self.test_batch_size = test_batch_size
+ self.num_workers = num_workers
+
+ self.pos_jitter = pos_jitter
+ self.color_jitter = color_jitter
+ self.flip_prob = flip_prob
+ self.rotate_prob = rotate_prob
+
+ # debug
+ self.train_few_shot = train_few_shot
+ self.val_few_shot = val_few_shot
+ self.intra_few_shot = intra_few_shot
+ self.inter_few_shot = inter_few_shot
+ self.few_shot_num = few_shot_num
+
+ def setup(self, stage: Optional[str] = None):
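+        # Besides the training set (fit/validate only), every stage builds three evaluation sets:
+        # val, test_intra (held-out instances of training categories) and test_inter (unseen categories).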
+ if stage in (None, "fit", "validate"):
+ self.train_data_files = GAPartNetDataset(
+ Path(self.root_dir) / "train" / "pth",
+ shuffle=True,
+ max_points=self.max_points,
+ augmentation=True,
+ voxel_size=self.voxel_size,
+ few_shot = self.train_few_shot,
+ few_shot_num=self.few_shot_num,
+ pos_jitter = self.pos_jitter,
+ color_jitter = self.color_jitter,
+ flip_prob = self.flip_prob,
+ rotate_prob = self.rotate_prob,
+ )
+
+ self.val_data_files = GAPartNetDataset(
+ Path(self.root_dir) / "val" / "pth",
+ shuffle=True,
+ max_points=self.max_points,
+ augmentation=False,
+ voxel_size=self.voxel_size,
+ few_shot = self.val_few_shot,
+ few_shot_num=self.few_shot_num,
+ pos_jitter = self.pos_jitter,
+ color_jitter = self.color_jitter,
+ flip_prob = self.flip_prob,
+ rotate_prob = self.rotate_prob,
+ )
+
+ self.intra_data_files = GAPartNetDataset(
+ Path(self.root_dir) / "test_intra" / "pth",
+ shuffle=True,
+ max_points=self.max_points,
+ augmentation=False,
+ voxel_size=self.voxel_size,
+ few_shot = self.intra_few_shot,
+ few_shot_num=self.few_shot_num,
+ pos_jitter = self.pos_jitter,
+ color_jitter = self.color_jitter,
+ flip_prob = self.flip_prob,
+ rotate_prob = self.rotate_prob,
+ )
+
+ self.inter_data_files = GAPartNetDataset(
+ Path(self.root_dir) / "test_inter" / "pth",
+ shuffle=True,
+ max_points=self.max_points,
+ augmentation=False,
+ voxel_size=self.voxel_size,
+ few_shot = self.inter_few_shot,
+ few_shot_num=self.few_shot_num,
+ pos_jitter = self.pos_jitter,
+ color_jitter = self.color_jitter,
+ flip_prob = self.flip_prob,
+ rotate_prob = self.rotate_prob,
+ )
+
+ if stage in (None, "test"):
+ self.val_data_files = GAPartNetDataset(
+ Path(self.root_dir) / "val" / "pth",
+ shuffle=True,
+ max_points=self.max_points,
+ augmentation=False,
+ voxel_size=self.voxel_size,
+ few_shot = self.val_few_shot,
+ few_shot_num=self.few_shot_num,
+ pos_jitter = self.pos_jitter,
+ color_jitter = self.color_jitter,
+ flip_prob = self.flip_prob,
+ rotate_prob = self.rotate_prob,
+ )
+
+ self.intra_data_files = GAPartNetDataset(
+ Path(self.root_dir) / "test_intra" / "pth",
+ shuffle=True,
+ max_points=self.max_points,
+ augmentation=False,
+ voxel_size=self.voxel_size,
+ few_shot = self.intra_few_shot,
+ few_shot_num=self.few_shot_num,
+ pos_jitter = self.pos_jitter,
+ color_jitter = self.color_jitter,
+ flip_prob = self.flip_prob,
+ rotate_prob = self.rotate_prob,
+ )
+
+ self.inter_data_files = GAPartNetDataset(
+ Path(self.root_dir) / "test_inter" / "pth",
+ shuffle=True,
+ max_points=self.max_points,
+ augmentation=False,
+ voxel_size=self.voxel_size,
+ few_shot = self.inter_few_shot,
+ few_shot_num=self.few_shot_num,
+ pos_jitter = self.pos_jitter,
+ color_jitter = self.color_jitter,
+ flip_prob = self.flip_prob,
+ rotate_prob = self.rotate_prob,
+ )
+
+ def train_dataloader(self):
+ return DataLoader(
+ self.train_data_files,
+ batch_size=self.train_batch_size,
+ shuffle=True,
+ num_workers=self.num_workers,
+ collate_fn=data_utils.trivial_batch_collator,
+ pin_memory=True,
+ drop_last=True,
+ )
+
+ def val_dataloader(self):
+ return [
+ DataLoader(
+ self.val_data_files,
+ batch_size=self.val_batch_size,
+ shuffle=False,
+ num_workers=self.num_workers,
+ collate_fn=data_utils.trivial_batch_collator,
+ pin_memory=True,
+ drop_last=False,
+ ),
+
+ DataLoader(
+ self.intra_data_files,
+ batch_size=self.val_batch_size,
+ shuffle=False,
+ num_workers=self.num_workers,
+ collate_fn=data_utils.trivial_batch_collator,
+ pin_memory=True,
+ drop_last=False,
+ ),
+
+ DataLoader(
+ self.inter_data_files,
+ batch_size=self.val_batch_size,
+ shuffle=False,
+ num_workers=self.num_workers,
+ collate_fn=data_utils.trivial_batch_collator,
+ pin_memory=True,
+ drop_last=False,
+ ),
+ ]
+
+ def test_dataloader(self):
+ return [
+ DataLoader(
+ self.val_data_files,
+ batch_size=self.val_batch_size,
+ shuffle=False,
+ num_workers=self.num_workers,
+ collate_fn=data_utils.trivial_batch_collator,
+ pin_memory=True,
+ drop_last=False,
+ ),
+
+ DataLoader(
+ self.intra_data_files,
+ batch_size=self.val_batch_size,
+ shuffle=False,
+ num_workers=self.num_workers,
+ collate_fn=data_utils.trivial_batch_collator,
+ pin_memory=True,
+ drop_last=False,
+ ),
+
+ DataLoader(
+ self.inter_data_files,
+ batch_size=self.val_batch_size,
+ shuffle=False,
+ num_workers=self.num_workers,
+ collate_fn=data_utils.trivial_batch_collator,
+ pin_memory=True,
+ drop_last=False,
+ ),
+ ]
\ No newline at end of file
diff --git a/gapartnet/gapartnet.yaml b/gapartnet/gapartnet.yaml
new file mode 100644
index 0000000..2175ea9
--- /dev/null
+++ b/gapartnet/gapartnet.yaml
@@ -0,0 +1,91 @@
+model:
+ class_path: network.model.GAPartNet
+ init_args:
+ debug: False
+ in_channels: 6
+ num_part_classes: 10
+ backbone_type: SparseUNet
+ backbone_cfg:
+ channels: [16,32,48,64,80,96,112]
+ block_repeat: 2
+ instance_seg_cfg:
+ ball_query_radius: 0.04
+ max_num_points_per_query: 50
+ min_num_points_per_proposal: 5 # 50 for scannet?
+ max_num_points_per_query_shift: 300
+ score_fullscale: 28
+ score_scale: 50
+ learning_rate: 0.001
+ ignore_sem_label: -100
+ use_sem_focal_loss: true
+ use_sem_dice_loss: true
+ training_schedule: [0,0]
+ val_nms_iou_threshold: 0.3
+ val_ap_iou_threshold: 0.5
+ symmetry_indices: [0, 1, 3, 3, 2, 0, 3, 2, 4, 1]
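+    # one entry per part class; presumably a symmetry-type id used to resolve symmetric-pose ambiguity in evaluation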
+ visualize_cfg:
+ visualize: True
+ visualize_dir: visu
+ sample_num: 10
+ RAW_IMG_ROOT: "data/image_kuafu"
+ GAPARTNET_DATA_ROOT: "data/GAPartNet_All"
+ SAVE_ROOT: "output/GAPartNet_result"
+ save_option: ["raw", "pc", "sem_pred", "sem_gt", "ins_pred", "ins_gt", "npcs_pred", "npcs_gt", "bbox_gt", "bbox_gt_pure", "bbox_pred", "bbox_pred_pure"]
+
+
+data:
+ class_path: dataset.gapartnet.GAPartNetInst
+ init_args:
+ root_dir: data/GAPartNet_All
+ max_points: 20000
+ voxel_size: [0.01,0.01,0.01]
+ train_batch_size: 64
+ val_batch_size: 32
+ test_batch_size: 32
+ num_workers: 16
+ pos_jitter: 0.1
+ color_jitter: 0.3
+ flip_prob: 0.3
+ rotate_prob: 0.3
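+    # The *_few_shot flags subsample the corresponding split to few_shot_num files (useful for quick runs and debugging).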
+ train_few_shot: true
+ val_few_shot: true
+ intra_few_shot: true
+ inter_few_shot: true
+ few_shot_num: 640
+
+trainer:
+ max_epochs: 700
+ accelerator: gpu
+ strategy: auto
+ devices: auto
+ num_nodes: 1
+ # resume_from_checkpoint: /data2/haoran/PLAffordance/wandb/perception/1osg9z17/checkpoints/epoch_118_miou_0.00.ckpt
+
+ # logger:
+ # class_path: WandbLogger
+ # init_args:
+ # save_dir: wandb
+ # project: perception
+ # entity: haoran-geng
+ # group: 1024_new
+ # name: test
+ # notes: "GAPartNet"
+ # tags: ["GAPartNet", "score", "npcs"]
+ # save_code: True
+ # mode: dryrun
+ callbacks:
+ - class_path: RichProgressBar
+ init_args:
+ leave: True
+ - class_path: ModelCheckpoint
+ init_args:
+ filename: "epoch_{epoch:03d}_mAP_{monitor_metrics/mean_mAP:.2f}"
+ auto_insert_metric_name: False
+ save_top_k: 5
+ mode: max
+ monitor: "monitor_metrics/mean_mAP"
+ every_n_epochs: 1
+
+ default_root_dir: wandb
+
+seed_everything: 23333
diff --git a/gapartnet/gapartnet/__init__.py b/gapartnet/gapartnet/__init__.py
new file mode 100644
index 0000000..5c0cf28
--- /dev/null
+++ b/gapartnet/gapartnet/__init__.py
@@ -0,0 +1 @@
+from .version import __version__ # noqa: F401
diff --git a/gapartnet/gapartnet/datasets/__init__.py b/gapartnet/gapartnet/datasets/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/gapartnet/gapartnet/datasets/gapartnet_new.py b/gapartnet/gapartnet/datasets/gapartnet_new.py
new file mode 100644
index 0000000..5fa2738
--- /dev/null
+++ b/gapartnet/gapartnet/datasets/gapartnet_new.py
@@ -0,0 +1,320 @@
+import copy
+import json
+from functools import partial
+from pathlib import Path
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import pytorch_lightning as pl
+import torch
+import torchdata.datapipes as dp
+from epic_ops.voxelize import voxelize
+from torch.utils.data import DataLoader
+
+from gapartnet.structures.point_cloud import PointCloud
+from gapartnet.utils import data as data_utils
+
+
+def apply_augmentations(
+ pc: PointCloud,
+ *,
+ pos_jitter: float = 0.,
+ color_jitter: float = 0.,
+ flip_prob: float = 0.,
+ rotate_prob: float = 0.,
+) -> PointCloud:
+ pc = copy.copy(pc)
+
+ m = np.eye(3)
+ if pos_jitter > 0:
+ m += np.random.randn(3, 3) * pos_jitter
+
+ if flip_prob > 0:
+ if np.random.rand() < flip_prob:
+ m[0, 0] = -m[0, 0]
+
+ if rotate_prob > 0:
+        if np.random.rand() < rotate_prob:
+ theta = np.random.rand() * np.pi * 2
+ m = m @ np.asarray([
+ [np.cos(theta), np.sin(theta), 0],
+ [-np.sin(theta), np.cos(theta), 0],
+ [0, 0, 1],
+ ])
+
+ pc.points = pc.points.copy()
+ pc.points[:, :3] = pc.points[:, :3] @ m
+
+ if color_jitter > 0:
+ pc.points[:, 3:] += np.random.randn(
+ 1, pc.points.shape[1] - 3
+ ) * color_jitter
+
+ return pc
+
+
+def downsample(pc: PointCloud, *, max_points: int = 20000) -> PointCloud:
+ pc = copy.copy(pc)
+
+ num_points = pc.points.shape[0]
+
+ if num_points > max_points:
+ assert False, (num_points, max_points)
+
+ return pc
+
+
+def compact_instance_labels(pc: PointCloud) -> PointCloud:
+ pc = copy.copy(pc)
+
+ valid_mask = pc.instance_labels >= 0
+ instance_labels = pc.instance_labels[valid_mask]
+ _, instance_labels = np.unique(instance_labels, return_inverse=True)
+ pc.instance_labels[valid_mask] = instance_labels
+
+ return pc
+
+
+def generate_inst_info(pc: PointCloud) -> PointCloud:
+ pc = copy.copy(pc)
+
+ num_points = pc.points.shape[0]
+
+ num_instances = int(pc.instance_labels.max()) + 1
+ instance_regions = np.zeros((num_points, 9), dtype=np.float32)
+ num_points_per_instance = []
+ instance_sem_labels = []
+
+ for i in range(num_instances):
+ indices = np.where(pc.instance_labels == i)[0]
+
+ xyz_i = pc.points[indices, :3]
+ min_i = xyz_i.min(0)
+ max_i = xyz_i.max(0)
+ mean_i = xyz_i.mean(0)
+ instance_regions[indices, 0:3] = mean_i
+ instance_regions[indices, 3:6] = min_i
+ instance_regions[indices, 6:9] = max_i
+
+ num_points_per_instance.append(indices.shape[0])
+ instance_sem_labels.append(int(pc.sem_labels[indices[0]]))
+
+ pc.num_instances = num_instances
+ pc.instance_regions = instance_regions
+ pc.num_points_per_instance = np.asarray(num_points_per_instance, dtype=np.int32)
+ pc.instance_sem_labels = np.asarray(instance_sem_labels, dtype=np.int32)
+
+ return pc
+
+
+def apply_voxelization(
+ pc: PointCloud, *, voxel_size: Tuple[float, float, float]
+) -> PointCloud:
+ pc = copy.copy(pc)
+
+ num_points = pc.points.shape[0]
+ pt_xyz = pc.points[:, :3]
+ points_range_min = pt_xyz.min(0)[0] - 1e-4
+ points_range_max = pt_xyz.max(0)[0] + 1e-4
+ voxel_features, voxel_coords, _, pc_voxel_id = voxelize(
+ pt_xyz, pc.points,
+ batch_offsets=torch.as_tensor([0, num_points], dtype=torch.int64, device = pt_xyz.device),
+ voxel_size=torch.as_tensor(voxel_size, device = pt_xyz.device),
+ points_range_min=torch.as_tensor(points_range_min, device = pt_xyz.device),
+ points_range_max=torch.as_tensor(points_range_max, device = pt_xyz.device),
+ reduction="mean",
+ )
+ assert (pc_voxel_id >= 0).all()
+
+ voxel_coords_range = (voxel_coords.max(0)[0] + 1).clamp(min=128, max=None)
+
+ pc.voxel_features = voxel_features
+ pc.voxel_coords = voxel_coords
+ pc.voxel_coords_range = voxel_coords_range.tolist()
+ pc.pc_voxel_id = pc_voxel_id
+
+ return pc
+
+
+def load_data(file_name: str, *, root_dir: Path):
+ pc_data = torch.load(root_dir / file_name)
+
+ scene_id = file_name.split("/")[-1].split(".")[0]
+
+ assert pc_data["xyz"].dtype == torch.float32
+
+ return PointCloud(
+ scene_id=scene_id,
+ points=np.concatenate(
+ [pc_data["xyz"].numpy(), pc_data["rgb"].numpy()],
+ axis=-1, dtype=np.float32,
+ ),
+ sem_labels=pc_data["sem_labels"].numpy().astype(np.int64),
+ instance_labels=pc_data["instance_labels"].numpy().astype(np.int32),
+ gt_npcs=pc_data["gt_npcs"].numpy().astype(np.float32),
+ )
+
+
+def from_folder(
+ root_dir: Union[str, Path] = "",
+ split: str = "train_new",
+ shuffle: bool = False,
+ max_points: int = 20000,
+ augmentation: bool = False,
+ voxel_size: Tuple[float, float, float] = (1 / 100, 1 / 100, 1 / 100),
+ pos_jitter: float = 0.,
+ color_jitter: float = 0.1,
+ flip_prob: float = 0.,
+ rotate_prob: float = 0.,
+):
+ root_dir = Path(root_dir)
+
+ with open(root_dir / f"{split}.json") as f:
+ file_names = json.load(f)
+
+ pipe = dp.iter.IterableWrapper(file_names)
+
+ # pipe = pipe.filter(filter_fn=lambda x: x == "pth_new/StorageFurniture_41004_00_013.pth")
+
+ pipe = pipe.distributed_sharding_filter()
+ if shuffle:
+ pipe = pipe.shuffle()
+
+ # Load data
+ pipe = pipe.map(partial(load_data, root_dir=root_dir))
+ # Remove empty samples
+ pipe = pipe.filter(filter_fn=lambda x: bool((x.instance_labels != -100).any()))
+
+ # Downsample
+ # TODO: Crop
+ pipe = pipe.map(partial(downsample, max_points=max_points))
+ pipe = pipe.map(compact_instance_labels)
+
+ # Augmentations
+ if augmentation:
+ pipe = pipe.map(partial(
+ apply_augmentations,
+ pos_jitter=pos_jitter,
+ color_jitter=color_jitter,
+ flip_prob=flip_prob,
+ rotate_prob=rotate_prob,
+ ))
+
+ # Generate instance info
+ pipe = pipe.map(generate_inst_info)
+
+ # To tensor
+ pipe = pipe.map(lambda pc: pc.to_tensor())
+
+ # Voxelization
+ pipe = pipe.map(partial(apply_voxelization, voxel_size=voxel_size))
+
+ return pipe
+
+
+class GAPartNetInst(pl.LightningDataModule):
+ def __init__(
+ self,
+ root_dir: str,
+ max_points: int = 20000,
+ voxel_size: Tuple[float, float, float] = (1 / 100, 1 / 100, 1 / 100),
+ train_batch_size: int = 32,
+ val_batch_size: int = 32,
+ test_batch_size: int = 32,
+ num_workers: int = 16,
+ pos_jitter: float = 0.,
+ color_jitter: float = 0.1,
+ flip_prob: float = 0.,
+ rotate_prob: float = 0.,
+ ):
+ super().__init__()
+ self.save_hyperparameters()
+
+ self.root_dir = root_dir
+ self.max_points = max_points
+ self.voxel_size = voxel_size
+
+ self.train_batch_size = train_batch_size
+ self.val_batch_size = val_batch_size
+ self.test_batch_size = test_batch_size
+ self.num_workers = num_workers
+
+ self.pos_jitter = pos_jitter
+ self.color_jitter = color_jitter
+ self.flip_prob = flip_prob
+ self.rotate_prob = rotate_prob
+
+ def setup(self, stage: Optional[str] = None):
+ if stage in (None, "fit", "validate"):
+ self.train_data_pipe = from_folder(
+ Path(self.root_dir),
+ split="train_new",
+ shuffle=True,
+ max_points=self.max_points,
+ augmentation=True,
+ voxel_size=self.voxel_size,
+ pos_jitter=self.pos_jitter,
+ color_jitter=self.color_jitter,
+ flip_prob=self.flip_prob,
+ rotate_prob=self.rotate_prob,
+ )
+
+ self.val_data_pipe = from_folder(
+ Path(self.root_dir),
+ split="val_new_intra",
+ shuffle=False,
+ max_points=self.max_points,
+ augmentation=False,
+ voxel_size=self.voxel_size,
+ pos_jitter=self.pos_jitter,
+ color_jitter=self.color_jitter,
+ flip_prob=self.flip_prob,
+ rotate_prob=self.rotate_prob,
+ )
+
+ if stage in (None, "test"):
+ self.test_data_pipe = from_folder(
+ Path(self.root_dir),
+ split="test_new",
+ shuffle=False,
+ max_points=self.max_points,
+ augmentation=False,
+ voxel_size=self.voxel_size,
+ pos_jitter=self.pos_jitter,
+ color_jitter=self.color_jitter,
+ flip_prob=self.flip_prob,
+ rotate_prob=self.rotate_prob,
+ )
+
+ def train_dataloader(self):
+ return DataLoader(
+ self.train_data_pipe,
+ batch_size=self.train_batch_size,
+ shuffle=True,
+ num_workers=self.num_workers,
+ collate_fn=data_utils.trivial_batch_collator,
+ pin_memory=True,
+ drop_last=True,
+ )
+
+ def val_dataloader(self):
+ return DataLoader(
+ self.val_data_pipe,
+ batch_size=self.val_batch_size,
+ shuffle=False,
+ num_workers=self.num_workers,
+ collate_fn=data_utils.trivial_batch_collator,
+ pin_memory=True,
+ drop_last=False,
+ )
+
+ def test_dataloader(self):
+ return DataLoader(
+ self.test_data_pipe,
+ batch_size=self.test_batch_size,
+ shuffle=False,
+ num_workers=self.num_workers,
+ collate_fn=data_utils.trivial_batch_collator,
+ pin_memory=True,
+ drop_last=False,
+ )
diff --git a/gapartnet/gapartnet/losses/__init__.py b/gapartnet/gapartnet/losses/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/gapartnet/gapartnet/losses/dice_loss.py b/gapartnet/gapartnet/losses/dice_loss.py
new file mode 100644
index 0000000..8a2dfd7
--- /dev/null
+++ b/gapartnet/gapartnet/losses/dice_loss.py
@@ -0,0 +1,55 @@
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+
+
+def one_hot(
+ labels: torch.Tensor,
+ num_classes: int,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ eps: float = 1e-6,
+) -> torch.Tensor:
+ if not isinstance(labels, torch.Tensor):
+ raise TypeError(f"Input labels type is not a torch.Tensor. Got {type(labels)}")
+
+ if not labels.dtype == torch.int64:
+ raise ValueError(f"labels must be of the same dtype torch.int64. Got: {labels.dtype}")
+
+ if num_classes < 1:
+ raise ValueError("The number of classes must be bigger than one." " Got: {}".format(num_classes))
+
+ shape = labels.shape
+ one_hot = torch.zeros((shape[0], num_classes) + shape[1:], device=device, dtype=dtype)
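+    # Scatter 1.0 at each label index; adding eps keeps all entries strictly positive for numerical stability.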
+
+ return one_hot.scatter_(1, labels.unsqueeze(1), 1.0) + eps
+
+
+def dice_loss(input: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
+ if not isinstance(input, torch.Tensor):
+ raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")
+
+ if not len(input.shape) == 4:
+ raise ValueError(f"Invalid input shape, we expect BxNxHxW. Got: {input.shape}")
+
+ if not input.shape[-2:] == target.shape[-2:]:
+ raise ValueError(f"input and target shapes must be the same. Got: {input.shape} and {target.shape}")
+
+ if not input.device == target.device:
+ raise ValueError(f"input and target must be in the same device. Got: {input.device} and {target.device}")
+
+ # compute softmax over the classes axis
+ input_soft: torch.Tensor = F.softmax(input, dim=1)
+
+ # create the labels one hot tensor
+ target_one_hot: torch.Tensor = one_hot(target, num_classes=input.shape[1], device=input.device, dtype=input.dtype)
+
+ # compute the actual dice score
+ dims = (1, 2, 3)
+ intersection = torch.sum(input_soft * target_one_hot, dims)
+ cardinality = torch.sum(input_soft + target_one_hot, dims)
+
+ dice_score = 2.0 * intersection / (cardinality + eps)
+
+ return torch.mean(-dice_score + 1.0)
diff --git a/gapartnet/gapartnet/losses/focal_loss.py b/gapartnet/gapartnet/losses/focal_loss.py
new file mode 100644
index 0000000..5ce104f
--- /dev/null
+++ b/gapartnet/gapartnet/losses/focal_loss.py
@@ -0,0 +1,79 @@
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+
+
+def focal_loss(
+ inputs: torch.Tensor,
+ targets: torch.Tensor,
+ alpha: Optional[torch.Tensor] = None,
+ gamma: float = 2.0,
+ reduction: str = "mean",
+ ignore_index: int = -100,
+) -> torch.Tensor:
+ if ignore_index is not None:
+ valid_mask = targets != ignore_index
+ targets = targets[valid_mask]
+
+ if targets.shape[0] == 0:
+ return torch.tensor(0.0).to(dtype=inputs.dtype, device=inputs.device)
+
+ inputs = inputs[valid_mask]
+
+ log_p = F.log_softmax(inputs, dim=-1)
+ ce_loss = F.nll_loss(
+ log_p, targets, weight=alpha, ignore_index=ignore_index, reduction="none"
+ )
+ log_p_t = log_p.gather(1, targets[:, None]).squeeze(-1)
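+    # Focal modulation: (1 - p_t)^gamma down-weights well-classified (easy) examples.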
+ loss = ce_loss * ((1 - log_p_t.exp()) ** gamma)
+
+ if reduction == "mean":
+ loss = loss.mean()
+ elif reduction == "sum":
+ loss = loss.sum()
+
+ return loss
+
+
+def sigmoid_focal_loss(
+ inputs: torch.Tensor,
+ targets: torch.Tensor,
+ alpha: float = 0.25,
+ gamma: float = 2,
+ reduction: str = "none",
+) -> torch.Tensor:
+ """
+ Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
+ Args:
+ inputs (Tensor): A float tensor of arbitrary shape.
+ The predictions for each example.
+ targets (Tensor): A float tensor with the same shape as inputs. Stores the binary
+ classification label for each element in inputs
+ (0 for the negative class and 1 for the positive class).
+ alpha (float): Weighting factor in range (0,1) to balance
+ positive vs negative examples or -1 for ignore. Default: ``0.25``.
+ gamma (float): Exponent of the modulating factor (1 - p_t) to
+ balance easy vs hard examples. Default: ``2``.
+ reduction (string): ``'none'`` | ``'mean'`` | ``'sum'``
+ ``'none'``: No reduction will be applied to the output.
+ ``'mean'``: The output will be averaged.
+ ``'sum'``: The output will be summed. Default: ``'none'``.
+ Returns:
+ Loss tensor with the reduction option applied.
+ """
+ p = torch.sigmoid(inputs)
+ ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
+ p_t = p * targets + (1 - p) * (1 - targets)
+ loss = ce_loss * ((1 - p_t) ** gamma)
+
+ if alpha >= 0:
+ alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
+ loss = alpha_t * loss
+
+ if reduction == "mean":
+ loss = loss.mean()
+ elif reduction == "sum":
+ loss = loss.sum()
+
+ return loss
diff --git a/gapartnet/gapartnet/metrics/__init__.py b/gapartnet/gapartnet/metrics/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/gapartnet/gapartnet/metrics/eval.py b/gapartnet/gapartnet/metrics/eval.py
new file mode 100644
index 0000000..44607f6
--- /dev/null
+++ b/gapartnet/gapartnet/metrics/eval.py
@@ -0,0 +1,375 @@
+# Modified from ScanNet evaluation script: https://github.com/ScanNet/ScanNet/blob/master/BenchmarkScripts/3d_evaluation/evaluate_semantic_instance.py
+
+import os, sys, numpy as np
+import utils.utils_3d as util_3d
+import utils.utils as util
+
+# ---------- Label info ---------- #
+CLASS_LABELS = [
+ 'others', 'line_fixed_handle', 'round_fixed_handle', 'slider_button', 'hinge_door', 'slider_drawer', 'slider_lid',
+ 'hinge_lid', 'hinge_knob'
+]
+VALID_CLASS_IDS = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8])
+ID_TO_LABEL = {}
+LABEL_TO_ID = {}
+for i in range(len(VALID_CLASS_IDS)):
+ LABEL_TO_ID[CLASS_LABELS[i]] = VALID_CLASS_IDS[i]
+ ID_TO_LABEL[VALID_CLASS_IDS[i]] = CLASS_LABELS[i]
+# ---------- Evaluation params ---------- #
+# overlaps for evaluation
+OVERLAPS = np.append(np.arange(0.5, 0.95, 0.05), 0.25)
+# minimum region size for evaluation [verts]
+MIN_REGION_SIZES = np.array([100]) # ! important hyperparamter
+# distance thresholds [m]
+DISTANCE_THRESHES = np.array([float('inf')])
+# distance confidences
+DISTANCE_CONFS = np.array([-float('inf')])
+
+
+def evaluate_matches(matches):
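+    # ScanNet-style instance AP: for each overlap threshold and class, greedily match predictions
+    # to ground-truth instances, accumulate true/false positives with their confidences, and
+    # integrate the resulting precision-recall curve into an AP value.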
+ overlaps = OVERLAPS
+ min_region_sizes = [MIN_REGION_SIZES[0]]
+ dist_threshes = [DISTANCE_THRESHES[0]]
+ dist_confs = [DISTANCE_CONFS[0]]
+
+ # results: class x overlap
+    ap = np.zeros((len(dist_threshes), len(CLASS_LABELS), len(overlaps)), float)
+ for di, (min_region_size, distance_thresh,
+ distance_conf) in enumerate(zip(min_region_sizes, dist_threshes, dist_confs)):
+ for oi, overlap_th in enumerate(overlaps):
+ pred_visited = {}
+ for m in matches:
+ for p in matches[m]['pred']:
+ for label_name in CLASS_LABELS:
+ for p in matches[m]['pred'][label_name]:
+ if 'filename' in p:
+ pred_visited[p['filename']] = False
+ for li, label_name in enumerate(CLASS_LABELS):
+ y_true = np.empty(0)
+ y_score = np.empty(0)
+ hard_false_negatives = 0
+ has_gt = False
+ has_pred = False
+ for m in matches:
+ pred_instances = matches[m]['pred'][label_name]
+ gt_instances = matches[m]['gt'][label_name]
+ # filter groups in ground truth
+ gt_instances = [
+ gt for gt in gt_instances if gt['instance_id'] >= 1000 and gt['vert_count'] >= min_region_size
+ and gt['med_dist'] <= distance_thresh and gt['dist_conf'] >= distance_conf
+ ]
+ if gt_instances:
+ has_gt = True
+ if pred_instances:
+ has_pred = True
+
+ cur_true = np.ones(len(gt_instances))
+ cur_score = np.ones(len(gt_instances)) * (-float("inf"))
+                cur_match = np.zeros(len(gt_instances), dtype=bool)
+ # collect matches
+ for (gti, gt) in enumerate(gt_instances):
+ found_match = False
+ num_pred = len(gt['matched_pred'])
+ for pred in gt['matched_pred']:
+ # greedy assignments
+ if pred_visited[pred['filename']]:
+ continue
+ overlap = float(
+ pred['intersection']) / (gt['vert_count'] + pred['vert_count'] - pred['intersection'])
+ if overlap > overlap_th:
+ confidence = pred['confidence']
+ # if already have a prediction for this gt,
+ # the prediction with the lower score is automatically a false positive
+ if cur_match[gti]:
+ max_score = max(cur_score[gti], confidence)
+ min_score = min(cur_score[gti], confidence)
+ cur_score[gti] = max_score
+ # append false positive
+ cur_true = np.append(cur_true, 0)
+ cur_score = np.append(cur_score, min_score)
+ cur_match = np.append(cur_match, True)
+ # otherwise set score
+ else:
+ found_match = True
+ cur_match[gti] = True
+ cur_score[gti] = confidence
+ pred_visited[pred['filename']] = True
+ if not found_match:
+ hard_false_negatives += 1
+ # remove non-matched ground truth instances
+ cur_true = cur_true[cur_match == True]
+ cur_score = cur_score[cur_match == True]
+
+ # collect non-matched predictions as false positive
+ for pred in pred_instances:
+ found_gt = False
+ for gt in pred['matched_gt']:
+ overlap = float(
+ gt['intersection']) / (gt['vert_count'] + pred['vert_count'] - gt['intersection'])
+ if overlap > overlap_th:
+ found_gt = True
+ break
+ if not found_gt:
+ num_ignore = pred['void_intersection']
+ for gt in pred['matched_gt']:
+ # group?
+ if gt['instance_id'] < 1000:
+ num_ignore += gt['intersection']
+ # small ground truth instances
+ if gt['vert_count'] < min_region_size or gt['med_dist'] > distance_thresh or gt[
+ 'dist_conf'] < distance_conf:
+ num_ignore += gt['intersection']
+ proportion_ignore = float(num_ignore) / pred['vert_count']
+ # if not ignored append false positive
+ if proportion_ignore <= overlap_th:
+ cur_true = np.append(cur_true, 0)
+ confidence = pred["confidence"]
+ cur_score = np.append(cur_score, confidence)
+
+ # append to overall results
+ y_true = np.append(y_true, cur_true)
+ y_score = np.append(y_score, cur_score)
+
+ # compute average precision
+ if has_gt and has_pred:
+ # compute precision recall curve first
+
+ # sorting and cumsum
+ score_arg_sort = np.argsort(y_score)
+ y_score_sorted = y_score[score_arg_sort]
+ y_true_sorted = y_true[score_arg_sort]
+ y_true_sorted_cumsum = np.cumsum(y_true_sorted)
+
+ # unique thresholds
+ (thresholds, unique_indices) = np.unique(y_score_sorted, return_index=True)
+ num_prec_recall = len(unique_indices) + 1
+
+ # prepare precision recall
+ num_examples = len(y_score_sorted)
+ if(len(y_true_sorted_cumsum) == 0):
+ num_true_examples = 0
+ else:
+ num_true_examples = y_true_sorted_cumsum[-1]
+ precision = np.zeros(num_prec_recall)
+ recall = np.zeros(num_prec_recall)
+
+ # deal with the first point
+ y_true_sorted_cumsum = np.append(y_true_sorted_cumsum, 0)
+ # deal with remaining
+ for idx_res, idx_scores in enumerate(unique_indices):
+ cumsum = y_true_sorted_cumsum[idx_scores - 1]
+ tp = num_true_examples - cumsum
+ fp = num_examples - idx_scores - tp
+ fn = cumsum + hard_false_negatives
+ p = float(tp) / (tp + fp)
+ r = float(tp) / (tp + fn)
+ precision[idx_res] = p
+ recall[idx_res] = r
+
+ # first point in curve is artificial
+ precision[-1] = 1.
+ recall[-1] = 0.
+
+ # compute average of precision-recall curve
+ recall_for_conv = np.copy(recall)
+ recall_for_conv = np.append(recall_for_conv[0], recall_for_conv)
+ recall_for_conv = np.append(recall_for_conv, 0.)
+
+ stepWidths = np.convolve(recall_for_conv, [-0.5, 0, 0.5], 'valid')
+ # integrate is now simply a dot product
+ ap_current = np.dot(precision, stepWidths)
+
+ elif has_gt:
+ ap_current = 0.0
+ else:
+ ap_current = float('nan')
+ ap[di, li, oi] = ap_current
+ return ap
+
+
+def compute_averages(aps):
+ d_inf = 0
+ o50 = np.where(np.isclose(OVERLAPS, 0.5))
+ o25 = np.where(np.isclose(OVERLAPS, 0.25))
+ oAllBut25 = np.where(np.logical_not(np.isclose(OVERLAPS, 0.25)))
+ avg_dict = {}
+ #avg_dict['all_ap'] = np.nanmean(aps[ d_inf,:,: ])
+ avg_dict['all_ap'] = np.nanmean(aps[d_inf, :, oAllBut25])
+ avg_dict['all_ap_50%'] = np.nanmean(aps[d_inf, :, o50])
+ avg_dict['all_ap_25%'] = np.nanmean(aps[d_inf, :, o25])
+ avg_dict["classes"] = {}
+ for (li, label_name) in enumerate(CLASS_LABELS):
+ avg_dict["classes"][label_name] = {}
+ #avg_dict["classes"][label_name]["ap"] = np.average(aps[ d_inf,li, :])
+ avg_dict["classes"][label_name]["ap"] = np.average(aps[d_inf,li,oAllBut25])
+ avg_dict["classes"][label_name]["ap50%"] = np.average(aps[d_inf,li,o50])
+ avg_dict["classes"][label_name]["ap25%"] = np.average(aps[d_inf,li,o25])
+ return avg_dict
+
+
+def assign_instances_for_scan(scene_name, pred_info, gt_file):
+ try:
+ gt_ids,npcs = util_3d.load_ids_npcs(gt_file)
+ gt_ids = gt_ids.astype(np.int32)
+ except Exception as e:
+ util.print_error('unable to load ' + gt_file + ': ' + str(e))
+
+ # get gt instances
+ gt_instances = util_3d.get_instances(gt_ids, VALID_CLASS_IDS, CLASS_LABELS, ID_TO_LABEL)
+
+ # associate
+ gt2pred = gt_instances.copy()
+ for label in gt2pred:
+ for gt in gt2pred[label]:
+ gt['matched_pred'] = []
+ pred2gt = {}
+ for label in CLASS_LABELS:
+ pred2gt[label] = []
+ num_pred_instances = 0
+ # mask of void labels in the groundtruth
+ bool_void = np.logical_not(np.in1d(gt_ids // 1000, VALID_CLASS_IDS)) # ! should be all False here
+ # go thru all prediction masks
+ nMask = pred_info['label_id'].shape[0]
+ for i in range(nMask):
+ label_id = int(pred_info['label_id'][i])
+ conf = pred_info['conf'][i]
+ if not label_id in ID_TO_LABEL:
+ continue
+ label_name = ID_TO_LABEL[label_id]
+ # read the mask
+ pred_mask = pred_info['mask'][i] # (N), long
+ if len(pred_mask) != len(gt_ids):
+ util.print_error('wrong number of lines in mask#%d: ' % (i) + '(%d) vs #mesh vertices (%d)' %
+ (len(pred_mask), len(gt_ids)))
+ # convert to binary
+ pred_mask = np.not_equal(pred_mask, 0)
+ num = np.count_nonzero(pred_mask)
+ if num < MIN_REGION_SIZES[0]:
+ continue # skip if empty
+
+ pred_instance = {}
+ pred_instance['filename'] = '{}_{:03d}'.format(scene_name, num_pred_instances)
+ pred_instance['pred_id'] = num_pred_instances
+ pred_instance['label_id'] = label_id
+ pred_instance['vert_count'] = num
+ pred_instance['confidence'] = conf
+ pred_instance['void_intersection'] = np.count_nonzero(np.logical_and(bool_void,
+ pred_mask)) # ! should be 0 here
+
+ # matched gt instances
+ matched_gt = []
+ # go thru all gt instances with matching label
+ for (gt_num, gt_inst) in enumerate(gt2pred[label_name]):
+ intersection = np.count_nonzero(np.logical_and(gt_ids == gt_inst['instance_id'], pred_mask))
+ if intersection > 0:
+ gt_copy = gt_inst.copy()
+ pred_copy = pred_instance.copy()
+ gt_copy['intersection'] = intersection
+ pred_copy['intersection'] = intersection
+ matched_gt.append(gt_copy)
+ gt2pred[label_name][gt_num]['matched_pred'].append(pred_copy)
+ pred_instance['matched_gt'] = matched_gt
+ num_pred_instances += 1
+ pred2gt[label_name].append(pred_instance)
+
+ return gt2pred, pred2gt, npcs
+
+
+def print_results(avgs):
+ from util.log import logger
+ sep = ""
+ col1 = ":"
+ lineLen = 64
+
+ logger.info("")
+ logger.info("#" * lineLen)
+ line = ""
+ line += "{:<15}".format("what") + sep + col1
+ line += "{:>15}".format("AP") + sep
+ line += "{:>15}".format("AP_50%") + sep
+ line += "{:>15}".format("AP_25%") + sep
+ logger.info(line)
+ logger.info("#" * lineLen)
+
+ for (li, label_name) in enumerate(CLASS_LABELS):
+ ap_avg = avgs["classes"][label_name]["ap"]
+ ap_50o = avgs["classes"][label_name]["ap50%"]
+ ap_25o = avgs["classes"][label_name]["ap25%"]
+ line = "{:<15}".format(label_name) + sep + col1
+ line += sep + "{:>15.3f}".format(ap_avg) + sep
+ line += sep + "{:>15.3f}".format(ap_50o) + sep
+ line += sep + "{:>15.3f}".format(ap_25o) + sep
+ logger.info(line)
+
+ all_ap_avg = avgs["all_ap"]
+ all_ap_50o = avgs["all_ap_50%"]
+ all_ap_25o = avgs["all_ap_25%"]
+
+ logger.info("-" * lineLen)
+ line = "{:<15}".format("average") + sep + col1
+ line += "{:>15.3f}".format(all_ap_avg) + sep
+ line += "{:>15.3f}".format(all_ap_50o) + sep
+ line += "{:>15.3f}".format(all_ap_25o) + sep
+ logger.info(line)
+ logger.info("")
+
+
+class Prec_Rec_Acc_Calculator(object):
+ """Compute Precision, Recall, Accuracy for each semantic label"""
+ def __init__(self):
+ self.id_to_label = ID_TO_LABEL
+ self.classes = 9
+ ls = ['tp', 'fp', 'tn', 'fn']
+ self.stat = {x: {y: 0 for y in ls} for x in range(self.classes)}
+ self.acc = 0
+ self.n = 0
+
+ def update(self, semantic_label, semantic_pred):
+ '''
+ semantic_label: tensor, (N)
+ semantic_pred: tensor, (N)
+ '''
+ assert semantic_label.size(0) == semantic_pred.size(
+ 0), "Prec_Rec_Acc_Calculator: semantic_label.size(0) not equal to semantic_pred.size(0)!"
+
+ for cl in self.stat.keys():
+ gt_mask = (semantic_label == cl).float()
+ pred_mask = (semantic_pred == cl).float()
+
+ tp_mask = pred_mask * gt_mask
+ fp_mask = pred_mask * (1 - gt_mask)
+ tn_mask = (1 - pred_mask) * (1 - gt_mask)
+ fn_mask = (1 - pred_mask) * gt_mask
+
+ self.stat[cl]['tp'] += tp_mask.sum()
+ self.stat[cl]['fp'] += fp_mask.sum()
+ self.stat[cl]['tn'] += tn_mask.sum()
+ self.stat[cl]['fn'] += fn_mask.sum()
+
+ self.n += semantic_label.size(0)
+
+ def get_stat(self):
+ ret = {}
+ '''precision'''
+ for cl in self.stat.keys():
+ class_name = self.id_to_label[cl]
+ precision = self.stat[cl]['tp'] / (self.stat[cl]['tp'] + self.stat[cl]['fp'] + 1e-6)
+ ret[class_name + '_' + 'semantic_precision'] = precision
+ '''recall'''
+ for cl in self.stat.keys():
+ class_name = self.id_to_label[cl]
+ recall = self.stat[cl]['tp'] / (self.stat[cl]['tp'] + self.stat[cl]['fn'] + 1e-6)
+ ret[class_name + '_' + 'semantic_recall'] = recall
+ '''accuracy'''
+ sum_tp = 0
+ for cl in self.stat.keys():
+ sum_tp += self.stat[cl]['tp']
+ accuracy = sum_tp / (self.n + 1e-6)
+ ret['semantic_accuracy'] = accuracy
+
+ return ret
+
+
+if __name__ == "__main__":
+ test = Prec_Rec_Acc_Calculator()
+ ret = test.get_stat()
diff --git a/gapartnet/gapartnet/metrics/pose.py b/gapartnet/gapartnet/metrics/pose.py
new file mode 100644
index 0000000..8442753
--- /dev/null
+++ b/gapartnet/gapartnet/metrics/pose.py
@@ -0,0 +1,89 @@
+import torch
+import numpy as np
+import itertools
+
+def rot_diff_rad(rot1, rot2):
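+    # Geodesic angle between two rotations: cos(theta) = (trace(R1 @ R2^T) - 1) / 2.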
+ mat_diff = np.dot(rot1, rot2.transpose(-1, -2))
+ diff = mat_diff[..., 0, 0] + mat_diff[..., 1, 1] + mat_diff[..., 2, 2]
+ diff = (diff - 1) / 2.0
+ diff = np.clip(diff, -1.0, 1.0)
+ return np.arccos(diff)
+
+def z_rot_diff_rad(r1,r2):
+ return np.arccos(np.dot(r1,r2.T)/(np.linalg.norm(r1)*np.linalg.norm(r2)))
+
+def z_rot_diff_degree(r1,r2):
+ return z_rot_diff_rad(r1,r2) / np.pi * 180.0
+
+
+def z_rot_norm_diff_rad(r1,r2):
+ return np.arccos(np.linalg.norm(r1*r2)/(np.linalg.norm(r1)*np.linalg.norm(r2)))
+
+def z_rot_norm_diff_degree(r1,r2):
+ return z_rot_norm_diff_rad(r1,r2) / np.pi * 180.0
+
+def rot_diff_degree(rot1, rot2):
+ return rot_diff_rad(rot1, rot2) / np.pi * 180.0
+
+
+def trans_diff(trans1, trans2):
+ return np.linalg.norm((trans1 - trans2)) # [..., 3, 1] -> [..., 3] -> [...]
+
+
+def scale_diff(scale1, scale2):
+ return np.absolute(scale1 - scale2)
+
+
+def theta_diff(theta1, theta2):
+ return np.absolute(theta1 - theta2)
+
+def pts_inside_box(pts, bbox):
+ # pts: N x 3
+ u1 = bbox[1, :] - bbox[0, :]
+ u2 = bbox[2, :] - bbox[0, :]
+ u3 = bbox[3, :] - bbox[0, :]
+
+ up = pts - np.reshape(bbox[0, :], (1, 3))
+ p1 = np.matmul(up, u1.reshape((3, 1)))
+ p2 = np.matmul(up, u2.reshape((3, 1)))
+ p3 = np.matmul(up, u3.reshape((3, 1)))
+    p1 = np.logical_and(p1 > 0, p1 < np.dot(u1, u1))
+    p2 = np.logical_and(p2 > 0, p2 < np.dot(u2, u2))
+    p3 = np.logical_and(p3 > 0, p3 < np.dot(u3, u3))
+    return np.logical_and(np.logical_and(p1, p2), p3)
+
+
+def estimateSimilarityTransform(source: np.array, target: np.array, verbose=False):
+    SourceHom = np.transpose(np.hstack([source, np.ones([source.shape[0], 1])]))
+    TargetHom = np.transpose(np.hstack([target, np.ones([source.shape[0], 1])]))
+
+    # Auto-parameter selection based on source/target scale heuristics
+    TargetNorm = np.mean(np.linalg.norm(target, axis=1))
+    SourceNorm = np.mean(np.linalg.norm(source, axis=1))
+    RatioTS = (TargetNorm / SourceNorm)
+    RatioST = (SourceNorm / TargetNorm)
+    PassT = RatioST if (RatioST > RatioTS) else RatioTS
+ # print(TargetNorm,SourceNorm,RatioTS,RatioST,PassT)
+ StopT = 0.5 #PassT / 100
+ nIter = 100
+ if verbose:
+ print('Pass threshold: ', PassT)
+ print('Stop threshold: ', StopT)
+ print('Number of iterations: ', nIter)
+
+ SourceInliersHom, TargetInliersHom, BestInlierRatio, BestInlierIdx = \
+ getRANSACInliers(SourceHom, TargetHom, MaxIterations=nIter, PassThreshold=PassT, StopThreshold=StopT)
+ # print("###################")
+ # print(len(BestInlierIdx))
+
+ # print("###################")
+ # print(SourceInliersHom)
+ if(BestInlierRatio < 0.01): # haoran: 0.1->0.01
+ print('[ WARN ] - Something is wrong. Small BestInlierRatio: ', BestInlierRatio)
+ return None, np.array([None,None,None,]), None, None, None
+
+ Scales, Rotation, Translation, OutTransform = estimateSimilarityUmeyama(SourceInliersHom, TargetInliersHom)
+
+ if verbose:
+ print('BestInlierRatio:', BestInlierRatio)
+ print('Rotation:\n', Rotation)
+ print('Translation:\n', Translation)
+ print('Scales:', Scales)
+
+ return Scales, Rotation, Translation, OutTransform, BestInlierIdx
+
+def estimateRestrictedAffineTransform(source: np.array, target: np.array, verbose=False):
+ SourceHom = np.transpose(np.hstack([source, np.ones([source.shape[0], 1])]))
+ TargetHom = np.transpose(np.hstack([target, np.ones([source.shape[0], 1])]))
+
+ RetVal, AffineTrans, Inliers = cv2.estimateAffine3D(source, target)
+ # We assume no shear in the affine matrix and decompose into rotation, non-uniform scales, and translation
+ Translation = AffineTrans[:3, 3]
+ NUScaleRotMat = AffineTrans[:3, :3]
+ # NUScaleRotMat should be the matrix SR, where S is a diagonal scale matrix and R is the rotation matrix (equivalently RS)
+ # Let us do the SVD of NUScaleRotMat to obtain R1*S*R2 and then R = R1 * R2
+ R1, ScalesSorted, R2 = np.linalg.svd(NUScaleRotMat, full_matrices=True)
+
+ if verbose:
+ print('-----------------------------------------------------------------------')
+    # Now, the scales are sorted in ascending order, which is painful because we don't know the x, y, z scales
+ # Let's figure that out by evaluating all 6 possible permutations of the scales
+ ScalePermutations = list(itertools.permutations(ScalesSorted))
+ MinResidual = 1e8
+ Scales = ScalePermutations[0]
+ OutTransform = np.identity(4)
+ Rotation = np.identity(3)
+ for ScaleCand in ScalePermutations:
+ CurrScale = np.asarray(ScaleCand)
+ CurrTransform = np.identity(4)
+ CurrRotation = (np.diag(1 / CurrScale) @ NUScaleRotMat).transpose()
+ CurrTransform[:3, :3] = np.diag(CurrScale) @ CurrRotation
+ CurrTransform[:3, 3] = Translation
+ # Residual = evaluateModel(CurrTransform, SourceHom, TargetHom)
+ Residual = evaluateModelNonHom(source, target, CurrScale,CurrRotation, Translation)
+ if verbose:
+ # print('CurrTransform:\n', CurrTransform)
+ print('CurrScale:', CurrScale)
+ print('Residual:', Residual)
+ print('AltRes:', evaluateModelNoThresh(CurrTransform, SourceHom, TargetHom))
+ if Residual < MinResidual:
+ MinResidual = Residual
+ Scales = CurrScale
+ Rotation = CurrRotation
+ OutTransform = CurrTransform
+
+ if verbose:
+ print('Best Scale:', Scales)
+
+ if verbose:
+ print('Affine Scales:', Scales)
+ print('Affine Translation:', Translation)
+ print('Affine Rotation:\n', Rotation)
+ print('-----------------------------------------------------------------------')
+
+ return Scales, Rotation, Translation, OutTransform
+
+def getRANSACInliers(SourceHom, TargetHom, MaxIterations=100, PassThreshold=200, StopThreshold=0.5):
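+    # RANSAC over 5-point correspondences: fit a similarity via Umeyama, keep the hypothesis
+    # with the lowest residual, and stop early once the residual drops below StopThreshold.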
+ BestResidual = 1e10
+ BestInlierRatio = 0
+ BestInlierIdx = np.arange(SourceHom.shape[1])
+ for i in range(0, MaxIterations):
+ # Pick 5 random (but corresponding) points from source and target
+ RandIdx = np.random.randint(SourceHom.shape[1], size=5)
+ _, _, _, OutTransform = estimateSimilarityUmeyama(SourceHom[:, RandIdx], TargetHom[:, RandIdx])
+ Residual, InlierRatio, InlierIdx = evaluateModel(OutTransform, SourceHom, TargetHom, PassThreshold)
+ if Residual < BestResidual:
+ BestResidual = Residual
+ BestInlierRatio = InlierRatio
+ BestInlierIdx = InlierIdx
+ if BestResidual < StopThreshold:
+ break
+
+ # print('Iteration: ', i)
+ # print('Residual: ', Residual)
+ # print('Inlier ratio: ', InlierRatio)
+
+ return SourceHom[:, BestInlierIdx], TargetHom[:, BestInlierIdx], BestInlierRatio, BestInlierIdx
+
+def evaluateModel(OutTransform, SourceHom, TargetHom, PassThreshold):
+ Diff = TargetHom - np.matmul(OutTransform, SourceHom)
+ ResidualVec = np.linalg.norm(Diff[:3, :], axis=0)
+ Residual = np.linalg.norm(ResidualVec)
+ InlierIdx = np.where(ResidualVec < PassThreshold)
+ nInliers = np.count_nonzero(InlierIdx)
+ InlierRatio = nInliers / SourceHom.shape[1]
+ return Residual, InlierRatio, InlierIdx[0]
+
+def evaluateModelNoThresh(OutTransform, SourceHom, TargetHom):
+ Diff = TargetHom - np.matmul(OutTransform, SourceHom)
+ ResidualVec = np.linalg.norm(Diff[:3, :], axis=0)
+ Residual = np.linalg.norm(ResidualVec)
+ return Residual
+
+def evaluateModelNonHom(source, target, Scales, Rotation, Translation):
+ RepTrans = np.tile(Translation, (source.shape[0], 1))
+ TransSource = (np.diag(Scales) @ Rotation @ source.transpose() + RepTrans.transpose()).transpose()
+ Diff = target - TransSource
+ ResidualVec = np.linalg.norm(Diff, axis=0)
+ Residual = np.linalg.norm(ResidualVec)
+ return Residual
+
+def testNonUniformScale(SourceHom, TargetHom):
+ OutTransform = np.matmul(TargetHom, np.linalg.pinv(SourceHom))
+ ScaledRotation = OutTransform[:3, :3]
+ Translation = OutTransform[:3, 3]
+ Sx = np.linalg.norm(ScaledRotation[0, :])
+ Sy = np.linalg.norm(ScaledRotation[1, :])
+ Sz = np.linalg.norm(ScaledRotation[2, :])
+ Rotation = np.vstack([ScaledRotation[0, :] / Sx, ScaledRotation[1, :] / Sy, ScaledRotation[2, :] / Sz])
+ print('Rotation matrix norm:', np.linalg.norm(Rotation))
+ Scales = np.array([Sx, Sy, Sz])
+
+ # # Check
+ # Diff = TargetHom - np.matmul(OutTransform, SourceHom)
+ # Residual = np.linalg.norm(Diff[:3, :], axis=0)
+ return Scales, Rotation, Translation, OutTransform
+
+def estimateSimilarityUmeyama(SourceHom, TargetHom):
+ # Copy of original paper is at: http://web.stanford.edu/class/cs273/refs/umeyama.pdf
+ SourceCentroid = np.mean(SourceHom[:3, :], axis=1)
+ TargetCentroid = np.mean(TargetHom[:3, :], axis=1)
+ nPoints = SourceHom.shape[1]
+
+ CenteredSource = SourceHom[:3, :] - np.tile(SourceCentroid, (nPoints, 1)).transpose()
+ CenteredTarget = TargetHom[:3, :] - np.tile(TargetCentroid, (nPoints, 1)).transpose()
+
+ CovMatrix = np.matmul(CenteredTarget, np.transpose(CenteredSource)) / nPoints
+ # print(CenteredTarget, CenteredSource)
+ # print(CovMatrix)
+ if np.isnan(CovMatrix).any():
+ print('nPoints:', nPoints)
+ print(SourceHom.shape)
+ print(TargetHom.shape)
+ raise RuntimeError('There are NANs in the input.')
+
+ U, D, Vh = np.linalg.svd(CovMatrix, full_matrices=True)
+ d = (np.linalg.det(U) * np.linalg.det(Vh)) < 0.0
+ if d:
+ D[-1] = -D[-1]
+ U[:, -1] = -U[:, -1]
+
+ Rotation = np.matmul(U, Vh).T # Transpose is the one that works
+
+ varP = np.var(SourceHom[:3, :], axis=1).sum()
+ ScaleFact = 1/varP * np.sum(D) # scale factor
+ Scales = np.array([ScaleFact, ScaleFact, ScaleFact])
+ ScaleMatrix = np.diag(Scales)
+
+ Translation = TargetHom[:3, :].mean(axis=1) - SourceHom[:3, :].mean(axis=1).dot(ScaleFact*Rotation)
+
+ OutTransform = np.identity(4)
+ OutTransform[:3, :3] = ScaleMatrix @ Rotation
+ OutTransform[:3, 3] = Translation
+
+ # # Check
+ # Diff = TargetHom - np.matmul(OutTransform, SourceHom)
+ # Residual = np.linalg.norm(Diff[:3, :], axis=0)
+ return Scales, Rotation, Translation, OutTransform
+
+def estimate3dbbox(coord, npcs_pred, part_mask):
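+ # Fits the NPCS-to-camera similarity transform for one masked part, maps the
+ # part points back into the normalized frame, and derives a tight box from the
+ # inlier extents whose eight corners are then transformed to camera space.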
+ coord_part = coord[np.where(part_mask == True)]
+ npcs_part = npcs_pred[np.where(part_mask == True)]
+
+ coord_part=np.array(coord_part)
+ npcs_part=np.array(npcs_part)
+
+ Scales, Rotation, Translation, OutTransform , BestInlierIdx= estimateSimilarityTransform(npcs_part,coord_part)
+ Rotation_I = np.linalg.pinv(Rotation)
+ trans_seg = np.dot((coord_part-Translation),Rotation_I)/Scales[0]
+ npcs_max = abs(trans_seg[BestInlierIdx]).max(0)
+ bbox_raw = np.array([
+ [-npcs_max[0], -npcs_max[1], -npcs_max[2]],
+ [npcs_max[0], -npcs_max[1], -npcs_max[2]],
+ [-npcs_max[0], npcs_max[1], -npcs_max[2]],
+ [-npcs_max[0], -npcs_max[1], npcs_max[2]],
+ [npcs_max[0], npcs_max[1], -npcs_max[2]],
+
+ [npcs_max[0], -npcs_max[1], npcs_max[2]],
+ [-npcs_max[0], npcs_max[1], npcs_max[2]],
+ [npcs_max[0], npcs_max[1], npcs_max[2]]
+ ])
+ bbox_trans = np.dot((bbox_raw * Scales[0]) , Rotation) + Translation
+ return bbox_trans, Scales, Rotation, Translation
+
+
+def estimate3dbbox_part(coord_part, npcs_part):
+ Scales, Rotation, Translation, OutTransform, BestInlierIdx = estimateSimilarityTransform(npcs_part,coord_part)
+ # estimateSimilarityTransform returns Scales=None when RANSAC fails.
+ if Scales is None:
+ return np.array([None, None, None]), None, None, None
+ Rotation_I = np.linalg.pinv(Rotation)
+ trans_seg = np.dot((coord_part-Translation),Rotation_I)/Scales[0]
+ npcs_max = abs(trans_seg[BestInlierIdx]).max(0)
+ bbox_raw = np.array([
+ [-npcs_max[0], -npcs_max[1], -npcs_max[2]],
+ [npcs_max[0], -npcs_max[1], -npcs_max[2]],
+ [-npcs_max[0], npcs_max[1], -npcs_max[2]],
+ [-npcs_max[0], -npcs_max[1], npcs_max[2]],
+ [npcs_max[0], npcs_max[1], -npcs_max[2]],
+
+ [npcs_max[0], -npcs_max[1], npcs_max[2]],
+ [-npcs_max[0], npcs_max[1], npcs_max[2]],
+ [npcs_max[0], npcs_max[1], npcs_max[2]]
+ ])
+ bbox_trans = np.dot((bbox_raw * Scales[0]) , Rotation) + Translation
+ return bbox_trans, Scales, Rotation, Translation
\ No newline at end of file
diff --git a/gapartnet/gapartnet/metrics/pose_fitting/pose_fitting_nocs.py b/gapartnet/gapartnet/metrics/pose_fitting/pose_fitting_nocs.py
new file mode 100644
index 0000000..13c2deb
--- /dev/null
+++ b/gapartnet/gapartnet/metrics/pose_fitting/pose_fitting_nocs.py
@@ -0,0 +1,254 @@
+'''
+Normalized Object Coordinate Space for Category-Level 6D Object Pose and Size Estimation
+RANSAC for Similarity Transformation Estimation
+
+Written by Srinath Sridhar
+'''
+
+import numpy as np
+import cv2
+import itertools
+
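+# Minimal usage sketch (assuming `npcs` and `xyz` are corresponding (N, 3)
+# float arrays of NPCS predictions and camera-space points):
+#   Scales, Rotation, Translation, OutTransform, BestInlierIdx = \
+#       estimateSimilarityTransform(npcs, xyz)
+# getRANSACInliers selects inlier correspondences and estimateSimilarityUmeyama
+# fits the final 7-DoF similarity transform on those inliers.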
+def estimateSimilarityTransform(source: np.ndarray, target: np.ndarray, verbose=False):
+ # Degenerate case: duplicate a lone correspondence so the homogeneous
+ # stacking below still receives an (N, 3) array with N >= 2.
+ if source.shape[0] == 1:
+ source = np.vstack([source, source])
+ target = np.vstack([target, target])
+
+ SourceHom = np.transpose(np.hstack([source, np.ones([source.shape[0], 1])]))
+ # print(SourceHom.shape)
+ TargetHom = np.transpose(np.hstack([target, np.ones([source.shape[0], 1])]))
+ # print(SourceHom)
+ # Auto-parameter selection based on source-target heuristics
+ TargetNorm = np.mean(np.linalg.norm(target, axis=1))
+ SourceNorm = np.mean(np.linalg.norm(source, axis=1))
+ RatioTS = (TargetNorm / SourceNorm)
+ RatioST = (SourceNorm / TargetNorm)
+ PassT = RatioST if(RatioST>RatioTS) else RatioTS
+ # print(TargetNorm,SourceNorm,RatioTS,RatioST,PassT)
+ StopT = 0.5 #PassT / 100
+ nIter = 100
+ if verbose:
+ print('Pass threshold: ', PassT)
+ print('Stop threshold: ', StopT)
+ print('Number of iterations: ', nIter)
+
+ SourceInliersHom, TargetInliersHom, BestInlierRatio, BestInlierIdx = \
+ getRANSACInliers(SourceHom, TargetHom, MaxIterations=nIter, PassThreshold=PassT, StopThreshold=StopT)
+ # print("###################")
+ # print(len(BestInlierIdx))
+
+ # print("###################")
+ # print(SourceInliersHom)
+ if(BestInlierRatio < 0.01): # haoran: 0.1->0.01
+ print('[ WARN ] - Something is wrong. Small BestInlierRatio: ', BestInlierRatio)
+ return None, np.array([None,None,None,]), None, None, None
+
+ Scales, Rotation, Translation, OutTransform = estimateSimilarityUmeyama(SourceInliersHom, TargetInliersHom)
+
+ if verbose:
+ print('BestInlierRatio:', BestInlierRatio)
+ print('Rotation:\n', Rotation)
+ print('Translation:\n', Translation)
+ print('Scales:', Scales)
+
+ return Scales, Rotation, Translation, OutTransform, BestInlierIdx
+
+def estimateRestrictedAffineTransform(source: np.array, target: np.array, verbose=False):
+ SourceHom = np.transpose(np.hstack([source, np.ones([source.shape[0], 1])]))
+ TargetHom = np.transpose(np.hstack([target, np.ones([source.shape[0], 1])]))
+
+ RetVal, AffineTrans, Inliers = cv2.estimateAffine3D(source, target)
+ # We assume no shear in the affine matrix and decompose into rotation, non-uniform scales, and translation
+ Translation = AffineTrans[:3, 3]
+ NUScaleRotMat = AffineTrans[:3, :3]
+ # NUScaleRotMat should be the matrix SR, where S is a diagonal scale matrix and R is the rotation matrix (equivalently RS)
+ # Let us do the SVD of NUScaleRotMat to obtain R1*S*R2 and then R = R1 * R2
+ R1, ScalesSorted, R2 = np.linalg.svd(NUScaleRotMat, full_matrices=True)
+
+ if verbose:
+ print('-----------------------------------------------------------------------')
+ # Now, the scales are sorted in ascending order, which is painful because we don't know which one is the x, y, or z scale
+ # Let's figure that out by evaluating all 6 possible permutations of the scales
+ ScalePermutations = list(itertools.permutations(ScalesSorted))
+ MinResidual = 1e8
+ Scales = ScalePermutations[0]
+ OutTransform = np.identity(4)
+ Rotation = np.identity(3)
+ for ScaleCand in ScalePermutations:
+ CurrScale = np.asarray(ScaleCand)
+ CurrTransform = np.identity(4)
+ CurrRotation = (np.diag(1 / CurrScale) @ NUScaleRotMat).transpose()
+ CurrTransform[:3, :3] = np.diag(CurrScale) @ CurrRotation
+ CurrTransform[:3, 3] = Translation
+ # Residual = evaluateModel(CurrTransform, SourceHom, TargetHom)
+ Residual = evaluateModelNonHom(source, target, CurrScale,CurrRotation, Translation)
+ if verbose:
+ # print('CurrTransform:\n', CurrTransform)
+ print('CurrScale:', CurrScale)
+ print('Residual:', Residual)
+ print('AltRes:', evaluateModelNoThresh(CurrTransform, SourceHom, TargetHom))
+ if Residual < MinResidual:
+ MinResidual = Residual
+ Scales = CurrScale
+ Rotation = CurrRotation
+ OutTransform = CurrTransform
+
+ if verbose:
+ print('Best Scale:', Scales)
+
+ if verbose:
+ print('Affine Scales:', Scales)
+ print('Affine Translation:', Translation)
+ print('Affine Rotation:\n', Rotation)
+ print('-----------------------------------------------------------------------')
+
+ return Scales, Rotation, Translation, OutTransform
+
+def getRANSACInliers(SourceHom, TargetHom, MaxIterations=100, PassThreshold=200, StopThreshold=0.5):
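+ # Each iteration fits a similarity transform to 5 random correspondences with
+ # estimateSimilarityUmeyama and keeps the hypothesis with the smallest overall
+ # residual; the search stops early once that residual drops below StopThreshold.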
+ BestResidual = 1e10
+ BestInlierRatio = 0
+ BestInlierIdx = np.arange(SourceHom.shape[1])
+ for i in range(0, MaxIterations):
+ # Pick 5 random (but corresponding) points from source and target
+ RandIdx = np.random.randint(SourceHom.shape[1], size=5)
+ _, _, _, OutTransform = estimateSimilarityUmeyama(SourceHom[:, RandIdx], TargetHom[:, RandIdx])
+ Residual, InlierRatio, InlierIdx = evaluateModel(OutTransform, SourceHom, TargetHom, PassThreshold)
+ if Residual < BestResidual:
+ BestResidual = Residual
+ BestInlierRatio = InlierRatio
+ BestInlierIdx = InlierIdx
+ if BestResidual < StopThreshold:
+ break
+
+ # print('Iteration: ', i)
+ # print('Residual: ', Residual)
+ # print('InlierIdx: ', InlierIdx)
+
+ return SourceHom[:, BestInlierIdx], TargetHom[:, BestInlierIdx], BestInlierRatio, BestInlierIdx
+
+def evaluateModel(OutTransform, SourceHom, TargetHom, PassThreshold):
+ Diff = TargetHom - np.matmul(OutTransform, SourceHom)
+ ResidualVec = np.linalg.norm(Diff[:3, :], axis=0)
+ Residual = np.linalg.norm(ResidualVec)
+ InlierIdx = np.where(ResidualVec < PassThreshold)[0]
+ nInliers = InlierIdx.shape[0] # count inliers directly; index 0 is a valid inlier
+ InlierRatio = nInliers / SourceHom.shape[1]
+ return Residual, InlierRatio, InlierIdx
+
+def evaluateModelNoThresh(OutTransform, SourceHom, TargetHom):
+ Diff = TargetHom - np.matmul(OutTransform, SourceHom)
+ ResidualVec = np.linalg.norm(Diff[:3, :], axis=0)
+ Residual = np.linalg.norm(ResidualVec)
+ return Residual
+
+def evaluateModelNonHom(source, target, Scales, Rotation, Translation):
+ RepTrans = np.tile(Translation, (source.shape[0], 1))
+ TransSource = (np.diag(Scales) @ Rotation @ source.transpose() + RepTrans.transpose()).transpose()
+ Diff = target - TransSource
+ ResidualVec = np.linalg.norm(Diff, axis=0)
+ Residual = np.linalg.norm(ResidualVec)
+ return Residual
+
+def testNonUniformScale(SourceHom, TargetHom):
+ OutTransform = np.matmul(TargetHom, np.linalg.pinv(SourceHom))
+ ScaledRotation = OutTransform[:3, :3]
+ Translation = OutTransform[:3, 3]
+ Sx = np.linalg.norm(ScaledRotation[0, :])
+ Sy = np.linalg.norm(ScaledRotation[1, :])
+ Sz = np.linalg.norm(ScaledRotation[2, :])
+ Rotation = np.vstack([ScaledRotation[0, :] / Sx, ScaledRotation[1, :] / Sy, ScaledRotation[2, :] / Sz])
+ print('Rotation matrix norm:', np.linalg.norm(Rotation))
+ Scales = np.array([Sx, Sy, Sz])
+
+ # # Check
+ # Diff = TargetHom - np.matmul(OutTransform, SourceHom)
+ # Residual = np.linalg.norm(Diff[:3, :], axis=0)
+ return Scales, Rotation, Translation, OutTransform
+
+def estimateSimilarityUmeyama(SourceHom, TargetHom):
+ # Copy of original paper is at: http://web.stanford.edu/class/cs273/refs/umeyama.pdf
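+ # Closed-form least-squares similarity fit: center both point sets, take the
+ # SVD of their 3x3 cross-covariance, correct for reflections, recover an
+ # isotropic scale (sum of singular values over the source variance), and set
+ # the translation so the scaled/rotated source centroid lands on the target one.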
+ SourceCentroid = np.mean(SourceHom[:3, :], axis=1)
+ TargetCentroid = np.mean(TargetHom[:3, :], axis=1)
+ nPoints = SourceHom.shape[1]
+
+ CenteredSource = SourceHom[:3, :] - np.tile(SourceCentroid, (nPoints, 1)).transpose()
+ CenteredTarget = TargetHom[:3, :] - np.tile(TargetCentroid, (nPoints, 1)).transpose()
+
+ CovMatrix = np.matmul(CenteredTarget, np.transpose(CenteredSource)) / nPoints
+ # print(CenteredTarget, CenteredSource)
+ # print(CovMatrix)
+ if np.isnan(CovMatrix).any():
+ print('nPoints:', nPoints)
+ print(SourceHom.shape)
+ print(TargetHom.shape)
+ raise RuntimeError('There are NANs in the input.')
+
+ U, D, Vh = np.linalg.svd(CovMatrix, full_matrices=True)
+ d = (np.linalg.det(U) * np.linalg.det(Vh)) < 0.0
+ if d:
+ D[-1] = -D[-1]
+ U[:, -1] = -U[:, -1]
+
+ Rotation = np.matmul(U, Vh).T # Transpose is the one that works
+
+ varP = np.var(SourceHom[:3, :], axis=1).sum()
+ ScaleFact = 1/varP * np.sum(D) # scale factor
+ Scales = np.array([ScaleFact, ScaleFact, ScaleFact])
+ ScaleMatrix = np.diag(Scales)
+
+ Translation = TargetHom[:3, :].mean(axis=1) - SourceHom[:3, :].mean(axis=1).dot(ScaleFact*Rotation)
+
+ OutTransform = np.identity(4)
+ OutTransform[:3, :3] = ScaleMatrix @ Rotation
+ OutTransform[:3, 3] = Translation
+
+ # # Check
+ # Diff = TargetHom - np.matmul(OutTransform, SourceHom)
+ # Residual = np.linalg.norm(Diff[:3, :], axis=0)
+ return Scales, Rotation, Translation, OutTransform
+
+def estimate3dbbox(coord, npcs_pred, part_mask):
+ coord_part = coord[np.where(part_mask == True)]
+ npcs_part = npcs_pred[np.where(part_mask == True)]
+
+ coord_part=np.array(coord_part)
+ npcs_part=np.array(npcs_part)
+
+ Scales, Rotation, Translation, OutTransform, BestInlierIdx = estimateSimilarityTransform(npcs_part,coord_part)
+ scale_pred = np.max(abs(npcs_part[BestInlierIdx]), axis=0)
+
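+ # Eight corners of the axis-aligned box in NPCS (all sign combinations of the
+ # per-axis inlier extents); they are mapped to camera coordinates below with
+ # the estimated similarity transform.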
+ bbox_raw = np.array([
+ [-scale_pred[0], -scale_pred[1], -scale_pred[2]],
+ [scale_pred[0], -scale_pred[1], -scale_pred[2]],
+ [-scale_pred[0], scale_pred[1], -scale_pred[2]],
+ [-scale_pred[0], -scale_pred[1], scale_pred[2]],
+ [scale_pred[0], scale_pred[1], -scale_pred[2]],
+
+ [scale_pred[0], -scale_pred[1], scale_pred[2]],
+ [-scale_pred[0], scale_pred[1], scale_pred[2]],
+ [scale_pred[0], scale_pred[1], scale_pred[2]]
+ ])
+ bbox_trans = np.dot((bbox_raw * Scales[0]) , Rotation) + Translation
+
+ return bbox_trans, Scales, Rotation, Translation
+
+def estimate3dbbox_part(coord_part, npcs_pred):
+ Scales, Rotation, Translation, OutTransform, BestInlierIdx = estimateSimilarityTransform(npcs_pred,coord_part)
+ # estimateSimilarityTransform returns Scales=None when RANSAC fails.
+ if Scales is None:
+ return np.array([None, None, None]), None, None, None
+ scale_pred = np.max(abs(npcs_pred[BestInlierIdx]), axis=0)
+
+ bbox_raw = np.array([
+ [-scale_pred[0], -scale_pred[1], -scale_pred[2]],
+ [scale_pred[0], -scale_pred[1], -scale_pred[2]],
+ [-scale_pred[0], scale_pred[1], -scale_pred[2]],
+ [-scale_pred[0], -scale_pred[1], scale_pred[2]],
+ [scale_pred[0], scale_pred[1], -scale_pred[2]],
+
+ [scale_pred[0], -scale_pred[1], scale_pred[2]],
+ [-scale_pred[0], scale_pred[1], scale_pred[2]],
+ [scale_pred[0], scale_pred[1], scale_pred[2]]
+ ])
+ bbox_trans = np.dot((bbox_raw * Scales[0]) , Rotation) + Translation
+
+ return bbox_trans, Scales, Rotation, Translation
\ No newline at end of file
diff --git a/gapartnet/gapartnet/metrics/segmentation.py b/gapartnet/gapartnet/metrics/segmentation.py
new file mode 100644
index 0000000..59db2df
--- /dev/null
+++ b/gapartnet/gapartnet/metrics/segmentation.py
@@ -0,0 +1,29 @@
+import torch
+from kornia.metrics import mean_iou as _mean_iou
+
+
+@torch.no_grad()
+def pixel_accuracy(pred_mask: torch.Tensor, gt_mask: torch.Tensor) -> float:
+ """
+ Compute pixel accuracy.
+ """
+
+ if gt_mask.numel() > 0:
+ accuracy = (pred_mask == gt_mask).sum() / gt_mask.numel()
+ accuracy = accuracy.item()
+ else:
+ accuracy = 0.
+ return accuracy
+
+
+@torch.no_grad()
+def mean_iou(pred_mask: torch.Tensor, gt_mask: torch.Tensor, num_classes: int) -> float:
+ """
+ Compute mIoU.
+ """
+
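+ # Points with a negative ground-truth label (ignore regions) are dropped
+ # before computing the per-class IoU with kornia's mean_iou.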
+ valid_mask = gt_mask >= 0
+ miou = _mean_iou(
+ pred_mask[valid_mask][None], gt_mask[valid_mask][None], num_classes=num_classes
+ ).mean()
+ return miou
diff --git a/gapartnet/gapartnet/models/__init__.py b/gapartnet/gapartnet/models/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/gapartnet/gapartnet/models/gapartnet.py b/gapartnet/gapartnet/models/gapartnet.py
new file mode 100644
index 0000000..cde8baf
--- /dev/null
+++ b/gapartnet/gapartnet/models/gapartnet.py
@@ -0,0 +1,811 @@
+import functools
+from typing import List, Optional, Tuple
+
+import numpy as np
+import pytorch_lightning as pl
+import spconv.pytorch as spconv
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from epic_ops.iou import batch_instance_seg_iou
+from epic_ops.reduce import segmented_maxpool
+from gapartnet.metrics.segmentation import mean_iou, pixel_accuracy
+from gapartnet.structures.instances import Instances, Result
+from gapartnet.structures.point_cloud import PointCloud
+from gapartnet.structures.segmentation import Segmentation
+from gapartnet.utils.symmetry_matrix import get_symmetry_matrix
+from gapartnet.losses.focal_loss import focal_loss
+from gapartnet.losses.dice_loss import dice_loss
+
+from .gapartnet_utils import (apply_nms, cluster_proposals, compute_ap,
+ compute_npcs_loss, filter_invalid_proposals,
+ get_gt_scores, segmented_voxelize)
+from .sparse_unet import SparseUNet
+from torch.autograd import Function
+from .util_net import ReverseLayerF, Discriminator
+from gapartnet.utils.info import OBJECT_NAME2ID, PART_ID2NAME, PART_NAME2ID
+
+class InsSeg(pl.LightningModule):
+ def __init__(
+ self,
+ in_channels: int,
+ num_classes: int,
+ num_obj_cats: int,
+ channels: List[int],
+ block_repeat: int = 2,
+ learning_rate: float = 1e-3,
+ ignore_sem_label: int = -100,
+ ignore_instance_label: int = -100,
+ ball_query_radius: float = 0.03,
+ max_num_points_per_query: int = 50,
+ max_num_points_per_query_shift: int = 300,
+ min_num_points_per_proposal: int = 50,
+ score_net_start_at: int = 100,
+ score_fullscale: float = 14,
+ score_scale: float = 50,
+ npcs_net_start_at: int = 100,
+ symmetry_indices: Optional[List[int]] = None,
+ pretrained_model_path: Optional[str] = None,
+ loss_sem_seg_weight: Optional[List[float]] = None,
+ loss_prop_cls_weight: Optional[List[float]] = None,
+ use_focal_loss: bool = False,
+ use_dice_loss: bool = False,
+ val_score_threshold: float = 0.09,
+ val_min_num_points_per_proposal: int = 3,
+ val_nms_iou_threshold: float = 0.3,
+ val_ap_iou_threshold: float = 0.5,
+ # cls_global_weight: float = 0.1,
+ # cls_local_weight: float = 0.1,
+ # cls_start_at: int = 50,
+ reverse: bool = True,
+ alpha: float = 0.5,
+ discrimination_score_thresh: float = 0.3,
+ discrimination_use_score: bool = True,
+ ckpt: str = None,
+ ):
+ super().__init__()
+ self.save_hyperparameters()
+
+ self.in_channels = in_channels
+ self.num_classes = num_classes
+ self.num_obj_cats = num_obj_cats
+ self.channels = channels
+ self.reverse = reverse
+ self.alpha = alpha
+ self.discrimination_score_thresh = discrimination_score_thresh
+ self.discrimination_use_score = discrimination_use_score
+
+ norm_fn = functools.partial(nn.BatchNorm1d, eps=1e-4, momentum=0.1)
+
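+ # Backbone and heads: a sparse U-Net yields per-point features, followed by a
+ # semantic head and an offset head on the full cloud, plus two proposal-level
+ # sparse U-Nets with a score head and an NPCS head (3 coords per part class).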
+ self.unet = SparseUNet.build(in_channels, channels, block_repeat, norm_fn)
+ self.sem_seg_head = nn.Linear(channels[0], num_classes)
+ self.offset_head = nn.Sequential(
+ nn.Linear(channels[0], channels[0]),
+ norm_fn(channels[0]),
+ nn.ReLU(inplace=True),
+ nn.Linear(channels[0], 3),
+ )
+
+ self.score_unet = SparseUNet.build(
+ channels[0], channels[:2], block_repeat, norm_fn, without_stem=True
+ )
+ self.score_head = nn.Linear(channels[0], num_classes - 1)
+
+ self.npcs_unet = SparseUNet.build(
+ channels[0], channels[:2], block_repeat, norm_fn, without_stem=True
+ )
+ self.npcs_head = nn.Linear(channels[0], 3 * (num_classes - 1))
+
+ self.learning_rate = learning_rate
+ self.ignore_sem_label = ignore_sem_label
+ self.ignore_instance_label = ignore_instance_label
+ self.ball_query_radius = ball_query_radius
+ self.max_num_points_per_query = max_num_points_per_query
+ self.max_num_points_per_query_shift = max_num_points_per_query_shift
+ self.min_num_points_per_proposal = min_num_points_per_proposal
+
+ self.score_net_start_at = score_net_start_at
+ self.score_fullscale = score_fullscale
+ self.score_scale = score_scale
+
+ self.npcs_net_start_at = npcs_net_start_at
+ self.register_buffer(
+ "symmetry_indices", torch.as_tensor(symmetry_indices, dtype=torch.int64)
+ )
+ if symmetry_indices is not None:
+ assert len(symmetry_indices) == num_classes, (symmetry_indices, num_classes)
+
+ (
+ symmetry_matrix_1, symmetry_matrix_2, symmetry_matrix_3
+ ) = get_symmetry_matrix()
+ self.register_buffer("symmetry_matrix_1", symmetry_matrix_1)
+ self.register_buffer("symmetry_matrix_2", symmetry_matrix_2)
+ self.register_buffer("symmetry_matrix_3", symmetry_matrix_3)
+
+ if pretrained_model_path is not None:
+ print("Loading pretrained model from:", pretrained_model_path)
+ state_dict = torch.load(
+ pretrained_model_path, map_location="cpu"
+ )["state_dict"]
+ missing_keys, unexpected_keys = self.load_state_dict(
+ state_dict, strict=False,
+ )
+ if len(missing_keys) > 0:
+ print("missing_keys:", missing_keys)
+ if len(unexpected_keys) > 0:
+ print("unexpected_keys:", unexpected_keys)
+
+ if loss_sem_seg_weight is None:
+ self.loss_sem_seg_weight = loss_sem_seg_weight
+ else:
+ assert len(loss_sem_seg_weight) == num_classes
+ self.register_buffer(
+ "loss_sem_seg_weight",
+ torch.as_tensor(loss_sem_seg_weight, dtype=torch.float32),
+ persistent=False,
+ )
+ if loss_prop_cls_weight is None:
+ self.loss_prop_cls_weight = loss_prop_cls_weight
+ else:
+ assert len(loss_prop_cls_weight) == num_obj_cats
+ self.register_buffer(
+ "loss_prop_cls_weight",
+ torch.as_tensor(loss_prop_cls_weight, dtype=torch.float32),
+ persistent=False,
+ )
+ self.use_focal_loss = use_focal_loss
+ self.use_dice_loss = use_dice_loss
+
+ self.cluster_proposals_start_at = min(
+ self.score_net_start_at, self.npcs_net_start_at
+ )
+
+ self.val_score_threshold = val_score_threshold
+ self.val_min_num_points_per_proposal = val_min_num_points_per_proposal
+ self.val_nms_iou_threshold = val_nms_iou_threshold
+ self.val_ap_iou_threshold = val_ap_iou_threshold
+ self.ckpt = ckpt
+
+
+ def forward_sem_seg(
+ self,
+ voxel_tensor: spconv.SparseConvTensor,
+ pc_voxel_id: torch.Tensor,
+ ) -> Tuple[spconv.SparseConvTensor, torch.Tensor, torch.Tensor]:
+ voxel_features = self.unet(voxel_tensor)
+ sem_logits = self.sem_seg_head(voxel_features.features)
+
+ pt_features = voxel_features.features[pc_voxel_id]
+ sem_logits = sem_logits[pc_voxel_id]
+
+ return voxel_features, pt_features, sem_logits
+
+ def loss_sem_seg(
+ self,
+ sem_logits: torch.Tensor,
+ sem_labels: torch.Tensor,
+ ) -> torch.Tensor:
+ if self.use_focal_loss:
+ loss = focal_loss(
+ sem_logits, sem_labels,
+ alpha=self.loss_sem_seg_weight,
+ gamma=2.0,
+ ignore_index=self.ignore_sem_label,
+ reduction="mean",
+ )
+ else:
+ loss = F.cross_entropy(
+ sem_logits, sem_labels,
+ weight=self.loss_sem_seg_weight,
+ ignore_index=self.ignore_sem_label,
+ reduction="mean",
+ )
+
+ if self.use_dice_loss:
+ loss += dice_loss(
+ sem_logits[:, :, None, None], sem_labels[:, None, None],
+ )
+
+ return loss
+
+ def forward_pt_offset(
+ self,
+ voxel_features: spconv.SparseConvTensor,
+ pc_voxel_id: torch.Tensor,
+ ) -> torch.Tensor:
+ pt_offsets = self.offset_head(voxel_features.features)
+ return pt_offsets[pc_voxel_id]
+
+ def loss_pt_offset(
+ self,
+ pt_offsets: torch.Tensor,
+ gt_pt_offsets: torch.Tensor,
+ sem_labels: torch.Tensor,
+ instance_labels: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ valid_instance_mask = (sem_labels > 0) & (instance_labels >= 0)
+
+ pt_diff = pt_offsets - gt_pt_offsets
+ pt_dist = torch.sum(pt_diff.abs(), dim=-1)
+ loss_pt_offset_dist = pt_dist[valid_instance_mask].mean()
+
+ gt_pt_offsets_norm = torch.norm(gt_pt_offsets, p=2, dim=-1)
+ gt_pt_offsets = gt_pt_offsets / (gt_pt_offsets_norm[:, None] + 1e-8)
+
+ pt_offsets_norm = torch.norm(pt_offsets, p=2, dim=-1)
+ pt_offsets = pt_offsets / (pt_offsets_norm[:, None] + 1e-8)
+
+ dir_diff = -(gt_pt_offsets * pt_offsets).sum(-1)
+ loss_pt_offset_dir = dir_diff[valid_instance_mask].mean()
+
+ return loss_pt_offset_dist, loss_pt_offset_dir
+
+ def cluster_proposals_and_revoxelize(
+ self,
+ pt_xyz: torch.Tensor,
+ batch_indices: torch.Tensor,
+ pt_features: torch.Tensor,
+ sem_preds: torch.Tensor,
+ pt_offsets: torch.Tensor,
+ instance_labels: Optional[torch.Tensor],
+ cls_labels: Optional[torch.Tensor],
+ ):
+ device = pt_xyz.device
+
+ # get rid of stuff classes (e.g. wall)
+ if instance_labels is not None:
+ valid_mask = (sem_preds > 0) & (instance_labels >= 0)
+ else:
+ valid_mask = sem_preds > 0
+
+ pt_xyz = pt_xyz[valid_mask]
+ if pt_xyz.shape[0] == 0:
+ return None, None, None
+
+ batch_indices = batch_indices[valid_mask]
+ pt_features = pt_features[valid_mask]
+ sem_preds = sem_preds[valid_mask].int()
+ pt_offsets = pt_offsets[valid_mask]
+ if instance_labels is not None:
+ instance_labels = instance_labels[valid_mask]
+
+ # get batch offsets (csr) from batch indices
+ _, batch_indices_compact, num_points_per_batch = torch.unique_consecutive(
+ batch_indices, return_inverse=True, return_counts=True
+ )
+ batch_indices_compact = batch_indices_compact.int()
+ batch_offsets = torch.zeros(
+ (num_points_per_batch.shape[0] + 1,), dtype=torch.int32, device=device
+ )
+ batch_offsets[1:] = num_points_per_batch.cumsum(0)
+
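+ # Proposals are clustered twice: once from the raw coordinates and once from
+ # the coordinates shifted by the predicted offsets; both candidate sets are
+ # merged before small clusters are filtered out.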
+ # cluster proposals
+ sorted_cc_labels, sorted_indices = cluster_proposals(
+ pt_xyz, batch_indices_compact, batch_offsets, sem_preds,
+ self.ball_query_radius, self.max_num_points_per_query,
+ )
+
+ sorted_cc_labels_shift, sorted_indices_shift = cluster_proposals(
+ pt_xyz + pt_offsets, batch_indices_compact, batch_offsets, sem_preds,
+ self.ball_query_radius, self.max_num_points_per_query_shift,
+ )
+
+ # combine clusters
+ sorted_cc_labels = torch.cat([
+ sorted_cc_labels,
+ sorted_cc_labels_shift + sorted_cc_labels.shape[0],
+ ], dim=0)
+ sorted_indices = torch.cat([sorted_indices, sorted_indices_shift], dim=0)
+
+ # compact the proposal ids
+ _, proposal_indices, num_points_per_proposal = torch.unique_consecutive(
+ sorted_cc_labels, return_inverse=True, return_counts=True
+ )
+
+ # remove small proposals
+ valid_proposal_mask = (
+ num_points_per_proposal >= self.min_num_points_per_proposal
+ )
+ # proposal to point
+ valid_point_mask = valid_proposal_mask[proposal_indices]
+
+ sorted_indices = sorted_indices[valid_point_mask]
+ if sorted_indices.shape[0] == 0:
+ return None, None, None
+
+ batch_indices = batch_indices[sorted_indices]
+ pt_xyz = pt_xyz[sorted_indices]
+ pt_features = pt_features[sorted_indices]
+ sem_preds = sem_preds[sorted_indices]
+ if instance_labels is not None:
+ instance_labels = instance_labels[sorted_indices]
+
+ # re-compact the proposal ids
+ proposal_indices = proposal_indices[valid_point_mask]
+ _, proposal_indices, num_points_per_proposal = torch.unique_consecutive(
+ proposal_indices, return_inverse=True, return_counts=True
+ )
+ num_proposals = num_points_per_proposal.shape[0]
+
+ # get proposal batch offsets
+ proposal_offsets = torch.zeros(
+ num_proposals + 1, dtype=torch.int32, device=device
+ )
+ proposal_offsets[1:] = num_points_per_proposal.cumsum(0)
+
+ # voxelization
+ voxel_features, voxel_coords, pc_voxel_id = segmented_voxelize(
+ pt_xyz, pt_features,
+ proposal_offsets, proposal_indices,
+ num_points_per_proposal,
+ self.score_fullscale, self.score_scale,
+ )
+ voxel_tensor = spconv.SparseConvTensor(
+ voxel_features, voxel_coords.int(),
+ spatial_shape=[self.score_fullscale] * 3,
+ batch_size=num_proposals,
+ )
+ assert (pc_voxel_id >= 0).all()
+ cls_labels = cls_labels.to(batch_indices.device)
+ proposal_cls_labels = cls_labels[batch_indices.long()]
+ proposals = Instances(
+ valid_mask=valid_mask,
+ sorted_indices=sorted_indices,
+ pt_xyz=pt_xyz,
+ batch_indices=batch_indices,
+ proposal_offsets=proposal_offsets,
+ proposal_indices=proposal_indices,
+ num_points_per_proposal=num_points_per_proposal,
+ sem_preds=sem_preds,
+ instance_labels=instance_labels,
+ cls_labels=proposal_cls_labels,
+ )
+
+ return voxel_tensor, pc_voxel_id, proposals
+
+ def forward_proposal_score(
+ self,
+ voxel_tensor: spconv.SparseConvTensor,
+ pc_voxel_id: torch.Tensor,
+ proposals: Instances,
+ ):
+ proposal_offsets = proposals.proposal_offsets
+ proposal_offsets_begin = proposal_offsets[:-1]
+ proposal_offsets_end = proposal_offsets[1:]
+
+ score_features = self.score_unet(voxel_tensor)
+ score_features = score_features.features[pc_voxel_id]
+ pooled_score_features, _ = segmented_maxpool(
+ score_features, proposal_offsets_begin, proposal_offsets_end
+ )
+ score_logits = self.score_head(pooled_score_features)
+
+ return score_logits
+
+ def loss_proposal_score(
+ self,
+ score_logits: torch.Tensor,
+ proposals: Instances,
+ num_points_per_instance: torch.Tensor,
+ ) -> torch.Tensor:
+ ious = batch_instance_seg_iou(
+ proposals.proposal_offsets,
+ proposals.instance_labels,
+ proposals.batch_indices,
+ num_points_per_instance,
+ )
+ proposals.ious = ious
+ proposals.num_points_per_instance = num_points_per_instance
+
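+ # Supervise the scores with soft targets: IoU with the best-matching GT above
+ # 0.75 maps to 1, below 0.25 to 0, and values in between scale linearly
+ # (see get_gt_scores).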
+ ious_max = ious.max(-1)[0]
+ gt_scores = get_gt_scores(ious_max, 0.75, 0.25)
+
+ return F.binary_cross_entropy_with_logits(score_logits, gt_scores)
+
+ def forward_proposal_npcs(
+ self,
+ voxel_tensor: spconv.SparseConvTensor,
+ pc_voxel_id: torch.Tensor,
+ ) -> torch.Tensor:
+ npcs_features = self.npcs_unet(voxel_tensor)
+ npcs_logits = self.npcs_head(npcs_features.features)
+ npcs_logits = npcs_logits[pc_voxel_id]
+
+ return npcs_logits
+
+ def loss_proposal_npcs(
+ self,
+ npcs_logits: torch.Tensor,
+ gt_npcs: torch.Tensor,
+ proposals: Instances,
+ ) -> torch.Tensor:
+ sem_preds, sem_labels = proposals.sem_preds, proposals.sem_labels
+ proposal_indices = proposals.proposal_indices
+ valid_mask = (sem_preds == sem_labels) & (gt_npcs != 0).any(dim=-1)
+
+ npcs_logits = npcs_logits[valid_mask]
+ gt_npcs = gt_npcs[valid_mask]
+ sem_preds = sem_preds[valid_mask].long()
+ sem_labels = sem_labels[valid_mask]
+ proposal_indices = proposal_indices[valid_mask]
+
+ npcs_logits = rearrange(npcs_logits, "n (k c) -> n k c", c=3)
+ npcs_logits = npcs_logits.gather(
+ 1, index=repeat(sem_preds - 1, "n -> n one c", one=1, c=3)
+ ).squeeze(1)
+
+ proposals.npcs_preds = npcs_logits.detach()
+ proposals.gt_npcs = gt_npcs
+ proposals.npcs_valid_mask = valid_mask
+
+ loss_npcs = 0
+
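+ # Per-class symmetry handling: classes are grouped by symmetry index and the
+ # NPCS loss takes the minimum over the corresponding symmetry rotations, so
+ # symmetric parts are not penalized for an equivalent pose.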
+ symmetry_indices = self.symmetry_indices[sem_preds]
+ # group #1
+ group_1_mask = symmetry_indices < 3
+ symmetry_indices_1 = symmetry_indices[group_1_mask]
+ if symmetry_indices_1.shape[0] > 0:
+ loss_npcs += compute_npcs_loss(
+ npcs_logits[group_1_mask], gt_npcs[group_1_mask],
+ proposal_indices[group_1_mask],
+ self.symmetry_matrix_1[symmetry_indices_1]
+ )
+
+ # group #2
+ group_2_mask = symmetry_indices == 3
+ symmetry_indices_2 = symmetry_indices[group_2_mask]
+ if symmetry_indices_2.shape[0] > 0:
+ loss_npcs += compute_npcs_loss(
+ npcs_logits[group_2_mask], gt_npcs[group_2_mask],
+ proposal_indices[group_2_mask],
+ self.symmetry_matrix_2[symmetry_indices_2 - 3]
+ )
+
+ # group #3
+ group_3_mask = symmetry_indices == 4
+ symmetry_indices_3 = symmetry_indices[group_3_mask]
+ if symmetry_indices_3.shape[0] > 0:
+ loss_npcs += compute_npcs_loss(
+ npcs_logits[group_3_mask], gt_npcs[group_3_mask],
+ proposal_indices[group_3_mask],
+ self.symmetry_matrix_3[symmetry_indices_3 - 4]
+ )
+
+ return loss_npcs
+
+ def _training_or_validation_step(
+ self,
+ point_clouds: List[PointCloud],
+ batch_idx: int,
+ running_mode: str,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ batch_size = len(point_clouds)
+ (
+ scene_ids, cls_labels, num_points, points, batch_indices, sem_labels, instance_labels, gt_npcs,
+ num_instances, instance_regions, num_points_per_instance, instance_sem_labels,
+ voxel_tensor, pc_voxel_id,
+ ) = PointCloud.collate(point_clouds)
+
+ pt_xyz = points[:, :3]
+ cls_labels = cls_labels.to(pt_xyz.device)
+ assert (pc_voxel_id >= 0).all()
+
+
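+ # Per-batch pipeline: semantic segmentation -> per-point offsets -> (after the
+ # warm-up epochs) proposal clustering, proposal scoring, and NPCS regression;
+ # the individual losses are summed into a single objective below.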
+ # semantic segmentation
+ voxel_features, pt_features, sem_logits = self.forward_sem_seg(voxel_tensor, pc_voxel_id)
+
+ sem_preds = torch.argmax(sem_logits.detach(), dim=-1)
+
+ if sem_labels is not None:
+ loss_sem_seg = self.loss_sem_seg(sem_logits, sem_labels)
+ else:
+ loss_sem_seg = 0.
+
+ sem_seg = Segmentation(batch_size=batch_size,num_points=num_points,sem_preds=sem_preds,sem_labels=sem_labels,)
+
+ # point offset
+ pt_offsets = self.forward_pt_offset(voxel_features, pc_voxel_id)
+ if instance_regions is not None:
+ gt_pt_offsets = instance_regions[:, :3] - pt_xyz
+ loss_pt_offset_dist, loss_pt_offset_dir = self.loss_pt_offset(
+ pt_offsets, gt_pt_offsets, sem_labels, instance_labels,
+ )
+ else:
+ loss_pt_offset_dist, loss_pt_offset_dir = 0., 0.
+
+ if self.current_epoch >= self.cluster_proposals_start_at:
+ (
+ voxel_tensor, pc_voxel_id, proposals
+ ) = self.cluster_proposals_and_revoxelize(
+ pt_xyz, batch_indices, pt_features,
+ sem_preds, pt_offsets, instance_labels, cls_labels
+ )
+ if sem_labels is not None and proposals is not None:
+ proposals.sem_labels = sem_labels[proposals.valid_mask][
+ proposals.sorted_indices
+ ]
+ if proposals is not None:
+ proposals.instance_sem_labels = instance_sem_labels
+ else:
+ voxel_tensor, pc_voxel_id, proposals = None, None, None
+
+ # clustering and scoring
+ if self.current_epoch >= self.score_net_start_at and voxel_tensor is not None and proposals is not None:
+ score_logits = self.forward_proposal_score(
+ voxel_tensor, pc_voxel_id, proposals
+ )
+ proposal_offsets_begin = proposals.proposal_offsets[:-1].long()
+
+ if proposals.sem_labels is not None:
+ proposal_sem_labels = proposals.sem_labels[proposal_offsets_begin].long()
+ else:
+ proposal_sem_labels = proposals.sem_preds[proposal_offsets_begin].long()
+ score_logits = score_logits.gather(
+ 1, proposal_sem_labels[:, None] - 1
+ ).squeeze(1)
+ proposals.score_preds = score_logits.detach().sigmoid()
+ if num_points_per_instance is not None:
+ loss_prop_score = self.loss_proposal_score(
+ score_logits, proposals, num_points_per_instance,
+ )
+ else:
+ loss_prop_score = 0.0
+
+ else:
+ loss_prop_score = 0.0
+
+ if self.current_epoch >= self.npcs_net_start_at and voxel_tensor is not None:
+ npcs_logits = self.forward_proposal_npcs(
+ voxel_tensor, pc_voxel_id
+ )
+ if gt_npcs is not None:
+ gt_npcs = gt_npcs[proposals.valid_mask][proposals.sorted_indices]
+ loss_prop_npcs = self.loss_proposal_npcs(npcs_logits, gt_npcs, proposals)
+ else:
+ npcs_preds = npcs_logits.detach()
+ npcs_preds = rearrange(npcs_preds, "n (k c) -> n k c", c=3)
+ npcs_preds = npcs_preds.gather(
+ 1, index=repeat(proposals.sem_preds.long() - 1, "n -> n one c", one=1, c=3)
+ ).squeeze(1)
+ proposals.npcs_preds = npcs_preds
+ loss_prop_npcs = 0.0
+ else:
+ loss_prop_npcs = 0.0
+
+
+ # total loss
+ loss = loss_sem_seg + loss_pt_offset_dist + loss_pt_offset_dir
+ loss += loss_prop_score + loss_prop_npcs #+ self.cls_local_weight * loss_local_cls + self.cls_global_weight * loss_global_cls
+
+ if sem_labels is not None:
+ instance_mask = sem_labels > 0
+ pixel_acc = pixel_accuracy(sem_preds[instance_mask], sem_labels[instance_mask])
+ else:
+ pixel_acc = 0.0
+
+ prefix = running_mode
+ self.log(f"{prefix}/total_loss", loss, batch_size=batch_size)
+ self.log(
+ f"{prefix}/loss_sem_seg",
+ loss_sem_seg,
+ batch_size=batch_size,
+ )
+ self.log(
+ f"{prefix}/loss_pt_offset_dist",
+ loss_pt_offset_dist,
+ batch_size=batch_size,
+ )
+ self.log(
+ f"{prefix}/loss_pt_offset_dir",
+ loss_pt_offset_dir,
+ batch_size=batch_size,
+ )
+ self.log(
+ f"{prefix}/loss_prop_score",
+ loss_prop_score,
+ batch_size=batch_size,
+ )
+ self.log(
+ f"{prefix}/loss_prop_npcs",
+ loss_prop_npcs,
+ batch_size=batch_size,
+ )
+ self.log(
+ f"{prefix}/pixel_acc",
+ pixel_acc * 100,
+ batch_size=batch_size,
+ )
+
+ return scene_ids, sem_seg, proposals, loss
+
+ def training_step(self, point_clouds: List[PointCloud], batch_idx: int):
+ _, _, _, loss = self._training_or_validation_step(
+ point_clouds, batch_idx, "train"
+ )
+
+ return loss
+
+ def validation_step(self, point_clouds: List[PointCloud], batch_idx: int, dataloader_idx: int = 0):
+ split = ["val", "intra", "inter"]
+ scene_ids, sem_seg, proposals, _ = self._training_or_validation_step(
+ point_clouds, batch_idx, split[dataloader_idx]
+ )
+
+ if proposals is not None:
+ proposals = filter_invalid_proposals(
+ proposals,
+ score_threshold=self.val_score_threshold,
+ min_num_points_per_proposal=self.val_min_num_points_per_proposal
+ )
+ proposals = apply_nms(proposals, self.val_nms_iou_threshold)
+
+ if proposals is not None:
+ proposals.pt_sem_classes = proposals.sem_preds[proposals.proposal_offsets[:-1].long()]
+ proposals.valid_mask = None
+ proposals.pt_xyz = None
+ proposals.sem_preds = None
+ proposals.npcs_preds = None
+ proposals.sem_labels = None
+ proposals.npcs_valid_mask = None
+ proposals.gt_npcs = None
+ return scene_ids, sem_seg, proposals
+
+ def validation_epoch_end(self, validation_step_outputs_list):
+ splits = ["val", "intra", "inter"]
+ for i_, validation_step_outputs in enumerate(validation_step_outputs_list):
+ split = splits[i_]
+
+ batch_size = sum(x[1].batch_size for x in validation_step_outputs)
+
+ proposals = [x[2] for x in validation_step_outputs]
+ # torch.save(validation_step_outputs, "wandb/predictions_gap.pth")
+ del validation_step_outputs
+
+ # miou = mean_iou(sem_preds, sem_labels, num_classes=self.num_classes)
+ # self.log(f"{split}/mean_iou", miou * 100, batch_size=batch_size)
+
+
+ if proposals[0] is not None:
+
+ aps = compute_ap(proposals, self.num_classes, self.val_ap_iou_threshold)
+
+ for class_idx in range(1, self.num_classes):
+ partname = PART_ID2NAME[class_idx]
+ self.log(
+ f"{split}/AP@50_{partname}",
+ aps[class_idx - 1] * 100,
+ batch_size=batch_size,
+
+ )
+ self.log(f"{split}/AP@50", np.mean(aps) * 100, )
+ else:
+ self.log(f"{split}/AP@50", 0.0, )
+
+ def test_step(self, point_clouds: List[PointCloud], batch_idx: int, dataloader_idx: int = 0):
+ split = ["val", "intra", "inter"]
+ scene_ids, sem_seg, proposals, _ = self._training_or_validation_step(
+ point_clouds, batch_idx, split[dataloader_idx]
+ )
+
+ if proposals is not None:
+ proposals = filter_invalid_proposals(
+ proposals,
+ score_threshold=self.val_score_threshold,
+ min_num_points_per_proposal=self.val_min_num_points_per_proposal
+ )
+ proposals = apply_nms(proposals, self.val_nms_iou_threshold)
+
+ if proposals is not None:
+ proposals.pt_sem_classes = proposals.sem_preds[proposals.proposal_offsets[:-1].long()]
+ proposals.valid_mask = None
+ proposals.pt_xyz = None
+ proposals.sem_preds = None
+ proposals.npcs_preds = None
+ proposals.sem_labels = None
+ proposals.npcs_valid_mask = None
+ proposals.gt_npcs = None
+
+ return scene_ids, sem_seg, proposals
+
+ def test_epoch_end(self, validation_step_outputs_list):
+ splits = ["val", "intra", "inter"]
+ for i_, validation_step_outputs in enumerate(validation_step_outputs_list):
+ split = splits[i_]
+
+ batch_size = sum(x[1].batch_size for x in validation_step_outputs)
+
+ proposals = [x[2] for x in validation_step_outputs] # if x[2] != None
+
+ del validation_step_outputs
+
+ if proposals[0] is not None:
+
+ aps = compute_ap(proposals, self.num_classes, self.val_ap_iou_threshold)
+
+ for class_idx in range(1, self.num_classes):
+ partname = PART_ID2NAME[class_idx]
+ self.log(
+ f"{split}/AP@50_{partname}",
+ aps[class_idx - 1] * 100,
+ batch_size=batch_size,
+
+ )
+ self.log(f"{split}/AP@50", np.mean(aps) * 100, )
+
+ def test_step_(self, point_clouds: List[PointCloud], batch_idx: int, dataloader_idx: int):
+ split = ["val", "intra", "inter"]
+ scene_ids, sem_seg, proposals, _ = self._training_or_validation_step(
+ point_clouds, batch_idx, split[dataloader_idx]
+ )
+
+ if proposals is not None:
+ proposals = filter_invalid_proposals(
+ proposals,
+ score_threshold=self.val_score_threshold,
+ min_num_points_per_proposal=self.val_min_num_points_per_proposal
+ )
+ proposals = apply_nms(proposals, self.val_nms_iou_threshold)
+
+ if proposals is not None:
+ instance_label = torch.zeros_like(sem_seg.sem_preds, device=sem_seg.sem_preds.device)
+ npcs_map = torch.zeros_like(point_clouds[0].points[:, :3], device=sem_seg.sem_preds.device)
+ # Chained boolean/fancy indexing writes into a temporary copy, so compose
+ # the global point indices explicitly before assigning.
+ point_idx = proposals.valid_mask.nonzero(as_tuple=True)[0][proposals.sorted_indices.long()]
+ instance_label[point_idx] = proposals.proposal_indices + 1
+ npcs_map[point_idx[proposals.npcs_valid_mask]] = proposals.npcs_preds
+ result = Result(xyz=point_clouds[0].points[:, :3], rgb=point_clouds[0].points[:, 3:], sem_preds=sem_seg.sem_preds, ins_preds=instance_label, npcs_preds=npcs_map)
+ else:
+ instance_label = torch.zeros_like(sem_seg.sem_preds, device = sem_seg.sem_preds.device)
+ npcs_map = torch.zeros_like(point_clouds[0].points[:,:3], device = sem_seg.sem_preds.device)
+ result = Result(xyz = point_clouds[0].points[:,:3], rgb=point_clouds[0].points[:,3:], sem_preds=sem_seg.sem_preds, ins_preds = instance_label, npcs_preds=npcs_map)
+
+
+ return scene_ids, sem_seg, result
+
+ def test_epoch_end_(self, validation_step_outputs_list):
+ splits = ["val", "intra", "inter"]
+ for i_, validation_step_outputs in enumerate(validation_step_outputs_list):
+ split = splits[i_]
+
+ batch_size = sum(x[1].batch_size for x in validation_step_outputs)
+
+ proposals = [x[2] for x in validation_step_outputs]
+ ids = [y for x in validation_step_outputs for y in x[0]]
+ results = [y for x in validation_step_outputs for y in x[2]]
+ del validation_step_outputs
+
+
+ if proposals[0] is not None:
+ aps = compute_ap(proposals, self.num_classes, self.val_ap_iou_threshold)
+
+ for class_idx in range(1, self.num_classes):
+ partname = PART_ID2NAME[class_idx]
+ self.log(
+ f"{split}/AP@50_{partname}",
+ aps[class_idx - 1] * 100,
+ batch_size=batch_size,
+
+ )
+ self.log(f"{split}/AP@50", np.mean(aps) * 100, )
+
+ def forward(self, point_clouds: List[PointCloud]):
+ scene_ids, sem_seg, proposals, _ = self._training_or_validation_step(
+ point_clouds, 0, "val"
+ )
+
+ if proposals is not None:
+ proposals = filter_invalid_proposals(
+ proposals,
+ score_threshold=self.val_score_threshold,
+ min_num_points_per_proposal=self.val_min_num_points_per_proposal
+ )
+ proposals = apply_nms(proposals, self.val_nms_iou_threshold)
+ return scene_ids, sem_seg, proposals
+
+ def configure_optimizers(self):
+ return torch.optim.Adam(
+ filter(lambda p: p.requires_grad, self.parameters()),
+ lr=self.learning_rate,
+ )
diff --git a/gapartnet/gapartnet/models/gapartnet_utils.py b/gapartnet/gapartnet/models/gapartnet_utils.py
new file mode 100644
index 0000000..8b00e79
--- /dev/null
+++ b/gapartnet/gapartnet/models/gapartnet_utils.py
@@ -0,0 +1,465 @@
+from typing import List, Tuple
+
+import torch
+from epic_ops.ball_query import ball_query
+from epic_ops.ccl import connected_components_labeling
+from epic_ops.nms import nms
+from epic_ops.reduce import segmented_reduce
+from epic_ops.voxelize import voxelize
+
+from gapartnet.structures.instances import Instances
+
+
+@torch.jit.script
+def compute_npcs_loss(
+ npcs_preds: torch.Tensor,
+ gt_npcs: torch.Tensor,
+ proposal_indices: torch.Tensor,
+ symmetry_matrix: torch.Tensor,
+) -> torch.Tensor:
+ _, num_points_per_proposal = torch.unique_consecutive(
+ proposal_indices, return_counts=True
+ )
+
+ # gt_npcs: n, 3 -> n, 1, 1, 3
+ # symmetry_matrix: n, m, 3, 3
+ gt_npcs = gt_npcs[:, None, None, :] @ symmetry_matrix
+ # n, m, 1, 3 -> n, m, 3
+ gt_npcs = gt_npcs.squeeze(2)
+
+ # npcs_preds: n, 3 -> n, 1, 3
+ dist2 = (npcs_preds[:, None, :] - gt_npcs - 0.5) ** 2
+ # n, m, 3 -> n, m
+ dist2 = dist2.sum(dim=-1)
+
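+ # Smooth-L1-style penalty on the (symmetry-transformed) NPCS error, averaged
+ # per proposal and then minimized over the symmetry candidates.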
+ loss = torch.where(
+ dist2 <= 0.01,
+ 5 * dist2, torch.sqrt(dist2) - 0.05,
+ )
+ loss = torch.segment_reduce(
+ loss, "mean", lengths=num_points_per_proposal
+ )
+ loss, _ = loss.min(dim=-1)
+ return loss.mean()
+
+
+@torch.jit.script
+def segmented_voxelize(
+ pt_xyz: torch.Tensor,
+ pt_features: torch.Tensor,
+ segment_offsets: torch.Tensor,
+ segment_indices: torch.Tensor,
+ num_points_per_segment: torch.Tensor,
+ score_fullscale: float,
+ score_scale: float,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ segment_offsets_begin = segment_offsets[:-1]
+ segment_offsets_end = segment_offsets[1:]
+
+ segment_coords_mean = segmented_reduce(
+ pt_xyz, segment_offsets_begin, segment_offsets_end, mode="sum"
+ ) / num_points_per_segment[:, None]
+
+ centered_points = pt_xyz - segment_coords_mean[segment_indices]
+
+ segment_coords_min = segmented_reduce(
+ centered_points, segment_offsets_begin, segment_offsets_end, mode="min"
+ )
+ segment_coords_max = segmented_reduce(
+ centered_points, segment_offsets_begin, segment_offsets_end, mode="max"
+ )
+
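+ # NOTE: the score_fullscale / score_scale arguments are overridden by the
+ # hard-coded values below, so caller-supplied settings are ignored here.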
+ score_fullscale = 28.
+ score_scale = 50.
+ segment_scales = 1. / (
+ (segment_coords_max - segment_coords_min) / score_fullscale
+ ).max(-1)[0] - 0.01
+ segment_scales = torch.clamp(segment_scales, min=None, max=score_scale)
+
+ min_xyz = segment_coords_min * segment_scales[..., None]
+ max_xyz = segment_coords_max * segment_scales[..., None]
+
+ segment_scales = segment_scales[segment_indices]
+ scaled_points = centered_points * segment_scales[..., None]
+
+ range_xyz = max_xyz - min_xyz
+ offsets = -min_xyz + torch.clamp(
+ score_fullscale - range_xyz - 0.001, min=0
+ ) * torch.rand(3, dtype=min_xyz.dtype, device=min_xyz.device) + torch.clamp(
+ score_fullscale - range_xyz + 0.001, max=0
+ ) * torch.rand(3, dtype=min_xyz.dtype, device=min_xyz.device)
+ scaled_points += offsets[segment_indices]
+
+ voxel_features, voxel_coords, voxel_batch_indices, pc_voxel_id = voxelize(
+ scaled_points,
+ pt_features,
+ batch_offsets=segment_offsets.long(),
+ voxel_size=torch.as_tensor([1., 1., 1.]),
+ points_range_min=torch.as_tensor([0., 0., 0.]),
+ points_range_max=torch.as_tensor([score_fullscale, score_fullscale, score_fullscale]),
+ reduction="mean",
+ )
+ voxel_coords = torch.cat([voxel_batch_indices[:, None], voxel_coords], dim=1)
+
+ return voxel_features, voxel_coords, pc_voxel_id
+
+
+@torch.jit.script
+def cluster_proposals(
+ pt_xyz: torch.Tensor,
+ batch_indices: torch.Tensor,
+ batch_offsets: torch.Tensor,
+ sem_preds: torch.Tensor,
+ ball_query_radius: float,
+ max_num_points_per_query: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ device = pt_xyz.device
+ index_dtype = batch_indices.dtype
+
+ clustered_indices, num_points_per_query = ball_query(
+ pt_xyz,
+ pt_xyz,
+ batch_indices,
+ batch_offsets,
+ ball_query_radius,
+ max_num_points_per_query,
+ point_labels=sem_preds,
+ query_labels=sem_preds,
+ )
+
+ ccl_indices_begin = torch.arange(
+ pt_xyz.shape[0], dtype=index_dtype, device=device
+ ) * max_num_points_per_query
+ ccl_indices_end = ccl_indices_begin + num_points_per_query
+ ccl_indices = torch.stack([ccl_indices_begin, ccl_indices_end], dim=1)
+ cc_labels = connected_components_labeling(
+ ccl_indices.view(-1), clustered_indices.view(-1), compacted=False
+ )
+
+ sorted_cc_labels, sorted_indices = torch.sort(cc_labels)
+ return sorted_cc_labels, sorted_indices
+
+
+@torch.jit.script
+def get_gt_scores(
+ ious: torch.Tensor, fg_thresh: float = 0.75, bg_thresh: float = 0.25
+) -> torch.Tensor:
+ fg_mask = ious > fg_thresh
+ bg_mask = ious < bg_thresh
+ intermidiate_mask = ~(fg_mask | bg_mask)
+
+ gt_scores = fg_mask.float()
+ k = 1 / (fg_thresh - bg_thresh)
+ b = bg_thresh / (bg_thresh - fg_thresh)
+ gt_scores[intermidiate_mask] = ious[intermidiate_mask] * k + b
+
+ return gt_scores
+
+
+def filter_invalid_proposals(
+ proposals: Instances,
+ score_threshold: float,
+ min_num_points_per_proposal: int,
+) -> Instances:
+ score_preds = proposals.score_preds
+ proposal_indices = proposals.proposal_indices
+ num_points_per_proposal = proposals.num_points_per_proposal
+
+ valid_proposals_mask = (
+ score_preds > score_threshold
+ ) & (num_points_per_proposal > min_num_points_per_proposal)
+ valid_points_mask = valid_proposals_mask[proposal_indices]
+
+ proposal_indices = proposal_indices[valid_points_mask]
+ _, proposal_indices, num_points_per_proposal = torch.unique_consecutive(
+ proposal_indices, return_inverse=True, return_counts=True
+ )
+ num_proposals = num_points_per_proposal.shape[0]
+
+ proposal_offsets = torch.zeros(
+ num_proposals + 1, dtype=torch.int32, device=proposal_indices.device
+ )
+ proposal_offsets[1:] = num_points_per_proposal.cumsum(0)
+
+ if proposals.npcs_valid_mask is not None:
+ valid_npcs_mask = valid_points_mask[proposals.npcs_valid_mask]
+ else:
+ valid_npcs_mask = valid_points_mask
+
+ return Instances(
+ valid_mask=proposals.valid_mask,
+ sorted_indices=proposals.sorted_indices[valid_points_mask],
+ pt_xyz=proposals.pt_xyz[valid_points_mask],
+ batch_indices=proposals.batch_indices[valid_points_mask],
+ proposal_offsets=proposal_offsets,
+ proposal_indices=proposal_indices,
+ num_points_per_proposal=num_points_per_proposal,
+ sem_preds=proposals.sem_preds[valid_points_mask],
+ score_preds=proposals.score_preds[valid_proposals_mask],
+ npcs_preds=proposals.npcs_preds[
+ valid_npcs_mask
+ ] if proposals.npcs_preds is not None else None,
+ sem_labels=proposals.sem_labels[
+ valid_points_mask
+ ] if proposals.sem_labels is not None else None,
+ instance_labels=proposals.instance_labels[
+ valid_points_mask
+ ] if proposals.instance_labels is not None else None,
+ instance_sem_labels=proposals.instance_sem_labels,
+ num_points_per_instance=proposals.num_points_per_instance,
+ gt_npcs=proposals.gt_npcs[
+ valid_npcs_mask
+ ] if proposals.gt_npcs is not None else None,
+ npcs_valid_mask=proposals.npcs_valid_mask[valid_points_mask] \
+ if proposals.npcs_valid_mask is not None else None,
+ ious=proposals.ious[
+ valid_proposals_mask
+ ] if proposals.ious is not None else None,
+ )
+
+
+def apply_nms(
+ proposals: Instances,
+ iou_threshold: float = 0.3,
+):
+ score_preds = proposals.score_preds
+ sorted_indices = proposals.sorted_indices
+ proposal_offsets = proposals.proposal_offsets
+ proposal_indices = proposals.proposal_indices
+ num_points_per_proposal = proposals.num_points_per_proposal
+
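+ # Pairwise proposal IoUs via a sparse trick: build a CSR proposal-by-point
+ # indicator matrix M, so intersection = M @ M.T and the union follows from
+ # the per-proposal point counts.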
+ values = torch.ones(
+ sorted_indices.shape[0], dtype=torch.float32, device=sorted_indices.device
+ )
+ csr = torch.sparse_csr_tensor(
+ proposal_offsets.int(), sorted_indices.int(), values,
+ dtype=torch.float32, device=sorted_indices.device,
+ )
+ intersection = csr @ csr.t()
+ intersection = intersection.to_dense()
+ union = num_points_per_proposal[:, None] + num_points_per_proposal[None, :]
+ union = union - intersection
+
+ ious = intersection / (union + 1e-8)
+ keep = nms(ious.cuda(), score_preds.cuda(), iou_threshold)
+ keep = keep.to(score_preds.device)
+
+ valid_proposals_mask = torch.zeros(
+ ious.shape[0], dtype=torch.bool, device=score_preds.device
+ )
+ valid_proposals_mask[keep] = True
+ valid_points_mask = valid_proposals_mask[proposal_indices]
+
+ proposal_indices = proposal_indices[valid_points_mask]
+ _, proposal_indices, num_points_per_proposal = torch.unique_consecutive(
+ proposal_indices, return_inverse=True, return_counts=True
+ )
+ num_proposals = num_points_per_proposal.shape[0]
+
+ proposal_offsets = torch.zeros(
+ num_proposals + 1, dtype=torch.int32, device=proposal_indices.device
+ )
+ proposal_offsets[1:] = num_points_per_proposal.cumsum(0)
+
+ if proposals.npcs_valid_mask is not None:
+ valid_npcs_mask = valid_points_mask[proposals.npcs_valid_mask]
+ else:
+ valid_npcs_mask = valid_points_mask
+
+ return Instances(
+ valid_mask=proposals.valid_mask,
+ sorted_indices=proposals.sorted_indices[valid_points_mask],
+ pt_xyz=proposals.pt_xyz[valid_points_mask],
+ batch_indices=proposals.batch_indices[valid_points_mask],
+ proposal_offsets=proposal_offsets,
+ proposal_indices=proposal_indices,
+ num_points_per_proposal=num_points_per_proposal,
+ sem_preds=proposals.sem_preds[valid_points_mask],
+ score_preds=proposals.score_preds[valid_proposals_mask],
+ npcs_preds=proposals.npcs_preds[
+ valid_npcs_mask
+ ] if proposals.npcs_preds is not None else None,
+ sem_labels=proposals.sem_labels[
+ valid_points_mask
+ ] if proposals.sem_labels is not None else None,
+ instance_labels=proposals.instance_labels[
+ valid_points_mask
+ ] if proposals.instance_labels is not None else None,
+ instance_sem_labels=proposals.instance_sem_labels,
+ num_points_per_instance=proposals.num_points_per_instance,
+ gt_npcs=proposals.gt_npcs[
+ valid_npcs_mask
+ ] if proposals.gt_npcs is not None else None,
+ npcs_valid_mask=proposals.npcs_valid_mask[valid_points_mask] \
+ if proposals.npcs_valid_mask is not None else None,
+ ious=proposals.ious[
+ valid_proposals_mask
+ ] if proposals.ious is not None else None,
+ )
+
+
+@torch.jit.script
+def voc_ap(
+ rec: torch.Tensor,
+ prec: torch.Tensor,
+ use_07_metric: bool = False,
+) -> float:
+ if use_07_metric:
+ # 11 point metric
+ ap = torch.as_tensor(0, dtype=prec.dtype, device=prec.device)
+ for t in range(0, 11, 1):
+ t /= 10.0
+ if torch.sum(rec >= t) == 0:
+ p = torch.as_tensor(0, dtype=prec.dtype, device=prec.device)
+ else:
+ p = torch.max(prec[rec >= t])
+ ap = ap + p / 11.0
+ else:
+ # correct AP calculation
+ # first append sentinel values at the end
+ mrec = torch.cat([
+ torch.as_tensor([0.0], dtype=rec.dtype, device=rec.device),
+ rec,
+ torch.as_tensor([1.0], dtype=rec.dtype, device=rec.device),
+ ], dim=0)
+ mpre = torch.cat([
+ torch.as_tensor([0.0], dtype=prec.dtype, device=prec.device),
+ prec,
+ torch.as_tensor([0.0], dtype=prec.dtype, device=prec.device),
+ ], dim=0)
+
+ # compute the precision envelope
+ for i in range(mpre.shape[0] - 1, 0, -1):
+ mpre[i - 1] = torch.maximum(mpre[i - 1], mpre[i])
+
+ # to calculate area under PR curve, look for points
+ # where X axis (recall) changes value
+ i = torch.where(mrec[1:] != mrec[:-1])[0]
+
+ # and sum (\Delta recall) * prec
+ ap = torch.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
+ return float(ap.item())
+
+
+@torch.jit.script
+def _compute_ap_per_class(
+ tp: torch.Tensor, fp: torch.Tensor, num_gt_instances: int
+) -> float:
+ if tp.shape[0] == 0:
+ return 0.
+
+ tp = tp.cumsum(0)
+ fp = fp.cumsum(0)
+ rec = tp / num_gt_instances
+ prec = tp / (tp + fp + 1e-8)
+
+ return voc_ap(rec, prec)
+
+
+@torch.jit.script
+def _compute_ap(
+ confidence: torch.Tensor,
+ classes: torch.Tensor,
+ sorted_indices: torch.Tensor,
+ batch_indices: torch.Tensor,
+ sample_indices: torch.Tensor,
+ proposal_indices: torch.Tensor,
+ matched: List[torch.Tensor],
+ instance_sem_labels: List[torch.Tensor],
+ ious: List[torch.Tensor],
+ num_classes: int,
+ iou_threshold: float,
+):
+ sorted_indices_cpu = sorted_indices.cpu()
+
+ num_proposals = confidence.shape[0]
+ tp = torch.zeros(num_proposals, dtype=torch.float32)
+ fp = torch.zeros(num_proposals, dtype=torch.float32)
+ for i in range(num_proposals):
+ idx = sorted_indices_cpu[i]
+
+ class_idx = classes[idx]
+ batch_idx = batch_indices[idx].item()
+ sample_idx = sample_indices[idx]
+ proposal_idx = proposal_indices[idx]
+
+ instance_sem_labels_i = instance_sem_labels[batch_idx][sample_idx]
+ invalid_instance_mask = instance_sem_labels_i != class_idx
+
+ ious_i = ious[batch_idx][proposal_idx].clone()
+ ious_i[invalid_instance_mask] = 0.
+ if ious_i.shape[0] == 0:
+ max_iou, max_idx = 0., 0
+ else:
+ max_iou, max_idx = ious_i.max(0)
+ max_iou, max_idx = max_iou.item(), int(max_idx.item())
+
+ if max_iou > iou_threshold:
+ if not matched[batch_idx][sample_idx, max_idx].item():
+ tp[i] = 1.0
+ matched[batch_idx][sample_idx, max_idx] = True
+ else:
+ fp[i] = 1.0
+ else:
+ fp[i] = 1.0
+
+ tp = tp.to(device=confidence.device)
+ fp = fp.to(device=confidence.device)
+
+ sorted_classes = classes[sorted_indices]
+ gt_classes = torch.cat([x.view(-1) for x in instance_sem_labels], dim=0)
+ aps: List[float] = []
+ for c in range(1, num_classes):
+ num_gt_instances = (gt_classes == c).sum()
+ mask = sorted_classes == c
+ ap = _compute_ap_per_class(tp[mask], fp[mask], num_gt_instances)
+ aps.append(ap)
+ return aps
+
+
+def compute_ap(
+ proposals: List[Instances],
+ num_classes: int = 9,
+ iou_threshold: float = 0.5,
+ device="cpu",
+):
+ confidence = torch.cat([
+ p.score_preds for p in proposals if p is not None
+ ], dim=0).to(device=device)
+ classes = torch.cat([
+ p.pt_sem_classes
+ for p in proposals if p is not None
+ ], dim=0).to(device=device)
+ sorted_indices = torch.argsort(confidence, descending=True)
+
+ batch_indices = torch.cat([
+ torch.full((p.score_preds.shape[0],), i, dtype=torch.int64)
+ for i, p in enumerate(proposals) if p is not None
+ ], dim=0)
+ sample_indices = torch.cat([
+ p.batch_indices[p.proposal_offsets[:-1].long()].long()
+ for p in proposals if p is not None
+ ], dim=0).cpu()
+ proposal_indices = torch.cat([
+ torch.arange(p.score_preds.shape[0], dtype=torch.int64)
+ for p in proposals if p is not None
+ ], dim=0)
+
+ matched = [
+ torch.zeros_like(p.instance_sem_labels, dtype=torch.bool, device="cpu")
+ for p in proposals if p is not None
+ ]
+
+ return _compute_ap(
+ confidence,
+ classes,
+ sorted_indices,
+ batch_indices,
+ sample_indices,
+ proposal_indices,
+ matched,
+ [p.instance_sem_labels.to(device=device) for p in proposals if p is not None],
+ [p.ious.to(device=device) for p in proposals if p is not None],
+ num_classes,
+ iou_threshold,
+ )
diff --git a/gapartnet/gapartnet/models/sparse_unet.py b/gapartnet/gapartnet/models/sparse_unet.py
new file mode 100644
index 0000000..1bd50ec
--- /dev/null
+++ b/gapartnet/gapartnet/models/sparse_unet.py
@@ -0,0 +1,287 @@
+from typing import List
+
+import spconv.pytorch as spconv
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ResBlock(spconv.SparseModule):
+ def __init__(
+ self, in_channels: int, out_channels: int, norm_fn: nn.Module, indice_key=None
+ ):
+ super().__init__()
+
+ if in_channels == out_channels:
+ self.shortcut = nn.Identity()
+ else:
+ # assert False
+ self.shortcut = spconv.SparseSequential(
+ spconv.SubMConv3d(in_channels, out_channels, kernel_size=1, \
+ bias=False),
+ norm_fn(out_channels),
+ )
+
+ self.conv1 = spconv.SparseSequential(
+ spconv.SubMConv3d(
+ in_channels, out_channels, kernel_size=3,
+ padding=1, bias=False, indice_key=indice_key,
+ ),
+ norm_fn(out_channels),
+ )
+
+ self.conv2 = spconv.SparseSequential(
+ spconv.SubMConv3d(
+ out_channels, out_channels, kernel_size=3,
+ padding=1, bias=False, indice_key=indice_key,
+ ),
+ norm_fn(out_channels),
+ )
+
+ def forward(self, x: spconv.SparseConvTensor) -> spconv.SparseConvTensor:
+ shortcut = self.shortcut(x)
+
+ x = self.conv1(x)
+ x = x.replace_feature(F.relu(x.features))
+
+ x = self.conv2(x)
+ x = x.replace_feature(F.relu(x.features + shortcut.features))
+
+ return x
+
+
+class UBlock(nn.Module):
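+    """Recursive sparse U-Net stage: encoder blocks, strided downsampling, a nested
+    UBlock, inverse-convolution upsampling, skip concatenation, and decoder blocks."""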
+ def __init__(
+ self,
+ channels: List[int],
+ block_fn: nn.Module,
+ block_repeat: int,
+ norm_fn: nn.Module,
+ indice_key_id: int = 1,
+ ):
+ super().__init__()
+
+ self.channels = channels
+
+ encoder_blocks = [
+ block_fn(
+ channels[0], channels[0], norm_fn, indice_key=f"subm{indice_key_id}"
+ )
+ for _ in range(block_repeat)
+ ]
+ self.encoder_blocks = spconv.SparseSequential(*encoder_blocks)
+
+ if len(channels) > 1:
+ self.downsample = spconv.SparseSequential(
+ spconv.SparseConv3d(
+ channels[0], channels[1], kernel_size=2, stride=2,
+ bias=False, indice_key=f"spconv{indice_key_id}",
+ ),
+ norm_fn(channels[1]),
+ nn.ReLU(),
+ )
+
+ self.ublock = UBlock(
+ channels[1:], block_fn, block_repeat, norm_fn, indice_key_id + 1
+ )
+
+ self.upsample = spconv.SparseSequential(
+ spconv.SparseInverseConv3d(
+ channels[1], channels[0], kernel_size=2,
+ bias=False, indice_key=f"spconv{indice_key_id}",
+ ),
+ norm_fn(channels[0]),
+ nn.ReLU(),
+ )
+
+ decoder_blocks = [
+ block_fn(
+ channels[0] * 2, channels[0], norm_fn,
+ indice_key=f"subm{indice_key_id}",
+ ),
+ ]
+ for _ in range(block_repeat -1):
+ decoder_blocks.append(
+ block_fn(
+ channels[0], channels[0], norm_fn,
+ indice_key=f"subm{indice_key_id}",
+ )
+ )
+ self.decoder_blocks = spconv.SparseSequential(*decoder_blocks)
+
+ def forward(self, x: spconv.SparseConvTensor) -> spconv.SparseConvTensor:
+ x = self.encoder_blocks(x)
+ shortcut = x
+
+ if len(self.channels) > 1:
+ x = self.downsample(x)
+ x = self.ublock(x)
+ x = self.upsample(x)
+
+ x = x.replace_feature(torch.cat([x.features, shortcut.features],\
+ dim=-1))
+ x = self.decoder_blocks(x)
+
+ return x
+
+
+class SparseUNet(nn.Module):
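+    """Sparse U-Net backbone: an optional stem (submanifold conv + norm + ReLU) followed by a UBlock."""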
+ def __init__(self, stem: nn.Module, ublock: UBlock):
+ super().__init__()
+
+ self.stem = stem
+ self.ublock = ublock
+
+ def forward(self, x):
+ if self.stem is not None:
+ x = self.stem(x)
+ x = self.ublock(x)
+ return x
+
+ @classmethod
+ def build(
+ cls,
+ in_channels: int,
+ channels: List[int],
+ block_repeat: int,
+ norm_fn: nn.Module,
+ without_stem: bool = False,
+ ):
+ if not without_stem:
+ stem = spconv.SparseSequential(
+ spconv.SubMConv3d(
+ in_channels, channels[0], kernel_size=3,
+ padding=1, bias=False, indice_key="subm1",
+ ),
+ norm_fn(channels[0]),
+ nn.ReLU(),
+ )
+ else:
+ stem = spconv.SparseSequential(
+ norm_fn(channels[0]),
+ nn.ReLU(),
+ )
+
+ block = UBlock(channels, ResBlock, block_repeat, norm_fn, \
+ indice_key_id=1)
+
+ return SparseUNet(stem, block)
+
+
+
+class UBlock_NoSkip(nn.Module):
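+    """UBlock variant without the encoder-decoder skip concatenation."""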
+ def __init__(
+ self,
+ channels: List[int],
+ block_fn: nn.Module,
+ block_repeat: int,
+ norm_fn: nn.Module,
+ indice_key_id: int = 1,
+ ):
+ super().__init__()
+
+ self.channels = channels
+
+ encoder_blocks = [
+ block_fn(
+ channels[0], channels[0], norm_fn, indice_key=f"subm{indice_key_id}"
+ )
+ for _ in range(block_repeat)
+ ]
+ self.encoder_blocks = spconv.SparseSequential(*encoder_blocks)
+
+ if len(channels) > 1:
+ self.downsample = spconv.SparseSequential(
+ spconv.SparseConv3d(
+ channels[0], channels[1], kernel_size=2, stride=2,
+ bias=False, indice_key=f"spconv{indice_key_id}",
+ ),
+ norm_fn(channels[1]),
+ nn.ReLU(),
+ )
+
+            self.ublock = UBlock_NoSkip(
+                channels[1:], block_fn, block_repeat, norm_fn, indice_key_id + 1
+            )
+
+ self.upsample = spconv.SparseSequential(
+ spconv.SparseInverseConv3d(
+ channels[1], channels[0], kernel_size=2,
+ bias=False, indice_key=f"spconv{indice_key_id}",
+ ),
+ norm_fn(channels[0]),
+ nn.ReLU(),
+ )
+
+ decoder_blocks = [
+ block_fn(
+ channels[0], channels[0], norm_fn,
+ indice_key=f"subm{indice_key_id}",
+ ),
+ ]
+ for _ in range(block_repeat -1):
+ decoder_blocks.append(
+ block_fn(
+ channels[0], channels[0], norm_fn,
+ indice_key=f"subm{indice_key_id}",
+ )
+ )
+ self.decoder_blocks = spconv.SparseSequential(*decoder_blocks)
+
+ def forward(self, x: spconv.SparseConvTensor) -> spconv.SparseConvTensor:
+ x = self.encoder_blocks(x)
+ # shortcut = x
+
+ if len(self.channels) > 1:
+ x = self.downsample(x)
+ x = self.ublock(x)
+ x = self.upsample(x)
+
+ # x = x.replace_feature(torch.cat([x.features, shortcut.features],\
+ # dim=-1))
+ x = self.decoder_blocks(x)
+
+ return x
+
+
+class SparseUNet_NoSkip(nn.Module):
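+    """Sparse U-Net variant built on UBlock_NoSkip (no skip connections)."""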
+ def __init__(self, stem: nn.Module, ublock: UBlock_NoSkip):
+ super().__init__()
+
+ self.stem = stem
+ self.ublock = ublock
+
+ def forward(self, x):
+ if self.stem is not None:
+ x = self.stem(x)
+ x = self.ublock(x)
+ return x
+
+ @classmethod
+ def build(
+ cls,
+ in_channels: int,
+ channels: List[int],
+ block_repeat: int,
+ norm_fn: nn.Module,
+ without_stem: bool = False,
+ ):
+ if not without_stem:
+ stem = spconv.SparseSequential(
+ spconv.SubMConv3d(
+ in_channels, channels[0], kernel_size=3,
+ padding=1, bias=False, indice_key="subm1",
+ ),
+ norm_fn(channels[0]),
+ nn.ReLU(),
+ )
+ else:
+ stem = spconv.SparseSequential(
+ norm_fn(channels[0]),
+ nn.ReLU(),
+ )
+
+        block = UBlock_NoSkip(
+            channels, ResBlock, block_repeat, norm_fn, indice_key_id=1
+        )
+
+        return SparseUNet_NoSkip(stem, block)
diff --git a/gapartnet/gapartnet/models/util_net.py b/gapartnet/gapartnet/models/util_net.py
new file mode 100644
index 0000000..5cb7bbb
--- /dev/null
+++ b/gapartnet/gapartnet/models/util_net.py
@@ -0,0 +1,37 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Function
+
+class ReverseLayerF(Function):
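+    """Gradient reversal layer: identity in the forward pass; the backward pass negates
+    the incoming gradient and scales it by alpha (used for domain-adversarial training)."""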
+
+ @staticmethod
+ def forward(ctx, x, alpha = 0.1):
+ ctx.alpha = alpha
+
+ return x.view_as(x)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ output = grad_output.neg() * ctx.alpha
+
+ return output, None
+
+class Discriminator(nn.Module):
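+    """Domain discriminator MLP: two Linear-BatchNorm-ReLU blocks followed by a linear head over num_domains."""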
+ def __init__(self, input_dim=256, hidden_dim=256, num_domains=9):
+ super(Discriminator, self).__init__()
+ self.input_dim = input_dim
+ self.hidden_dim = hidden_dim
+ layers = [
+ nn.Linear(input_dim, hidden_dim),
+ nn.BatchNorm1d(hidden_dim),
+ nn.ReLU(),
+ nn.Linear(hidden_dim, hidden_dim),
+ nn.BatchNorm1d(hidden_dim),
+ nn.ReLU(),
+ nn.Linear(hidden_dim, num_domains),
+ ]
+ self.layers = torch.nn.Sequential(*layers)
+
+ def forward(self, x):
+ return self.layers(x)
\ No newline at end of file
diff --git a/gapartnet/gapartnet/structures/__init__.py b/gapartnet/gapartnet/structures/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/gapartnet/gapartnet/structures/instances.py b/gapartnet/gapartnet/structures/instances.py
new file mode 100644
index 0000000..2a6670a
--- /dev/null
+++ b/gapartnet/gapartnet/structures/instances.py
@@ -0,0 +1,44 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+
+
+@dataclass
+class Instances:
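+    """Per-batch proposal container: point-level indices and offsets plus optional per-proposal
+    predictions (semantic class, score, NPCS) and ground-truth labels / IoUs."""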
+ valid_mask: torch.Tensor
+ sorted_indices: torch.Tensor
+ pt_xyz: torch.Tensor
+
+ batch_indices: torch.Tensor
+ proposal_offsets: torch.Tensor
+ proposal_indices: torch.Tensor
+ num_points_per_proposal: torch.Tensor
+
+ sem_preds: Optional[torch.Tensor] = None
+ pt_sem_classes: Optional[torch.Tensor] = None
+ score_preds: Optional[torch.Tensor] = None
+ npcs_preds: Optional[torch.Tensor] = None
+
+ sem_labels: Optional[torch.Tensor] = None
+ instance_labels: Optional[torch.Tensor] = None
+ instance_sem_labels: Optional[torch.Tensor] = None
+ num_points_per_instance: Optional[torch.Tensor] = None
+ gt_npcs: Optional[torch.Tensor] = None
+
+ npcs_valid_mask: Optional[torch.Tensor] = None
+
+ ious: Optional[torch.Tensor] = None
+
+    cls_preds: Optional[torch.Tensor] = None
+    cls_labels: Optional[torch.Tensor] = None
+
+ name: Optional[str] = None
+
+@dataclass
+class Result:
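+    """Point-wise prediction bundle: coordinates, colors and semantic / instance / NPCS predictions."""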
+ xyz: torch.Tensor
+    rgb: torch.Tensor
+ sem_preds: torch.Tensor
+ ins_preds: torch.Tensor
+ npcs_preds: torch.Tensor
\ No newline at end of file
diff --git a/gapartnet/gapartnet/structures/part_pc.py b/gapartnet/gapartnet/structures/part_pc.py
new file mode 100644
index 0000000..e007746
--- /dev/null
+++ b/gapartnet/gapartnet/structures/part_pc.py
@@ -0,0 +1,13 @@
+from dataclasses import dataclass
+
+import torch
+
+
+@dataclass
+class PartPC:
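+    """A single part point cloud: class label, point coordinates, colors and NPCS coordinates."""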
+ scene_id: str
+ cls_label: int
+
+ points: torch.Tensor
+ rgb: torch.Tensor
+ npcs: torch.Tensor
diff --git a/gapartnet/gapartnet/structures/point_cloud.py b/gapartnet/gapartnet/structures/point_cloud.py
new file mode 100644
index 0000000..bf112ad
--- /dev/null
+++ b/gapartnet/gapartnet/structures/point_cloud.py
@@ -0,0 +1,146 @@
+from dataclasses import dataclass, fields
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+import spconv.pytorch as spconv
+import torch
+
+
+@dataclass
+class PointCloud:
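+    """One point-cloud sample with optional semantic / instance labels, NPCS targets,
+    per-instance annotations and voxelization outputs."""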
+ scene_id: str
+ obj_cat: int
+
+ points: Union[torch.Tensor, np.ndarray]
+
+
+ sem_labels: Optional[Union[torch.Tensor, np.ndarray]] = None
+ instance_labels: Optional[Union[torch.Tensor, np.ndarray]] = None
+
+ gt_npcs: Optional[Union[torch.Tensor, np.ndarray]] = None
+
+ num_instances: Optional[int] = None
+ instance_regions: Optional[Union[torch.Tensor, np.ndarray]] = None
+ num_points_per_instance: Optional[Union[torch.Tensor, np.ndarray]] = None
+ instance_sem_labels: Optional[torch.Tensor] = None
+
+ voxel_features: Optional[torch.Tensor] = None
+ voxel_coords: Optional[torch.Tensor] = None
+ voxel_coords_range: Optional[List[int]] = None
+ pc_voxel_id: Optional[torch.Tensor] = None
+
+ def to_dict(self) -> Dict[str, Any]:
+ return {
+ field.name: getattr(self, field.name)
+ for field in fields(self)
+ }
+
+ def to_tensor(self) -> "PointCloud":
+ return PointCloud(**{
+ k: torch.from_numpy(v) if isinstance(v, np.ndarray) else v
+ for k, v in self.to_dict().items()
+ })
+
+ def to(self, device: torch.device) -> "PointCloud":
+ return PointCloud(**{
+ k: v.to(device) if isinstance(v, torch.Tensor) else v
+ for k, v in self.to_dict().items()
+ })
+
+ @staticmethod
+ def collate(point_clouds: List["PointCloud"]):
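+        """Collate a list of PointCloud samples into batched tensors and a spconv SparseConvTensor."""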
+ batch_size = len(point_clouds)
+ device = point_clouds[0].points.device
+
+ scene_ids = [pc.scene_id for pc in point_clouds]
+ cls_labels = torch.tensor([pc.obj_cat for pc in point_clouds])
+ num_points = [pc.points.shape[0] for pc in point_clouds]
+
+ points = torch.cat([pc.points for pc in point_clouds], dim=0)
+ batch_indices = torch.cat([
+ torch.full((pc.points.shape[0],), i, dtype=torch.int32, device=device)
+ for i, pc in enumerate(point_clouds)
+ ], dim=0)
+
+ if point_clouds[0].sem_labels is not None:
+ sem_labels = torch.cat([pc.sem_labels for pc in point_clouds], dim=0)
+ else:
+ sem_labels = None
+
+ if point_clouds[0].instance_labels is not None:
+ instance_labels = torch.cat([pc.instance_labels for pc in point_clouds], dim=0)
+ else:
+ instance_labels = None
+
+ if point_clouds[0].gt_npcs is not None:
+ gt_npcs = torch.cat([pc.gt_npcs for pc in point_clouds], dim=0)
+ else:
+ gt_npcs = None
+
+ if point_clouds[0].num_instances is not None:
+ num_instances = [pc.num_instances for pc in point_clouds]
+ max_num_instances = max(num_instances)
+ num_points_per_instance = torch.zeros(
+ batch_size, max_num_instances, dtype=torch.int32, device=device
+ )
+ instance_sem_labels = torch.full(
+ (batch_size, max_num_instances), -1, dtype=torch.int32, device=device
+ )
+ for i, pc in enumerate(point_clouds):
+ num_points_per_instance[i, :pc.num_instances] = pc.num_points_per_instance
+ instance_sem_labels[i, :pc.num_instances] = pc.instance_sem_labels
+ else:
+ num_instances = None
+ num_points_per_instance = None
+ instance_sem_labels = None
+
+ if point_clouds[0].instance_regions is not None:
+ instance_regions = torch.cat([
+ pc.instance_regions for pc in point_clouds
+ ], dim=0)
+ else:
+ instance_regions = None
+
+ voxel_batch_indices = torch.cat([
+ torch.full((
+ pc.voxel_coords.shape[0],), i, dtype=torch.int32, device=device
+ )
+ for i, pc in enumerate(point_clouds)
+ ], dim=0)
+ voxel_coords = torch.cat([
+ pc.voxel_coords for pc in point_clouds
+ ], dim=0)
+ voxel_coords = torch.cat([
+ voxel_batch_indices[:, None], voxel_coords
+ ], dim=-1)
+ voxel_features = torch.cat([
+ pc.voxel_features for pc in point_clouds
+ ], dim=0)
+
+ voxel_coords_range = np.max([
+ pc.voxel_coords_range for pc in point_clouds
+ ], axis=0)
+ voxel_tensor = spconv.SparseConvTensor(
+ voxel_features, voxel_coords,
+ spatial_shape=voxel_coords_range.tolist(),
+ batch_size=len(point_clouds),
+ )
+
+ pc_voxel_id = []
+ num_voxel_offset = 0
+ for pc in point_clouds:
+ pc.pc_voxel_id[pc.pc_voxel_id >= 0] += num_voxel_offset
+ pc_voxel_id.append(pc.pc_voxel_id)
+ num_voxel_offset += pc.voxel_coords.shape[0]
+ pc_voxel_id = torch.cat(pc_voxel_id, dim=0)
+
+ return (
+ scene_ids, cls_labels, num_points, points, batch_indices, sem_labels, instance_labels, gt_npcs,
+ num_instances, instance_regions, num_points_per_instance, instance_sem_labels,
+ voxel_tensor, pc_voxel_id,
+ )
+
+if __name__ == "__main__":
+    # minimal smoke test with dummy values
+    pc = PointCloud("dummy_scene", 0, np.ones((10000, 3)))
+ print(pc.to_tensor().to("cuda:0"))
diff --git a/gapartnet/gapartnet/structures/segmentation.py b/gapartnet/gapartnet/structures/segmentation.py
new file mode 100644
index 0000000..d972844
--- /dev/null
+++ b/gapartnet/gapartnet/structures/segmentation.py
@@ -0,0 +1,14 @@
+from dataclasses import dataclass
+from typing import List, Optional
+
+import torch
+
+
+@dataclass
+class Segmentation:
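+    """Batched semantic segmentation output: per-point predictions and optional labels."""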
+ batch_size: int
+ num_points: List[int]
+
+ sem_preds: torch.Tensor
+ sem_labels: Optional[torch.Tensor] = None
+
diff --git a/gapartnet/gapartnet/utils/__init__.py b/gapartnet/gapartnet/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/gapartnet/gapartnet/utils/color.py b/gapartnet/gapartnet/utils/color.py
new file mode 100644
index 0000000..5ef3eeb
--- /dev/null
+++ b/gapartnet/gapartnet/utils/color.py
@@ -0,0 +1,36 @@
+from typing import Optional
+
+import numpy as np
+
+
+def hsv_to_rgb(
+ h: np.ndarray, s: Optional[np.ndarray] = None, v: Optional[np.ndarray] = None
+) -> np.ndarray:
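+    """Vectorized HSV-to-RGB conversion with h, s, v in [0, 1]; s and v default to all ones."""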
+ if s is None:
+ s = np.ones_like(h)
+ if v is None:
+ v = np.ones_like(h)
+
+
+ i = (h * 6.0).astype(np.int32)
+ f = (h * 6.0) - i
+ p = v * (1.0 - s)
+ q = v * (1.0 - s * f)
+ t = v * (1.0 - s * (1.0 - f))
+ i = i % 6
+
+ res = np.zeros((h.shape[0], 3), dtype=np.float32)
+ mask_0 = i == 0
+ res[mask_0] = np.stack([v[mask_0], t[mask_0], p[mask_0]], axis=1)
+ mask_1 = i == 1
+ res[mask_1] = np.stack([q[mask_1], v[mask_1], p[mask_1]], axis=1)
+ mask_2 = i == 2
+ res[mask_2] = np.stack([p[mask_2], v[mask_2], t[mask_2]], axis=1)
+ mask_3 = i == 3
+ res[mask_3] = np.stack([p[mask_3], q[mask_3], v[mask_3]], axis=1)
+ mask_4 = i == 4
+ res[mask_4] = np.stack([t[mask_4], p[mask_4], v[mask_4]], axis=1)
+ mask_5 = i == 5
+ res[mask_5] = np.stack([v[mask_5], p[mask_5], q[mask_5]], axis=1)
+ return res
diff --git a/gapartnet/gapartnet/utils/data.py b/gapartnet/gapartnet/utils/data.py
new file mode 100644
index 0000000..d1109c4
--- /dev/null
+++ b/gapartnet/gapartnet/utils/data.py
@@ -0,0 +1,36 @@
+from typing import Any, Iterator
+
+import torch
+import torch.distributed as dist
+import torchdata.datapipes as dp
+
+
+def trivial_batch_collator(batch):
+ """
+ A batch collator that does nothing.
+ """
+ return batch
+
+
+@dp.functional_datapipe("distributed_sharding_filter")
+class DistributedShardingFilter(dp.iter.ShardingFilter):
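+    """Sharding filter that splits a datapipe across distributed ranks and DataLoader workers."""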
+ def __init__(self, source_datapipe: dp.iter.IterDataPipe) -> None:
+ super().__init__(source_datapipe)
+
+ self.rank = 0
+ self.world_size = 1
+ if dist.is_available() and dist.is_initialized():
+ self.rank = dist.get_rank()
+ self.world_size = dist.get_world_size()
+ self.apply_sharding(self.world_size, self.rank)
+
+ def __iter__(self) -> Iterator[Any]:
+ num_workers = self.world_size
+ worker_id = self.rank
+ worker_info = torch.utils.data.get_worker_info()
+ if worker_info is not None:
+ worker_id = worker_id + worker_info.id * num_workers
+ num_workers *= worker_info.num_workers
+ self.apply_sharding(num_workers, worker_id)
+
+ yield from super().__iter__()
diff --git a/gapartnet/gapartnet/utils/info.py b/gapartnet/gapartnet/utils/info.py
new file mode 100644
index 0000000..95be397
--- /dev/null
+++ b/gapartnet/gapartnet/utils/info.py
@@ -0,0 +1,352 @@
+OBJECT_NAME2ID = {
+ # seen category
+ "Box": 0,
+ "Remote": 1,
+ "Microwave": 2,
+ "Camera": 3,
+ "Dishwasher": 4,
+ "WashingMachine": 5,
+ "CoffeeMachine": 6,
+ "Toaster": 7,
+ "StorageFurniture": 8,
+ "AKBBucket": 9, # akb48
+ "AKBBox": 10, # akb48
+ "AKBDrawer": 11, # akb48
+ "AKBTrashCan": 12, # akb48
+ "Bucket": 13, # new
+ "Keyboard": 14, # new
+ "Printer": 15, # new
+ "Toilet": 16, # new
+ # unseen category
+ "KitchenPot": 17,
+ "Safe": 18,
+ "Oven": 19,
+ "Phone": 20,
+ "Refrigerator": 21,
+ "Table": 22,
+ "TrashCan": 23,
+ "Door": 24,
+ "Laptop": 25,
+ "Suitcase": 26, # new
+}
+
+TARGET_PARTS = [
+ 'others',
+ 'line_fixed_handle',
+ 'round_fixed_handle',
+ 'slider_button',
+ 'hinge_door',
+ 'slider_drawer',
+ 'slider_lid',
+ 'hinge_lid',
+ 'hinge_knob',
+ 'revolute_handle'
+]
+
+PART_NAME2ID = {
+ 'others': 0,
+ 'line_fixed_handle': 1,
+ 'round_fixed_handle': 2,
+ 'slider_button': 3,
+ 'hinge_door': 4,
+ 'slider_drawer': 5,
+ 'slider_lid': 6,
+ 'hinge_lid': 7,
+ 'hinge_knob': 8,
+ 'revolute_handle': 9,
+}
+
+PART_ID2NAME = {
+ 0: 'others' ,
+ 1: 'line_fixed_handle' ,
+ 2: 'round_fixed_handle' ,
+ 3: 'slider_button' ,
+ 4: 'hinge_door' ,
+ 5: 'slider_drawer' ,
+ 6: 'slider_lid' ,
+ 7: 'hinge_lid' ,
+ 8: 'hinge_knob' ,
+ 9: 'revolute_handle' ,
+}
+
+
+TARGET_PARTS = [
+ 'others',
+ 'line_fixed_handle',
+ 'round_fixed_handle',
+ 'slider_button',
+ 'hinge_door',
+ 'slider_drawer',
+ 'slider_lid',
+ 'hinge_lid',
+ 'hinge_knob',
+ 'revolute_handle',
+]
+
+TARGET_IDX = [
+ 0,
+ 1,
+ 2,
+ 3,
+ 4,
+ 5,
+ 6,
+ 7,
+ 8,
+ 9,
+]
+
+# import numpy as np
+# OTHER_COLOR = [230, 230, 230]
+
+# COLOR20 = np.array(
+# [[0, 128, 128], [230, 190, 255], [170, 110, 40], [255, 250, 200], [128, 0, 0],
+# [170, 255, 195], [128, 128, 0], [255, 215, 180], [0, 0, 128], [128, 128, 128],
+# [230, 25, 75], [60, 180, 75], [255, 225, 25], [0, 130, 200], [245, 130, 48],
+# [145, 30, 180], [70, 240, 240], [240, 50, 230], [210, 245, 60], [250, 190, 190]])
+
+# COLOR40 = np.array(
+# [[175,58,119], [81,175,144], [184,70,74], [40,116,79], [184,134,219], [130,137,46], [110,89,164], [92,135,74], [220,140,190], [94,103,39],
+# [144,154,219], [160,86,40], [67,107,165], [194,170,104], [162,95,150], [143,110,44], [146,72,105], [225,142,106], [162,83,86], [227,124,143],[88,170,108], [174,105,226], [78,194,83], [198,62,165], [133,188,52], [97,101,219], [190,177,52], [139,65,168], [75,202,137], [225,66,129],
+# [68,135,42], [226,116,210], [146,186,98], [68,105,201], [219,148,53], [85,142,235], [212,85,42], [78,176,223], [221,63,77], [68,195,195]
+# ])
+
+# SEMANTIC_COLOR = np.array(
+# [OTHER_COLOR, # others
+# [202, 145, 99],
+# [203, 202, 102],
+# [140, 203, 103],
+# [109, 189, 205],
+# [112,157,206],
+# [128,129,212],
+# [175,124,211],
+# [208,118,167]
+# [203,140,103],
+
+# ])
+
+# # SEMANTIC_COLOR = np.array(
+# # [OTHER_COLOR, # others
+# # [140, 103,203,],
+# # [140, 103,203,],
+# # [140, 103,203,],
+# # [140, 103,203,],
+# # [140, 103,203,],
+# # [140, 103,203,],
+# # [140, 103,203,],
+# # [140, 103,203,],
+
+# # ])
+# SEMANTIC_IDX2NAME = {
+# 0: 'others',
+# 1: 'line_fixed_handle',
+# 2: 'round_fixed_handle',
+# 3: 'slider_button',
+# 4: 'hinge_door',
+# 5: 'slider_drawer',
+# 6: 'slider_lid',
+# 7: 'hinge_lid',
+# 8: 'hinge_knob',
+# 9: 'revolute_handle',
+
+# }
+
+
+# SYMMETRY_MATRIX_INDEX = [0,1,2,2,4,0,2,4,3,1]
+
+
+# import math
+# PI = math.pi
+
+# SYMMETRY_MATRIX = {
+# 1:[[
+# [-1.0, 0, 0],
+# [ 0, -1.0, 0],
+# [ 0, 0, 1.0]
+# ]],
+# 2:[
+# [
+# [ math.cos(PI/6),math.sin(PI/6), 0],
+# [-math.sin(PI/6),math.cos(PI/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*2/6),math.sin(PI*2/6), 0],
+# [-math.sin(PI*2/6),math.cos(PI*2/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*3/6),math.sin(PI*3/6), 0],
+# [-math.sin(PI*3/6),math.cos(PI*3/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*4/6),math.sin(PI*4/6), 0],
+# [-math.sin(PI*4/6),math.cos(PI*4/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*5/6),math.sin(PI*5/6), 0],
+# [-math.sin(PI*5/6),math.cos(PI*5/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*6/6),math.sin(PI*6/6), 0],
+# [-math.sin(PI*6/6),math.cos(PI*6/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*7/6),math.sin(PI*7/6), 0],
+# [-math.sin(PI*7/6),math.cos(PI*7/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*8/6),math.sin(PI*8/6), 0],
+# [-math.sin(PI*8/6),math.cos(PI*8/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*9/6),math.sin(PI*9/6), 0],
+# [-math.sin(PI*9/6),math.cos(PI*9/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*10/6),math.sin(PI*10/6), 0],
+# [-math.sin(PI*10/6),math.cos(PI*10/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*11/6),math.sin(PI*11/6), 0],
+# [-math.sin(PI*11/6),math.cos(PI*11/6), 0],
+# [ 0, 0, 1.0]
+# ]
+# ],
+
+# 3:[
+# [
+# [ math.cos(PI/6),math.sin(PI/6), 0],
+# [-math.sin(PI/6),math.cos(PI/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*2/6),math.sin(PI*2/6), 0],
+# [-math.sin(PI*2/6),math.cos(PI*2/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*3/6),math.sin(PI*3/6), 0],
+# [-math.sin(PI*3/6),math.cos(PI*3/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*4/6),math.sin(PI*4/6), 0],
+# [-math.sin(PI*4/6),math.cos(PI*4/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*5/6),math.sin(PI*5/6), 0],
+# [-math.sin(PI*5/6),math.cos(PI*5/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*6/6),math.sin(PI*6/6), 0],
+# [-math.sin(PI*6/6),math.cos(PI*6/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*7/6),math.sin(PI*7/6), 0],
+# [-math.sin(PI*7/6),math.cos(PI*7/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*8/6),math.sin(PI*8/6), 0],
+# [-math.sin(PI*8/6),math.cos(PI*8/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*9/6),math.sin(PI*9/6), 0],
+# [-math.sin(PI*9/6),math.cos(PI*9/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*10/6),math.sin(PI*10/6), 0],
+# [-math.sin(PI*10/6),math.cos(PI*10/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*11/6),math.sin(PI*11/6), 0],
+# [-math.sin(PI*11/6),math.cos(PI*11/6), 0],
+# [ 0, 0, 1.0]
+# ],
+
+# ###################### inverse ######################
+
+# [
+# [ math.sin(PI/6),math.cos(PI/6), 0],
+# [ math.cos(PI/6),-math.sin(PI/6), 0],
+# [ 0, 0, -1.0]
+# ],
+# [
+# [ math.sin(PI*2/6),math.cos(PI*2/6), 0],
+# [ math.cos(PI*2/6),-math.sin(PI*2/6), 0],
+# [ 0, 0, -1.0]
+# ],
+# [
+# [ math.sin(PI*3/6),math.cos(PI*3/6), 0],
+# [ math.cos(PI*3/6),-math.sin(PI*3/6), 0],
+# [ 0, 0, -1.0]
+# ],
+# [
+# [ math.sin(PI*4/6),math.cos(PI*4/6), 0],
+# [ math.cos(PI*4/6),-math.sin(PI*4/6), 0],
+# [ 0, 0, -1.0]
+# ],
+# [
+# [ math.sin(PI*5/6),math.cos(PI*5/6), 0],
+# [ math.cos(PI*5/6),-math.sin(PI*5/6), 0],
+# [ 0, 0, -1.0]
+# ],
+# [
+# [ math.sin(PI*6/6),math.cos(PI*6/6), 0],
+# [ math.cos(PI*6/6),-math.sin(PI*6/6), 0],
+# [ 0, 0, -1.0]
+# ],
+# [
+# [ math.sin(PI*7/6),math.cos(PI*7/6), 0],
+# [ math.cos(PI*7/6),-math.sin(PI*7/6), 0],
+# [ 0, 0, -1.0]
+# ],
+# [
+# [ math.sin(PI*8/6),math.cos(PI*8/6), 0],
+# [ math.cos(PI*8/6),-math.sin(PI*8/6), 0],
+# [ 0, 0, -1.0]
+# ],
+# [
+# [ math.sin(PI*9/6),math.cos(PI*9/6), 0],
+# [ math.cos(PI*9/6),-math.sin(PI*9/6), 0],
+# [ 0, 0, -1.0]
+# ],
+# [
+# [ math.sin(PI*10/6),math.cos(PI*10/6), 0],
+# [ math.cos(PI*10/6),-math.sin(PI*10/6), 0],
+# [ 0, 0, -1.0]
+# ],
+# [
+# [ math.sin(PI*11/6),math.cos(PI*11/6), 0],
+# [ math.cos(PI*11/6),-math.sin(PI*11/6), 0],
+# [ 0, 0, -1.0]
+# ],
+# [
+# [ math.sin(PI*12/6),math.cos(PI*12/6), 0],
+# [ math.cos(PI*12/6),-math.sin(PI*12/6), 0],
+# [ 0, 0, -1.0]
+# ]
+
+# ],
+# 4:[[
+# [-1.0, 0, 0],
+# [ 0, 1.0, 0],
+# [ 0, 0, -1.0]
+# ]],
+
+# }
diff --git a/gapartnet/gapartnet/utils/logger.py b/gapartnet/gapartnet/utils/logger.py
new file mode 100644
index 0000000..03dc8a6
--- /dev/null
+++ b/gapartnet/gapartnet/utils/logger.py
@@ -0,0 +1,61 @@
+from typing import Callable, Mapping, Optional, Sequence, Union
+
+from pytorch_lightning.loggers import WandbLogger as _WandbLogger
+
+
+class WandbLogger(_WandbLogger):
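+    """Thin wrapper around pytorch_lightning's WandbLogger that forwards extra wandb arguments
+    (entity, job_type, tags, group, notes, ...) to the parent logger."""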
+ def __init__(
+ self,
+ name: Optional[str] = None,
+ save_dir: Optional[str] = None,
+ offline: Optional[bool] = False,
+ id: Optional[str] = None,
+ anonymous: Optional[bool] = None,
+ version: Optional[str] = None,
+ project: Optional[str] = None,
+ log_model: Union[str, bool] = False,
+ experiment=None,
+ prefix: Optional[str] = "",
+ agg_key_funcs: Optional[
+ Mapping[str, Callable[[Sequence[float]], float]]
+ ] = None,
+ agg_default_func: Optional[
+ Callable[[Sequence[float]], float]
+ ] = None,
+ entity: Optional[str] = None,
+ job_type: Optional[str] = None,
+ tags: Optional[Sequence] = None,
+ group: Optional[str] = None,
+ notes: Optional[str] = None,
+ mode: Optional[str] = None,
+ sync_tensorboard: Optional[bool] = False,
+ monitor_gym: Optional[bool] = False,
+ save_code: Optional[bool] = False,
+ **kwargs,
+ ):
+ super().__init__(
+ name=name,
+ save_dir=save_dir,
+ offline=offline,
+ id=id,
+ anonymous=anonymous,
+ version=version,
+ project=project,
+ log_model=log_model,
+ experiment=experiment,
+ prefix=prefix,
+ agg_key_funcs=agg_key_funcs,
+ agg_default_func=agg_default_func,
+ entity=entity,
+ job_type=job_type,
+ tags=tags,
+ group=group,
+ notes=notes,
+ mode=mode,
+ sync_tensorboard=sync_tensorboard,
+ monitor_gym=monitor_gym,
+ save_code=save_code,
+ **kwargs,
+ )
+
+
diff --git a/gapartnet/gapartnet/utils/pose_fitting.py b/gapartnet/gapartnet/utils/pose_fitting.py
new file mode 100644
index 0000000..f02ba52
--- /dev/null
+++ b/gapartnet/gapartnet/utils/pose_fitting.py
@@ -0,0 +1,147 @@
+import numpy as np
+
+
+def estimate_similarity_umeyama(source_hom: np.ndarray, target_hom: np.ndarray):
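+    """Estimate a similarity transform (uniform scale, rotation, translation) between two
+    homogeneous point sets (4 x N) using the Umeyama method."""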
+ num_points = source_hom.shape[1]
+
+ source_centroid = np.mean(source_hom[:3, :], axis=1)
+ target_centroid = np.mean(target_hom[:3, :], axis=1)
+
+ centered_source = source_hom[:3, :] - np.tile(source_centroid, (num_points, 1)).transpose()
+ centered_target = target_hom[:3, :] - np.tile(target_centroid, (num_points, 1)).transpose()
+
+ cov = np.matmul(centered_target, np.transpose(centered_source)) / num_points
+
+ if np.isnan(cov).any():
+ raise RuntimeError("There are NANs in the input.")
+
+ U, D, Vh = np.linalg.svd(cov, full_matrices=True)
+ d = (np.linalg.det(U) * np.linalg.det(Vh)) < 0.0
+ if d:
+ D[-1] = -D[-1]
+ U[:, -1] = -U[:, -1]
+
+ var_P = np.var(source_hom[:3, :], axis=1).sum()
+ scale_factor = 1 / var_P * np.sum(D)
+ scale = np.array([scale_factor, scale_factor, scale_factor])
+ scale_matrix = np.diag(scale)
+
+ rotation = np.matmul(U, Vh).T
+
+ translation = target_hom[:3, :].mean(axis=1) - source_hom[:3, :].mean(axis=1).dot(
+ scale_factor * rotation
+ )
+
+ out_transform = np.identity(4)
+ out_transform[:3, :3] = scale_matrix @ rotation
+ out_transform[:3, 3] = translation
+
+ return scale, rotation, translation, out_transform
+
+
+def evaluate_model(
+ out_transform: np.ndarray, source_hom: np.ndarray, target_hom: np.ndarray, pass_thrsh: float
+):
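+    """Score a candidate transform: overall residual, inlier ratio and inlier indices under pass_thrsh."""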
+ diff = target_hom - np.matmul(out_transform, source_hom)
+ residual_vec = np.linalg.norm(diff[:3, :], axis=0)
+ residual = np.linalg.norm(residual_vec)
+ inlier_idx = np.where(residual_vec < pass_thrsh)
+ num_inliers = np.count_nonzero(inlier_idx)
+ inlier_ratio = num_inliers / source_hom.shape[1]
+ return residual, inlier_ratio, inlier_idx[0]
+
+
+def get_RANSAC_inliers(
+ source_hom: np.ndarray, target_hom: np.ndarray,
+ max_iters: int, pass_thrsh: float, stop_thrsh: float,
+):
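+    """RANSAC loop: fit a similarity transform on 5 random correspondences per iteration and keep
+    the inliers of the lowest-residual model; stop early once the residual drops below stop_thrsh."""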
+ best_residual = 1e10
+ best_inlier_ratio = 0
+ best_inlier_idx = np.arange(source_hom.shape[1])
+
+ for i in range(max_iters):
+ # Pick 5 random (but corresponding) points from source and target
+ rand_idx = np.random.randint(source_hom.shape[1], size=5)
+ _, _, _, out_transform = estimate_similarity_umeyama(
+ source_hom[:, rand_idx], target_hom[:, rand_idx]
+ )
+
+ residual, inlier_ratio, inlier_idx = evaluate_model(
+ out_transform, source_hom, target_hom, pass_thrsh
+ )
+ if residual < best_residual:
+ best_residual = residual
+ best_inlier_ratio = inlier_ratio
+ best_inlier_idx = inlier_idx
+
+ if best_residual < stop_thrsh:
+ break
+
+ return best_inlier_ratio, best_inlier_idx
+
+
+def estimate_similarity_transform(
+ source: np.ndarray, target: np.ndarray,
+ stop_thrsh: float = 0.5,
+ max_iters: int = 100,
+):
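+    """RANSAC-robust similarity transform estimation from source to target points; returns None
+    placeholders when fewer than 1% of the points are inliers."""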
+ if source.shape[0] == 1:
+ source = np.repeat(source, 2, axis=0)
+ target = np.repeat(target, 2, axis=0)
+
+ source_hom = np.transpose(np.hstack([source, np.ones([source.shape[0], 1])]))
+ target_hom = np.transpose(np.hstack([target, np.ones([target.shape[0], 1])]))
+
+ # Auto-parameter selection based on source-target heuristics
+ source_norm = np.mean(np.linalg.norm(source, axis=1))
+ target_norm = np.mean(np.linalg.norm(target, axis=1))
+
+ ratio_st = (source_norm / target_norm)
+ ratio_ts = (target_norm / source_norm)
+ pass_thrsh = ratio_st if ratio_st > ratio_ts else ratio_ts
+
+ best_inlier_ratio, best_inlier_idx = \
+ get_RANSAC_inliers(
+ source_hom, target_hom, max_iters=max_iters,
+ pass_thrsh=pass_thrsh, stop_thrsh=stop_thrsh,
+ )
+ source_inliers_hom = source_hom[:, best_inlier_idx]
+ target_inliers_hom = target_hom[:, best_inlier_idx]
+
+ if best_inlier_ratio < 0.01:
+ return np.asarray([None, None, None]), None, None, None, None
+
+ scale, rotation, translation, out_transform = estimate_similarity_umeyama(
+ source_inliers_hom, target_inliers_hom
+ )
+
+ return scale, rotation, translation, out_transform, best_inlier_idx
+
+
+def estimate_pose_from_npcs(xyz, npcs):
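+    """Fit a similarity transform from predicted NPCS coordinates to observed points and derive
+    a tight oriented bounding box together with scale, rotation and translation."""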
+ scale, rotation, translation, out_transform, best_inlier_idx = \
+ estimate_similarity_transform(npcs, xyz)
+
+    if scale[0] is None:
+        return None, np.asarray([None, None, None]), None, None, None, best_inlier_idx
+    rotation_inv = np.linalg.pinv(rotation)
+ trans_seg = np.dot((xyz - translation), rotation_inv) / scale[0]
+ npcs_max = abs(trans_seg[best_inlier_idx]).max(0)
+
+ bbox_raw = np.asarray([
+ [-npcs_max[0], -npcs_max[1], -npcs_max[2]],
+ [npcs_max[0], -npcs_max[1], -npcs_max[2]],
+ [-npcs_max[0], npcs_max[1], -npcs_max[2]],
+ [-npcs_max[0], -npcs_max[1], npcs_max[2]],
+ [npcs_max[0], npcs_max[1], -npcs_max[2]],
+ [npcs_max[0], -npcs_max[1], npcs_max[2]],
+ [-npcs_max[0], npcs_max[1], npcs_max[2]],
+ [npcs_max[0], npcs_max[1], npcs_max[2]],
+ ])
+ bbox_trans = np.dot((bbox_raw * scale[0]), rotation) + translation
+
+ return bbox_trans, scale, rotation, translation, out_transform, best_inlier_idx
diff --git a/gapartnet/gapartnet/utils/symmetry_matrix.py b/gapartnet/gapartnet/utils/symmetry_matrix.py
new file mode 100644
index 0000000..14b0ddd
--- /dev/null
+++ b/gapartnet/gapartnet/utils/symmetry_matrix.py
@@ -0,0 +1,249 @@
+import math
+from typing import Tuple
+
+import torch
+
+PI = math.pi
+SYMMETRY_MATRIX = [
+ # type 0
+ [
+ [
+ [ 1.0, 0, 0],
+ [ 0, 1.0, 0],
+ [ 0, 0, 1.0],
+ ],
+ [
+ [ 1.0, 0, 0],
+ [ 0, 1.0, 0],
+ [ 0, 0, 1.0],
+ ],
+ ],
+
+ # type 1
+ [
+ [
+ [ 1.0, 0, 0],
+ [ 0, 1.0, 0],
+ [ 0, 0, 1.0],
+ ],
+ [
+ [-1.0, 0, 0],
+ [ 0, -1.0, 0],
+ [ 0, 0, 1.0],
+ ],
+ ],
+
+ # type 2
+ [
+ [
+ [ 1.0, 0, 0],
+ [ 0, 1.0, 0],
+ [ 0, 0, 1.0],
+ ],
+ [
+ [-1.0, 0, 0],
+ [ 0, 1.0, 0],
+ [ 0, 0, -1.0],
+ ],
+ ],
+
+ # type 3
+ [
+ [
+ [ 1.0, 0, 0],
+ [ 0, 1.0, 0],
+ [ 0, 0, 1.0],
+ ],
+ [
+ [ math.cos(PI/6), math.sin(PI/6), 0],
+ [-math.sin(PI/6), math.cos(PI/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*2/6), math.sin(PI*2/6), 0],
+ [-math.sin(PI*2/6), math.cos(PI*2/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*3/6), math.sin(PI*3/6), 0],
+ [-math.sin(PI*3/6), math.cos(PI*3/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*4/6), math.sin(PI*4/6), 0],
+ [-math.sin(PI*4/6), math.cos(PI*4/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*5/6), math.sin(PI*5/6), 0],
+ [-math.sin(PI*5/6), math.cos(PI*5/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*6/6), math.sin(PI*6/6), 0],
+ [-math.sin(PI*6/6), math.cos(PI*6/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*7/6), math.sin(PI*7/6), 0],
+ [-math.sin(PI*7/6), math.cos(PI*7/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*8/6), math.sin(PI*8/6), 0],
+ [-math.sin(PI*8/6), math.cos(PI*8/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*9/6), math.sin(PI*9/6), 0],
+ [-math.sin(PI*9/6), math.cos(PI*9/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*10/6), math.sin(PI*10/6), 0],
+ [-math.sin(PI*10/6), math.cos(PI*10/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*11/6), math.sin(PI*11/6), 0],
+ [-math.sin(PI*11/6), math.cos(PI*11/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ ],
+
+ # type 4
+ [
+ [
+ [ 1.0, 0, 0],
+ [ 0, 1.0, 0],
+ [ 0, 0, 1.0],
+ ],
+ [
+ [ math.cos(PI/6), math.sin(PI/6), 0],
+ [-math.sin(PI/6), math.cos(PI/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*2/6), math.sin(PI*2/6), 0],
+ [-math.sin(PI*2/6), math.cos(PI*2/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*3/6), math.sin(PI*3/6), 0],
+ [-math.sin(PI*3/6), math.cos(PI*3/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*4/6), math.sin(PI*4/6), 0],
+ [-math.sin(PI*4/6), math.cos(PI*4/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*5/6), math.sin(PI*5/6), 0],
+ [-math.sin(PI*5/6), math.cos(PI*5/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*6/6), math.sin(PI*6/6), 0],
+ [-math.sin(PI*6/6), math.cos(PI*6/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*7/6), math.sin(PI*7/6), 0],
+ [-math.sin(PI*7/6), math.cos(PI*7/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*8/6), math.sin(PI*8/6), 0],
+ [-math.sin(PI*8/6), math.cos(PI*8/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*9/6), math.sin(PI*9/6), 0],
+ [-math.sin(PI*9/6), math.cos(PI*9/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*10/6), math.sin(PI*10/6), 0],
+ [-math.sin(PI*10/6), math.cos(PI*10/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*11/6), math.sin(PI*11/6), 0],
+ [-math.sin(PI*11/6), math.cos(PI*11/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ ###################### inverse ######################
+ [
+ [ math.sin(PI/6), math.cos(PI/6), 0],
+ [ math.cos(PI/6), -math.sin(PI/6), 0],
+ [ 0, 0, -1.0]
+ ],
+ [
+ [ math.sin(PI*2/6), math.cos(PI*2/6), 0],
+ [ math.cos(PI*2/6), -math.sin(PI*2/6), 0],
+ [ 0, 0, -1.0]
+ ],
+ [
+ [ math.sin(PI*3/6), math.cos(PI*3/6), 0],
+ [ math.cos(PI*3/6), -math.sin(PI*3/6), 0],
+ [ 0, 0, -1.0]
+ ],
+ [
+ [ math.sin(PI*4/6), math.cos(PI*4/6), 0],
+ [ math.cos(PI*4/6), -math.sin(PI*4/6), 0],
+ [ 0, 0, -1.0]
+ ],
+ [
+ [ math.sin(PI*5/6), math.cos(PI*5/6), 0],
+ [ math.cos(PI*5/6), -math.sin(PI*5/6), 0],
+ [ 0, 0, -1.0]
+ ],
+ [
+ [ math.sin(PI*6/6), math.cos(PI*6/6), 0],
+ [ math.cos(PI*6/6), -math.sin(PI*6/6), 0],
+ [ 0, 0, -1.0]
+ ],
+ [
+ [ math.sin(PI*7/6), math.cos(PI*7/6), 0],
+ [ math.cos(PI*7/6), -math.sin(PI*7/6), 0],
+ [ 0, 0, -1.0]
+ ],
+ [
+ [ math.sin(PI*8/6), math.cos(PI*8/6), 0],
+ [ math.cos(PI*8/6), -math.sin(PI*8/6), 0],
+ [ 0, 0, -1.0]
+ ],
+ [
+ [ math.sin(PI*9/6), math.cos(PI*9/6), 0],
+ [ math.cos(PI*9/6), -math.sin(PI*9/6), 0],
+ [ 0, 0, -1.0]
+ ],
+ [
+ [ math.sin(PI*10/6), math.cos(PI*10/6), 0],
+ [ math.cos(PI*10/6), -math.sin(PI*10/6), 0],
+ [ 0, 0, -1.0]
+ ],
+ [
+ [ math.sin(PI*11/6), math.cos(PI*11/6), 0],
+ [ math.cos(PI*11/6), -math.sin(PI*11/6), 0],
+ [ 0, 0, -1.0]
+ ],
+ [
+ [ math.sin(PI*12/6), math.cos(PI*12/6), 0],
+ [ math.cos(PI*12/6), -math.sin(PI*12/6), 0],
+ [ 0, 0, -1.0]
+ ],
+ ],
+]
+
+
+def get_symmetry_matrix() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ # type 0 / 1 / 2
+ sm_1 = torch.as_tensor(SYMMETRY_MATRIX[:3], dtype=torch.float32)
+ # type 3
+ sm_2 = torch.as_tensor(SYMMETRY_MATRIX[3:4], dtype=torch.float32)
+ # type 4
+ sm_3 = torch.as_tensor(SYMMETRY_MATRIX[4:5], dtype=torch.float32)
+
+ return sm_1, sm_2, sm_3
diff --git a/gapartnet/gapartnet/utils/utils.py b/gapartnet/gapartnet/utils/utils.py
new file mode 100644
index 0000000..6c54729
--- /dev/null
+++ b/gapartnet/gapartnet/utils/utils.py
@@ -0,0 +1,161 @@
+import torch, glob, os, numpy as np
+import sys
+import math
+sys.path.append('../')
+
+from util.log import logger
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ if self.count == 0:
+ self.avg = 0
+ else:
+ self.avg = self.sum / self.count
+
+
+def step_learning_rate(optimizer,
+ base_lr,
+ epoch,
+ step_epoch_1,
+ step_epoch_2,
+ prepare_epochs,
+ multiplier=0.1,
+ clip=1e-6):
+ """Sets the learning rate to the base LR decayed by 10 every step epochs"""
+ if epoch <= prepare_epochs:
+ lr = max(base_lr * (multiplier**(epoch // step_epoch_1)), clip)
+ else:
+ lr = max(base_lr * (multiplier**((epoch - prepare_epochs) // step_epoch_2)), clip)
+ for param_group in optimizer.param_groups:
+ param_group['lr'] = lr
+
+
+def intersectionAndUnion(output, target, K, ignore_index=255):
+ # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1.
+ assert (output.ndim in [1, 2, 3])
+ assert output.shape == target.shape
+ output = output.reshape(output.size).copy()
+ target = target.reshape(target.size)
+ output[np.where(target == ignore_index)[0]] = ignore_index
+ intersection = output[np.where(output == target)[0]]
+ area_intersection, _ = np.histogram(
+ intersection,
+ bins=np.arange(K + 1)) # area_intersection: K, indicates the number of members in each class in intersection
+ area_output, _ = np.histogram(output, bins=np.arange(K + 1))
+ area_target, _ = np.histogram(target, bins=np.arange(K + 1))
+ area_union = area_output + area_target - area_intersection
+ return area_intersection, area_union, area_target
+
+
+def checkpoint_restore(model, exp_path, exp_name, use_cuda=True, epoch=0, dist=False, f=''):
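+    """Restore model weights from a specific-epoch or the latest checkpoint (stripping any
+    'module.' prefixes) and return the next epoch index."""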
+ if use_cuda:
+ model.cpu()
+ if not f:
+ if epoch > 0:
+ f = os.path.join(exp_path, exp_name + '-%09d' % epoch + '.pth')
+ assert os.path.isfile(f)
+ else:
+ f = sorted(glob.glob(os.path.join(exp_path, exp_name + '-*.pth')))
+ if len(f) > 0:
+ f = f[-1]
+ epoch = int(f[len(exp_path) + len(exp_name) + 2:-4])
+
+ if len(f) > 0:
+ logger.info('Restore from ' + f)
+ checkpoint = torch.load(f)
+ for k, v in checkpoint.items():
+ if 'module.' in k:
+ checkpoint = {k[len('module.'):]: v for k, v in checkpoint.items()}
+ break
+ if dist:
+ model.module.load_state_dict(checkpoint)
+ else:
+ model.load_state_dict(checkpoint)
+
+ if use_cuda:
+ model.cuda()
+ return epoch + 1
+
+
+def is_power2(num):
+ return num != 0 and ((num & (num - 1)) == 0)
+
+
+def is_multiple(num, multiple):
+ return num != 0 and num % multiple == 0
+
+
+def checkpoint_save(model, exp_path, exp_name, epoch, save_freq=16, use_cuda=True):
+ f = os.path.join(exp_path, exp_name + '-%09d' % epoch + '.pth')
+ logger.info('Saving ' + f)
+ model.cpu()
+ torch.save(model.state_dict(), f)
+ if use_cuda:
+ model.cuda()
+
+ #remove previous checkpoints unless they are a multiple of save_freq to save disk space
+ epoch = epoch - 1
+ f = os.path.join(exp_path, exp_name + '-%09d' % epoch + '.pth')
+ if os.path.isfile(f):
+ if not is_multiple(epoch, save_freq):
+ os.remove(f)
+
+
+def load_model_param(model, pretrained_dict, prefix=""):
+ # suppose every param in model should exist in pretrain_dict, but may differ in the prefix of the name
+ # For example: model_dict: "0.conv.weight" pretrain_dict: "FC_layer.0.conv.weight"
+ model_dict = model.state_dict()
+ len_prefix = 0 if len(prefix) == 0 else len(prefix) + 1
+ pretrained_dict_filter = {
+ k[len_prefix:]: v
+ for k, v in pretrained_dict.items() if k[len_prefix:] in model_dict and prefix in k
+ }
+ assert len(pretrained_dict_filter) > 0
+ model_dict.update(pretrained_dict_filter)
+ model.load_state_dict(model_dict)
+ return len(pretrained_dict_filter), len(model_dict)
+
+
+def write_obj(points, colors, out_filename):
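+    """Write points with per-point RGB colors to an OBJ file, one 'v x y z r g b' line per point."""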
+ N = points.shape[0]
+ fout = open(out_filename, 'w')
+ for i in range(N):
+ c = colors[i]
+ fout.write('v %f %f %f %d %d %d\n' % (points[i, 0], points[i, 1], points[i, 2], c[0], c[1], c[2]))
+ fout.close()
+
+
+def get_batch_offsets(batch_idxs, bs):
+ '''
+ :param batch_idxs: (N), int
+ :param bs: int
+ :return: batch_offsets: (bs + 1)
+ '''
+ batch_offsets = torch.zeros(bs + 1).int().cuda()
+ for i in range(bs):
+ batch_offsets[i + 1] = batch_offsets[i] + (batch_idxs == i).sum()
+ assert batch_offsets[-1] == batch_idxs.shape[0]
+ return batch_offsets
+
+
+def print_error(message, user_fault=False):
+ sys.stderr.write('ERROR: ' + str(message) + '\n')
+ if user_fault:
+ sys.exit(2)
+ sys.exit(-1)
+
diff --git a/gapartnet/gapartnet/utils/utils_3d.py b/gapartnet/gapartnet/utils/utils_3d.py
new file mode 100644
index 0000000..15b9e9e
--- /dev/null
+++ b/gapartnet/gapartnet/utils/utils_3d.py
@@ -0,0 +1,77 @@
+# ScanNet util_3d: https://github.com/ScanNet/ScanNet/blob/master/BenchmarkScripts/util_3d.py
+
+import json, numpy as np
+
+
+def load_ids(filename):
+ ids = open(filename).read().splitlines()
+ ids = np.array(ids, dtype=np.int64)
+ return ids
+
+def load_ids_npcs(filename):
+ tmp = np.loadtxt(filename)
+ # print(tmp)
+ ids = tmp[0]
+ npcs = tmp[1:4].T
+ return ids,npcs
+
+
+# ------------ Instance Utils ------------ #
+
+
+class Instance(object):
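+    """Instance record parsed from encoded ids (label_id * 1000 + instance index), following
+    the ScanNet benchmark utilities."""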
+ instance_id = 0
+ label_id = 0
+ vert_count = 0
+ med_dist = -1
+ dist_conf = 0.0
+
+ def __init__(self, mesh_vert_instances, instance_id):
+ if (instance_id == -1):
+ return
+ self.instance_id = int(instance_id)
+ self.label_id = int(self.get_label_id(instance_id))
+ self.vert_count = int(self.get_instance_verts(mesh_vert_instances, instance_id))
+
+ def get_label_id(self, instance_id):
+ return int(instance_id // 1000)
+
+ def get_instance_verts(self, mesh_vert_instances, instance_id):
+ return (mesh_vert_instances == instance_id).sum()
+
+ def to_json(self):
+ return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4)
+
+ def to_dict(self):
+ dict = {}
+ dict["instance_id"] = self.instance_id
+ dict["label_id"] = self.label_id
+ dict["vert_count"] = self.vert_count
+ dict["med_dist"] = self.med_dist
+ dict["dist_conf"] = self.dist_conf
+ return dict
+
+ def from_json(self, data):
+ self.instance_id = int(data["instance_id"])
+ self.label_id = int(data["label_id"])
+ self.vert_count = int(data["vert_count"])
+ if ("med_dist" in data):
+ self.med_dist = float(data["med_dist"])
+ self.dist_conf = float(data["dist_conf"])
+
+ def __str__(self):
+ return "(" + str(self.instance_id) + ")"
+
+
+def get_instances(ids, class_ids, class_labels, id2label):
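+    """Group instance records by semantic class name, skipping the 'others' id (-100)."""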
+ instances = {}
+ for label in class_labels:
+ instances[label] = []
+ instance_ids = np.unique(ids)
+ for id in instance_ids:
+ if id == -100: # -100 for others instance
+ continue
+ inst = Instance(ids, id)
+ if inst.label_id in class_ids:
+ instances[id2label[inst.label_id]].append(inst.to_dict())
+ return instances
diff --git a/gapartnet/gapartnet/version.py b/gapartnet/gapartnet/version.py
new file mode 100644
index 0000000..eb33cbb
--- /dev/null
+++ b/gapartnet/gapartnet/version.py
@@ -0,0 +1,2 @@
+__version__ = '0.1.0+cadfe09'
+git_version = 'cadfe095a43e3b9466b9f70a47541987728da0fb'
diff --git a/gapartnet/misc/info.py b/gapartnet/misc/info.py
new file mode 100644
index 0000000..d5d5357
--- /dev/null
+++ b/gapartnet/misc/info.py
@@ -0,0 +1,602 @@
+import math
+from typing import Tuple
+
+import torch
+
+OBJECT_NAME2ID = {
+ # seen category
+ "Box": 0,
+ "Remote": 1,
+ "Microwave": 2,
+ "Camera": 3,
+ "Dishwasher": 4,
+ "WashingMachine": 5,
+ "CoffeeMachine": 6,
+ "Toaster": 7,
+ "StorageFurniture": 8,
+ "AKBBucket": 9, # akb48
+ "AKBBox": 10, # akb48
+ "AKBDrawer": 11, # akb48
+ "AKBTrashCan": 12, # akb48
+ "Bucket": 13, # new
+ "Keyboard": 14, # new
+ "Printer": 15, # new
+ "Toilet": 16, # new
+ # unseen category
+ "KitchenPot": 17,
+ "Safe": 18,
+ "Oven": 19,
+ "Phone": 20,
+ "Refrigerator": 21,
+ "Table": 22,
+ "TrashCan": 23,
+ "Door": 24,
+ "Laptop": 25,
+ "Suitcase": 26, # new
+}
+
+TARGET_PARTS = [
+ 'others',
+ 'line_fixed_handle',
+ 'round_fixed_handle',
+ 'slider_button',
+ 'hinge_door',
+ 'slider_drawer',
+ 'slider_lid',
+ 'hinge_lid',
+ 'hinge_knob',
+ 'revolute_handle'
+]
+
+PART_NAME2ID = {
+ 'others': 0,
+ 'line_fixed_handle': 1,
+ 'round_fixed_handle': 2,
+ 'slider_button': 3,
+ 'hinge_door': 4,
+ 'slider_drawer': 5,
+ 'slider_lid': 6,
+ 'hinge_lid': 7,
+ 'hinge_knob': 8,
+ 'revolute_handle': 9,
+}
+
+PART_ID2NAME = {
+ 0: 'others' ,
+ 1: 'line_fixed_handle' ,
+ 2: 'round_fixed_handle' ,
+ 3: 'slider_button' ,
+ 4: 'hinge_door' ,
+ 5: 'slider_drawer' ,
+ 6: 'slider_lid' ,
+ 7: 'hinge_lid' ,
+ 8: 'hinge_knob' ,
+ 9: 'revolute_handle' ,
+}
+
+
+TARGET_PARTS = [
+ 'others',
+ 'line_fixed_handle',
+ 'round_fixed_handle',
+ 'slider_button',
+ 'hinge_door',
+ 'slider_drawer',
+ 'slider_lid',
+ 'hinge_lid',
+ 'hinge_knob',
+ 'revolute_handle',
+]
+
+TARGET_IDX = [
+ 0,
+ 1,
+ 2,
+ 3,
+ 4,
+ 5,
+ 6,
+ 7,
+ 8,
+ 9,
+]
+PI = math.pi
+SYMMETRY_MATRIX = [
+ # type 0
+ [
+ [
+ [ 1.0, 0, 0],
+ [ 0, 1.0, 0],
+ [ 0, 0, 1.0],
+ ],
+ [
+ [ 1.0, 0, 0],
+ [ 0, 1.0, 0],
+ [ 0, 0, 1.0],
+ ],
+ ],
+
+ # type 1
+ [
+ [
+ [ 1.0, 0, 0],
+ [ 0, 1.0, 0],
+ [ 0, 0, 1.0],
+ ],
+ [
+ [-1.0, 0, 0],
+ [ 0, -1.0, 0],
+ [ 0, 0, 1.0],
+ ],
+ ],
+
+ # type 2
+ [
+ [
+ [ 1.0, 0, 0],
+ [ 0, 1.0, 0],
+ [ 0, 0, 1.0],
+ ],
+ [
+ [-1.0, 0, 0],
+ [ 0, 1.0, 0],
+ [ 0, 0, -1.0],
+ ],
+ ],
+
+ # type 3
+ [
+ [
+ [ 1.0, 0, 0],
+ [ 0, 1.0, 0],
+ [ 0, 0, 1.0],
+ ],
+ [
+ [ math.cos(PI/6), math.sin(PI/6), 0],
+ [-math.sin(PI/6), math.cos(PI/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*2/6), math.sin(PI*2/6), 0],
+ [-math.sin(PI*2/6), math.cos(PI*2/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*3/6), math.sin(PI*3/6), 0],
+ [-math.sin(PI*3/6), math.cos(PI*3/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*4/6), math.sin(PI*4/6), 0],
+ [-math.sin(PI*4/6), math.cos(PI*4/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*5/6), math.sin(PI*5/6), 0],
+ [-math.sin(PI*5/6), math.cos(PI*5/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*6/6), math.sin(PI*6/6), 0],
+ [-math.sin(PI*6/6), math.cos(PI*6/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*7/6), math.sin(PI*7/6), 0],
+ [-math.sin(PI*7/6), math.cos(PI*7/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*8/6), math.sin(PI*8/6), 0],
+ [-math.sin(PI*8/6), math.cos(PI*8/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*9/6), math.sin(PI*9/6), 0],
+ [-math.sin(PI*9/6), math.cos(PI*9/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*10/6), math.sin(PI*10/6), 0],
+ [-math.sin(PI*10/6), math.cos(PI*10/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*11/6), math.sin(PI*11/6), 0],
+ [-math.sin(PI*11/6), math.cos(PI*11/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ ],
+
+ # type 4
+ [
+ [
+ [ 1.0, 0, 0],
+ [ 0, 1.0, 0],
+ [ 0, 0, 1.0],
+ ],
+ [
+ [ math.cos(PI/6), math.sin(PI/6), 0],
+ [-math.sin(PI/6), math.cos(PI/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*2/6), math.sin(PI*2/6), 0],
+ [-math.sin(PI*2/6), math.cos(PI*2/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*3/6), math.sin(PI*3/6), 0],
+ [-math.sin(PI*3/6), math.cos(PI*3/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*4/6), math.sin(PI*4/6), 0],
+ [-math.sin(PI*4/6), math.cos(PI*4/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*5/6), math.sin(PI*5/6), 0],
+ [-math.sin(PI*5/6), math.cos(PI*5/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*6/6), math.sin(PI*6/6), 0],
+ [-math.sin(PI*6/6), math.cos(PI*6/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*7/6), math.sin(PI*7/6), 0],
+ [-math.sin(PI*7/6), math.cos(PI*7/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*8/6), math.sin(PI*8/6), 0],
+ [-math.sin(PI*8/6), math.cos(PI*8/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*9/6), math.sin(PI*9/6), 0],
+ [-math.sin(PI*9/6), math.cos(PI*9/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*10/6), math.sin(PI*10/6), 0],
+ [-math.sin(PI*10/6), math.cos(PI*10/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ [
+ [ math.cos(PI*11/6), math.sin(PI*11/6), 0],
+ [-math.sin(PI*11/6), math.cos(PI*11/6), 0],
+ [ 0, 0, 1.0]
+ ],
+ ###################### inverse ######################
+ [
+ [ math.sin(PI/6), math.cos(PI/6), 0],
+ [ math.cos(PI/6), -math.sin(PI/6), 0],
+ [ 0, 0, -1.0]
+ ],
+ [
+ [ math.sin(PI*2/6), math.cos(PI*2/6), 0],
+ [ math.cos(PI*2/6), -math.sin(PI*2/6), 0],
+ [ 0, 0, -1.0]
+ ],
+ [
+ [ math.sin(PI*3/6), math.cos(PI*3/6), 0],
+ [ math.cos(PI*3/6), -math.sin(PI*3/6), 0],
+ [ 0, 0, -1.0]
+ ],
+ [
+ [ math.sin(PI*4/6), math.cos(PI*4/6), 0],
+ [ math.cos(PI*4/6), -math.sin(PI*4/6), 0],
+ [ 0, 0, -1.0]
+ ],
+ [
+ [ math.sin(PI*5/6), math.cos(PI*5/6), 0],
+ [ math.cos(PI*5/6), -math.sin(PI*5/6), 0],
+ [ 0, 0, -1.0]
+ ],
+ [
+ [ math.sin(PI*6/6), math.cos(PI*6/6), 0],
+ [ math.cos(PI*6/6), -math.sin(PI*6/6), 0],
+ [ 0, 0, -1.0]
+ ],
+ [
+ [ math.sin(PI*7/6), math.cos(PI*7/6), 0],
+ [ math.cos(PI*7/6), -math.sin(PI*7/6), 0],
+ [ 0, 0, -1.0]
+ ],
+ [
+ [ math.sin(PI*8/6), math.cos(PI*8/6), 0],
+ [ math.cos(PI*8/6), -math.sin(PI*8/6), 0],
+ [ 0, 0, -1.0]
+ ],
+ [
+ [ math.sin(PI*9/6), math.cos(PI*9/6), 0],
+ [ math.cos(PI*9/6), -math.sin(PI*9/6), 0],
+ [ 0, 0, -1.0]
+ ],
+ [
+ [ math.sin(PI*10/6), math.cos(PI*10/6), 0],
+ [ math.cos(PI*10/6), -math.sin(PI*10/6), 0],
+ [ 0, 0, -1.0]
+ ],
+ [
+ [ math.sin(PI*11/6), math.cos(PI*11/6), 0],
+ [ math.cos(PI*11/6), -math.sin(PI*11/6), 0],
+ [ 0, 0, -1.0]
+ ],
+ [
+ [ math.sin(PI*12/6), math.cos(PI*12/6), 0],
+ [ math.cos(PI*12/6), -math.sin(PI*12/6), 0],
+ [ 0, 0, -1.0]
+ ],
+ ],
+]
+
+
+def get_symmetry_matrix() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ # type 0 / 1 / 2
+ sm_1 = torch.as_tensor(SYMMETRY_MATRIX[:3], dtype=torch.float32)
+ # type 3
+ sm_2 = torch.as_tensor(SYMMETRY_MATRIX[3:4], dtype=torch.float32)
+ # type 4
+ sm_3 = torch.as_tensor(SYMMETRY_MATRIX[4:5], dtype=torch.float32)
+
+ return sm_1, sm_2, sm_3
+
+
+# import numpy as np
+# OTHER_COLOR = [230, 230, 230]
+
+# COLOR20 = np.array(
+# [[0, 128, 128], [230, 190, 255], [170, 110, 40], [255, 250, 200], [128, 0, 0],
+# [170, 255, 195], [128, 128, 0], [255, 215, 180], [0, 0, 128], [128, 128, 128],
+# [230, 25, 75], [60, 180, 75], [255, 225, 25], [0, 130, 200], [245, 130, 48],
+# [145, 30, 180], [70, 240, 240], [240, 50, 230], [210, 245, 60], [250, 190, 190]])
+
+# COLOR40 = np.array(
+# [[175,58,119], [81,175,144], [184,70,74], [40,116,79], [184,134,219], [130,137,46], [110,89,164], [92,135,74], [220,140,190], [94,103,39],
+# [144,154,219], [160,86,40], [67,107,165], [194,170,104], [162,95,150], [143,110,44], [146,72,105], [225,142,106], [162,83,86], [227,124,143],[88,170,108], [174,105,226], [78,194,83], [198,62,165], [133,188,52], [97,101,219], [190,177,52], [139,65,168], [75,202,137], [225,66,129],
+# [68,135,42], [226,116,210], [146,186,98], [68,105,201], [219,148,53], [85,142,235], [212,85,42], [78,176,223], [221,63,77], [68,195,195]
+# ])
+
+# SEMANTIC_COLOR = np.array(
+# [OTHER_COLOR, # others
+# [202, 145, 99],
+# [203, 202, 102],
+# [140, 203, 103],
+# [109, 189, 205],
+# [112,157,206],
+# [128,129,212],
+# [175,124,211],
+# [208,118,167]
+# [203,140,103],
+
+# ])
+
+# # SEMANTIC_COLOR = np.array(
+# # [OTHER_COLOR, # others
+# # [140, 103,203,],
+# # [140, 103,203,],
+# # [140, 103,203,],
+# # [140, 103,203,],
+# # [140, 103,203,],
+# # [140, 103,203,],
+# # [140, 103,203,],
+# # [140, 103,203,],
+
+# # ])
+# SEMANTIC_IDX2NAME = {
+# 0: 'others',
+# 1: 'line_fixed_handle',
+# 2: 'round_fixed_handle',
+# 3: 'slider_button',
+# 4: 'hinge_door',
+# 5: 'slider_drawer',
+# 6: 'slider_lid',
+# 7: 'hinge_lid',
+# 8: 'hinge_knob',
+# 9: 'revolute_handle',
+
+# }
+
+
+# SYMMETRY_MATRIX_INDEX = [0,1,2,2,4,0,2,4,3,1]
+
+
+# import math
+# PI = math.pi
+
+# SYMMETRY_MATRIX = {
+# 1:[[
+# [-1.0, 0, 0],
+# [ 0, -1.0, 0],
+# [ 0, 0, 1.0]
+# ]],
+# 2:[
+# [
+# [ math.cos(PI/6),math.sin(PI/6), 0],
+# [-math.sin(PI/6),math.cos(PI/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*2/6),math.sin(PI*2/6), 0],
+# [-math.sin(PI*2/6),math.cos(PI*2/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*3/6),math.sin(PI*3/6), 0],
+# [-math.sin(PI*3/6),math.cos(PI*3/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*4/6),math.sin(PI*4/6), 0],
+# [-math.sin(PI*4/6),math.cos(PI*4/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*5/6),math.sin(PI*5/6), 0],
+# [-math.sin(PI*5/6),math.cos(PI*5/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*6/6),math.sin(PI*6/6), 0],
+# [-math.sin(PI*6/6),math.cos(PI*6/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*7/6),math.sin(PI*7/6), 0],
+# [-math.sin(PI*7/6),math.cos(PI*7/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*8/6),math.sin(PI*8/6), 0],
+# [-math.sin(PI*8/6),math.cos(PI*8/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*9/6),math.sin(PI*9/6), 0],
+# [-math.sin(PI*9/6),math.cos(PI*9/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*10/6),math.sin(PI*10/6), 0],
+# [-math.sin(PI*10/6),math.cos(PI*10/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*11/6),math.sin(PI*11/6), 0],
+# [-math.sin(PI*11/6),math.cos(PI*11/6), 0],
+# [ 0, 0, 1.0]
+# ]
+# ],
+
+# 3:[
+# [
+# [ math.cos(PI/6),math.sin(PI/6), 0],
+# [-math.sin(PI/6),math.cos(PI/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*2/6),math.sin(PI*2/6), 0],
+# [-math.sin(PI*2/6),math.cos(PI*2/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*3/6),math.sin(PI*3/6), 0],
+# [-math.sin(PI*3/6),math.cos(PI*3/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*4/6),math.sin(PI*4/6), 0],
+# [-math.sin(PI*4/6),math.cos(PI*4/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*5/6),math.sin(PI*5/6), 0],
+# [-math.sin(PI*5/6),math.cos(PI*5/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*6/6),math.sin(PI*6/6), 0],
+# [-math.sin(PI*6/6),math.cos(PI*6/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*7/6),math.sin(PI*7/6), 0],
+# [-math.sin(PI*7/6),math.cos(PI*7/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*8/6),math.sin(PI*8/6), 0],
+# [-math.sin(PI*8/6),math.cos(PI*8/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*9/6),math.sin(PI*9/6), 0],
+# [-math.sin(PI*9/6),math.cos(PI*9/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*10/6),math.sin(PI*10/6), 0],
+# [-math.sin(PI*10/6),math.cos(PI*10/6), 0],
+# [ 0, 0, 1.0]
+# ],
+# [
+# [ math.cos(PI*11/6),math.sin(PI*11/6), 0],
+# [-math.sin(PI*11/6),math.cos(PI*11/6), 0],
+# [ 0, 0, 1.0]
+# ],
+
+# ###################### inverse ######################
+
+# [
+# [ math.sin(PI/6),math.cos(PI/6), 0],
+# [ math.cos(PI/6),-math.sin(PI/6), 0],
+# [ 0, 0, -1.0]
+# ],
+# [
+# [ math.sin(PI*2/6),math.cos(PI*2/6), 0],
+# [ math.cos(PI*2/6),-math.sin(PI*2/6), 0],
+# [ 0, 0, -1.0]
+# ],
+# [
+# [ math.sin(PI*3/6),math.cos(PI*3/6), 0],
+# [ math.cos(PI*3/6),-math.sin(PI*3/6), 0],
+# [ 0, 0, -1.0]
+# ],
+# [
+# [ math.sin(PI*4/6),math.cos(PI*4/6), 0],
+# [ math.cos(PI*4/6),-math.sin(PI*4/6), 0],
+# [ 0, 0, -1.0]
+# ],
+# [
+# [ math.sin(PI*5/6),math.cos(PI*5/6), 0],
+# [ math.cos(PI*5/6),-math.sin(PI*5/6), 0],
+# [ 0, 0, -1.0]
+# ],
+# [
+# [ math.sin(PI*6/6),math.cos(PI*6/6), 0],
+# [ math.cos(PI*6/6),-math.sin(PI*6/6), 0],
+# [ 0, 0, -1.0]
+# ],
+# [
+# [ math.sin(PI*7/6),math.cos(PI*7/6), 0],
+# [ math.cos(PI*7/6),-math.sin(PI*7/6), 0],
+# [ 0, 0, -1.0]
+# ],
+# [
+# [ math.sin(PI*8/6),math.cos(PI*8/6), 0],
+# [ math.cos(PI*8/6),-math.sin(PI*8/6), 0],
+# [ 0, 0, -1.0]
+# ],
+# [
+# [ math.sin(PI*9/6),math.cos(PI*9/6), 0],
+# [ math.cos(PI*9/6),-math.sin(PI*9/6), 0],
+# [ 0, 0, -1.0]
+# ],
+# [
+# [ math.sin(PI*10/6),math.cos(PI*10/6), 0],
+# [ math.cos(PI*10/6),-math.sin(PI*10/6), 0],
+# [ 0, 0, -1.0]
+# ],
+# [
+# [ math.sin(PI*11/6),math.cos(PI*11/6), 0],
+# [ math.cos(PI*11/6),-math.sin(PI*11/6), 0],
+# [ 0, 0, -1.0]
+# ],
+# [
+# [ math.sin(PI*12/6),math.cos(PI*12/6), 0],
+# [ math.cos(PI*12/6),-math.sin(PI*12/6), 0],
+# [ 0, 0, -1.0]
+# ]
+
+# ],
+# 4:[[
+# [-1.0, 0, 0],
+# [ 0, 1.0, 0],
+# [ 0, 0, -1.0]
+# ]],
+
+# }
diff --git a/gapartnet/misc/pose_fitting.py b/gapartnet/misc/pose_fitting.py
new file mode 100644
index 0000000..f02ba52
--- /dev/null
+++ b/gapartnet/misc/pose_fitting.py
@@ -0,0 +1,147 @@
+import numpy as np
+
+
+def estimate_similarity_umeyama(source_hom: np.ndarray, target_hom: np.ndarray):
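+    """Estimate a similarity transform (uniform scale, rotation, translation) between two
+    homogeneous point sets (4 x N) using the Umeyama method."""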
+ num_points = source_hom.shape[1]
+
+ source_centroid = np.mean(source_hom[:3, :], axis=1)
+ target_centroid = np.mean(target_hom[:3, :], axis=1)
+
+ centered_source = source_hom[:3, :] - np.tile(source_centroid, (num_points, 1)).transpose()
+ centered_target = target_hom[:3, :] - np.tile(target_centroid, (num_points, 1)).transpose()
+
+ cov = np.matmul(centered_target, np.transpose(centered_source)) / num_points
+
+ if np.isnan(cov).any():
+ raise RuntimeError("There are NANs in the input.")
+
+ U, D, Vh = np.linalg.svd(cov, full_matrices=True)
+ d = (np.linalg.det(U) * np.linalg.det(Vh)) < 0.0
+ if d:
+ D[-1] = -D[-1]
+ U[:, -1] = -U[:, -1]
+
+ var_P = np.var(source_hom[:3, :], axis=1).sum()
+ scale_factor = 1 / var_P * np.sum(D)
+ scale = np.array([scale_factor, scale_factor, scale_factor])
+ scale_matrix = np.diag(scale)
+
+ rotation = np.matmul(U, Vh).T
+
+ translation = target_hom[:3, :].mean(axis=1) - source_hom[:3, :].mean(axis=1).dot(
+ scale_factor * rotation
+ )
+
+ out_transform = np.identity(4)
+ out_transform[:3, :3] = scale_matrix @ rotation
+ out_transform[:3, 3] = translation
+
+ return scale, rotation, translation, out_transform
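+
+# Convention note (added for clarity): the returned `rotation` acts on row vectors from the
+# right, i.e. target ~= scale * source @ rotation + translation, which is how the
+# translation above is computed and how estimate_pose_from_npcs below applies it.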
+
+
+def evaluate_model(
+ out_transform: np.ndarray, source_hom: np.ndarray, target_hom: np.ndarray, pass_thrsh: float
+):
+ diff = target_hom - np.matmul(out_transform, source_hom)
+ residual_vec = np.linalg.norm(diff[:3, :], axis=0)
+ residual = np.linalg.norm(residual_vec)
+    inlier_idx = np.where(residual_vec < pass_thrsh)
+    num_inliers = inlier_idx[0].shape[0]  # number of inliers, not the count of non-zero indices
+ inlier_ratio = num_inliers / source_hom.shape[1]
+ return residual, inlier_ratio, inlier_idx[0]
+
+
+def get_RANSAC_inliers(
+ source_hom: np.ndarray, target_hom: np.ndarray,
+ max_iters: int, pass_thrsh: float, stop_thrsh: float,
+):
+ best_residual = 1e10
+ best_inlier_ratio = 0
+ best_inlier_idx = np.arange(source_hom.shape[1])
+
+ for i in range(max_iters):
+ # Pick 5 random (but corresponding) points from source and target
+ rand_idx = np.random.randint(source_hom.shape[1], size=5)
+ _, _, _, out_transform = estimate_similarity_umeyama(
+ source_hom[:, rand_idx], target_hom[:, rand_idx]
+ )
+
+ residual, inlier_ratio, inlier_idx = evaluate_model(
+ out_transform, source_hom, target_hom, pass_thrsh
+ )
+ if residual < best_residual:
+ best_residual = residual
+ best_inlier_ratio = inlier_ratio
+ best_inlier_idx = inlier_idx
+
+ if best_residual < stop_thrsh:
+ break
+
+ return best_inlier_ratio, best_inlier_idx
+
+
+def estimate_similarity_transform(
+ source: np.ndarray, target: np.ndarray,
+ stop_thrsh: float = 0.5,
+ max_iters: int = 100,
+):
+ if source.shape[0] == 1:
+ source = np.repeat(source, 2, axis=0)
+ target = np.repeat(target, 2, axis=0)
+
+ source_hom = np.transpose(np.hstack([source, np.ones([source.shape[0], 1])]))
+ target_hom = np.transpose(np.hstack([target, np.ones([target.shape[0], 1])]))
+
+ # Auto-parameter selection based on source-target heuristics
+ source_norm = np.mean(np.linalg.norm(source, axis=1))
+ target_norm = np.mean(np.linalg.norm(target, axis=1))
+
+ ratio_st = (source_norm / target_norm)
+ ratio_ts = (target_norm / source_norm)
+ pass_thrsh = ratio_st if ratio_st > ratio_ts else ratio_ts
+
+ best_inlier_ratio, best_inlier_idx = \
+ get_RANSAC_inliers(
+ source_hom, target_hom, max_iters=max_iters,
+ pass_thrsh=pass_thrsh, stop_thrsh=stop_thrsh,
+ )
+ source_inliers_hom = source_hom[:, best_inlier_idx]
+ target_inliers_hom = target_hom[:, best_inlier_idx]
+
+ if best_inlier_ratio < 0.01:
+ return np.asarray([None, None, None]), None, None, None, None
+
+ scale, rotation, translation, out_transform = estimate_similarity_umeyama(
+ source_inliers_hom, target_inliers_hom
+ )
+
+ return scale, rotation, translation, out_transform, best_inlier_idx
+
+
+def estimate_pose_from_npcs(xyz, npcs):
+ scale, rotation, translation, out_transform, best_inlier_idx = \
+ estimate_similarity_transform(npcs, xyz)
+
+    if scale[0] is None:
+        return None, np.asarray([None, None, None]), None, None, None, best_inlier_idx
+    rotation_inv = np.linalg.pinv(rotation)
+ trans_seg = np.dot((xyz - translation), rotation_inv) / scale[0]
+ npcs_max = abs(trans_seg[best_inlier_idx]).max(0)
+
+ bbox_raw = np.asarray([
+ [-npcs_max[0], -npcs_max[1], -npcs_max[2]],
+ [npcs_max[0], -npcs_max[1], -npcs_max[2]],
+ [-npcs_max[0], npcs_max[1], -npcs_max[2]],
+ [-npcs_max[0], -npcs_max[1], npcs_max[2]],
+ [npcs_max[0], npcs_max[1], -npcs_max[2]],
+ [npcs_max[0], -npcs_max[1], npcs_max[2]],
+ [-npcs_max[0], npcs_max[1], npcs_max[2]],
+ [npcs_max[0], npcs_max[1], npcs_max[2]],
+ ])
+ bbox_trans = np.dot((bbox_raw * scale[0]), rotation) + translation
+
+ return bbox_trans, scale, rotation, translation, out_transform, best_inlier_idx
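+
+# Minimal usage sketch (illustrative only; `xyz` and `npcs` stand for one part's
+# camera-space points and its predicted normalized part coordinates, both (N, 3) arrays):
+#
+#   bbox, scale, rotation, translation, out_transform, inlier_idx = \
+#       estimate_pose_from_npcs(xyz, npcs)
+#   if bbox is not None:
+#       print(bbox.shape)  # (8, 3): corners of the oriented bounding box in camera space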
diff --git a/gapartnet/misc/visu.py b/gapartnet/misc/visu.py
new file mode 100644
index 0000000..3f849fd
--- /dev/null
+++ b/gapartnet/misc/visu.py
@@ -0,0 +1,261 @@
+import torch
+import numpy as np
+import yaml
+from os.path import join as pjoin
+import os
+import argparse
+import sys
+sys.path.append(sys.path[0] + "/..")
+import importlib
+from structure.point_cloud import PointCloud
+from dataset.gapartnet import apply_voxelization
+from misc.pose_fitting import estimate_pose_from_npcs
+import cv2
+from typing import List
+import glob
+from misc.visu_util import OBJfile2points, map2image, save_point_cloud_to_ply, \
+ WorldSpaceToBallSpace, FindMaxDis, draw_bbox_old, draw_bbox, COLOR20, \
+    OTHER_COLOR, HEIGHT, WIDTH, EDGE, K, font, fontScale, fontColor, thickness, lineType
+
+def process_gapartnetfile(GAPARTNET_DATA_ROOT, name, split = "train"):
+ data_path = f"{GAPARTNET_DATA_ROOT}/{split}/pth/{name}.pth"
+ trans_path = f"{GAPARTNET_DATA_ROOT}/{split}/meta/{name}.txt"
+
+ pc, rgb, semantic_label, instance_label, npcs_map = torch.load(data_path)
+
+ trans = np.loadtxt(trans_path)
+ xyz = pc * trans[0] + trans[1:4]
+
+ # save_point_cloud_to_ply(xyz, rgb*255, data_path.split("/")[-1].split(".")[0]+"_preinput.ply")
+ # save_point_cloud_to_ply(pc, rgb*255, data_path.split("/")[-1].split(".")[0]+"_input.ply")
+ points_input = torch.cat((torch.tensor(pc),torch.tensor(rgb)), dim = 1)
+ return points_input, trans, semantic_label, instance_label, npcs_map
+
+
+def visualize_gapartnet(
+ SAVE_ROOT,
+ GAPARTNET_DATA_ROOT,
+ RAW_IMG_ROOT,
+ save_option: List = [],
+ name: str = "pc",
+ split: str = "",
+ bboxes: np.ndarray = None, # type: ignore
+ sem_preds: np.ndarray = None, # type: ignore
+ ins_preds: np.ndarray = None, # type: ignore
+ npcs_preds: np.ndarray = None, # type: ignore
+ have_proposal = True,
+ save_detail = False,
+):
+
+ final_save_root = f"{SAVE_ROOT}/{split}"
+ save_root = f"{SAVE_ROOT}/{split}/{name}"
+ os.makedirs(final_save_root, exist_ok=True)
+ if save_detail:
+ os.makedirs(f"{save_root}", exist_ok=True)
+ final_img = np.ones((3 * (HEIGHT + EDGE) + EDGE, 4 * (WIDTH + EDGE) + EDGE, 3), dtype=np.uint8) * 255
+
+ points_input, trans, semantic_label, instance_label, npcs_map = process_gapartnetfile(GAPARTNET_DATA_ROOT, name, split)
+
+ points_input = points_input.numpy()
+ xyz_input = points_input[:,:3]
+ rgb = points_input[:,3:6]
+ xyz = xyz_input * trans[0] + trans[1:4]
+ pc_img = map2image(xyz, rgb*255.0)
+ pc_img = cv2.cvtColor(pc_img, cv2.COLOR_BGR2RGB)
+
+ if "raw" in save_option:
+ raw_img_path = f"{RAW_IMG_ROOT}/{name}.png"
+ if os.path.exists(raw_img_path):
+ raw_img = cv2.imread(raw_img_path)
+ if save_detail:
+ cv2.imwrite(f"{save_root}/raw.png", raw_img)
+ X_START = EDGE
+ Y_START = EDGE
+ final_img[X_START:X_START+HEIGHT, Y_START:Y_START+WIDTH, :] = raw_img
+ text = "raw"
+ cv2.putText(final_img, text,
+ (Y_START + int(0.5*(WIDTH - 3 * EDGE)), X_START + HEIGHT + int(0.5*EDGE)),
+ font, fontScale, fontColor, thickness, lineType)
+ if "pc" in save_option:
+ if save_detail:
+ cv2.imwrite(f"{save_root}/pc.png", pc_img)
+ X_START = EDGE + (HEIGHT + EDGE)
+ Y_START = EDGE
+ final_img[X_START:X_START+HEIGHT, Y_START:Y_START+WIDTH, :] = pc_img
+ text = "pc"
+ cv2.putText(final_img, text,
+ (Y_START + int(0.5*(WIDTH - 3 * EDGE)), X_START + HEIGHT + int(0.5*EDGE)),
+ font, fontScale, fontColor, thickness, lineType)
+ if "sem_pred" in save_option:
+ sem_pred_img = map2image(xyz, COLOR20[sem_preds])
+ sem_pred_img = cv2.cvtColor(sem_pred_img, cv2.COLOR_BGR2RGB)
+ if save_detail:
+ cv2.imwrite(f"{save_root}/sem_pred.png", sem_pred_img)
+ X_START = EDGE + (WIDTH + EDGE)
+ Y_START = EDGE + (HEIGHT + EDGE)
+ final_img[X_START:X_START+HEIGHT, Y_START:Y_START+WIDTH, :] = sem_pred_img
+ text = "sem_pred"
+ cv2.putText(final_img, text,
+ (Y_START + int(0.5*(WIDTH - 3 * EDGE)), X_START + HEIGHT + int(0.5*EDGE)),
+ font, fontScale, fontColor, thickness, lineType)
+ if "ins_pred" in save_option:
+ # ins_pred_color = np.ones_like(xyz) * 230
+ # if have_proposal:
+ # for ins_i in range(len(proposal_offsets) - 1):
+ # ins_pred_color[proposal_indices[proposal_offsets[ins_i]:proposal_offsets[ins_i + 1]]] = COLOR20[ins_i%19 + 1]
+ # import pdb; pdb.set_trace()
+ ins_pred_img = map2image(xyz, COLOR20[(ins_preds%20).astype(np.int_)])
+ ins_pred_img = cv2.cvtColor(ins_pred_img, cv2.COLOR_BGR2RGB)
+ if save_detail:
+ cv2.imwrite(f"{save_root}/ins_pred.png", ins_pred_img)
+ X_START = EDGE + (WIDTH + EDGE) * 1
+ Y_START = EDGE + (HEIGHT + EDGE) * 2
+ final_img[X_START:X_START+HEIGHT, Y_START:Y_START+WIDTH, :] = ins_pred_img
+ text = "ins_pred"
+ cv2.putText(final_img, text,
+ (Y_START + int(0.5*(WIDTH - 3 * EDGE)), X_START + HEIGHT + int(0.5*EDGE)),
+ font, fontScale, fontColor, thickness, lineType)
+ if "npcs_pred" in save_option:
+ npcs_pred_img = map2image(xyz, npcs_preds*255.0)
+ npcs_pred_img = cv2.cvtColor(npcs_pred_img, cv2.COLOR_BGR2RGB)
+ if save_detail:
+ cv2.imwrite(f"{save_root}/npcs_pred.png", npcs_pred_img)
+ X_START = EDGE + (WIDTH + EDGE) * 1
+ Y_START = EDGE + (HEIGHT + EDGE) * 3
+ final_img[X_START:X_START+HEIGHT, Y_START:Y_START+WIDTH, :] = npcs_pred_img
+ text = "npcs_pred"
+ cv2.putText(final_img, text,
+ (Y_START + int(0.5*(WIDTH - 3 * EDGE)), X_START + HEIGHT + int(0.5*EDGE)),
+ font, fontScale, fontColor, thickness, lineType)
+ if "bbox_pred" in save_option:
+ img_bbox_pred = pc_img.copy()
+ draw_bbox(img_bbox_pred, bboxes, trans)
+ if save_detail:
+ cv2.imwrite(f"{save_root}/bbox_pred.png", img_bbox_pred)
+ X_START = EDGE + (WIDTH + EDGE) * 2
+ Y_START = EDGE + (HEIGHT + EDGE) * 2
+ final_img[X_START:X_START+HEIGHT, Y_START:Y_START+WIDTH, :] = img_bbox_pred
+ text = "bbox_pred"
+ cv2.putText(final_img, text,
+ (Y_START + int(0.5*(WIDTH - 3 * EDGE)), X_START + HEIGHT + int(0.5*EDGE)),
+ font, fontScale, fontColor, thickness, lineType)
+ if "bbox_pred_pure" in save_option:
+ img_empty = np.ones((HEIGHT, WIDTH, 3), dtype=np.uint8) * 255
+ draw_bbox(img_empty, bboxes, trans)
+ if save_detail:
+ cv2.imwrite(f"{save_root}/bbox_pure.png", img_empty)
+ X_START = EDGE + (WIDTH + EDGE) * 2
+ Y_START = EDGE + (HEIGHT + EDGE) * 3
+ final_img[X_START:X_START+HEIGHT, Y_START:Y_START+WIDTH, :] = img_empty
+ text = "bbox_pred_pure"
+ cv2.putText(final_img, text,
+ (Y_START + int(0.5*(WIDTH - 3 * EDGE)), X_START + HEIGHT + int(0.5*EDGE)),
+ font, fontScale, fontColor, thickness, lineType)
+ if "sem_gt" in save_option:
+ sem_gt = semantic_label
+ sem_gt_img = map2image(xyz, COLOR20[sem_gt])
+ sem_gt_img = cv2.cvtColor(sem_gt_img, cv2.COLOR_BGR2RGB)
+ if save_detail:
+ cv2.imwrite(f"{save_root}/sem_gt.png", sem_gt_img)
+ X_START = EDGE + (WIDTH + EDGE) * 0
+ Y_START = EDGE + (HEIGHT + EDGE) * 1
+ final_img[X_START:X_START+HEIGHT, Y_START:Y_START+WIDTH, :] = sem_gt_img
+ text = "sem_gt"
+ cv2.putText(final_img, text,
+ (Y_START + int(0.5*(WIDTH - 3 * EDGE)), X_START + HEIGHT + int(0.5*EDGE)),
+ font, fontScale, fontColor, thickness, lineType)
+ if "ins_gt" in save_option:
+ ins_gt = instance_label
+ ins_color = COLOR20[ins_gt%19 + 1]
+ ins_color[np.where(ins_gt == -100)] = 230
+ ins_gt_img = map2image(xyz, ins_color)
+
+ ins_gt_img = cv2.cvtColor(ins_gt_img, cv2.COLOR_BGR2RGB)
+ if save_detail:
+ cv2.imwrite(f"{save_root}/ins_gt.png", ins_gt_img)
+ X_START = EDGE + (WIDTH + EDGE) * 0
+ Y_START = EDGE + (HEIGHT + EDGE) * 2
+ final_img[X_START:X_START+HEIGHT, Y_START:Y_START+WIDTH, :] = ins_gt_img
+ text = "ins_gt"
+ cv2.putText(final_img, text,
+ (Y_START + int(0.5*(WIDTH - 3 * EDGE)), X_START + HEIGHT + int(0.5*EDGE)),
+ font, fontScale, fontColor, thickness, lineType)
+ if "npcs_gt" in save_option:
+ npcs_gt = npcs_map + 0.5
+ npcs_gt_img = map2image(xyz, npcs_gt*255.0)
+ npcs_gt_img = cv2.cvtColor(npcs_gt_img, cv2.COLOR_BGR2RGB)
+ if save_detail:
+ cv2.imwrite(f"{save_root}/npcs_gt.png", npcs_gt_img)
+ X_START = EDGE + (WIDTH + EDGE) * 0
+ Y_START = EDGE + (HEIGHT + EDGE) * 3
+ final_img[X_START:X_START+HEIGHT, Y_START:Y_START+WIDTH, :] = npcs_gt_img
+ text = "npcs_gt"
+ cv2.putText(final_img, text,
+ (Y_START + int(0.5*(WIDTH - 3 * EDGE)), X_START + HEIGHT + int(0.5*EDGE)),
+ font, fontScale, fontColor, thickness, lineType)
+ if "bbox_gt" in save_option:
+ bboxes_gt = [[]]
+ ins_gt = instance_label
+ npcs_gt = npcs_map
+ num_ins = ins_gt.max()+1
+ if num_ins >= 1:
+ for ins_i in range(num_ins):
+ mask_i = ins_gt == ins_i
+ xyz_input_i = xyz_input[mask_i]
+ npcs_i = npcs_gt[mask_i]
+ if xyz_input_i.shape[0]<=5:
+ continue
+
+ bbox_xyz, scale, rotation, translation, out_transform, best_inlier_idx = \
+ estimate_pose_from_npcs(xyz_input_i, npcs_i)
+ if scale[0] == None:
+ continue
+ bboxes_gt[0].append(bbox_xyz.tolist())
+ img_bbox_gt = pc_img.copy()
+ draw_bbox(img_bbox_gt, bboxes_gt[0], trans)
+ if save_detail:
+ cv2.imwrite(f"{save_root}/bbox_gt.png", img_bbox_gt)
+ X_START = EDGE + (WIDTH + EDGE) * 2
+ Y_START = EDGE + (HEIGHT + EDGE) * 1
+ final_img[X_START:X_START+HEIGHT, Y_START:Y_START+WIDTH, :] = img_bbox_gt
+ text = "bbox_gt"
+ cv2.putText(final_img, text,
+ (Y_START + int(0.5*(WIDTH - 3 * EDGE)), X_START + HEIGHT + int(0.5*EDGE)),
+ font, fontScale, fontColor, thickness, lineType)
+ if "bbox_gt_pure" in save_option:
+ bboxes_gt = [[]]
+ ins_gt = instance_label
+ npcs_gt = npcs_map
+ num_ins = ins_gt.max()+1
+ if num_ins >= 1:
+ for ins_i in range(num_ins):
+ mask_i = ins_gt == ins_i
+ xyz_input_i = xyz_input[mask_i]
+ npcs_i = npcs_gt[mask_i]
+ if xyz_input_i.shape[0]<=5:
+ continue
+
+ bbox_xyz, scale, rotation, translation, out_transform, best_inlier_idx = \
+ estimate_pose_from_npcs(xyz_input_i, npcs_i)
+ if scale[0] == None:
+ continue
+
+ bboxes_gt[0].append(bbox_xyz.tolist())
+ img_bbox_gt_pure = np.ones((HEIGHT, WIDTH, 3), dtype=np.uint8) * 255
+ draw_bbox(img_bbox_gt_pure, bboxes_gt[0], trans)
+ if save_detail:
+ cv2.imwrite(f"{save_root}/bbox_gt_pure.png", img_bbox_gt_pure)
+ X_START = EDGE + (WIDTH + EDGE) * 2
+ Y_START = EDGE + (HEIGHT + EDGE) * 0
+ final_img[X_START:X_START+HEIGHT, Y_START:Y_START+WIDTH, :] = img_bbox_gt_pure
+ text = "bbox_gt_pure"
+ cv2.putText(final_img, text,
+ (Y_START + int(0.5*(WIDTH - 3 * EDGE)), X_START + HEIGHT + int(0.5*EDGE)),
+ font, fontScale, fontColor, thickness, lineType)
+ cv2.imwrite(f"{final_save_root}/{name}.png", final_img)
diff --git a/gapartnet/misc/visu_util.py b/gapartnet/misc/visu_util.py
new file mode 100644
index 0000000..790c2b6
--- /dev/null
+++ b/gapartnet/misc/visu_util.py
@@ -0,0 +1,173 @@
+
+from os.path import join as pjoin
+import cv2
+import numpy as np
+
+COLOR20 = np.array(
+ [[230, 230, 230], [0, 128, 128], [230, 190, 255], [170, 110, 40], [255, 250, 200], [128, 0, 0],
+ [170, 255, 195], [128, 128, 0], [255, 215, 180], [0, 0, 128], [128, 128, 128],
+ [230, 25, 75], [60, 180, 75], [255, 225, 25], [0, 130, 200], [245, 130, 48],
+ [145, 30, 180], [70, 240, 240], [240, 50, 230], [210, 245, 60], [250, 190, 190]])
+HEIGHT = int(800)
+WIDTH = int(800)
+EDGE = int(40)
+K = np.array([[1268.637939453125, 0, 400, 0], [0, 1268.637939453125, 400, 0],
+ [0, 0, 1, 0], [0, 0, 0, 1]], dtype=np.float32)
+
+OTHER_COLOR = np.array([230, 230, 230])
+
+font = cv2.FONT_HERSHEY_SIMPLEX
+fontScale = 2
+fontColor = (0,0,0)
+thickness = 2
+lineType = 3
+
+def save_point_cloud_to_ply(points, colors, save_name='01.ply', save_root='/scratch/genghaoran/GAPartNet/GAPartNet_inference/asset/real'):
+ '''
+ Save point cloud to ply file
+ '''
+ PLY_HEAD = f"ply\nformat ascii 1.0\nelement vertex {len(points)}\nproperty float x\nproperty float y\nproperty float z\nproperty uchar red\nproperty uchar green\nproperty uchar blue\nend_header\n"
+    file_string = PLY_HEAD
+    for i in range(len(points)):
+        file_string += f'{points[i][0]} {points[i][1]} {points[i][2]} {int(colors[i][0])} {int(colors[i][1])} {int(colors[i][2])}\n'
+    with open(pjoin(save_root, save_name), 'w') as f:
+        f.write(file_string)
+
+def draw_bbox(img, bbox_list, trans):
+ for i,bbox in enumerate(bbox_list):
+ if len(bbox) == 0:
+ continue
+ bbox = np.array(bbox)
+ bbox = bbox * trans[0]+trans[1:4]
+ K = np.array([[1268.637939453125, 0, 400, 0], [0, 1268.637939453125, 400, 0],
+ [0, 0, 1, 0], [0, 0, 0, 1]], dtype=np.float32)
+ point2image = []
+ for pts in bbox:
+ x = pts[0]
+ y = pts[1]
+ z = pts[2]
+ x_new = (np.around(x * K[0][0] / z + K[0][2])).astype(dtype=int)
+ y_new = (np.around(y * K[1][1] / z + K[1][2])).astype(dtype=int)
+ point2image.append([x_new, y_new])
+ cl = [255,0,255]
+ cv2.line(img,point2image[0],point2image[1],color=(int(cl[0]),int(cl[1]),int(cl[2])),thickness=2)
+ cv2.line(img,point2image[0],point2image[2],color=(int(cl[0]),int(cl[1]),int(cl[2])),thickness=2)
+ cv2.line(img,point2image[0],point2image[3],color=(int(cl[0]),int(cl[1]),int(cl[2])),thickness=2)
+ cv2.line(img,point2image[1],point2image[4],color=(int(cl[0]),int(cl[1]),int(cl[2])),thickness=2)
+ cv2.line(img,point2image[1],point2image[5],color=(int(cl[0]),int(cl[1]),int(cl[2])),thickness=2)
+ cv2.line(img,point2image[2],point2image[6],color=(int(cl[0]),int(cl[1]),int(cl[2])),thickness=2)
+ cv2.line(img,point2image[6],point2image[3],color=(int(cl[0]),int(cl[1]),int(cl[2])),thickness=2)
+ cv2.line(img,point2image[4],point2image[7],color=(int(cl[0]),int(cl[1]),int(cl[2])),thickness=2)
+ cv2.line(img,point2image[5],point2image[7],color=(int(cl[0]),int(cl[1]),int(cl[2])),thickness=2)
+ cv2.line(img,point2image[3],point2image[5],color=(int(cl[0]),int(cl[1]),int(cl[2])),thickness=2)
+ cv2.line(img,point2image[2],point2image[4],color=(int(cl[0]),int(cl[1]),int(cl[2])),thickness=2)
+ cv2.line(img,point2image[6],point2image[7],color=(int(cl[0]),int(cl[1]),int(cl[2])),thickness=2)
+ cv2.line(img,point2image[0],point2image[1],color=(0,0,255),thickness=3) # red
+ cv2.line(img,point2image[0],point2image[3],color=(255,0,0),thickness=3) # green
+ cv2.line(img,point2image[0],point2image[2],color=(0,255,0),thickness=3) # blue
+ return img
+
+def draw_bbox_old(img, bbox_list, trans):
+ for i,bbox in enumerate(bbox_list):
+ if len(bbox) == 0:
+ continue
+ bbox = np.array(bbox)
+ bbox = bbox * trans[0]+trans[1:4]
+ K = np.array([[1268.637939453125, 0, 400, 0], [0, 1268.637939453125, 400, 0],
+ [0, 0, 1, 0], [0, 0, 0, 1]], dtype=np.float32)
+ point2image = []
+ for pts in bbox:
+ x = pts[0]
+ y = pts[1]
+ z = pts[2]
+ x_new = (np.around(x * K[0][0] / z + K[0][2])).astype(dtype=int)
+ y_new = (np.around(y * K[1][1] / z + K[1][2])).astype(dtype=int)
+ point2image.append([x_new, y_new])
+ cl = [255,0,0]
+ cv2.line(img,point2image[0],point2image[1],color=(int(cl[0]),int(cl[1]),int(cl[2])),thickness=2)
+ cv2.line(img,point2image[0],point2image[1],color=(255,0,0),thickness=1)
+ cv2.line(img,point2image[1],point2image[2],color=(int(cl[0]),int(cl[1]),int(cl[2])),thickness=2)
+ cv2.line(img,point2image[1],point2image[2],color=(0,255,0),thickness=1)
+ cv2.line(img,point2image[2],point2image[3],color=(int(cl[0]),int(cl[1]),int(cl[2])),thickness=2)
+ cv2.line(img,point2image[2],point2image[3],color=(0,0,255),thickness=1)
+ cv2.line(img,point2image[3],point2image[0],color=(int(cl[0]),int(cl[1]),int(cl[2])),thickness=2)
+ cv2.line(img,point2image[4],point2image[5],color=(int(cl[0]),int(cl[1]),int(cl[2])),thickness=2)
+ cv2.line(img,point2image[5],point2image[6],color=(int(cl[0]),int(cl[1]),int(cl[2])),thickness=2)
+ cv2.line(img,point2image[6],point2image[7],color=(int(cl[0]),int(cl[1]),int(cl[2])),thickness=2)
+ cv2.line(img,point2image[4],point2image[7],color=(int(cl[0]),int(cl[1]),int(cl[2])),thickness=2)
+ cv2.line(img,point2image[4],point2image[0],color=(int(cl[0]),int(cl[1]),int(cl[2])),thickness=2)
+ cv2.line(img,point2image[1],point2image[5],color=(int(cl[0]),int(cl[1]),int(cl[2])),thickness=2)
+ cv2.line(img,point2image[2],point2image[6],color=(int(cl[0]),int(cl[1]),int(cl[2])),thickness=2)
+ cv2.line(img,point2image[3],point2image[7],color=(int(cl[0]),int(cl[1]),int(cl[2])),thickness=2)
+ return img
+
+def map2image(pts, rgb):
+    # Project the points of one shape onto the image plane using the fixed intrinsics K
+ image_rgb = np.ones((HEIGHT, WIDTH, 3), dtype=np.uint8) * 255
+ K = np.array([[1268.637939453125, 0, 400, 0], [0, 1268.637939453125, 400, 0],
+ [0, 0, 1, 0], [0, 0, 0, 1]], dtype=np.float32)
+
+ num_point = pts.shape[0]
+ # print(num_point)
+ # print(pts)
+ # print(rgb.shape)
+
+ point2image = {}
+ for i in range(num_point):
+ x = pts[i][0]
+ y = pts[i][1]
+ z = pts[i][2]
+ x_new = (np.around(x * K[0][0] / z + K[0][2])).astype(dtype=int)
+ y_new = (np.around(y * K[1][1] / z + K[1][2])).astype(dtype=int)
+ point2image[i] = (y_new, x_new)
+
+    # Rasterize the projected points back into an RGB image (2x2 splat per point)
+ for i in range(num_point):
+ # print(i, point2image[i][0], point2image[i][1])
+ if point2image[i][0]+1 >= HEIGHT or point2image[i][0] < 0 or point2image[i][1]+1 >= WIDTH or point2image[i][1] < 0:
+ continue
+ image_rgb[point2image[i][0]][point2image[i][1]] = rgb[i]
+ image_rgb[point2image[i][0]+1][point2image[i][1]] = rgb[i]
+ image_rgb[point2image[i][0]+1][point2image[i][1]+1] = rgb[i]
+ image_rgb[point2image[i][0]][point2image[i][1]+1] = rgb[i]
+
+ # rgb_pil = Image.fromarray(image_rgb, mode='RGB')
+ # rgb_pil.save(os.path.join(save_path, f'{instance_name}_{task}.png'))
+ return image_rgb
+
+def OBJfile2points(file):
+    # Read an OBJ file whose vertex lines carry xyz followed by rgb ("v x y z r g b");
+    # parsing stops at the first texture-coordinate ("vt") line.
+    objFilePath = file
+    with open(objFilePath) as file:
+        points = []
+        while True:
+ line = file.readline()
+ if not line:
+ break
+ strs = line.split(" ")
+ if strs[0] == "v":
+ points.append((float(strs[1]), float(strs[2]), float(strs[3]),float(strs[4]), float(strs[5]), float(strs[6])))
+ if strs[0] == "vt":
+ break
+ points = np.array(points)
+ return points
+
+def FindMaxDis(pointcloud):
+ max_xyz = pointcloud.max(0)
+ min_xyz = pointcloud.min(0)
+ center = (max_xyz + min_xyz) / 2
+ max_radius = ((((pointcloud - center)**2).sum(1))**0.5).max()
+ return max_radius, center
+
+def WorldSpaceToBallSpace(pointcloud):
+ """
+    Normalize a raw world-space point cloud into the unit ball.
+    return: pointcloud_normalized: points centered at the origin and scaled to radius <= 1
+            max_radius: the largest distance from the raw points to the center
+            center: [x, y, z] center of the raw point cloud's bounding box
+ """
+ max_radius, center = FindMaxDis(pointcloud)
+ pointcloud_normalized = (pointcloud - center) / max_radius
+ return pointcloud_normalized, max_radius, center
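+
+# Illustrative round trip (names are examples only): normalize a cloud into the unit ball
+# and recover the original coordinates from the returned scale and center.
+#
+#   pc_norm, max_radius, center = WorldSpaceToBallSpace(pc)  # pc: (N, 3) array
+#   pc_restored = pc_norm * max_radius + center               # equals pc up to float error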
diff --git a/gapartnet/network/backbone.py b/gapartnet/network/backbone.py
new file mode 100644
index 0000000..1bd50ec
--- /dev/null
+++ b/gapartnet/network/backbone.py
@@ -0,0 +1,287 @@
+from typing import List
+
+import spconv.pytorch as spconv
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ResBlock(spconv.SparseModule):
+ def __init__(
+ self, in_channels: int, out_channels: int, norm_fn: nn.Module, indice_key=None
+ ):
+ super().__init__()
+
+ if in_channels == out_channels:
+ self.shortcut = nn.Identity()
+ else:
+ # assert False
+ self.shortcut = spconv.SparseSequential(
+ spconv.SubMConv3d(in_channels, out_channels, kernel_size=1, \
+ bias=False),
+ norm_fn(out_channels),
+ )
+
+ self.conv1 = spconv.SparseSequential(
+ spconv.SubMConv3d(
+ in_channels, out_channels, kernel_size=3,
+ padding=1, bias=False, indice_key=indice_key,
+ ),
+ norm_fn(out_channels),
+ )
+
+ self.conv2 = spconv.SparseSequential(
+ spconv.SubMConv3d(
+ out_channels, out_channels, kernel_size=3,
+ padding=1, bias=False, indice_key=indice_key,
+ ),
+ norm_fn(out_channels),
+ )
+
+ def forward(self, x: spconv.SparseConvTensor) -> spconv.SparseConvTensor:
+ shortcut = self.shortcut(x)
+
+ x = self.conv1(x)
+ x = x.replace_feature(F.relu(x.features))
+
+ x = self.conv2(x)
+ x = x.replace_feature(F.relu(x.features + shortcut.features))
+
+ return x
+
+
+class UBlock(nn.Module):
+ def __init__(
+ self,
+ channels: List[int],
+ block_fn: nn.Module,
+ block_repeat: int,
+ norm_fn: nn.Module,
+ indice_key_id: int = 1,
+ ):
+ super().__init__()
+
+ self.channels = channels
+
+ encoder_blocks = [
+ block_fn(
+ channels[0], channels[0], norm_fn, indice_key=f"subm{indice_key_id}"
+ )
+ for _ in range(block_repeat)
+ ]
+ self.encoder_blocks = spconv.SparseSequential(*encoder_blocks)
+
+ if len(channels) > 1:
+ self.downsample = spconv.SparseSequential(
+ spconv.SparseConv3d(
+ channels[0], channels[1], kernel_size=2, stride=2,
+ bias=False, indice_key=f"spconv{indice_key_id}",
+ ),
+ norm_fn(channels[1]),
+ nn.ReLU(),
+ )
+
+ self.ublock = UBlock(
+ channels[1:], block_fn, block_repeat, norm_fn, indice_key_id + 1
+ )
+
+ self.upsample = spconv.SparseSequential(
+ spconv.SparseInverseConv3d(
+ channels[1], channels[0], kernel_size=2,
+ bias=False, indice_key=f"spconv{indice_key_id}",
+ ),
+ norm_fn(channels[0]),
+ nn.ReLU(),
+ )
+
+ decoder_blocks = [
+ block_fn(
+ channels[0] * 2, channels[0], norm_fn,
+ indice_key=f"subm{indice_key_id}",
+ ),
+ ]
+ for _ in range(block_repeat -1):
+ decoder_blocks.append(
+ block_fn(
+ channels[0], channels[0], norm_fn,
+ indice_key=f"subm{indice_key_id}",
+ )
+ )
+ self.decoder_blocks = spconv.SparseSequential(*decoder_blocks)
+
+ def forward(self, x: spconv.SparseConvTensor) -> spconv.SparseConvTensor:
+ x = self.encoder_blocks(x)
+ shortcut = x
+
+ if len(self.channels) > 1:
+ x = self.downsample(x)
+ x = self.ublock(x)
+ x = self.upsample(x)
+
+ x = x.replace_feature(torch.cat([x.features, shortcut.features],\
+ dim=-1))
+ x = self.decoder_blocks(x)
+
+ return x
+
+
+class SparseUNet(nn.Module):
+ def __init__(self, stem: nn.Module, ublock: UBlock):
+ super().__init__()
+
+ self.stem = stem
+ self.ublock = ublock
+
+ def forward(self, x):
+ if self.stem is not None:
+ x = self.stem(x)
+ x = self.ublock(x)
+ return x
+
+ @classmethod
+ def build(
+ cls,
+ in_channels: int,
+ channels: List[int],
+ block_repeat: int,
+ norm_fn: nn.Module,
+ without_stem: bool = False,
+ ):
+ if not without_stem:
+ stem = spconv.SparseSequential(
+ spconv.SubMConv3d(
+ in_channels, channels[0], kernel_size=3,
+ padding=1, bias=False, indice_key="subm1",
+ ),
+ norm_fn(channels[0]),
+ nn.ReLU(),
+ )
+ else:
+ stem = spconv.SparseSequential(
+ norm_fn(channels[0]),
+ nn.ReLU(),
+ )
+
+ block = UBlock(channels, ResBlock, block_repeat, norm_fn, \
+ indice_key_id=1)
+
+ return SparseUNet(stem, block)
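+
+# Construction sketch (channel sizes here are illustrative, not the shipped config;
+# functools is assumed to be imported at the call site):
+#
+#   norm_fn = functools.partial(nn.BatchNorm1d, eps=1e-4, momentum=0.1)
+#   unet = SparseUNet.build(in_channels=6, channels=[16, 32, 48, 64, 80, 96, 112],
+#                           block_repeat=2, norm_fn=norm_fn)
+#   out = unet(voxel_tensor)  # voxel_tensor: spconv.SparseConvTensor of per-voxel features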
+
+
+
+class UBlock_NoSkip(nn.Module):
+ def __init__(
+ self,
+ channels: List[int],
+ block_fn: nn.Module,
+ block_repeat: int,
+ norm_fn: nn.Module,
+ indice_key_id: int = 1,
+ ):
+ super().__init__()
+
+ self.channels = channels
+
+ encoder_blocks = [
+ block_fn(
+ channels[0], channels[0], norm_fn, indice_key=f"subm{indice_key_id}"
+ )
+ for _ in range(block_repeat)
+ ]
+ self.encoder_blocks = spconv.SparseSequential(*encoder_blocks)
+
+ if len(channels) > 1:
+ self.downsample = spconv.SparseSequential(
+ spconv.SparseConv3d(
+ channels[0], channels[1], kernel_size=2, stride=2,
+ bias=False, indice_key=f"spconv{indice_key_id}",
+ ),
+ norm_fn(channels[1]),
+ nn.ReLU(),
+ )
+
+            # recurse with the skip-free variant so the whole hierarchy omits skip connections
+            self.ublock = UBlock_NoSkip(
+                channels[1:], block_fn, block_repeat, norm_fn, indice_key_id + 1
+            )
+
+ self.upsample = spconv.SparseSequential(
+ spconv.SparseInverseConv3d(
+ channels[1], channels[0], kernel_size=2,
+ bias=False, indice_key=f"spconv{indice_key_id}",
+ ),
+ norm_fn(channels[0]),
+ nn.ReLU(),
+ )
+
+ decoder_blocks = [
+ block_fn(
+ channels[0], channels[0], norm_fn,
+ indice_key=f"subm{indice_key_id}",
+ ),
+ ]
+ for _ in range(block_repeat -1):
+ decoder_blocks.append(
+ block_fn(
+ channels[0], channels[0], norm_fn,
+ indice_key=f"subm{indice_key_id}",
+ )
+ )
+ self.decoder_blocks = spconv.SparseSequential(*decoder_blocks)
+
+ def forward(self, x: spconv.SparseConvTensor) -> spconv.SparseConvTensor:
+ x = self.encoder_blocks(x)
+ # shortcut = x
+
+ if len(self.channels) > 1:
+ x = self.downsample(x)
+ x = self.ublock(x)
+ x = self.upsample(x)
+
+ # x = x.replace_feature(torch.cat([x.features, shortcut.features],\
+ # dim=-1))
+ x = self.decoder_blocks(x)
+
+ return x
+
+
+class SparseUNet_NoSkip(nn.Module):
+ def __init__(self, stem: nn.Module, ublock: UBlock_NoSkip):
+ super().__init__()
+
+ self.stem = stem
+ self.ublock = ublock
+
+ def forward(self, x):
+ if self.stem is not None:
+ x = self.stem(x)
+ x = self.ublock(x)
+ return x
+
+ @classmethod
+ def build(
+ cls,
+ in_channels: int,
+ channels: List[int],
+ block_repeat: int,
+ norm_fn: nn.Module,
+ without_stem: bool = False,
+ ):
+ if not without_stem:
+ stem = spconv.SparseSequential(
+ spconv.SubMConv3d(
+ in_channels, channels[0], kernel_size=3,
+ padding=1, bias=False, indice_key="subm1",
+ ),
+ norm_fn(channels[0]),
+ nn.ReLU(),
+ )
+ else:
+ stem = spconv.SparseSequential(
+ norm_fn(channels[0]),
+ nn.ReLU(),
+ )
+
+        block = UBlock_NoSkip(channels, ResBlock, block_repeat, norm_fn, \
+            indice_key_id=1)
+
+        return SparseUNet_NoSkip(stem, block)
diff --git a/gapartnet/network/grouping_utils.py b/gapartnet/network/grouping_utils.py
new file mode 100644
index 0000000..5e2b20b
--- /dev/null
+++ b/gapartnet/network/grouping_utils.py
@@ -0,0 +1,454 @@
+from typing import List, Tuple
+
+import torch
+from epic_ops.ball_query import ball_query
+from epic_ops.ccl import connected_components_labeling
+from epic_ops.nms import nms
+from epic_ops.reduce import segmented_reduce
+from epic_ops.voxelize import voxelize
+
+from structure.instances import Instances
+
+
+@torch.jit.script
+def compute_npcs_loss(
+ npcs_preds: torch.Tensor,
+ gt_npcs: torch.Tensor,
+ proposal_indices: torch.Tensor,
+ symmetry_matrix: torch.Tensor,
+) -> torch.Tensor:
+ _, num_points_per_proposal = torch.unique_consecutive(
+ proposal_indices, return_counts=True
+ )
+
+ # gt_npcs: n, 3 -> n, 1, 1, 3
+ # symmetry_matrix: n, m, 3, 3
+ gt_npcs = gt_npcs[:, None, None, :] @ symmetry_matrix
+ # n, m, 1, 3 -> n, m, 3
+ gt_npcs = gt_npcs.squeeze(2)
+
+ # npcs_preds: n, 3 -> n, 1, 3
+ dist2 = (npcs_preds[:, None, :] - gt_npcs - 0.5) ** 2
+ # n, m, 3 -> n, m
+ dist2 = dist2.sum(dim=-1)
+
+    # Huber-like penalty: quadratic for distances below 0.1, linear beyond (continuous at 0.1)
+    loss = torch.where(
+        dist2 <= 0.01,
+        5 * dist2, torch.sqrt(dist2) - 0.05,
+    )
+ loss = torch.segment_reduce(
+ loss, "mean", lengths=num_points_per_proposal
+ )
+ loss, _ = loss.min(dim=-1)
+ return loss.mean()
+
+
+@torch.jit.script
+def segmented_voxelize(
+ pt_xyz: torch.Tensor,
+ pt_features: torch.Tensor,
+ segment_offsets: torch.Tensor,
+ segment_indices: torch.Tensor,
+ num_points_per_segment: torch.Tensor,
+ score_fullscale: float,
+ score_scale: float,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ segment_offsets_begin = segment_offsets[:-1]
+ segment_offsets_end = segment_offsets[1:]
+
+ segment_coords_mean = segmented_reduce(
+ pt_xyz, segment_offsets_begin, segment_offsets_end, mode="sum"
+ ) / num_points_per_segment[:, None]
+
+ centered_points = pt_xyz - segment_coords_mean[segment_indices]
+
+ segment_coords_min = segmented_reduce(
+ centered_points, segment_offsets_begin, segment_offsets_end, mode="min"
+ )
+ segment_coords_max = segmented_reduce(
+ centered_points, segment_offsets_begin, segment_offsets_end, mode="max"
+ )
+
+ # score_fullscale = 50.
+ # score_scale = 50.
+ segment_scales = 1. / (
+ (segment_coords_max - segment_coords_min) / score_fullscale
+ ).max(-1)[0] - 0.01
+ segment_scales = torch.clamp(segment_scales, min=None, max=score_scale)
+
+ min_xyz = segment_coords_min * segment_scales[..., None]
+ max_xyz = segment_coords_max * segment_scales[..., None]
+
+ segment_scales = segment_scales[segment_indices]
+ scaled_points = centered_points * segment_scales[..., None]
+
+ range_xyz = max_xyz - min_xyz
+ offsets = -min_xyz + torch.clamp(
+ score_fullscale - range_xyz - 0.001, min=0
+ ) * torch.rand(3, dtype=min_xyz.dtype, device=min_xyz.device) + torch.clamp(
+ score_fullscale - range_xyz + 0.001, max=0
+ ) * torch.rand(3, dtype=min_xyz.dtype, device=min_xyz.device)
+ scaled_points += offsets[segment_indices]
+
+ voxel_features, voxel_coords, voxel_batch_indices, pc_voxel_id = voxelize(
+ scaled_points,
+ pt_features,
+ batch_offsets=segment_offsets.long(),
+ voxel_size=torch.as_tensor([1., 1., 1.]),
+ points_range_min=torch.as_tensor([0., 0., 0.]),
+ points_range_max=torch.as_tensor([score_fullscale, score_fullscale, score_fullscale]),
+ reduction="mean",
+ )
+ voxel_coords = torch.cat([voxel_batch_indices[:, None], voxel_coords], dim=1)
+
+ return voxel_features, voxel_coords, pc_voxel_id
+
+
+@torch.jit.script
+def cluster_proposals(
+ pt_xyz: torch.Tensor,
+ batch_indices: torch.Tensor,
+ batch_offsets: torch.Tensor,
+ sem_preds: torch.Tensor,
+ ball_query_radius: float,
+ max_num_points_per_query: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ device = pt_xyz.device
+ index_dtype = batch_indices.dtype
+
+ clustered_indices, num_points_per_query = ball_query(
+ pt_xyz,
+ pt_xyz,
+ batch_indices,
+ batch_offsets,
+ ball_query_radius,
+ max_num_points_per_query,
+ point_labels=sem_preds,
+ query_labels=sem_preds,
+ )
+
+ ccl_indices_begin = torch.arange(
+ pt_xyz.shape[0], dtype=index_dtype, device=device
+ ) * max_num_points_per_query
+ ccl_indices_end = ccl_indices_begin + num_points_per_query
+ ccl_indices = torch.stack([ccl_indices_begin, ccl_indices_end], dim=1)
+ cc_labels = connected_components_labeling(
+ ccl_indices.view(-1), clustered_indices.view(-1), compacted=False
+ )
+
+ sorted_cc_labels, sorted_indices = torch.sort(cc_labels)
+ return sorted_cc_labels, sorted_indices
+
+
+@torch.jit.script
+def get_gt_scores(
+ ious: torch.Tensor, fg_thresh: float = 0.75, bg_thresh: float = 0.25
+) -> torch.Tensor:
+ fg_mask = ious > fg_thresh
+ bg_mask = ious < bg_thresh
+ intermidiate_mask = ~(fg_mask | bg_mask)
+
+ gt_scores = fg_mask.float()
+ k = 1 / (fg_thresh - bg_thresh)
+ b = bg_thresh / (bg_thresh - fg_thresh)
+ gt_scores[intermidiate_mask] = ious[intermidiate_mask] * k + b
+
+ return gt_scores
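+
+# Worked example: with fg_thresh=0.75 and bg_thresh=0.25, k = 2 and b = -0.5, so an IoU of
+# 0.8 gets score 1.0, an IoU of 0.1 gets 0.0, and an intermediate IoU of 0.5 maps to
+# 2 * 0.5 - 0.5 = 0.5.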
+
+
+def filter_invalid_proposals(
+ proposals: Instances,
+ score_threshold: float,
+ min_num_points_per_proposal: int,
+) -> Instances:
+ score_preds = proposals.score_preds
+ proposal_indices = proposals.proposal_indices
+ num_points_per_proposal = proposals.num_points_per_proposal
+
+ valid_proposals_mask = (
+ score_preds > score_threshold
+ ) & (num_points_per_proposal > min_num_points_per_proposal)
+ valid_points_mask = valid_proposals_mask[proposal_indices]
+
+ proposal_indices = proposal_indices[valid_points_mask]
+ _, proposal_indices, num_points_per_proposal = torch.unique_consecutive(
+ proposal_indices, return_inverse=True, return_counts=True
+ )
+ num_proposals = num_points_per_proposal.shape[0]
+
+ proposal_offsets = torch.zeros(
+ num_proposals + 1, dtype=torch.int32, device=proposal_indices.device
+ )
+ proposal_offsets[1:] = num_points_per_proposal.cumsum(0)
+
+ if proposals.npcs_valid_mask is not None:
+ valid_npcs_mask = valid_points_mask[proposals.npcs_valid_mask]
+ else:
+ valid_npcs_mask = valid_points_mask
+
+ return Instances(
+ valid_mask=proposals.valid_mask,
+ sorted_indices=proposals.sorted_indices[valid_points_mask],
+ pt_xyz=proposals.pt_xyz[valid_points_mask],
+ batch_indices=proposals.batch_indices[valid_points_mask],
+ proposal_offsets=proposal_offsets,
+ proposal_indices=proposal_indices,
+ num_points_per_proposal=num_points_per_proposal,
+ sem_preds=proposals.sem_preds[valid_points_mask],
+ score_preds=proposals.score_preds[valid_proposals_mask],
+ npcs_preds=proposals.npcs_preds[
+ valid_npcs_mask
+ ] if proposals.npcs_preds is not None else None,
+ sem_labels=proposals.sem_labels[
+ valid_points_mask
+ ] if proposals.sem_labels is not None else None,
+ instance_labels=proposals.instance_labels[
+ valid_points_mask
+ ] if proposals.instance_labels is not None else None,
+ instance_sem_labels=proposals.instance_sem_labels,
+ num_points_per_instance=proposals.num_points_per_instance,
+ gt_npcs=proposals.gt_npcs[
+ valid_npcs_mask
+ ] if proposals.gt_npcs is not None else None,
+ npcs_valid_mask=proposals.npcs_valid_mask[valid_points_mask] \
+ if proposals.npcs_valid_mask is not None else None,
+ ious=proposals.ious[
+ valid_proposals_mask
+ ] if proposals.ious is not None else None,
+ )
+
+
+def apply_nms(
+ proposals: Instances,
+ iou_threshold: float = 0.3,
+):
+ score_preds = proposals.score_preds
+ sorted_indices = proposals.sorted_indices
+ proposal_offsets = proposals.proposal_offsets
+ proposal_indices = proposals.proposal_indices
+ num_points_per_proposal = proposals.num_points_per_proposal
+
+ values = torch.ones(
+ sorted_indices.shape[0], dtype=torch.float32, device=sorted_indices.device
+ )
+ csr = torch.sparse_csr_tensor(
+ proposal_offsets.int(), sorted_indices.int(), values,
+ dtype=torch.float32, device=sorted_indices.device,
+ )
+ intersection = csr @ csr.t()
+ intersection = intersection.to_dense()
+ union = num_points_per_proposal[:, None] + num_points_per_proposal[None, :]
+ union = union - intersection
+
+ ious = intersection / (union + 1e-8)
+ keep = nms(ious.cuda(), score_preds.cuda(), iou_threshold)
+ keep = keep.to(score_preds.device)
+
+ valid_proposals_mask = torch.zeros(
+ ious.shape[0], dtype=torch.bool, device=score_preds.device
+ )
+ valid_proposals_mask[keep] = True
+ valid_points_mask = valid_proposals_mask[proposal_indices]
+
+ proposal_indices = proposal_indices[valid_points_mask]
+ _, proposal_indices, num_points_per_proposal = torch.unique_consecutive(
+ proposal_indices, return_inverse=True, return_counts=True
+ )
+ num_proposals = num_points_per_proposal.shape[0]
+
+ proposal_offsets = torch.zeros(
+ num_proposals + 1, dtype=torch.int32, device=proposal_indices.device
+ )
+ proposal_offsets[1:] = num_points_per_proposal.cumsum(0)
+
+ if proposals.npcs_valid_mask is not None:
+ valid_npcs_mask = valid_points_mask[proposals.npcs_valid_mask]
+ else:
+ valid_npcs_mask = valid_points_mask
+
+ return Instances(
+ valid_mask=proposals.valid_mask,
+ sorted_indices=proposals.sorted_indices[valid_points_mask],
+ pt_xyz=proposals.pt_xyz[valid_points_mask],
+ batch_indices=proposals.batch_indices[valid_points_mask],
+ proposal_offsets=proposal_offsets,
+ proposal_indices=proposal_indices,
+ num_points_per_proposal=num_points_per_proposal,
+ sem_preds=proposals.sem_preds[valid_points_mask],
+ score_preds=proposals.score_preds[valid_proposals_mask],
+ npcs_preds=proposals.npcs_preds[
+ valid_npcs_mask
+ ] if proposals.npcs_preds is not None else None,
+ sem_labels=proposals.sem_labels[
+ valid_points_mask
+ ] if proposals.sem_labels is not None else None,
+ instance_labels=proposals.instance_labels[
+ valid_points_mask
+ ] if proposals.instance_labels is not None else None,
+ instance_sem_labels=proposals.instance_sem_labels,
+ num_points_per_instance=proposals.num_points_per_instance,
+ gt_npcs=proposals.gt_npcs[
+ valid_npcs_mask
+ ] if proposals.gt_npcs is not None else None,
+ npcs_valid_mask=proposals.npcs_valid_mask[valid_points_mask] \
+ if proposals.npcs_valid_mask is not None else None,
+ ious=proposals.ious[
+ valid_proposals_mask
+ ] if proposals.ious is not None else None,
+ )
+
+
+@torch.jit.script
+def voc_ap(
+ rec: torch.Tensor,
+ prec: torch.Tensor,
+ use_07_metric: bool = False,
+) -> float:
+ if use_07_metric:
+ # 11 point metric
+ ap = torch.as_tensor(0, dtype=prec.dtype, device=prec.device)
+        for i in range(0, 11):
+            t = i / 10.0
+ if torch.sum(rec >= t) == 0:
+ p = torch.as_tensor(0, dtype=prec.dtype, device=prec.device)
+ else:
+ p = torch.max(prec[rec >= t])
+ ap = ap + p / 11.0
+ else:
+ # correct AP calculation
+ # first append sentinel values at the end
+ mrec = torch.cat([
+ torch.as_tensor([0.0], dtype=rec.dtype, device=rec.device),
+ rec,
+ torch.as_tensor([1.0], dtype=rec.dtype, device=rec.device),
+ ], dim=0)
+ mpre = torch.cat([
+ torch.as_tensor([0.0], dtype=prec.dtype, device=prec.device),
+ prec,
+ torch.as_tensor([0.0], dtype=prec.dtype, device=prec.device),
+ ], dim=0)
+
+ # compute the precision envelope
+ for i in range(mpre.shape[0] - 1, 0, -1):
+ mpre[i - 1] = torch.maximum(mpre[i - 1], mpre[i])
+
+ # to calculate area under PR curve, look for points
+ # where X axis (recall) changes value
+ i = torch.where(mrec[1:] != mrec[:-1])[0]
+
+ # and sum (\Delta recall) * prec
+ ap = torch.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
+ return float(ap.item())
+
+
+@torch.jit.script
+def _compute_ap_per_class(
+ tp: torch.Tensor, fp: torch.Tensor, num_gt_instances: int
+) -> float:
+ if tp.shape[0] == 0:
+ return 0.
+
+ tp = tp.cumsum(0)
+ fp = fp.cumsum(0)
+ rec = tp / num_gt_instances
+ prec = tp / (tp + fp + 1e-8)
+
+ return voc_ap(rec, prec)
+
+
+@torch.jit.script
+def _compute_ap(
+ confidence: torch.Tensor,
+ classes: torch.Tensor,
+ sorted_indices: torch.Tensor,
+ batch_indices: torch.Tensor,
+ sample_indices: torch.Tensor,
+ proposal_indices: torch.Tensor,
+ matched: List[torch.Tensor],
+ instance_sem_labels: List[torch.Tensor],
+ ious: List[torch.Tensor],
+ num_classes: int,
+ iou_threshold: float,
+):
+ sorted_indices_cpu = sorted_indices.cpu()
+
+ num_proposals = confidence.shape[0]
+ tp = torch.zeros(num_proposals, dtype=torch.float32)
+ fp = torch.zeros(num_proposals, dtype=torch.float32)
+ for i in range(num_proposals):
+ idx = sorted_indices_cpu[i]
+
+ class_idx = classes[idx]
+ batch_idx = batch_indices[idx].item()
+ sample_idx = sample_indices[idx]
+ proposal_idx = proposal_indices[idx]
+
+ instance_sem_labels_i = instance_sem_labels[batch_idx][sample_idx]
+ invalid_instance_mask = instance_sem_labels_i != class_idx
+
+ ious_i = ious[batch_idx][proposal_idx].clone()
+ ious_i[invalid_instance_mask] = 0.
+ if ious_i.shape[0] == 0:
+ max_iou, max_idx = 0., 0
+ else:
+ max_iou, max_idx = ious_i.max(0)
+ max_iou, max_idx = max_iou.item(), int(max_idx.item())
+
+ if max_iou > iou_threshold:
+ if not matched[batch_idx][sample_idx, max_idx].item():
+ tp[i] = 1.0
+ matched[batch_idx][sample_idx, max_idx] = True
+ else:
+ fp[i] = 1.0
+ else:
+ fp[i] = 1.0
+
+ tp = tp.to(device=confidence.device)
+ fp = fp.to(device=confidence.device)
+
+ sorted_classes = classes[sorted_indices]
+ gt_classes = torch.cat([x.view(-1) for x in instance_sem_labels], dim=0)
+ aps: List[float] = []
+ for c in range(1, num_classes):
+        num_gt_instances = int((gt_classes == c).sum())
+ mask = sorted_classes == c
+ ap = _compute_ap_per_class(tp[mask], fp[mask], num_gt_instances)
+ aps.append(ap)
+ return aps
+
+
+def compute_ap(
+ proposals: List[Instances],
+ num_classes: int = 9,
+ iou_threshold: float = 0.5,
+ device="cpu",
+):
+ confidence = torch.cat([p.score_preds for p in proposals], dim=0).to(device=device)
+ classes = torch.cat([p.pt_sem_classes for p in proposals], dim=0).to(device=device)
+ sorted_indices = torch.argsort(confidence, descending=True)
+
+ batch_indices = torch.cat([
+ torch.full((p.score_preds.shape[0],), i, dtype=torch.int64)
+ for i, p in enumerate(proposals)
+ ], dim=0)
+ sample_indices = torch.cat([
+ p.batch_indices[p.proposal_offsets[:-1].long()].long()
+ for p in proposals
+ ], dim=0).cpu()
+ proposal_indices = torch.cat([
+ torch.arange(p.score_preds.shape[0], dtype=torch.int64)
+ for p in proposals
+ ], dim=0)
+
+ matched = [
+ torch.zeros_like(p.instance_sem_labels, dtype=torch.bool, device="cpu")
+ for p in proposals
+ ]
+
+    ap = _compute_ap(confidence, classes, sorted_indices, batch_indices,
+                     sample_indices, proposal_indices, matched,
+                     [p.instance_sem_labels.to(device=device) for p in proposals],
+                     [p.ious.to(device=device) for p in proposals],
+                     num_classes, iou_threshold)
+
+ return ap
diff --git a/gapartnet/network/losses.py b/gapartnet/network/losses.py
new file mode 100644
index 0000000..576e4d1
--- /dev/null
+++ b/gapartnet/network/losses.py
@@ -0,0 +1,158 @@
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from kornia.metrics import mean_iou as _mean_iou
+
+
+@torch.no_grad()
+def pixel_accuracy(pred_mask: torch.Tensor, gt_mask: torch.Tensor) -> float:
+ """
+ Compute pixel accuracy.
+ """
+
+ if gt_mask.numel() > 0:
+ accuracy = (pred_mask == gt_mask).sum() / gt_mask.numel()
+ accuracy = accuracy.item()
+ else:
+ accuracy = 0.
+ return accuracy
+
+
+@torch.no_grad()
+def mean_iou(pred_mask: torch.Tensor, gt_mask: torch.Tensor, num_classes: int) -> float:
+ """
+ Compute mIoU.
+ """
+
+ valid_mask = gt_mask >= 0
+ miou = _mean_iou(
+ pred_mask[valid_mask][None], gt_mask[valid_mask][None], num_classes=num_classes
+ ).mean()
+ return miou
+
+
+def focal_loss(
+ inputs: torch.Tensor,
+ targets: torch.Tensor,
+ alpha: Optional[torch.Tensor] = None,
+ gamma: float = 2.0,
+ reduction: str = "mean",
+ ignore_index: int = -100,
+) -> torch.Tensor:
+ if ignore_index is not None:
+ valid_mask = targets != ignore_index
+ targets = targets[valid_mask]
+
+ if targets.shape[0] == 0:
+ return torch.tensor(0.0).to(dtype=inputs.dtype, device=inputs.device)
+
+ inputs = inputs[valid_mask]
+
+ log_p = F.log_softmax(inputs, dim=-1)
+ ce_loss = F.nll_loss(
+ log_p, targets, weight=alpha, ignore_index=ignore_index, reduction="none"
+ )
+ log_p_t = log_p.gather(1, targets[:, None]).squeeze(-1)
+ loss = ce_loss * ((1 - log_p_t.exp()) ** gamma)
+
+ if reduction == "mean":
+ loss = loss.mean()
+ elif reduction == "sum":
+ loss = loss.sum()
+
+ return loss
+
+
+def sigmoid_focal_loss(
+ inputs: torch.Tensor,
+ targets: torch.Tensor,
+ alpha: float = 0.25,
+ gamma: float = 2,
+ reduction: str = "none",
+) -> torch.Tensor:
+ """
+ Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
+ Args:
+ inputs (Tensor): A float tensor of arbitrary shape.
+ The predictions for each example.
+ targets (Tensor): A float tensor with the same shape as inputs. Stores the binary
+ classification label for each element in inputs
+ (0 for the negative class and 1 for the positive class).
+ alpha (float): Weighting factor in range (0,1) to balance
+ positive vs negative examples or -1 for ignore. Default: ``0.25``.
+ gamma (float): Exponent of the modulating factor (1 - p_t) to
+ balance easy vs hard examples. Default: ``2``.
+ reduction (string): ``'none'`` | ``'mean'`` | ``'sum'``
+ ``'none'``: No reduction will be applied to the output.
+ ``'mean'``: The output will be averaged.
+ ``'sum'``: The output will be summed. Default: ``'none'``.
+ Returns:
+ Loss tensor with the reduction option applied.
+ """
+ p = torch.sigmoid(inputs)
+ ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
+ p_t = p * targets + (1 - p) * (1 - targets)
+ loss = ce_loss * ((1 - p_t) ** gamma)
+
+ if alpha >= 0:
+ alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
+ loss = alpha_t * loss
+
+ if reduction == "mean":
+ loss = loss.mean()
+ elif reduction == "sum":
+ loss = loss.sum()
+
+ return loss
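+
+# Hypothetical usage (shapes and values are examples only):
+#
+#   logits = torch.randn(1024)                  # raw per-element logits
+#   targets = (torch.rand(1024) > 0.5).float()  # binary 0/1 labels
+#   loss = sigmoid_focal_loss(logits, targets, alpha=0.25, gamma=2, reduction="mean")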
+
+
+def one_hot(
+ labels: torch.Tensor,
+ num_classes: int,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ eps: float = 1e-6,
+) -> torch.Tensor:
+ if not isinstance(labels, torch.Tensor):
+ raise TypeError(f"Input labels type is not a torch.Tensor. Got {type(labels)}")
+
+ if not labels.dtype == torch.int64:
+ raise ValueError(f"labels must be of the same dtype torch.int64. Got: {labels.dtype}")
+
+ if num_classes < 1:
+ raise ValueError("The number of classes must be bigger than one." " Got: {}".format(num_classes))
+
+ shape = labels.shape
+ one_hot = torch.zeros((shape[0], num_classes) + shape[1:], device=device, dtype=dtype)
+
+ return one_hot.scatter_(1, labels.unsqueeze(1), 1.0) + eps
+
+
+def dice_loss(input: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
+ if not isinstance(input, torch.Tensor):
+ raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")
+
+ if not len(input.shape) == 4:
+ raise ValueError(f"Invalid input shape, we expect BxNxHxW. Got: {input.shape}")
+
+ if not input.shape[-2:] == target.shape[-2:]:
+ raise ValueError(f"input and target shapes must be the same. Got: {input.shape} and {target.shape}")
+
+ if not input.device == target.device:
+ raise ValueError(f"input and target must be in the same device. Got: {input.device} and {target.device}")
+
+ # compute softmax over the classes axis
+ input_soft: torch.Tensor = F.softmax(input, dim=1)
+
+ # create the labels one hot tensor
+ target_one_hot: torch.Tensor = one_hot(target, num_classes=input.shape[1], device=input.device, dtype=input.dtype)
+
+ # compute the actual dice score
+ dims = (1, 2, 3)
+ intersection = torch.sum(input_soft * target_one_hot, dims)
+ cardinality = torch.sum(input_soft + target_one_hot, dims)
+
+ dice_score = 2.0 * intersection / (cardinality + eps)
+
+ return torch.mean(-dice_score + 1.0)
diff --git a/gapartnet/network/model.py b/gapartnet/network/model.py
new file mode 100644
index 0000000..8bfe266
--- /dev/null
+++ b/gapartnet/network/model.py
@@ -0,0 +1,1051 @@
+import lightning.pytorch as lp
+from typing import Optional, Dict, Tuple, List
+import functools
+import torch
+import torch.nn as nn
+import numpy as np
+import spconv.pytorch as spconv
+import torch.nn.functional as F
+from einops import rearrange, repeat
+
+from epic_ops.reduce import segmented_maxpool
+from epic_ops.iou import batch_instance_seg_iou
+
+from network.losses import focal_loss, dice_loss, pixel_accuracy, mean_iou
+from network.grouping_utils import (apply_nms, cluster_proposals, compute_ap,
+ compute_npcs_loss, filter_invalid_proposals,
+ get_gt_scores, segmented_voxelize)
+from structure.point_cloud import PointCloudBatch, PointCloud
+from structure.segmentation import Segmentation
+from structure.instances import Instances
+
+from misc.info import OBJECT_NAME2ID, PART_ID2NAME, PART_NAME2ID, get_symmetry_matrix
+from misc.visu import visualize_gapartnet
+from misc.pose_fitting import estimate_pose_from_npcs
+
+class GAPartNet(lp.LightningModule):
+ def __init__(
+ self,
+ in_channels: int,
+ num_part_classes: int,
+ backbone_type: str = "SparseUNet",
+ backbone_cfg: Dict = {},
+ learning_rate: float = 1e-3,
+ # semantic segmentation
+ ignore_sem_label: int = -100,
+ use_sem_focal_loss: bool = True,
+ use_sem_dice_loss: bool = True,
+ # instance segmentation
+ instance_seg_cfg: Dict = {},
+ # npcs segmentation
+ symmetry_indices: List = [],
+ # training
+ training_schedule: List = [],
+ # validation
+ val_score_threshold: float = 0.09,
+ val_min_num_points_per_proposal: int = 3,
+ val_nms_iou_threshold: float = 0.3,
+ val_ap_iou_threshold: float = 0.5,
+ # testing
+ visualize_cfg: Dict = {},
+
+ debug: bool = True,
+ ckpt: str = "", # type: ignore
+ ):
+ super().__init__()
+ self.save_hyperparameters()
+ self.validation_step_outputs = []
+
+ self.in_channels = in_channels
+ self.num_part_classes = num_part_classes
+ self.backbone_type = backbone_type
+ self.backbone_cfg = backbone_cfg
+ self.learning_rate = learning_rate
+ self.ignore_sem_label = ignore_sem_label
+ self.use_sem_focal_loss = use_sem_focal_loss
+ self.use_sem_dice_loss = use_sem_dice_loss
+ self.visualize_cfg = visualize_cfg
+ self.start_scorenet, self.start_npcs = training_schedule
+ self.start_clustering = min(self.start_scorenet, self.start_npcs)
+ self.val_nms_iou_threshold = val_nms_iou_threshold
+ self.val_ap_iou_threshold = val_ap_iou_threshold
+ self.val_score_threshold = val_score_threshold
+ self.val_min_num_points_per_proposal = val_min_num_points_per_proposal
+ self.symmetry_indices = torch.as_tensor(symmetry_indices, dtype=torch.int64).to(self.device)
+
+ self.ball_query_radius = instance_seg_cfg["ball_query_radius"]
+ self.max_num_points_per_query = instance_seg_cfg["max_num_points_per_query"]
+ self.min_num_points_per_proposal = instance_seg_cfg["min_num_points_per_proposal"]
+ self.max_num_points_per_query_shift = instance_seg_cfg["max_num_points_per_query_shift"]
+ self.score_fullscale = instance_seg_cfg["score_fullscale"]
+ self.score_scale = instance_seg_cfg["score_scale"]
+
+
+ ## network
+ norm_fn = functools.partial(nn.BatchNorm1d, eps=1e-4, momentum=0.1)
+ # backbone
+ if self.backbone_type == "SparseUNet":
+ from .backbone import SparseUNet
+ channels = self.backbone_cfg["channels"]
+ block_repeat = self.backbone_cfg["block_repeat"]
+ self.backbone = SparseUNet.build(in_channels, channels, block_repeat, norm_fn)
+ else:
+ raise NotImplementedError(f"backbone type {self.backbone_type} not implemented")
+ # semantic segmentation head
+ self.sem_seg_head = nn.Linear(channels[0], self.num_part_classes)
+ # offset prediction
+ self.offset_head = nn.Sequential(
+ nn.Linear(channels[0], channels[0]),
+ norm_fn(channels[0]),
+ nn.ReLU(inplace=True),
+ nn.Linear(channels[0], 3),
+ )
+
+ self.score_unet = SparseUNet.build(
+ channels[0], channels[:2], block_repeat, norm_fn, without_stem=True
+ )
+ self.score_head = nn.Linear(channels[0], self.num_part_classes - 1)
+
+
+ self.npcs_unet = SparseUNet.build(
+ channels[0], channels[:2], block_repeat, norm_fn, without_stem=True
+ )
+ self.npcs_head = nn.Linear(channels[0], 3 * (self.num_part_classes - 1))
+
+ # symmetry
+ # self.register_buffer(
+ # "symmetry_indices", torch.as_tensor(symmetry_indices, dtype=torch.int64)
+ # )
+ # if symmetry_indices is not None:
+ # assert len(symmetry_indices) == self.num_part_classes, (symmetry_indices, self.num_part_classes)
+ (
+ symmetry_matrix_1, symmetry_matrix_2, symmetry_matrix_3
+ ) = get_symmetry_matrix()
+ self.symmetry_matrix_1 = symmetry_matrix_1
+ self.symmetry_matrix_2 = symmetry_matrix_2
+ self.symmetry_matrix_3 = symmetry_matrix_3
+
+
+ if ckpt != "":
+ print("Loading pretrained model from:", ckpt)
+ state_dict = torch.load(
+ ckpt, map_location="cpu"
+ )["state_dict"]
+ missing_keys, unexpected_keys = self.load_state_dict(
+ state_dict, strict=False,
+ )
+ if len(missing_keys) > 0:
+ print("missing_keys:", missing_keys)
+ if len(unexpected_keys) > 0:
+ print("unexpected_keys:", unexpected_keys)
+
+ def forward_backbone(
+ self,
+ pc_batch: PointCloudBatch,
+ ):
+ if self.backbone_type == "SparseUNet":
+ voxel_tensor = pc_batch.voxel_tensor
+ pc_voxel_id = pc_batch.pc_voxel_id
+ voxel_features = self.backbone(voxel_tensor)
+ pc_feature = voxel_features.features[pc_voxel_id]
+
+ return pc_feature
+
+ def forward_sem_seg(
+ self,
+ pc_feature: torch.Tensor,
+ ) -> torch.Tensor:
+ sem_logits = self.sem_seg_head(pc_feature)
+
+ return sem_logits
+
+ def loss_sem_seg(
+ self,
+ sem_logits: torch.Tensor,
+ sem_labels: torch.Tensor,
+ ) -> torch.Tensor:
+ if self.use_sem_focal_loss:
+ loss = focal_loss(
+ sem_logits, sem_labels,
+ alpha=None,
+ gamma=2.0,
+ ignore_index=self.ignore_sem_label,
+ reduction="mean",
+ )
+ else:
+ loss = F.cross_entropy(
+ sem_logits, sem_labels,
+ weight=None,
+ ignore_index=self.ignore_sem_label,
+ reduction="mean",
+ )
+
+        if self.use_sem_dice_loss:
+            # dice_loss expects image-shaped logits (BxCxHxW), so each point is treated
+            # as its own 1x1 "image"
+            loss += dice_loss(
+                sem_logits[:, :, None, None], sem_labels[:, None, None],
+            )
+
+ return loss
+
+ def forward_offset(
+ self,
+ pc_feature: torch.Tensor,
+ ) -> torch.Tensor:
+ offset = self.offset_head(pc_feature)
+
+ return offset
+
+ def loss_offset(
+ self,
+ offsets: torch.Tensor,
+ gt_offsets: torch.Tensor,
+ sem_labels: torch.Tensor,
+ instance_labels: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ valid_instance_mask = (sem_labels > 0) & (instance_labels >= 0)
+
+ pt_diff = offsets - gt_offsets
+ pt_dist = torch.sum(pt_diff.abs(), dim=-1)
+ loss_offset_dist = pt_dist[valid_instance_mask].mean()
+
+ gt_offsets_norm = torch.norm(gt_offsets, p=2, dim=-1)
+ gt_offsets = gt_offsets / (gt_offsets_norm[:, None] + 1e-8)
+
+ offsets_norm = torch.norm(offsets, p=2, dim=-1)
+ offsets = offsets / (offsets_norm[:, None] + 1e-8)
+
+ dir_diff = -(gt_offsets * offsets).sum(-1)
+ loss_offset_dir = dir_diff[valid_instance_mask].mean()
+
+ return loss_offset_dist, loss_offset_dir
+
+ def proposal_clustering_and_revoxelize(
+ self,
+ pt_xyz: torch.Tensor,
+ batch_indices: torch.Tensor,
+ pt_features: torch.Tensor,
+ sem_preds: torch.Tensor,
+ offset_preds: torch.Tensor,
+ instance_labels: Optional[torch.Tensor],
+ ):
+ device = self.device
+
+ if instance_labels is not None:
+ valid_mask = (sem_preds > 0) & (instance_labels >= 0)
+ else:
+ valid_mask = sem_preds > 0
+
+ pt_xyz = pt_xyz[valid_mask]
+ batch_indices = batch_indices[valid_mask]
+ pt_features = pt_features[valid_mask]
+ sem_preds = sem_preds[valid_mask].int()
+ offset_preds = offset_preds[valid_mask]
+ if instance_labels is not None:
+ instance_labels = instance_labels[valid_mask]
+
+ # get batch offsets (csr) from batch indices
+ _, batch_indices_compact, num_points_per_batch = torch.unique_consecutive(
+ batch_indices, return_inverse=True, return_counts=True
+ )
+ batch_indices_compact = batch_indices_compact.int()
+ batch_offsets = torch.zeros(
+ (num_points_per_batch.shape[0] + 1,), dtype=torch.int32, device=device
+ )
+ batch_offsets[1:] = num_points_per_batch.cumsum(0)
+
+ # cluster proposals: dual set
+ sorted_cc_labels, sorted_indices = cluster_proposals(
+ pt_xyz, batch_indices_compact, batch_offsets, sem_preds,
+ self.ball_query_radius, self.max_num_points_per_query,
+ )
+
+ sorted_cc_labels_shift, sorted_indices_shift = cluster_proposals(
+ pt_xyz + offset_preds, batch_indices_compact, batch_offsets, sem_preds,
+ self.ball_query_radius, self.max_num_points_per_query_shift,
+ )
+
+ # combine clusters
+ sorted_cc_labels = torch.cat([
+ sorted_cc_labels,
+ sorted_cc_labels_shift + sorted_cc_labels.shape[0],
+ ], dim=0)
+ sorted_indices = torch.cat([sorted_indices, sorted_indices_shift], dim=0)
+
+ # compact the proposal ids
+ _, proposal_indices, num_points_per_proposal = torch.unique_consecutive(
+ sorted_cc_labels, return_inverse=True, return_counts=True
+ )
+
+ # remove small proposals
+ valid_proposal_mask = (
+ num_points_per_proposal >= self.min_num_points_per_proposal
+ )
+ # proposal to point
+ valid_point_mask = valid_proposal_mask[proposal_indices]
+
+ sorted_indices = sorted_indices[valid_point_mask]
+ if sorted_indices.shape[0] == 0:
+            # no proposal survived the size filter; the caller falls back to
+            # semantic / offset losses only
+            return None, None, None
+
+ batch_indices = batch_indices[sorted_indices]
+ pt_xyz = pt_xyz[sorted_indices]
+ pt_features = pt_features[sorted_indices]
+ sem_preds = sem_preds[sorted_indices]
+ if instance_labels is not None:
+ instance_labels = instance_labels[sorted_indices]
+
+ # re-compact the proposal ids
+ proposal_indices = proposal_indices[valid_point_mask]
+ _, proposal_indices, num_points_per_proposal = torch.unique_consecutive(
+ proposal_indices, return_inverse=True, return_counts=True
+ )
+ num_proposals = num_points_per_proposal.shape[0]
+
+ # get proposal batch offsets
+ proposal_offsets = torch.zeros(
+ num_proposals + 1, dtype=torch.int32, device=device
+ )
+ proposal_offsets[1:] = num_points_per_proposal.cumsum(0)
+
+ # voxelization
+ voxel_features, voxel_coords, pc_voxel_id = segmented_voxelize(
+ pt_xyz, pt_features,
+ proposal_offsets, proposal_indices,
+ num_points_per_proposal,
+ self.score_fullscale, self.score_scale,
+ )
+ voxel_tensor = spconv.SparseConvTensor(
+ voxel_features, voxel_coords.int(),
+ spatial_shape=[self.score_fullscale] * 3,
+ batch_size=num_proposals,
+ )
+        # every point should be assigned to a voxel after segmented voxelization
+        assert (pc_voxel_id >= 0).all(), "segmented_voxelize left unassigned points"
+
+ proposals = Instances(
+ valid_mask=valid_mask,
+ sorted_indices=sorted_indices,
+ pt_xyz=pt_xyz,
+ batch_indices=batch_indices,
+ proposal_offsets=proposal_offsets,
+ proposal_indices=proposal_indices,
+ num_points_per_proposal=num_points_per_proposal,
+ sem_preds=sem_preds,
+ instance_labels=instance_labels,
+ )
+
+ return voxel_tensor, pc_voxel_id, proposals
+
+ def forward_proposal_score(
+ self,
+ voxel_tensor: spconv.SparseConvTensor,
+ pc_voxel_id: torch.Tensor,
+ proposals: Instances,
+ ):
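+        # per-proposal objectness: run the score UNet on the proposal voxels, map the
+        # voxel features back to points, max-pool within each proposal, then apply the MLP head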
+ proposal_offsets = proposals.proposal_offsets
+ proposal_offsets_begin = proposal_offsets[:-1] # type: ignore
+ proposal_offsets_end = proposal_offsets[1:] # type: ignore
+
+ score_features = self.score_unet(voxel_tensor)
+ score_features = score_features.features[pc_voxel_id]
+ pooled_score_features, _ = segmented_maxpool(
+ score_features, proposal_offsets_begin, proposal_offsets_end
+ )
+ score_logits = self.score_head(pooled_score_features)
+
+ return score_logits
+
+ def loss_proposal_score(
+ self,
+ score_logits: torch.Tensor,
+ proposals: Instances,
+ num_points_per_instance: torch.Tensor,
+ ) -> torch.Tensor:
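+        # IoU-based targets: each proposal's max IoU with the ground-truth instances is
+        # converted to a soft score by get_gt_scores (0.75 / 0.25 thresholds) and
+        # supervised with binary cross-entropy on the score logits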
+ ious = batch_instance_seg_iou(
+ proposals.proposal_offsets, # type: ignore
+ proposals.instance_labels, # type: ignore
+ proposals.batch_indices, # type: ignore
+ num_points_per_instance,
+ )
+ proposals.ious = ious
+ proposals.num_points_per_instance = num_points_per_instance
+
+ ious_max = ious.max(-1)[0]
+ gt_scores = get_gt_scores(ious_max, 0.75, 0.25)
+
+ return F.binary_cross_entropy_with_logits(score_logits, gt_scores)
+
+ def forward_proposal_npcs(
+ self,
+ voxel_tensor: spconv.SparseConvTensor,
+ pc_voxel_id: torch.Tensor,
+ ) -> torch.Tensor:
+ npcs_features = self.npcs_unet(voxel_tensor)
+ npcs_logits = self.npcs_head(npcs_features.features)
+ npcs_logits = npcs_logits[pc_voxel_id]
+
+ return npcs_logits
+
+ def loss_proposal_npcs(
+ self,
+ npcs_logits: torch.Tensor,
+ gt_npcs: torch.Tensor,
+ proposals: Instances,
+ ) -> torch.Tensor:
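+        # NPCS regression is supervised only where the predicted semantic label matches
+        # the ground truth and a ground-truth NPCS value exists; parts are split into
+        # three symmetry groups and compute_npcs_loss receives the matching symmetry matrices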
+ sem_preds, sem_labels = proposals.sem_preds, proposals.sem_labels
+ proposal_indices = proposals.proposal_indices
+ valid_mask = (sem_preds == sem_labels) & (gt_npcs != 0).any(dim=-1)
+
+ npcs_logits = npcs_logits[valid_mask]
+ gt_npcs = gt_npcs[valid_mask]
+ sem_preds = sem_preds[valid_mask].long()
+ sem_labels = sem_labels[valid_mask]
+ proposal_indices = proposal_indices[valid_mask]
+
+ npcs_logits = rearrange(npcs_logits, "n (k c) -> n k c", c=3)
+ npcs_logits = npcs_logits.gather(
+ 1, index=repeat(sem_preds - 1, "n -> n one c", one=1, c=3)
+ ).squeeze(1)
+
+ proposals.npcs_preds = npcs_logits.detach()
+ proposals.gt_npcs = gt_npcs
+ proposals.npcs_valid_mask = valid_mask
+
+ loss_npcs = 0
+
+        # keep the symmetry lookup tensors on the same device as the predictions
+        self.symmetry_indices = self.symmetry_indices.to(sem_preds.device)
+        self.symmetry_matrix_1 = self.symmetry_matrix_1.to(sem_preds.device)
+        self.symmetry_matrix_2 = self.symmetry_matrix_2.to(sem_preds.device)
+        self.symmetry_matrix_3 = self.symmetry_matrix_3.to(sem_preds.device)
+ symmetry_indices = self.symmetry_indices[sem_preds]
+ # group #1
+ group_1_mask = symmetry_indices < 3
+ symmetry_indices_1 = symmetry_indices[group_1_mask]
+ if symmetry_indices_1.shape[0] > 0:
+ loss_npcs += compute_npcs_loss(
+ npcs_logits[group_1_mask], gt_npcs[group_1_mask],
+ proposal_indices[group_1_mask],
+ self.symmetry_matrix_1[symmetry_indices_1]
+ )
+
+ # group #2
+ group_2_mask = symmetry_indices == 3
+ symmetry_indices_2 = symmetry_indices[group_2_mask]
+ if symmetry_indices_2.shape[0] > 0:
+ loss_npcs += compute_npcs_loss(
+ npcs_logits[group_2_mask], gt_npcs[group_2_mask],
+ proposal_indices[group_2_mask],
+ self.symmetry_matrix_2[symmetry_indices_2 - 3]
+ )
+
+ # group #3
+ group_3_mask = symmetry_indices == 4
+ symmetry_indices_3 = symmetry_indices[group_3_mask]
+ if symmetry_indices_3.shape[0] > 0:
+ loss_npcs += compute_npcs_loss(
+ npcs_logits[group_3_mask], gt_npcs[group_3_mask],
+ proposal_indices[group_3_mask],
+ self.symmetry_matrix_3[symmetry_indices_3 - 4]
+ )
+
+ return loss_npcs
+
+ def _training_or_validation_step(
+ self,
+ point_clouds: List[PointCloud],
+ batch_idx: int,
+ running_mode: str,
+ ):
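+        # shared train/val pipeline: backbone features -> semantic logits -> center
+        # offsets; after the configured warm-up epochs, cluster proposals, score them,
+        # and regress NPCS, logging every loss term under the given prefix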
+ batch_size = len(point_clouds)
+
+ # data batch parsing
+ data_batch = PointCloud.collate(point_clouds)
+ points = data_batch.points
+ sem_labels = data_batch.sem_labels
+ pc_ids = data_batch.pc_ids
+ instance_regions = data_batch.instance_regions
+ instance_labels = data_batch.instance_labels
+ batch_indices = data_batch.batch_indices
+ instance_sem_labels = data_batch.instance_sem_labels
+ num_points_per_instance = data_batch.num_points_per_instance
+ gt_npcs = data_batch.gt_npcs
+
+
+ pt_xyz = points[:, :3]
+ # cls_labels.to(pt_xyz.device)
+
+ pc_feature = self.forward_backbone(pc_batch=data_batch)
+
+ # semantic segmentation
+ sem_logits = self.forward_sem_seg(pc_feature)
+
+ sem_preds = torch.argmax(sem_logits.detach(), dim=-1)
+
+ if sem_labels is not None:
+ loss_sem_seg = self.loss_sem_seg(sem_logits, sem_labels)
+ else:
+ loss_sem_seg = 0.
+
+ # accuracy
+ all_accu = (sem_preds == sem_labels).sum().float() / (sem_labels.shape[0])
+
+ if sem_labels is not None:
+ instance_mask = sem_labels > 0
+ pixel_accu = pixel_accuracy(sem_preds[instance_mask], sem_labels[instance_mask])
+ else:
+ pixel_accu = 0.0
+
+ sem_seg = Segmentation(
+ batch_size=batch_size,
+ sem_preds=sem_preds,
+ sem_labels=sem_labels,
+ all_accu=all_accu,
+ pixel_accu=pixel_accu,)
+
+ offsets_preds = self.forward_offset(pc_feature)
+ if instance_regions is not None:
+ offsets_gt = instance_regions[:, :3] - pt_xyz
+ loss_offset_dist, loss_offset_dir = self.loss_offset(
+ offsets_preds, offsets_gt, sem_labels, instance_labels, # type: ignore
+ )
+ else:
+            loss_offset_dist, loss_offset_dir = 0., 0.
+
+ if self.current_epoch >= self.start_clustering:
+ voxel_tensor, pc_voxel_id, proposals = self.proposal_clustering_and_revoxelize(
+ pt_xyz = pt_xyz,
+ batch_indices=batch_indices,
+ pt_features=pc_feature,
+ sem_preds=sem_preds,
+ offset_preds=offsets_preds,
+ instance_labels=instance_labels,
+ )
+
+ if sem_labels is not None and proposals is not None:
+ proposals.sem_labels = sem_labels[proposals.valid_mask][
+ proposals.sorted_indices
+ ]
+ if proposals is not None:
+ proposals.instance_sem_labels = instance_sem_labels
+ else:
+ proposals = None
+
+ # clustering and scoring
+ if self.current_epoch >= self.start_scorenet and voxel_tensor is not None and proposals is not None: # type: ignore
+ score_logits = self.forward_proposal_score(
+ voxel_tensor, pc_voxel_id, proposals
+ ) # type: ignore
+ proposal_offsets_begin = proposals.proposal_offsets[:-1].long() # type: ignore
+
+ if proposals.sem_labels is not None: # type: ignore
+ proposal_sem_labels = proposals.sem_labels[proposal_offsets_begin].long() # type: ignore
+ else:
+ proposal_sem_labels = proposals.sem_preds[proposal_offsets_begin].long() # type: ignore
+ score_logits = score_logits.gather(
+ 1, proposal_sem_labels[:, None] - 1
+ ).squeeze(1)
+ proposals.score_preds = score_logits.detach().sigmoid() # type: ignore
+ if num_points_per_instance is not None: # type: ignore
+ loss_prop_score = self.loss_proposal_score(
+ score_logits, proposals, num_points_per_instance, # type: ignore
+ )
+ else:
+                loss_prop_score = 0.0
+ else:
+ loss_prop_score = 0.0
+
+ if self.current_epoch >= self.start_npcs and voxel_tensor is not None:
+ npcs_logits = self.forward_proposal_npcs(
+ voxel_tensor, pc_voxel_id
+ )
+ if gt_npcs is not None:
+ gt_npcs = gt_npcs[proposals.valid_mask][proposals.sorted_indices]
+ loss_prop_npcs = self.loss_proposal_npcs(npcs_logits, gt_npcs, proposals)
+
+ else:
+ npcs_preds = None
+
+ loss_prop_npcs = 0.0
+
+ # total loss
+ loss = loss_sem_seg + loss_offset_dist + loss_offset_dir + loss_prop_score + loss_prop_npcs
+
+
+ prefix = running_mode
+ # losses
+ self.log(
+ f"{prefix}_loss/total_loss",
+ loss,
+ batch_size=batch_size,
+ on_epoch=True, prog_bar=False, logger=True, sync_dist=True)
+ self.log(
+ f"{prefix}_loss/loss_sem_seg",
+ loss_sem_seg,
+ batch_size=batch_size,
+ on_epoch=True, prog_bar=False, logger=True, sync_dist=True
+ )
+ self.log(
+ f"{prefix}_loss/loss_offset_dist",
+ loss_offset_dist,
+ batch_size=batch_size,
+ on_epoch=True, prog_bar=False, logger=True, sync_dist=True
+ )
+ self.log(
+ f"{prefix}_loss/loss_offset_dir",
+ loss_offset_dir,
+ batch_size=batch_size,
+ on_epoch=True, prog_bar=False, logger=True, sync_dist=True
+ )
+ self.log(
+ f"{prefix}_loss/loss_prop_score",
+ loss_prop_score,
+ batch_size=batch_size,
+ on_epoch=True, prog_bar=False, logger=True, sync_dist=True
+ )
+ self.log(
+ f"{prefix}_loss/loss_prop_npcs",
+ loss_prop_npcs,
+ batch_size=batch_size,
+ on_epoch=True, prog_bar=False, logger=True, sync_dist=True
+ )
+
+        # evaluation metrics
+ self.log(
+ f"{prefix}/all_accu",
+ all_accu * 100,
+ batch_size=batch_size,
+ on_epoch=True, prog_bar=False, logger=True, sync_dist=True
+ )
+ self.log(
+ f"{prefix}/pixel_accu",
+ pixel_accu * 100,
+ batch_size=batch_size,
+ on_epoch=True, prog_bar=False, logger=True, sync_dist=True
+ )
+
+ return pc_ids, sem_seg, proposals, loss
+
+ def training_step(self, point_clouds: List[PointCloud], batch_idx: int):
+ _, _, proposals, loss = self._training_or_validation_step(
+ point_clouds, batch_idx, "train"
+ )
+ return loss
+
+ def validation_step(self, point_clouds: List[PointCloud], batch_idx: int, dataloader_idx: int = 0):
+ split = ["val", "test_intra", "test_inter"]
+ pc_ids, sem_seg, proposals, _ = self._training_or_validation_step(
+ point_clouds, batch_idx, split[dataloader_idx]
+ )
+
+ if proposals is not None:
+ proposals = filter_invalid_proposals(
+ proposals,
+ score_threshold=self.val_score_threshold,
+ min_num_points_per_proposal=self.val_min_num_points_per_proposal
+ )
+ proposals = apply_nms(proposals, self.val_nms_iou_threshold)
+
+
+ if dataloader_idx > len(self.validation_step_outputs) - 1:
+ self.validation_step_outputs.append([])
+
+ proposals.pt_sem_classes = proposals.sem_preds[proposals.proposal_offsets[:-1].long()]
+
+        proposals_ = Instances(
+            score_preds=proposals.score_preds, pt_sem_classes=proposals.pt_sem_classes,
+            batch_indices=proposals.batch_indices, instance_sem_labels=proposals.instance_sem_labels,
+            ious=proposals.ious, proposal_offsets=proposals.proposal_offsets,
+            valid_mask=proposals.valid_mask,
+        )
+
+ self.validation_step_outputs[dataloader_idx].append((pc_ids, sem_seg, proposals_))
+ return pc_ids, sem_seg, proposals_
+
+ def on_validation_epoch_end(self):
+
+
+ splits = ["val", "test_intra", "test_inter"]
+ all_accus = []
+ pixel_accus = []
+ mious = []
+ mean_ap50 = []
+ mAPs = []
+ for i_, validation_step_outputs in enumerate(self.validation_step_outputs):
+ split = splits[i_]
+ pc_ids = [i for x in validation_step_outputs for i in x[0]]
+ batch_size = validation_step_outputs[0][1].batch_size
+ data_size = sum(x[1].batch_size for x in validation_step_outputs)
+ all_accu = sum(x[1].all_accu for x in validation_step_outputs) / len(validation_step_outputs)
+ pixel_accu = sum(x[1].pixel_accu for x in validation_step_outputs) / len(validation_step_outputs)
+
+
+ # semantic segmentation
+ sem_preds = torch.cat(
+ [x[1].sem_preds for x in validation_step_outputs], dim=0
+ )
+ sem_labels = torch.cat(
+ [x[1].sem_labels for x in validation_step_outputs], dim=0
+ )
+ miou = mean_iou(sem_preds, sem_labels, num_classes=self.num_part_classes)
+
+ # instance segmentation
+            proposals = [x[2] for x in validation_step_outputs if x[2] is not None]
+
+ del validation_step_outputs
+
+ # semantic segmentation
+ all_accus.append(all_accu)
+ mious.append(miou)
+ pixel_accus.append(pixel_accu)
+
+ # instance segmentation
+ thes = [0.5 + 0.05 * i for i in range(10)]
+ aps = []
+ for the in thes:
+ ap = compute_ap(proposals, self.num_part_classes, the)
+ aps.append(ap)
+ if the == 0.5:
+ ap50 = ap
+ mAP = np.array(aps).mean()
+ mAPs.append(mAP)
+
+ for class_idx in range(1, self.num_part_classes):
+ partname = PART_ID2NAME[class_idx]
+ self.log(
+ f"{split}/AP@50_{partname}",
+ np.mean(ap50[class_idx - 1]) * 100,
+ batch_size=data_size,
+ on_epoch=True, prog_bar=False, logger=True, sync_dist=True,
+ )
+
+
+ mean_ap50.append(np.mean(ap50))
+
+
+ self.log(f"{split}/AP@50",
+ np.mean(ap50) * 100,
+ batch_size=data_size,
+ on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
+ self.log(f"{split}/mAP",
+ mAP * 100,
+ batch_size=data_size,
+ on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
+ self.log(f"{split}/all_accu",
+ all_accu * 100.0,
+ batch_size=data_size,
+ on_epoch=True, prog_bar=False, logger=True, sync_dist=True)
+ self.log(f"{split}/pixel_accu",
+ pixel_accu * 100.0,
+ batch_size=data_size,
+ on_epoch=True, prog_bar=False, logger=True, sync_dist=True)
+ self.log(f"{split}/miou",
+ miou * 100.0,
+ batch_size=data_size,
+ on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
+
+ self.log("monitor_metrics/mean_all_accu",
+ (all_accus[1]+all_accus[2])/2 * 100.0,
+ batch_size=data_size,
+ on_epoch=True, prog_bar=False, logger=True, sync_dist=True)
+ self.log("monitor_metrics/mean_pixel_accu",
+ (pixel_accus[1]+pixel_accus[2])/2 * 100.0,
+ batch_size=data_size,
+ on_epoch=True, prog_bar=False, logger=True, sync_dist=True)
+ self.log("monitor_metrics/mean_imou",
+ (mious[1]+mious[2])/2 * 100.0,
+ batch_size=data_size,
+ on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
+ self.log("monitor_metrics/mean_AP@50",
+ (mean_ap50[1]+mean_ap50[2])/2 * 100.0,
+ batch_size=data_size,
+ on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
+ self.log("monitor_metrics/mean_mAP",
+ (mAPs[1]+mAPs[2])/2 * 100.0,
+ batch_size=data_size,
+ on_epoch=True, prog_bar=True, logger=True, sync_dist=True
+ )
+
+ self.validation_step_outputs.clear()
+
+ def test_step(self, point_clouds: List[PointCloud], batch_idx: int, dataloader_idx: int = 0):
+ split = ["val", "intra", "inter"]
+ pc_ids, sem_seg, proposals, _ = self._training_or_validation_step(
+ point_clouds, batch_idx, split[dataloader_idx]
+ )
+
+ if proposals is not None:
+ proposals = filter_invalid_proposals(
+ proposals,
+ score_threshold=self.val_score_threshold,
+ min_num_points_per_proposal=self.val_min_num_points_per_proposal
+ )
+ proposals = apply_nms(proposals, self.val_nms_iou_threshold)
+
+
+ if dataloader_idx > len(self.validation_step_outputs) - 1:
+ self.validation_step_outputs.append([])
+
+ proposals.pt_sem_classes = proposals.sem_preds[proposals.proposal_offsets[:-1].long()]
+
+        proposals_ = Instances(
+            pt_xyz=proposals.pt_xyz,
+            score_preds=proposals.score_preds,
+            pt_sem_classes=proposals.pt_sem_classes,
+            batch_indices=proposals.batch_indices,
+            instance_sem_labels=proposals.instance_sem_labels,
+            ious=proposals.ious,
+            proposal_offsets=proposals.proposal_offsets,
+            proposal_indices=proposals.proposal_indices,
+            valid_mask=proposals.valid_mask,
+            num_points_per_proposal=proposals.num_points_per_proposal,
+            num_points_per_instance=proposals.num_points_per_instance,
+            sorted_indices=proposals.sorted_indices,
+            npcs_preds=proposals.npcs_preds,
+            npcs_valid_mask=proposals.npcs_valid_mask,
+        )
+
+ self.validation_step_outputs[dataloader_idx].append((pc_ids, sem_seg, proposals_))
+ return pc_ids, sem_seg, proposals_
+
+ def on_test_epoch_end(self):
+
+ splits = ["val", "test_intra", "test_inter"]
+ all_accus = []
+ pixel_accus = []
+ mious = []
+ mean_ap50 = []
+ mAPs = []
+ for i_, validation_step_outputs in enumerate(self.validation_step_outputs):
+ split = splits[i_]
+ pc_ids = [i for x in validation_step_outputs for i in x[0]]
+ batch_size = validation_step_outputs[0][1].batch_size
+ data_size = sum(x[1].batch_size for x in validation_step_outputs)
+ all_accu = sum(x[1].all_accu for x in validation_step_outputs) / len(validation_step_outputs)
+ pixel_accu = sum(x[1].pixel_accu for x in validation_step_outputs) / len(validation_step_outputs)
+
+
+ # semantic segmentation
+ sem_preds = torch.cat(
+ [x[1].sem_preds for x in validation_step_outputs], dim=0
+ )
+ sem_labels = torch.cat(
+ [x[1].sem_labels for x in validation_step_outputs], dim=0
+ )
+ miou = mean_iou(sem_preds, sem_labels, num_classes=self.num_part_classes)
+
+ # instance segmentation
+            proposals = [x[2] for x in validation_step_outputs if x[2] is not None]
+
+ del validation_step_outputs
+
+ # semantic segmentation
+ all_accus.append(all_accu)
+ mious.append(miou)
+ pixel_accus.append(pixel_accu)
+
+ # instance segmentation
+ thes = [0.5 + 0.05 * i for i in range(10)]
+ aps = []
+ for the in thes:
+ ap = compute_ap(proposals, self.num_part_classes, the)
+ aps.append(ap)
+ if the == 0.5:
+ ap50 = ap
+ mAP = np.array(aps).mean()
+ mAPs.append(mAP)
+
+ for class_idx in range(1, self.num_part_classes):
+ partname = PART_ID2NAME[class_idx]
+ self.log(
+ f"{split}/AP@50_{partname}",
+ np.mean(ap50[class_idx - 1]) * 100,
+ batch_size=data_size,
+ on_epoch=True, prog_bar=False, logger=True, sync_dist=True,
+ )
+
+
+ mean_ap50.append(np.mean(ap50))
+
+
+            if self.visualize_cfg["visualize"]:
+ if self.visualize_cfg["sample_num"] > 0:
+ import random
+ sample_ids = random.sample(range(len(pc_ids)), self.visualize_cfg["sample_num"])
+ else:
+ sample_ids = range(len(pc_ids))
+
+
+ for sample_id in sample_ids:
+ batch_id = sample_id // batch_size
+ batch_sample_id = sample_id % batch_size
+ proposals_ = proposals[batch_id]
+
+ if proposals_ is not None:
+ pt_xyz = proposals_.pt_xyz
+ batch_indices = proposals_.batch_indices
+ proposal_offsets = proposals_.proposal_offsets
+ num_points_per_proposal = proposals_.num_points_per_proposal
+ num_proposals = num_points_per_proposal.shape[0]
+ score_preds= proposals_.score_preds
+ mask = proposals_.valid_mask
+
+ indices = torch.arange(mask.shape[0], dtype=torch.int64,device = sem_preds.device)
+ proposal_indices = indices[proposals_.valid_mask][proposals_.sorted_indices]
+
+ ins_seg_preds = torch.ones(mask.shape[0]) * 0
+ for ins_i in range(len(proposal_offsets) - 1):
+ ins_seg_preds[proposal_indices[proposal_offsets[ins_i]:proposal_offsets[ins_i + 1]]] = ins_i+1
+
+ npcs_maps = torch.ones(proposals_.valid_mask.shape[0],3, device = proposals_.valid_mask.device)*0.0
+ valid_index = torch.where(proposals_.valid_mask==True)[0][proposals_.sorted_indices.long()[torch.where(proposals_.npcs_valid_mask==True)]]
+ npcs_maps[valid_index] = proposals_.npcs_preds
+
+ # bounding box
+ bboxes = []
+ bboxes_batch_index = []
+ for proposal_i in range(len(proposal_offsets) - 1):
+ npcs_i = npcs_maps[proposal_indices[proposal_offsets[proposal_i]:proposal_offsets[proposal_i + 1]]]
+ npcs_i = npcs_i - 0.5
+ xyz_i = pt_xyz[proposal_offsets[proposal_i]:proposal_offsets[proposal_i + 1]]
+ if xyz_i.shape[0] < 10:
+ continue
+ bbox_xyz, scale, rotation, translation, out_transform, best_inlier_idx = estimate_pose_from_npcs(xyz_i.cpu().numpy(), npcs_i.cpu().numpy())
+                            if scale[0] is None:
+ continue
+ bboxes_batch_index.append(batch_indices[proposal_offsets[proposal_i]])
+ bboxes.append(bbox_xyz.tolist())
+
+ # get the sampled data point
+ sample_sem_pred = sem_preds.reshape(-1,20000)[sample_id]
+ sample_ins_seg_pred = ins_seg_preds.reshape(-1,20000)[batch_sample_id]
+ sample_npcs_map = npcs_maps.reshape(-1,20000, 3)[batch_sample_id]
+ sample_bboxes = [bboxes[i] for i in range(len(bboxes)) if bboxes_batch_index[i] == batch_sample_id]
+
+ visualize_gapartnet(
+ SAVE_ROOT=self.visualize_cfg["SAVE_ROOT"],
+ RAW_IMG_ROOT = self.visualize_cfg["RAW_IMG_ROOT"],
+ GAPARTNET_DATA_ROOT=self.visualize_cfg["GAPARTNET_DATA_ROOT"],
+ save_option=self.visualize_cfg["save_option"],
+ name = pc_ids[sample_id],
+ split = split,
+ sem_preds=sample_sem_pred.cpu().numpy(), # type: ignore
+ ins_preds=sample_ins_seg_pred.cpu().numpy(),
+ npcs_preds=sample_npcs_map.cpu().numpy(),
+ bboxes = sample_bboxes,
+ )
+
+ self.log(f"{split}/AP@50",
+ np.mean(ap50) * 100,
+ batch_size=data_size,
+ on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
+ self.log(f"{split}/mAP",
+ mAP * 100,
+ batch_size=data_size,
+ on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
+ self.log(f"{split}/all_accu",
+ all_accu * 100.0,
+ batch_size=data_size,
+ on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
+ self.log(f"{split}/pixel_accu",
+ pixel_accu * 100.0,
+ batch_size=data_size,
+ on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
+ self.log(f"{split}/miou",
+ miou * 100.0,
+ batch_size=data_size,
+ on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
+
+ # need to make sure the order of the splits is correct:
+ # the second validation set is intra set and the third set is inter set
+ self.log("monitor_metrics/mean_all_accu",
+ (all_accus[1]+all_accus[2])/2 * 100.0,
+ batch_size=data_size,
+ on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
+ self.log("monitor_metrics/mean_pixel_accu",
+ (pixel_accus[1]+pixel_accus[2])/2 * 100.0,
+ batch_size=data_size,
+ on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
+ self.log("monitor_metrics/mean_imou",
+ (mious[1]+mious[2])/2 * 100.0,
+ batch_size=data_size,
+ on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
+ self.log("monitor_metrics/mean_AP@50",
+ (mean_ap50[1]+mean_ap50[2])/2 * 100.0,
+ batch_size=data_size,
+ on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
+ self.log("monitor_metrics/mean_mAP",
+ (mAPs[1]+mAPs[2])/2 * 100.0,
+ batch_size=data_size,
+ on_epoch=True, prog_bar=True, logger=True, sync_dist=True
+ )
+
+
+ self.validation_step_outputs.clear()
+
+ def configure_optimizers(self):
+ return torch.optim.Adam(
+ filter(lambda p: p.requires_grad, self.parameters()),
+ lr=self.learning_rate,
+ )
diff --git a/gapartnet/output/README.md b/gapartnet/output/README.md
new file mode 100644
index 0000000..bec7fc0
--- /dev/null
+++ b/gapartnet/output/README.md
@@ -0,0 +1 @@
+Output data will be placed here.
\ No newline at end of file
diff --git a/gapartnet/output/example.png b/gapartnet/output/example.png
new file mode 100644
index 0000000..decd719
Binary files /dev/null and b/gapartnet/output/example.png differ
diff --git a/gapartnet/output/example2.png b/gapartnet/output/example2.png
new file mode 100644
index 0000000..9888d6c
Binary files /dev/null and b/gapartnet/output/example2.png differ
diff --git a/gapartnet/output/example3.png b/gapartnet/output/example3.png
new file mode 100644
index 0000000..decd719
Binary files /dev/null and b/gapartnet/output/example3.png differ
diff --git a/gapartnet/structure/instances.py b/gapartnet/structure/instances.py
new file mode 100644
index 0000000..3586a08
--- /dev/null
+++ b/gapartnet/structure/instances.py
@@ -0,0 +1,44 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+
+
+@dataclass
+class Instances:
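+    # Proposal-major container: per-point fields (pt_xyz, sem_preds, ...) are
+    # concatenated over all proposals and proposal_offsets stores CSR-style boundaries.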
+ valid_mask: Optional[torch.Tensor] = None
+ sorted_indices: Optional[torch.Tensor] = None
+ pt_xyz: Optional[torch.Tensor] = None
+
+ batch_indices: Optional[torch.Tensor] = None
+ proposal_offsets: Optional[torch.Tensor] = None
+ proposal_indices: Optional[torch.Tensor] = None
+ num_points_per_proposal: Optional[torch.Tensor] = None
+
+ sem_preds: Optional[torch.Tensor] = None
+ pt_sem_classes: Optional[torch.Tensor] = None
+ score_preds: Optional[torch.Tensor] = None
+ npcs_preds: Optional[torch.Tensor] = None
+
+ sem_labels: Optional[torch.Tensor] = None
+ instance_labels: Optional[torch.Tensor] = None
+ instance_sem_labels: Optional[torch.Tensor] = None
+ num_points_per_instance: Optional[torch.Tensor] = None
+ gt_npcs: Optional[torch.Tensor] = None
+
+ npcs_valid_mask: Optional[torch.Tensor] = None
+
+ ious: Optional[torch.Tensor] = None
+
+ cls_preds: Optional[torch.Tensor] = None
+ cls_labels: Optional[torch.Tensor] = None
+
+ name: Optional[str] = None
+
+@dataclass
+class Result:
+ xyz: torch.Tensor
+ rgb: torch.Tensor
+ sem_preds: torch.Tensor
+ ins_preds: torch.Tensor
+ npcs_preds: torch.Tensor
\ No newline at end of file
diff --git a/gapartnet/structure/point_cloud.py b/gapartnet/structure/point_cloud.py
new file mode 100644
index 0000000..123c50a
--- /dev/null
+++ b/gapartnet/structure/point_cloud.py
@@ -0,0 +1,194 @@
+from dataclasses import dataclass, fields
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+import spconv.pytorch as spconv
+import torch
+
+@dataclass
+class PointCloudBatch:
+ # basic
+ pc_ids: List[str]
+ points: torch.Tensor
+ batch_indices: torch.Tensor
+ batch_size: int
+    device: Optional[str] = None
+
+    # voxel
+    voxel_tensor: Optional[spconv.SparseConvTensor] = None
+    pc_voxel_id: Optional[torch.Tensor] = None
+
+    # semantic
+    sem_labels: Optional[torch.Tensor] = None
+ obj_cls_labels = None
+
+ # instance
+ instance_labels: Optional[torch.Tensor] = None
+ num_instances: Optional[List[int]] = None
+ instance_regions: Optional[torch.Tensor] = None
+ num_points_per_instance: Optional[torch.Tensor] = None
+ instance_sem_labels: Optional[torch.Tensor] = None
+
+ #npcs
+ gt_npcs: Optional[Union[torch.Tensor, np.ndarray]] = None
+
+@dataclass
+class PointCloud:
+ pc_id: str
+
+ points: Union[torch.Tensor, np.ndarray]
+
+ obj_cat: int = -1
+
+ sem_labels: Optional[Union[torch.Tensor, np.ndarray]] = None
+ instance_labels: Optional[Union[torch.Tensor, np.ndarray]] = None
+
+ gt_npcs: Optional[Union[torch.Tensor, np.ndarray]] = None
+
+ # instance number
+ num_instances: Optional[int] = None
+
+ # for points in an instance: 0-3: mean_xyz; 3-6: max_xyz; 6-9: min_xyz
+ instance_regions: Optional[Union[torch.Tensor, np.ndarray]] = None
+
+ # instance points number
+ num_points_per_instance: Optional[Union[torch.Tensor, np.ndarray]] = None
+
+ # instance semantic label
+ instance_sem_labels: Optional[torch.Tensor] = None
+
+ voxel_features: Optional[torch.Tensor] = None
+ voxel_coords: Optional[torch.Tensor] = None
+ voxel_coords_range: Optional[List[int]] = None
+ pc_voxel_id: Optional[torch.Tensor] = None
+
+ def to_dict(self) -> Dict[str, Any]:
+ return {
+ field.name: getattr(self, field.name)
+ for field in fields(self)
+ }
+
+ def to_tensor(self) -> "PointCloud":
+ return PointCloud(**{
+ k: torch.from_numpy(v) if isinstance(v, np.ndarray) else v
+ for k, v in self.to_dict().items()
+ }) # type: ignore
+
+ def to(self, device: torch.device) -> "PointCloud":
+ return PointCloud(**{
+ k: v.to(device) if isinstance(v, torch.Tensor) else v
+ for k, v in self.to_dict().items()
+ }) # type: ignore
+
+ @staticmethod
+ def collate(point_clouds: List["PointCloud"]) -> PointCloudBatch:
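+        # Concatenate per-cloud tensors along the point dimension, build per-point batch
+        # indices, pad per-instance tensors to the largest instance count in the batch,
+        # and merge the per-cloud voxel grids into one SparseConvTensor (voxel ids are
+        # shifted so pc_voxel_id stays valid after merging).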
+ batch_size = len(point_clouds)
+ device = point_clouds[0].points.device # type: ignore
+
+ pc_ids = [pc.pc_id for pc in point_clouds]
+ cls_labels = torch.tensor([pc.obj_cat for pc in point_clouds])
+ num_points = [pc.points.shape[0] for pc in point_clouds]
+
+ points = torch.cat([pc.points for pc in point_clouds], dim=0)# type: ignore #
+ batch_indices = torch.cat([
+ torch.full((pc.points.shape[0],), i, dtype=torch.int32, device=device)
+ for i, pc in enumerate(point_clouds)
+ ], dim=0) #
+
+ if point_clouds[0].sem_labels is not None:
+ sem_labels = torch.cat([pc.sem_labels for pc in point_clouds], dim=0) # type: ignore
+ else:
+ sem_labels = None
+
+ if point_clouds[0].instance_labels is not None:
+ instance_labels = torch.cat([pc.instance_labels for pc in point_clouds], dim=0) # type: ignore
+ else:
+ instance_labels = None
+
+ if point_clouds[0].gt_npcs is not None:
+ gt_npcs = torch.cat([pc.gt_npcs for pc in point_clouds], dim=0)
+ else:
+ gt_npcs = None
+
+ if point_clouds[0].num_instances is not None:
+ num_instances = [pc.num_instances for pc in point_clouds]
+ max_num_instances = max(num_instances) # type: ignore
+ num_points_per_instance = torch.zeros(
+ batch_size, max_num_instances, dtype=torch.int32, device=device
+ )
+ instance_sem_labels = torch.full(
+ (batch_size, max_num_instances), -1, dtype=torch.int32, device=device
+ )
+ for i, pc in enumerate(point_clouds):
+ num_points_per_instance[i, :pc.num_instances] = pc.num_points_per_instance # type: ignore
+ instance_sem_labels[i, :pc.num_instances] = pc.instance_sem_labels # type: ignore
+ else:
+ num_instances = None
+ num_points_per_instance = None
+ instance_sem_labels = None
+
+ if point_clouds[0].instance_regions is not None:
+ instance_regions = torch.cat([
+ pc.instance_regions for pc in point_clouds
+ ], dim=0) # type: ignore
+ else:
+ instance_regions = None
+
+ voxel_batch_indices = torch.cat([
+ torch.full((
+ pc.voxel_coords.shape[0],), i, dtype=torch.int32, device=device # type: ignore
+ )
+ for i, pc in enumerate(point_clouds)
+ ], dim=0)
+ voxel_coords = torch.cat([
+ pc.voxel_coords for pc in point_clouds
+ ], dim=0) # type: ignore
+ voxel_coords = torch.cat([
+ voxel_batch_indices[:, None], voxel_coords
+ ], dim=-1)
+ voxel_features = torch.cat([
+ pc.voxel_features for pc in point_clouds
+ ], dim=0) # type: ignore
+
+ voxel_coords_range = np.max([
+ pc.voxel_coords_range for pc in point_clouds
+ ], axis=0) # type: ignore
+ voxel_tensor = spconv.SparseConvTensor(
+ voxel_features, voxel_coords,
+ spatial_shape=voxel_coords_range.tolist(),
+ batch_size=len(point_clouds),
+ )
+
+ pc_voxel_id = []
+ num_voxel_offset = 0
+ for pc in point_clouds:
+ pc.pc_voxel_id[pc.pc_voxel_id >= 0] += num_voxel_offset # type: ignore
+ pc_voxel_id.append(pc.pc_voxel_id)
+ num_voxel_offset += pc.voxel_coords.shape[0] # type: ignore
+ pc_voxel_id = torch.cat(pc_voxel_id, dim=0)
+
+ return PointCloudBatch(
+ pc_ids=pc_ids,
+ points = points,
+ batch_indices=batch_indices,
+ batch_size=batch_size,
+ device=device, # type: ignore
+ voxel_tensor=voxel_tensor,
+ pc_voxel_id=pc_voxel_id,
+ sem_labels=sem_labels, # type: ignore
+ # instance
+ num_instances=num_instances, # type: ignore
+ instance_regions=instance_regions,
+ num_points_per_instance=num_points_per_instance,
+ instance_sem_labels=instance_sem_labels,
+ instance_labels = instance_labels,
+ #npcs
+ gt_npcs=gt_npcs,
+ )
+
+
+if __name__ == "__main__":
+    pc = PointCloud(pc_id="test", points=np.ones((10000, 3), dtype=np.float32),
+                    sem_labels=np.ones(10000, dtype=np.int64))
+    print(pc.to_tensor().to("cuda:0"))
\ No newline at end of file
diff --git a/gapartnet/structure/segmentation.py b/gapartnet/structure/segmentation.py
new file mode 100644
index 0000000..e2db8b1
--- /dev/null
+++ b/gapartnet/structure/segmentation.py
@@ -0,0 +1,14 @@
+from dataclasses import dataclass
+from typing import List, Optional
+
+import torch
+
+
+@dataclass
+class Segmentation:
+ batch_size: int
+
+ sem_preds: torch.Tensor
+ sem_labels: Optional[torch.Tensor] = None
+ all_accu: Optional[torch.Tensor] = None
+ pixel_accu: Optional[float] = None
diff --git a/gapartnet/tools/visu.py b/gapartnet/tools/visu.py
new file mode 100644
index 0000000..e10dc0b
--- /dev/null
+++ b/gapartnet/tools/visu.py
@@ -0,0 +1,552 @@
+import torch
+import numpy as np
+import yaml
+from os.path import join as pjoin
+import os
+import argparse
+import sys
+sys.path.append(sys.path[0] + "/..")
+import importlib
+from gapartnet.structures.point_cloud import PointCloud
+from gapartnet.datasets.gapartnet_new import apply_voxelization
+from gapartnet.utils.pose_fitting import estimate_pose_from_npcs
+import cv2
+from typing import List
+import glob
+from visu_utils import OBJfile2points, map2image, save_point_cloud_to_ply, \
+ WorldSpaceToBallSpace, FindMaxDis, draw_bbox_old, draw_bbox, COLOR20, \
+ OTHER_COLOR, HEIGHT, WIDTH, EDGE, K, font, fontScale, fontColor,thickness, lineType
+
+GAPARTNET_DATA_ROOT = "data/GAPartNet_All"
+RAW_IMG_ROOT = "data/image_kuafu" # just for visualization, not necessary
+SAVE_ROOT = "output/GAPartNet_result"
+
+# OPTION
+FEW_SHOT = True # if True, only visualize the FEW_NUM samples, otherwise visualize all
+FEW_NUM = 10 # only valid when FEW_SHOT is True
+save_option = ["raw", "pc", "sem_pred", "ins_pred", "npcs_pred", "bbox_pred", "pure_bbox",
+ "sem_gt", "ins_gt", "npcs_gt", "bbox_gt", "bbox_gt_pure"] # save options
+SAVE_LOCAL = False
+splits = ["train", "val", "test_intra", "test_inter", ] #
+dir_name = "visu"
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--class_path", type=str, default = "gapartnet.models.gapartnet.InsSeg")
+ parser.add_argument("--ckpt", type=str, default = "ckpt/new.ckpt")
+ parser.add_argument("--in_channels", type=int, default = 6)
+ parser.add_argument("--device", type=str, default = "cuda:0")
+ args = parser.parse_args()
+
+
+ perception_cfg = {}
+ perception_cfg["class_path"] = args.class_path
+ perception_cfg["ckpt"] = args.ckpt
+ perception_cfg["device"] = args.device
+ perception_cfg["in_channels"] = args.in_channels
+ # perception_cfg["channels"] = [16, 64, 112] # [16, 32, 48, 64, 80, 96, 112]
+
+ return args, perception_cfg
+
+class MODEL:
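+    # Thin inference wrapper: loads the trained checkpoint once and exposes helpers
+    # for raw .obj scans (inference_real) and GAPartNet .pth samples (inference_gapartnet).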
+ def __init__(self, cfg):
+ self.model_cfg = cfg
+ self.perception_model = self._load_perception_model(self.model_cfg)
+ self.device = cfg["device"]
+
+
+ def _load_perception_model(self, perception_model_cfg):
+ class_path = perception_model_cfg["class_path"]
+ ckpt_pth = perception_model_cfg["ckpt"]
+ device = perception_model_cfg["device"]
+
+ module_name = ".".join(class_path.split(".")[:-1])
+ class_name = class_path.split(".")[-1]
+
+ module = importlib.import_module(module_name)
+ cls = getattr(module, class_name)
+ net = cls.load_from_checkpoint(ckpt_pth)
+
+        # run all stages from the start at inference time
+        net.start_clustering = 0
+        net.start_scorenet = 0
+        net.start_npcs = 0
+ net.freeze()
+ net.eval()
+ net.to(device)
+
+ return net
+
+ def _inference_perception_model(self, points_list: List[torch.Tensor]):
+ device = self.perception_model.device
+
+ pcs = []
+ for points in points_list:
+            pc = PointCloud(
+                pc_id="eval",
+                points=points,
+                obj_cat=0,
+            )
+ pc = apply_voxelization(
+ pc, voxel_size=(1. / 100, 1. / 100, 1. / 100)
+ )
+ pc = pc.to(device=device)
+ pcs.append(pc)
+
+ with torch.no_grad():
+ scene_ids, segmentations, proposals = self.perception_model(pcs)
+
+ sem_preds = segmentations.sem_preds
+ if proposals is not None:
+ pt_xyz = proposals.pt_xyz
+ batch_indices = proposals.batch_indices
+ proposal_offsets = proposals.proposal_offsets
+ num_points_per_proposal = proposals.num_points_per_proposal
+ num_proposals = num_points_per_proposal.shape[0]
+ npcs_preds = proposals.npcs_preds
+ score_preds= proposals.score_preds
+
+ indices = torch.arange(sem_preds.shape[0], dtype=torch.int64, device=device)
+ proposal_indices = indices[proposals.valid_mask][proposals.sorted_indices]
+
+
+ npcs_maps = pcs[0].points[:,:3].clone()
+ npcs_maps[:] = 230./255.
+ if proposals is not None:
+ npcs_maps[proposal_indices] = npcs_preds
+ bboxes = [[] for _ in range(len(points_list))]
+ if proposals is not None:
+ for i in range(num_proposals):
+ offset_begin = proposal_offsets[i].item()
+ offset_end = proposal_offsets[i + 1].item()
+
+ batch_idx = batch_indices[offset_begin]
+ xyz_i = pt_xyz[offset_begin:offset_end]
+ npcs_i = npcs_preds[offset_begin:offset_end]
+
+ npcs_i = npcs_i - 0.5
+ if xyz_i.shape[0]<=4:
+ continue
+ bbox_xyz, scale, rotation, translation, out_transform, best_inlier_idx = \
+ estimate_pose_from_npcs(xyz_i.cpu().numpy(), npcs_i.cpu().numpy())
+                if scale[0] is None:
+                    continue
+                bboxes[batch_idx].append(bbox_xyz.tolist())
+        if proposals is not None:
+            return bboxes, sem_preds, npcs_maps, proposal_indices, proposal_offsets
+        return bboxes, sem_preds, npcs_maps, None, None
+
+ def inference(self, points):
+ bboxes, sem_preds, npcs_maps, proposal_indices, proposal_offsets = self._inference_perception_model([points])
+ return bboxes, sem_preds, npcs_maps, proposal_indices, proposal_offsets
+
+ def inference_real(self, file_path, label = "", save_root = "/scratch/genghaoran/GAPartNet/GAPartNet_inference/asset/"):
+ trans_gapartnet = np.array([ 1.26171422e+00, -6.60613179e-04, 4.20249701e-02, 4.23497820e+00])
+ data_path = file_path
+ if ".obj" in file_path:
+ points = OBJfile2points(data_path)
+ points[:, 2] = -points[:,2]
+ points[:, 1] = -points[:,1]
+ save_point_cloud_to_ply(points[:,:3], points[:,3:6], data_path.split("/")[-1].split(".")[0] + label+"_preinput.ply")
+ xyz, max_radius, center = WorldSpaceToBallSpace(points[:,:3])
+ trans = np.array([max_radius, center[0], center[1], center[2]])
+        else:
+            raise NotImplementedError(f"unsupported input file: {file_path}")
+ points_input = torch.cat(
+ (torch.tensor(xyz, dtype=torch.float32 ,device = self.perception_model.device),
+ torch.tensor(points[:,-3:], dtype=torch.float32 ,device = self.perception_model.device)),
+ dim = 1)
+ save_point_cloud_to_ply(points_input[:,:3], points_input[:,3:6]*255, data_path.split("/")[-1].split(".")[0] + label+"_input.ply")
+        bboxes, sem_preds, npcs_maps, proposal_indices, proposal_offsets = self._inference_perception_model([points_input])
+
+ print("-------bbox" ,len(bboxes[0]),"-------")
+ point_img = points_input.cpu().numpy()
+ point_img[:,:3] = point_img[:,:3] * trans_gapartnet[0] + trans_gapartnet[1:4]
+ img = map2image(point_img[:,:3], point_img[:,3:6]*255.0)
+
+ im_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img_raw = im_rgb.copy()
+ # cv2.imwrite(save_root + label + data_path.split("/")[-1].split(".")[0]+".png", im_rgb)
+ # trans = np.array([xyz_scale, xyz_mean[0], xyz_mean[1], xyz_mean[2]])
+
+ cv2.imwrite(save_root + label+data_path.split("/")[-1].split(".")[0]+"_raw.png", im_rgb)
+ draw_bbox(im_rgb, bboxes[0], trans_gapartnet)
+ cv2.imwrite(save_root + label+data_path.split("/")[-1].split(".")[0]+"_bbox.png", im_rgb)
+ # for i,bbox in enumerate(bboxes[0]):
+ # bbox_now = [bbox,]
+ # img_now = img_raw.copy()
+ # draw_bbox(img_now, bbox_now, trans_gapartnet)
+ # cv2.imwrite(save_root + label+data_path.split("/")[-1].split(".")[0]+f"_bbox_{i}.png", img_now)
+ # npcs_maps_now = npcs_maps.clone()
+ # npcs_maps_now[:] = 230./255.
+ # # import pdb
+ # # pdb.set_trace()
+ # npcs_maps_now[proposal_indices[proposal_offsets[i]:proposal_offsets[i+1]]]=npcs_maps[proposal_indices[proposal_offsets[i]:proposal_offsets[i+1]]]
+ # img_npcs_now = map2image(point_img[:,:3], npcs_maps_now.cpu().numpy()*255)
+ # im_rgb_npcs_now = cv2.cvtColor(img_npcs_now, cv2.COLOR_BGR2RGB)
+ # cv2.imwrite(save_root + label+data_path.split("/")[-1].split(".")[0]+f"_npcs{i}.png", im_rgb_npcs_now)
+
+ rgb_sem = COLOR20[sem_preds.cpu().numpy()]
+ img_sem = map2image(point_img[:,:3], rgb_sem)
+ im_rgb_sem = cv2.cvtColor(img_sem, cv2.COLOR_BGR2RGB)
+ # draw_bbox(im_rgb_sem, bboxes[0], trans_gapartnet)
+ cv2.imwrite(save_root + label+data_path.split("/")[-1].split(".")[0]+"_sem.png", im_rgb_sem)
+
+ img_npcs = map2image(point_img[:,:3], npcs_maps.cpu().numpy()*255)
+ im_rgb_npcs = cv2.cvtColor(img_npcs, cv2.COLOR_BGR2RGB)
+ draw_bbox(im_rgb_npcs, bboxes[0], trans_gapartnet)
+ cv2.imwrite(save_root + label+data_path.split("/")[-1].split(".")[0]+"_npcs.png", im_rgb_npcs)
+
+ def inference_gapartnet(self, name, split = "train", other_string = ""):
+ data_path = f"{GAPARTNET_DATA_ROOT}/{split}/pth/{name}.pth"
+ trans_path = f"{GAPARTNET_DATA_ROOT}/{split}/meta/{name}.txt"
+ pc, rgb, semantic_label, instance_label, npcs_map = torch.load(data_path)
+
+ trans = np.loadtxt(trans_path)
+ xyz = pc * trans[0] + trans[1:4]
+
+ # save_point_cloud_to_ply(xyz, rgb*255, data_path.split("/")[-1].split(".")[0]+"_preinput.ply")
+ # save_point_cloud_to_ply(pc, rgb*255, data_path.split("/")[-1].split(".")[0]+"_input.ply")
+
+ points_input = torch.cat((torch.tensor(pc, device = self.perception_model.device),torch.tensor(rgb, device = self.perception_model.device)), dim = 1)
+
+        bboxes, sem_preds, npcs_maps, proposal_indices, proposal_offsets = self._inference_perception_model([points_input])
+
+ # img = map2image(xyz, rgb*255.0)
+ # im_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ # cv2.imwrite(save_root+data_path.split("/")[-1].split(".")[0]+".png", im_rgb)
+ # draw_bbox(im_rgb, bboxes[0], trans)
+ # cv2.imwrite(save_root+data_path.split("/")[-1].split(".")[0]+"_bbox.png", im_rgb)
+ img = map2image(xyz, rgb*255.0)
+
+ im_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img_raw = im_rgb.copy()
+ # cv2.imwrite(save_root + label + data_path.split("/")[-1].split(".")[0]+".png", im_rgb)
+ # trans = np.array([xyz_scale, xyz_mean[0], xyz_mean[1], xyz_mean[2]])
+
+        cv2.imwrite(SAVE_ROOT + other_string + "raw.png", im_rgb)
+        draw_bbox(im_rgb, bboxes[0], trans)
+        cv2.imwrite(SAVE_ROOT + other_string + "bbox.png", im_rgb)
+ # for i,bbox in enumerate(bboxes[0]):
+ # bbox_now = [bbox,]
+ # img_now = img_raw.copy()
+ # draw_bbox(img_now, bbox_now, trans)
+ # cv2.imwrite(save_root + label+data_path.split("/")[-1].split(".")[0]+f"_bbox_{i}.png", img_now)
+ # npcs_maps_now = npcs_maps.clone()
+ # npcs_maps_now[:] = 230./255.
+ # # import pdb
+ # # pdb.set_trace()
+ # npcs_maps_now[proposal_indices[proposal_offsets[i]:proposal_offsets[i+1]]]=npcs_maps[proposal_indices[proposal_offsets[i]:proposal_offsets[i+1]]]
+ # img_npcs_now = map2image(point_img[:,:3], npcs_maps_now.cpu().numpy()*255)
+ # im_rgb_npcs_now = cv2.cvtColor(img_npcs_now, cv2.COLOR_BGR2RGB)
+ # cv2.imwrite(save_root + label+data_path.split("/")[-1].split(".")[0]+f"_npcs{i}.png", im_rgb_npcs_now)
+
+ rgb_sem = COLOR20[sem_preds.cpu().numpy()]
+ img_sem = map2image(xyz[:,:3], rgb_sem)
+ im_rgb_sem = cv2.cvtColor(img_sem, cv2.COLOR_BGR2RGB)
+ # draw_bbox(im_rgb_sem, bboxes[0], trans)
+        cv2.imwrite(SAVE_ROOT + other_string + "sem.png", im_rgb_sem)
+
+ img_npcs = map2image(xyz[:,:3], npcs_maps.cpu().numpy()*255)
+ im_rgb_npcs = cv2.cvtColor(img_npcs, cv2.COLOR_BGR2RGB)
+ draw_bbox(im_rgb_npcs, bboxes[0], trans)
+        cv2.imwrite(SAVE_ROOT + other_string + "npcs.png", im_rgb_npcs)
+
+ def process_objfile(self, file_path, label = "", save_root = "/scratch/genghaoran/GAPartNet/GAPartNet_inference/asset/"):
+ data_path = file_path
+ if ".obj" in file_path:
+ points = OBJfile2points(data_path)
+ points[:, 2] = -points[:,2]
+ points[:, 1] = -points[:,1]
+ save_point_cloud_to_ply(points[:,:3], points[:,3:6], data_path.split("/")[-1].split(".")[0] + label+"_preinput.ply")
+ xyz, max_radius, center = WorldSpaceToBallSpace(points[:,:3])
+ trans = np.array([max_radius, center[0], center[1], center[2]])
+        else:
+            raise NotImplementedError(f"unsupported input file: {file_path}")
+ points_input = torch.cat(
+ (torch.tensor(xyz, dtype=torch.float32 ,device = self.perception_model.device),
+ torch.tensor(points[:,-3:], dtype=torch.float32 ,device = self.perception_model.device)),
+ dim = 1)
+ trans_gapartnet = np.array([ 1.26171422e+00, -6.60613179e-04, 4.20249701e-02, 4.23497820e+00])
+ return points_input, trans_gapartnet
+
+ def process_gapartnetfile(self, name, split = "train"):
+ data_path = f"{GAPARTNET_DATA_ROOT}/{split}/pth/{name}.pth"
+ trans_path = f"{GAPARTNET_DATA_ROOT}/{split}/meta/{name}.txt"
+
+ pc, rgb, semantic_label, instance_label, npcs_map = torch.load(data_path)
+
+ trans = np.loadtxt(trans_path)
+ xyz = pc * trans[0] + trans[1:4]
+
+ # save_point_cloud_to_ply(xyz, rgb*255, data_path.split("/")[-1].split(".")[0]+"_preinput.ply")
+ # save_point_cloud_to_ply(pc, rgb*255, data_path.split("/")[-1].split(".")[0]+"_input.ply")
+
+ points_input = torch.cat((torch.tensor(pc, device = self.perception_model.device),torch.tensor(rgb, device = self.perception_model.device)), dim = 1)
+ return points_input, trans, semantic_label, instance_label, npcs_map
+
+def draw_result(save_option, save_root, name, points_input, trans, bboxes, sem_preds, npcs_maps, proposal_indices, proposal_offsets, gts=None, have_proposal = True, save_local = False):
+
+ final_save_root = f"{save_root}/"
+ save_root = f"{save_root}/{name}/"
+ if save_local:
+ os.makedirs(save_root, exist_ok=True)
+ final_img = np.ones((3 * (HEIGHT + EDGE) + EDGE, 4 * (WIDTH + EDGE) + EDGE, 3), dtype=np.uint8) * 255
+ xyz_input = points_input[:,:3]
+ rgb = points_input[:,3:6]
+ xyz = xyz_input * trans[0] + trans[1:4]
+ pc_img = map2image(xyz, rgb*255.0)
+ pc_img = cv2.cvtColor(pc_img, cv2.COLOR_BGR2RGB)
+ if "raw" in save_option:
+ raw_img_path = f"{RAW_IMG_ROOT}/{name}.png"
+ if os.path.exists(raw_img_path):
+ raw_img = cv2.imread(raw_img_path)
+ if save_local:
+ cv2.imwrite(f"{save_root}/raw.png", raw_img)
+ X_START = EDGE
+ Y_START = EDGE
+ final_img[X_START:X_START+HEIGHT, Y_START:Y_START+WIDTH, :] = raw_img
+ text = "raw"
+ cv2.putText(final_img, text,
+ (Y_START + int(0.5*(WIDTH - 3 * EDGE)), X_START + HEIGHT + int(0.5*EDGE)),
+ font, fontScale, fontColor, thickness, lineType)
+ if "pc" in save_option:
+ if save_local:
+ cv2.imwrite(f"{save_root}/pc.png", pc_img)
+ X_START = EDGE + (HEIGHT + EDGE)
+ Y_START = EDGE
+ final_img[X_START:X_START+HEIGHT, Y_START:Y_START+WIDTH, :] = pc_img
+ text = "pc"
+ cv2.putText(final_img, text,
+ (Y_START + int(0.5*(WIDTH - 3 * EDGE)), X_START + HEIGHT + int(0.5*EDGE)),
+ font, fontScale, fontColor, thickness, lineType)
+ if "sem_pred" in save_option:
+ sem_pred_img = map2image(xyz, COLOR20[sem_preds])
+ sem_pred_img = cv2.cvtColor(sem_pred_img, cv2.COLOR_BGR2RGB)
+ if save_local:
+ cv2.imwrite(f"{save_root}/sem_pred.png", sem_pred_img)
+ X_START = EDGE + (WIDTH + EDGE)
+ Y_START = EDGE + (HEIGHT + EDGE)
+ final_img[X_START:X_START+HEIGHT, Y_START:Y_START+WIDTH, :] = sem_pred_img
+ text = "sem_pred"
+ cv2.putText(final_img, text,
+ (Y_START + int(0.5*(WIDTH - 3 * EDGE)), X_START + HEIGHT + int(0.5*EDGE)),
+ font, fontScale, fontColor, thickness, lineType)
+ if "ins_pred" in save_option:
+ ins_pred_color = np.ones_like(xyz) * 230
+ if have_proposal:
+ for ins_i in range(len(proposal_offsets) - 1):
+ ins_pred_color[proposal_indices[proposal_offsets[ins_i]:proposal_offsets[ins_i + 1]]] = COLOR20[ins_i%19 + 1]
+
+ ins_pred_img = map2image(xyz, ins_pred_color)
+ ins_pred_img = cv2.cvtColor(ins_pred_img, cv2.COLOR_BGR2RGB)
+ if save_local:
+ cv2.imwrite(f"{save_root}/ins_pred.png", ins_pred_img)
+ X_START = EDGE + (WIDTH + EDGE) * 1
+ Y_START = EDGE + (HEIGHT + EDGE) * 2
+ final_img[X_START:X_START+HEIGHT, Y_START:Y_START+WIDTH, :] = ins_pred_img
+ text = "ins_pred"
+ cv2.putText(final_img, text,
+ (Y_START + int(0.5*(WIDTH - 3 * EDGE)), X_START + HEIGHT + int(0.5*EDGE)),
+ font, fontScale, fontColor, thickness, lineType)
+ if "npcs_pred" in save_option:
+ npcs_pred_img = map2image(xyz, npcs_maps*255.0)
+ npcs_pred_img = cv2.cvtColor(npcs_pred_img, cv2.COLOR_BGR2RGB)
+ if save_local:
+ cv2.imwrite(f"{save_root}/npcs_pred.png", npcs_pred_img)
+ X_START = EDGE + (WIDTH + EDGE) * 1
+ Y_START = EDGE + (HEIGHT + EDGE) * 3
+ final_img[X_START:X_START+HEIGHT, Y_START:Y_START+WIDTH, :] = npcs_pred_img
+ text = "npcs_pred"
+ cv2.putText(final_img, text,
+ (Y_START + int(0.5*(WIDTH - 3 * EDGE)), X_START + HEIGHT + int(0.5*EDGE)),
+ font, fontScale, fontColor, thickness, lineType)
+ if "bbox_pred" in save_option:
+ img_bbox_pred = pc_img.copy()
+ draw_bbox(img_bbox_pred, bboxes[0], trans)
+ if save_local:
+ cv2.imwrite(f"{save_root}/bbox_pred.png", img_bbox_pred)
+ X_START = EDGE + (WIDTH + EDGE) * 2
+ Y_START = EDGE + (HEIGHT + EDGE) * 2
+ final_img[X_START:X_START+HEIGHT, Y_START:Y_START+WIDTH, :] = img_bbox_pred
+ text = "bbox_pred"
+ cv2.putText(final_img, text,
+ (Y_START + int(0.5*(WIDTH - 3 * EDGE)), X_START + HEIGHT + int(0.5*EDGE)),
+ font, fontScale, fontColor, thickness, lineType)
+ if "pure_bbox" in save_option:
+ img_empty = np.ones((HEIGHT, WIDTH, 3), dtype=np.uint8) * 255
+ draw_bbox(img_empty, bboxes[0], trans)
+ if save_local:
+ cv2.imwrite(f"{save_root}/bbox_pure.png", img_empty)
+ X_START = EDGE + (WIDTH + EDGE) * 2
+ Y_START = EDGE + (HEIGHT + EDGE) * 3
+ final_img[X_START:X_START+HEIGHT, Y_START:Y_START+WIDTH, :] = img_empty
+ text = "bbox_pred_pure"
+ cv2.putText(final_img, text,
+ (Y_START + int(0.5*(WIDTH - 3 * EDGE)), X_START + HEIGHT + int(0.5*EDGE)),
+ font, fontScale, fontColor, thickness, lineType)
+ if "sem_gt" in save_option:
+ sem_gt = gts[0]
+ sem_gt_img = map2image(xyz, COLOR20[sem_gt])
+ sem_gt_img = cv2.cvtColor(sem_gt_img, cv2.COLOR_BGR2RGB)
+ if save_local:
+ cv2.imwrite(f"{save_root}/sem_gt.png", sem_gt_img)
+ X_START = EDGE + (WIDTH + EDGE) * 0
+ Y_START = EDGE + (HEIGHT + EDGE) * 1
+ final_img[X_START:X_START+HEIGHT, Y_START:Y_START+WIDTH, :] = sem_gt_img
+ text = "sem_gt"
+ cv2.putText(final_img, text,
+ (Y_START + int(0.5*(WIDTH - 3 * EDGE)), X_START + HEIGHT + int(0.5*EDGE)),
+ font, fontScale, fontColor, thickness, lineType)
+ if "ins_gt" in save_option:
+ ins_gt = gts[1]
+ ins_color = COLOR20[ins_gt%19 + 1]
+ ins_color[np.where(ins_gt == -100)] = 230
+ ins_gt_img = map2image(xyz, ins_color)
+
+ ins_gt_img = cv2.cvtColor(ins_gt_img, cv2.COLOR_BGR2RGB)
+ if save_local:
+ cv2.imwrite(f"{save_root}/ins_gt.png", ins_gt_img)
+ X_START = EDGE + (WIDTH + EDGE) * 0
+ Y_START = EDGE + (HEIGHT + EDGE) * 2
+ final_img[X_START:X_START+HEIGHT, Y_START:Y_START+WIDTH, :] = ins_gt_img
+ text = "ins_gt"
+ cv2.putText(final_img, text,
+ (Y_START + int(0.5*(WIDTH - 3 * EDGE)), X_START + HEIGHT + int(0.5*EDGE)),
+ font, fontScale, fontColor, thickness, lineType)
+ if "npcs_gt" in save_option:
+ npcs_gt = gts[2] + 0.5
+ npcs_gt_img = map2image(xyz, npcs_gt*255.0)
+ npcs_gt_img = cv2.cvtColor(npcs_gt_img, cv2.COLOR_BGR2RGB)
+ if save_local:
+ cv2.imwrite(f"{save_root}/npcs_gt.png", npcs_gt_img)
+ X_START = EDGE + (WIDTH + EDGE) * 0
+ Y_START = EDGE + (HEIGHT + EDGE) * 3
+ final_img[X_START:X_START+HEIGHT, Y_START:Y_START+WIDTH, :] = npcs_gt_img
+ text = "npcs_gt"
+ cv2.putText(final_img, text,
+ (Y_START + int(0.5*(WIDTH - 3 * EDGE)), X_START + HEIGHT + int(0.5*EDGE)),
+ font, fontScale, fontColor, thickness, lineType)
+ if "bbox_gt" in save_option:
+ bboxes_gt = [[]]
+ ins_gt = gts[1]
+ npcs_gt = gts[2]
+ num_ins = ins_gt.max()+1
+ if num_ins >= 1:
+ for ins_i in range(num_ins):
+ mask_i = ins_gt == ins_i
+ xyz_input_i = xyz_input[mask_i]
+ npcs_i = npcs_gt[mask_i]
+ if xyz_input_i.shape[0]<=5:
+ continue
+
+ bbox_xyz, scale, rotation, translation, out_transform, best_inlier_idx = \
+ estimate_pose_from_npcs(xyz_input_i, npcs_i)
+                if scale[0] is None:
+ continue
+ bboxes_gt[0].append(bbox_xyz.tolist())
+ img_bbox_gt = pc_img.copy()
+ draw_bbox(img_bbox_gt, bboxes_gt[0], trans)
+ if save_local:
+ cv2.imwrite(f"{save_root}/bbox_gt.png", img_bbox_gt)
+ X_START = EDGE + (WIDTH + EDGE) * 2
+ Y_START = EDGE + (HEIGHT + EDGE) * 1
+ final_img[X_START:X_START+HEIGHT, Y_START:Y_START+WIDTH, :] = img_bbox_gt
+ text = "bbox_gt"
+ cv2.putText(final_img, text,
+ (Y_START + int(0.5*(WIDTH - 3 * EDGE)), X_START + HEIGHT + int(0.5*EDGE)),
+ font, fontScale, fontColor, thickness, lineType)
+ if "bbox_gt_pure" in save_option:
+ bboxes_gt = [[]]
+ ins_gt = gts[1]
+ npcs_gt = gts[2]
+ num_ins = ins_gt.max()+1
+ if num_ins >= 1:
+ for ins_i in range(num_ins):
+ mask_i = ins_gt == ins_i
+ xyz_input_i = xyz_input[mask_i]
+ npcs_i = npcs_gt[mask_i]
+ if xyz_input_i.shape[0]<=5:
+ continue
+
+ bbox_xyz, scale, rotation, translation, out_transform, best_inlier_idx = \
+ estimate_pose_from_npcs(xyz_input_i, npcs_i)
+                if scale[0] is None:
+ continue
+
+ bboxes_gt[0].append(bbox_xyz.tolist())
+ img_bbox_gt_pure = np.ones((HEIGHT, WIDTH, 3), dtype=np.uint8) * 255
+ draw_bbox(img_bbox_gt_pure, bboxes_gt[0], trans)
+ if save_local:
+ cv2.imwrite(f"{save_root}/bbox_gt_pure.png", img_bbox_gt_pure)
+ X_START = EDGE + (WIDTH + EDGE) * 2
+ Y_START = EDGE + (HEIGHT + EDGE) * 0
+ final_img[X_START:X_START+HEIGHT, Y_START:Y_START+WIDTH, :] = img_bbox_gt_pure
+ text = "bbox_gt_pure"
+ cv2.putText(final_img, text,
+ (Y_START + int(0.5*(WIDTH - 3 * EDGE)), X_START + HEIGHT + int(0.5*EDGE)),
+ font, fontScale, fontColor, thickness, lineType)
+ cv2.imwrite(f"{final_save_root}/{name}.png", final_img)
+
+def main():
+ args, perception_cfg = get_args()
+
+ # initialize the perception model
+ model = MODEL(perception_cfg)
+ print("finish load model")
+
+ FAIL = []
+ for split in splits:
+ paths = glob.glob(GAPARTNET_DATA_ROOT + "/" + split + "/pth/*")
+
+ if FEW_SHOT:
+ import random
+ r_nums = random.sample(list(range(0, len(paths))), FEW_NUM)
+ used_paths = []
+ for r_num in r_nums:
+ used_paths.append(paths[r_num])
+ paths = used_paths
+
+ for i, path in enumerate(paths):
+ name = path.split(".")[0].split("/")[-1]
+ print(split, " ", i, " ", name)
+
+ save_root = f"{SAVE_ROOT}/{dir_name}/{split}"
+ os.makedirs(save_root,exist_ok = True)
+ final_save_root = f"{save_root}/"
+ if os.path.exists(f"{final_save_root}/{name}.png"):
+ continue
+
+ points_input, trans, semantic_label, instance_label, npcs_map = model.process_gapartnetfile(name, split)
+ bboxes, sem_preds, npcs_maps, proposal_indices, proposal_offsets = model.inference(points_input)
+
+ # visualize results in the image
+
+ # try:
+            if proposal_indices is None:
+ draw_result(save_option, save_root, name,
+ points_input.cpu().numpy(), trans, bboxes, sem_preds.cpu().numpy(), npcs_maps.cpu().numpy(),
+ proposal_indices, proposal_offsets, gts = [semantic_label, instance_label, npcs_map],
+ have_proposal = False, save_local=SAVE_LOCAL)
+ else:
+ draw_result(save_option, save_root, name, points_input.cpu().numpy(), trans, bboxes, sem_preds.cpu().numpy(),
+ npcs_maps.cpu().numpy(), proposal_indices.cpu().numpy(), proposal_offsets.cpu().numpy(),
+ gts = [semantic_label, instance_label, npcs_map], have_proposal = True, save_local=SAVE_LOCAL)
+
+ # return model
+
+
+if __name__ == "__main__":
+ model = main()
\ No newline at end of file
diff --git a/gapartnet/tools/visu_utils.py b/gapartnet/tools/visu_utils.py
new file mode 100644
index 0000000..790c2b6
--- /dev/null
+++ b/gapartnet/tools/visu_utils.py
@@ -0,0 +1,173 @@
+
+from os.path import join as pjoin
+import cv2
+import numpy as np
+
+COLOR20 = np.array(
+ [[230, 230, 230], [0, 128, 128], [230, 190, 255], [170, 110, 40], [255, 250, 200], [128, 0, 0],
+ [170, 255, 195], [128, 128, 0], [255, 215, 180], [0, 0, 128], [128, 128, 128],
+ [230, 25, 75], [60, 180, 75], [255, 225, 25], [0, 130, 200], [245, 130, 48],
+ [145, 30, 180], [70, 240, 240], [240, 50, 230], [210, 245, 60], [250, 190, 190]])
+HEIGHT = int(800)
+WIDTH = int(800)
+EDGE = int(40)
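+# 4x4 pinhole intrinsics shared by the projection helpers below
+# (fx = fy ≈ 1268.64, principal point (400, 400) for the 800x800 canvas)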
+K = np.array([[1268.637939453125, 0, 400, 0], [0, 1268.637939453125, 400, 0],
+ [0, 0, 1, 0], [0, 0, 0, 1]], dtype=np.float32)
+
+OTHER_COLOR = np.array([230, 230, 230])
+
+font = cv2.FONT_HERSHEY_SIMPLEX
+fontScale = 2
+fontColor = (0,0,0)
+thickness = 2
+lineType = 3
+
+def save_point_cloud_to_ply(points, colors, save_name='01.ply', save_root='/scratch/genghaoran/GAPartNet/GAPartNet_inference/asset/real'):
+ '''
+ Save point cloud to ply file
+ '''
+    PLY_HEAD = f"ply\nformat ascii 1.0\nelement vertex {len(points)}\nproperty float x\nproperty float y\nproperty float z\nproperty uchar red\nproperty uchar green\nproperty uchar blue\nend_header\n"
+    file_string = PLY_HEAD
+    for i in range(len(points)):
+        file_string += f'{points[i][0]} {points[i][1]} {points[i][2]} {int(colors[i][0])} {int(colors[i][1])} {int(colors[i][2])}\n'
+    with open(pjoin(save_root, save_name), 'w') as f:
+        f.write(file_string)
+
+def draw_bbox(img, bbox_list, trans):
+    """Project each 8-corner bbox onto the image with the intrinsics K and draw
+    its 12 edges; the three edges incident to corner 0 are re-drawn in distinct
+    (BGR) colors to indicate the box orientation."""
+    # edges of the box, as index pairs into the 8 projected corners
+    edges = [(0, 1), (0, 2), (0, 3), (1, 4), (1, 5), (2, 6),
+             (6, 3), (4, 7), (5, 7), (3, 5), (2, 4), (6, 7)]
+    for bbox in bbox_list:
+        if len(bbox) == 0:
+            continue
+        bbox = np.array(bbox)
+        # undo the normalization (trans is assumed to hold [scale, cx, cy, cz])
+        bbox = bbox * trans[0] + trans[1:4]
+        point2image = []
+        for pts in bbox:
+            x, y, z = pts[0], pts[1], pts[2]
+            x_new = int(np.around(x * K[0][0] / z + K[0][2]))
+            y_new = int(np.around(y * K[1][1] / z + K[1][2]))
+            point2image.append((x_new, y_new))
+        # draw all 12 edges in magenta
+        for i0, i1 in edges:
+            cv2.line(img, point2image[i0], point2image[i1], color=(255, 0, 255), thickness=2)
+        # highlight the three edges incident to corner 0
+        cv2.line(img, point2image[0], point2image[1], color=(0, 0, 255), thickness=3)  # red
+        cv2.line(img, point2image[0], point2image[3], color=(255, 0, 0), thickness=3)  # blue
+        cv2.line(img, point2image[0], point2image[2], color=(0, 255, 0), thickness=3)  # green
+    return img
+
+def draw_bbox_old(img, bbox_list, trans):
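+    # older bbox-drawing variant with a different edge ordering and color scheme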
+ for i,bbox in enumerate(bbox_list):
+ if len(bbox) == 0:
+ continue
+ bbox = np.array(bbox)
+ bbox = bbox * trans[0]+trans[1:4]
+ K = np.array([[1268.637939453125, 0, 400, 0], [0, 1268.637939453125, 400, 0],
+ [0, 0, 1, 0], [0, 0, 0, 1]], dtype=np.float32)
+ point2image = []
+ for pts in bbox:
+ x = pts[0]
+ y = pts[1]
+ z = pts[2]
+ x_new = (np.around(x * K[0][0] / z + K[0][2])).astype(dtype=int)
+ y_new = (np.around(y * K[1][1] / z + K[1][2])).astype(dtype=int)
+ point2image.append([x_new, y_new])
+ cl = [255,0,0]
+ cv2.line(img,point2image[0],point2image[1],color=(int(cl[0]),int(cl[1]),int(cl[2])),thickness=2)
+ cv2.line(img,point2image[0],point2image[1],color=(255,0,0),thickness=1)
+ cv2.line(img,point2image[1],point2image[2],color=(int(cl[0]),int(cl[1]),int(cl[2])),thickness=2)
+ cv2.line(img,point2image[1],point2image[2],color=(0,255,0),thickness=1)
+ cv2.line(img,point2image[2],point2image[3],color=(int(cl[0]),int(cl[1]),int(cl[2])),thickness=2)
+ cv2.line(img,point2image[2],point2image[3],color=(0,0,255),thickness=1)
+ cv2.line(img,point2image[3],point2image[0],color=(int(cl[0]),int(cl[1]),int(cl[2])),thickness=2)
+ cv2.line(img,point2image[4],point2image[5],color=(int(cl[0]),int(cl[1]),int(cl[2])),thickness=2)
+ cv2.line(img,point2image[5],point2image[6],color=(int(cl[0]),int(cl[1]),int(cl[2])),thickness=2)
+ cv2.line(img,point2image[6],point2image[7],color=(int(cl[0]),int(cl[1]),int(cl[2])),thickness=2)
+ cv2.line(img,point2image[4],point2image[7],color=(int(cl[0]),int(cl[1]),int(cl[2])),thickness=2)
+ cv2.line(img,point2image[4],point2image[0],color=(int(cl[0]),int(cl[1]),int(cl[2])),thickness=2)
+ cv2.line(img,point2image[1],point2image[5],color=(int(cl[0]),int(cl[1]),int(cl[2])),thickness=2)
+ cv2.line(img,point2image[2],point2image[6],color=(int(cl[0]),int(cl[1]),int(cl[2])),thickness=2)
+ cv2.line(img,point2image[3],point2image[7],color=(int(cl[0]),int(cl[1]),int(cl[2])),thickness=2)
+ return img
+
+def map2image(pts, rgb):
+    # project the (N, 3) points onto a HEIGHT x WIDTH white canvas and splat
+    # each point's rgb color into a 2x2 pixel block
+    image_rgb = np.ones((HEIGHT, WIDTH, 3), dtype=np.uint8) * 255
+
+    num_point = pts.shape[0]
+
+ point2image = {}
+ for i in range(num_point):
+ x = pts[i][0]
+ y = pts[i][1]
+ z = pts[i][2]
+ x_new = (np.around(x * K[0][0] / z + K[0][2])).astype(dtype=int)
+ y_new = (np.around(y * K[1][1] / z + K[1][2])).astype(dtype=int)
+ point2image[i] = (y_new, x_new)
+
+    # reconstruct the RGB image: write each point's color into a 2x2 pixel block,
+    # skipping points that project outside the canvas
+    for i in range(num_point):
+        if point2image[i][0]+1 >= HEIGHT or point2image[i][0] < 0 or point2image[i][1]+1 >= WIDTH or point2image[i][1] < 0:
+            continue
+        image_rgb[point2image[i][0]][point2image[i][1]] = rgb[i]
+        image_rgb[point2image[i][0]+1][point2image[i][1]] = rgb[i]
+        image_rgb[point2image[i][0]+1][point2image[i][1]+1] = rgb[i]
+        image_rgb[point2image[i][0]][point2image[i][1]+1] = rgb[i]
+
+ return image_rgb
+
+def OBJfile2points(file):
+    """Read 'v x y z r g b' vertex lines from an OBJ file into an (N, 6) array;
+    assumes every vertex line carries a color and stops at the first 'vt' line."""
+    points = []
+    with open(file) as f:
+        for line in f:
+            strs = line.split(" ")
+            if strs[0] == "v":
+                points.append((float(strs[1]), float(strs[2]), float(strs[3]),
+                               float(strs[4]), float(strs[5]), float(strs[6])))
+            if strs[0] == "vt":
+                break
+    points = np.array(points)
+    return points
+
+def FindMaxDis(pointcloud):
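+    """Return (max_radius, center): center is the midpoint of the axis-aligned
+    bounds and max_radius the largest point-to-center distance."""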
+ max_xyz = pointcloud.max(0)
+ min_xyz = pointcloud.min(0)
+ center = (max_xyz + min_xyz) / 2
+ max_radius = ((((pointcloud - center)**2).sum(1))**0.5).max()
+ return max_radius, center
+
+def WorldSpaceToBallSpace(pointcloud):
+    """
+    Normalize a point cloud from world space into the unit ball (the input array
+    is not modified; a normalized copy is returned).
+    return: pointcloud_normalized: the normalized copy of the point cloud
+            max_radius: the largest distance from the center to any raw point
+            center: [x, y, z] center of the raw point cloud's axis-aligned bounds
+    """
+ max_radius, center = FindMaxDis(pointcloud)
+ pointcloud_normalized = (pointcloud - center) / max_radius
+ return pointcloud_normalized, max_radius, center
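+
+# Minimal usage sketch (illustrative; `pc` is assumed to be an (N, 3) float array in
+# camera space, `rgb` an (N, 3) uint8 color array, and `bboxes` a list of 8x3 corner
+# arrays expressed in the normalized ball space):
+#   pc_norm, max_radius, center = WorldSpaceToBallSpace(pc)
+#   trans = np.array([max_radius, *center])   # scale + offset consumed by draw_bbox
+#   img = map2image(pc, rgb)                  # splat the raw points onto an 800x800 canvas
+#   img = draw_bbox(img, bboxes, trans)       # overlay the (un-normalized) boxes
+#   cv2.imwrite("vis.png", img)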
diff --git a/gapartnet/train.py b/gapartnet/train.py
new file mode 100644
index 0000000..17d19c9
--- /dev/null
+++ b/gapartnet/train.py
@@ -0,0 +1,70 @@
+from lightning.pytorch.loggers import WandbLogger
+from lightning.pytorch.cli import LightningCLI
+import lightning.pytorch as pl
+import torch
+import wandb
+torch.set_float32_matmul_precision('medium')
+def log_name(config):
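+    """Build a run name from the model/data config plus a timestamp, e.g.
+    "SU_TT" + "_" + "BS32_Aug0.1-0.3-0.3-0.3" + "_" + "06-28-14-05" (values illustrative)."""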
+ # model
+ model_str = ""
+ if config["model"]["init_args"]["backbone_type"] == "SparseUNet":
+ model_str += "SU"
+ else:
+ raise NotImplementedError(f"backbone type {config['model']['init_args']['backbone_type']} not implemented")
+
+ model_str += "_"
+
+    model_str += "T" if config["model"]["init_args"]["use_sem_focal_loss"] else "F"
+    model_str += "T" if config["model"]["init_args"]["use_sem_dice_loss"] else "F"
+
+ # data
+ data_str = ""
+ data_str += "BS" + str(config["data"]["init_args"]["train_batch_size"]) + "_"
+ data_str += "Aug" + \
+ ""+str(config["data"]["init_args"]["pos_jitter"]) +\
+ "-"+str(config["data"]["init_args"]["color_jitter"]) +\
+ "-"+str(config["data"]["init_args"]["flip_prob"]) +\
+ "-"+str(config["data"]["init_args"]["rotate_prob"])
+
+ # time
+ from datetime import datetime
+ now = datetime.now()
+ time_str = now.strftime("%m-%d-%H-%M")
+ return model_str, data_str, time_str
+
+class CustomCLI(LightningCLI):
+ def before_fit(self):
+ # Use the parsed arguments to create a name
+        if not self.config["fit"]["model"]["init_args"]["debug"]:
+ model_str, data_str, time_str = log_name(self.config["fit"])
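+            # log to Weights & Biases; adjust `project`/`entity` below to your own workspace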
+ self.trainer.logger = WandbLogger(
+ save_dir = "wandb",
+ project = "perception",
+ entity = "haoran-geng",
+ group = "train_new",
+ name = model_str + "_" + data_str + "_" + time_str,
+ notes = "GAPartNet",
+ tags = ["GAPartNet", "score", "npcs"],
+ save_code = True,
+ mode = "online",
+ )
+ else:
+ print("Debugging, not using wandb logger")
+
+def main():
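+    # LightningCLI exposes the standard subcommands (fit / validate / test / predict);
+    # the concrete model and data-module classes are resolved from the YAML config
+    # because subclass mode is enabled for both.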
+ _ = CustomCLI(
+ pl.LightningModule, pl.LightningDataModule,
+ subclass_mode_model=True,
+ subclass_mode_data=True,
+ seed_everything_default=233,
+ save_config_kwargs={"overwrite": True},
+ )
+
+if __name__ == "__main__":
+ main()
diff --git a/gapartnet/train.sh b/gapartnet/train.sh
new file mode 100644
index 0000000..de4c04a
--- /dev/null
+++ b/gapartnet/train.sh
@@ -0,0 +1,12 @@
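+# Example invocations (typically run one at a time):
+
+# 1) debug fit run from a provided checkpoint (debug mode skips the wandb logger)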
+CUDA_VISIBLE_DEVICES=7 \
+python train.py fit -c gapartnet.yaml \
+--model.init_args.ckpt ckpt/sem_seg_accu_82.7.ckpt \
+--model.init_args.debug True
+
+
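+# 2) evaluate a trained checkpoint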
+CUDA_VISIBLE_DEVICES=0 \
+python train.py test -c gapartnet.yaml \
+--model.init_args.ckpt ckpt/new.ckpt
+
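+# 3) full training run using the settings in gapartnet.yaml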
+CUDA_VISIBLE_DEVICES=0 \
+python train.py fit -c gapartnet.yaml
\ No newline at end of file