diff --git a/.github/workflows/core_code_checks.yml b/.github/workflows/core_code_checks.yml index 1b4e9aae..0cf6128d 100644 --- a/.github/workflows/core_code_checks.yml +++ b/.github/workflows/core_code_checks.yml @@ -15,10 +15,10 @@ jobs: steps: - uses: actions/checkout@v3 - - name: Set up Python 3.8.13 + - name: Set up Python 3.10.14 uses: actions/setup-python@v4 with: - python-version: '3.8.13' + python-version: '3.10.14' - uses: actions/cache@v2 with: path: ${{ env.pythonLocation }} @@ -26,6 +26,7 @@ jobs: - name: Install dependencies run: | pip install --upgrade --upgrade-strategy eager -e .[dev] + pip install waymo-open-dataset-tf-2-11-0==1.6.1 - name: Check notebook cell metadata run: | python ./nerfstudio/scripts/docs/add_nb_tags.py --check diff --git a/Dockerfile b/Dockerfile index 242f587d..5f20383a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -138,6 +138,9 @@ RUN git clone --recursive https://github.com/cvg/pixel-perfect-sfm.git && \ python3.10 -m pip install --no-cache-dir -e . && \ cd .. +# Install waymo-open-dataset +RUN python3.10 -m pip install --no-cache-dir waymo-open-dataset-tf-2-11-0==1.6.1 + # Copy nerfstudio folder. ADD . /nerfstudio diff --git a/README.md b/README.md index adbbe4f8..c90af034 100644 --- a/README.md +++ b/README.md @@ -82,10 +82,10 @@ Our installation steps largely follow Nerfstudio, with some added dataset-specif ### Create environment -NeuRAD requires `python >= 3.8`. We recommend using conda to manage dependencies. Make sure to install [Conda](https://docs.conda.io/miniconda.html) before proceeding. +NeuRAD requires `python >= 3.10`. We recommend using conda to manage dependencies. Make sure to install [Conda](https://docs.conda.io/miniconda.html) before proceeding. ```bash -conda create --name neurad -y python=3.8 +conda create --name neurad -y python=3.10 conda activate neurad pip install --upgrade pip ``` @@ -108,6 +108,11 @@ pip install dill --upgrade pip install ninja git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch ``` +For Waymo Open Dataset v2 support (requires Python 3.10; this package's dependencies are very strict, so it cannot be added to pyproject.toml and must be installed separately first): +```bash +pip install waymo-open-dataset-tf-2-11-0==1.6.1 +``` + We refer to [Nerfstudio](https://github.com/nerfstudio-project/nerfstudio/blob/v1.0.3/docs/quickstart/installation.md) for more installation support.
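+After installing it, you can sanity-check the Waymo package (a minimal check; `v2` is the API used by the Waymo dataparser added here):
+```python
+# Should import without errors if waymo-open-dataset was installed correctly.
+from waymo_open_dataset import v2
+
+print(v2.__name__)
+```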
### Installing NeuRAD @@ -227,8 +232,9 @@ To add a dataset, create `nerfstudio/data/dataparsers/mydataset.py` containing o | 🚗 [Argoverse 2](https://www.argoverse.org/av2.html) | 7 ring cameras + 2 stereo cameras | 2 x 32-beam lidars | | 🚗 [PandaSet](https://pandaset.org/) ([huggingface download](https://huggingface.co/datasets/georghess/pandaset)) | 6 cameras | 64-beam lidar | | 🚗 [KITTIMOT](https://www.cvlibs.net/datasets/kitti/eval_tracking.php) ([Timestamps](https://www.cvlibs.net/datasets/kitti/raw_data.php)) | 2 stereo cameras | 64-beam lidar +| 🚗 [Waymo v2](https://waymo.com/open/) | 5 cameras | 64-beam lidar - +A brief introduction about Waymo dataparser for NeuRAD can be found in [waymo_dataparser.md](./nerfstudio/data//dataparsers/waymo_dataparser.md) ## Adding Methods diff --git a/docs/_static/imgs/NeuRAD-RS-Waymo-Front.png b/docs/_static/imgs/NeuRAD-RS-Waymo-Front.png new file mode 100644 index 00000000..10befb02 Binary files /dev/null and b/docs/_static/imgs/NeuRAD-RS-Waymo-Front.png differ diff --git a/docs/_static/imgs/NeuRAD-RS-Waymo-Left.png b/docs/_static/imgs/NeuRAD-RS-Waymo-Left.png new file mode 100644 index 00000000..3ed25d9f Binary files /dev/null and b/docs/_static/imgs/NeuRAD-RS-Waymo-Left.png differ diff --git a/docs/_static/imgs/NeuRAD-RS-Waymo-Right.png b/docs/_static/imgs/NeuRAD-RS-Waymo-Right.png new file mode 100644 index 00000000..eb5666d2 Binary files /dev/null and b/docs/_static/imgs/NeuRAD-RS-Waymo-Right.png differ diff --git a/nerfstudio/cameras/cameras.py b/nerfstudio/cameras/cameras.py index c865cd13..a723970b 100644 --- a/nerfstudio/cameras/cameras.py +++ b/nerfstudio/cameras/cameras.py @@ -921,12 +921,22 @@ def _compute_rays_for_vr180( if self.metadata and "rolling_shutter_offsets" in self.metadata and "velocities" in self.metadata: cam_idx = camera_indices.squeeze(-1) - heights, rows = self.height[cam_idx], coords[..., 0:1] - duration = self.metadata["rolling_shutter_offsets"][cam_idx].diff() - time_offsets = rows / heights * duration + self.metadata["rolling_shutter_offsets"][cam_idx][..., 0:1] + offsets = self.metadata["rolling_shutter_offsets"][cam_idx] + duration = offsets.diff() + if "rs_direction" in metadata and metadata["rs_direction"] == "Horizontal": + # wod (LEFT_TO_RIGHT or RIGHT_TO_LEFT) + width, cols = self.width[cam_idx], coords[..., 1:2] + time_offsets = cols / width * duration + offsets[..., 0:1] + else: + # pandaset (TOP_TO_BOTTOM) + heights, rows = self.height[cam_idx], coords[..., 0:1] + time_offsets = rows / heights * duration + offsets[..., 0:1] + origins = origins + self.metadata["velocities"][cam_idx] * time_offsets times = times + time_offsets del metadata["rolling_shutter_offsets"] # it has served its purpose + if "rs_direction" in metadata: + del metadata["rs_direction"] # it has served its purpose return RayBundle( origins=origins, diff --git a/nerfstudio/cameras/lidars.py b/nerfstudio/cameras/lidars.py index 811439ce..905b907a 100644 --- a/nerfstudio/cameras/lidars.py +++ b/nerfstudio/cameras/lidars.py @@ -35,6 +35,7 @@ from nerfstudio.utils.misc import strtobool, torch_compile from nerfstudio.utils.tensor_dataclass import TensorDataclass +# torch._dynamo.config.suppress_errors = True TORCH_DEVICE = Union[torch.device, str] # pylint: disable=invalid-name HORIZONTAL_BEAM_DIVERGENCE = 3.0e-3 # radians, or meters at a distance of 1m @@ -50,6 +51,7 @@ class LidarType(Enum): VELODYNE64E = auto() VELODYNE128 = auto() PANDAR64 = auto() + WOD64 = auto() LIDAR_MODEL_TO_TYPE = { @@ -59,6 +61,7 @@ class LidarType(Enum): 
"VELODYNE64E": LidarType.VELODYNE64E, "VELODYNE128": LidarType.VELODYNE128, "PANDAR64": LidarType.PANDAR64, + "WOD64": LidarType.WOD64, } diff --git a/nerfstudio/configs/dataparser_configs.py b/nerfstudio/configs/dataparser_configs.py index a0453d80..23b7846c 100644 --- a/nerfstudio/configs/dataparser_configs.py +++ b/nerfstudio/configs/dataparser_configs.py @@ -26,6 +26,7 @@ from nerfstudio.data.dataparsers.kittimot_dataparser import KittiMotDataParserConfig from nerfstudio.data.dataparsers.nuscenes_dataparser import NuScenesDataParserConfig from nerfstudio.data.dataparsers.pandaset_dataparser import PandaSetDataParserConfig +from nerfstudio.data.dataparsers.wod_dataparser import WoDParserConfig from nerfstudio.data.dataparsers.zod_dataparser import ZodDataParserConfig from nerfstudio.plugins.registry_dataparser import discover_dataparsers @@ -35,6 +36,7 @@ "argoverse2-data": Argoverse2DataParserConfig(), "zod-data": ZodDataParserConfig(), "pandaset-data": PandaSetDataParserConfig(), + "wod-data": WoDParserConfig(), } external_dataparsers, _ = discover_dataparsers() diff --git a/nerfstudio/data/dataparsers/waymo_dataparser.md b/nerfstudio/data/dataparsers/waymo_dataparser.md new file mode 100644 index 00000000..316f73a2 --- /dev/null +++ b/nerfstudio/data/dataparsers/waymo_dataparser.md @@ -0,0 +1,52 @@ +# NeuRAD on Waymo open dataset + +## About +Thanks to the excellent work of NeuRAD, we reproduce some results on the Waymo open dataset. + +Our goal in reproducing and open-sourcing this waymo dataparser for NeuRAD is to provide a basic reference for the self-driving community and to inspire more work. + +In the same folder, there is [wod_dataparser.py](./wod_dataparser.py) which followed the [README-Adding Datasets](https://github.com/georghess/neurad-studio?tab=readme-ov-file#adding-datasets) suggestions. In addition, we added also [wod_utils.py](./wod_utils.py) which did the main work for converting/exporting Waymo dataset. + +In addition, we have also added the rolling shutter support for Waymo dataset as the rolling shutter direction is horizontal instead of the vertical one in Pandaset. Here are some examples of the comparison results (on squence of 10588): +![](./../../../docs/_static/imgs/NeuRAD-RS-Waymo-Front.png) +![](./../../../docs/_static/imgs/NeuRAD-RS-Waymo-Left.png) +![](./../../../docs/_static/imgs/NeuRAD-RS-Waymo-Right.png) + + +### Benchmark between Pandaset & Waymo +| Dataset | Sequence | Frames | Cameras | PSNR | SSIM | LIPS | +|--- |--- |--- |--- |--- |--- |--- | +| Pandaset | 006 |80 | FC |25.1562​|0.8044​ |0.1575​| +| Pandaset | 011 |80 | 360 |26.3919​|0.8057​ |0.2029​| +| Waymo | 10588771936253546636| 50 | FC | 27.5555|0.8547|0.121 +| Waymo | 473735159277431842 | 150| FC | 29.1758|0.8717|0.1592 +| Waymo | 4468278022208380281 | ALL| FC |30.5247​|0.8787​|0.1701​ + +Notes: All above results were obtained with the same hyperparameters and configurations from NeuRAD paper (**Appendix A**) + +### Results +#### Waymo RGB rendering - Sequence 10588 - 3 cameras (FC_LEFT, FC, FC_RIGHT) +[![Sequence 10588 - 3 cameras](http://img.youtube.com/vi/eR1bHeh7p8A/0.jpg)](https://www.youtube.com/watch?v=eR1bHeh7p8A) +> Up is ground truth, bottom is rendered. + +#### Actor removal - Sequence 20946​ - FC cameras +[![Sequence 20946](http://img.youtube.com/vi/mkMdzAvTez4/0.jpg)](https://www.youtube.com/watch?v=mkMdzAvTez4) +> Left is ground truth, right is rendered. 
+ +#### Novel view synthesis - Sequence 20946 - Ego vehicle 1m up +[![Ego vehicle 1m up](http://img.youtube.com/vi/U8VRboWLj_c/0.jpg)](https://www.youtube.com/watch?v=U8VRboWLj_c) +> Left is ground truth, right is rendered. + +#### Novel view synthesis - Sequence 20946 - Ego vehicle 1m left +[![Ego vehicle 1m left](http://img.youtube.com/vi/q_HFmc6JPzQ/0.jpg)](https://www.youtube.com/watch?v=q_HFmc6JPzQ) +> Left is ground truth, right is rendered. + +## Links + +Results were obtained with the Waymo Open Dataset [v2.0.0, gcloud link](https://console.cloud.google.com/storage/browser/waymo_open_dataset_v_2_0_0) + +## Contributors + +- Lei Lei, Leddartech +- Julien Stanguennec, Leddartech +- Pierre Merriaux, Leddartech \ No newline at end of file diff --git a/nerfstudio/data/dataparsers/wod_dataparser.py b/nerfstudio/data/dataparsers/wod_dataparser.py new file mode 100644 index 00000000..74df1187 --- /dev/null +++ b/nerfstudio/data/dataparsers/wod_dataparser.py @@ -0,0 +1,303 @@ +# Copyright 2024 the authors of NeuRAD and contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Data parser for Waymo Open Dataset""" + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Dict, List, Literal, Optional, Tuple, Type + +import numpy as np +import torch +import transforms3d +from waymo_open_dataset.v2.perception import camera_image + +from nerfstudio.cameras.cameras import Cameras, CameraType +from nerfstudio.cameras.lidars import Lidars, LidarType +from nerfstudio.data.dataparsers.ad_dataparser import DUMMY_DISTANCE_VALUE, ADDataParser, ADDataParserConfig +from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs +from nerfstudio.data.dataparsers.wod_utils import ExportImages, ExportLidar, ObjectsID, ParquetReader, SelectedTimestamp +from nerfstudio.data.utils.lidar_elevation_mappings import WOD64_ELEVATION_MAPPING + +WOD_ELEVATION_MAPPING = {"Wod64": WOD64_ELEVATION_MAPPING} +WOD_AZIMUT_RESOLUTION = {"Wod64": 0.140625} +WOD_SKIP_ELEVATION_CHANNELS = {"Wod64": ()} + +HORIZONTAL_BEAM_DIVERGENCE = 2.4e-3 # radians +VERTICAL_BEAM_DIVERGENCE = 1.5e-3 # radians + +ALLOWED_DEFORMABLE_CLASSES = ( + "TYPE_PEDESTRIAN", + "TYPE_CYCLIST", +) + +ALLOWED_RIGID_CLASSES = ( + "TYPE_VEHICLE", + "TYPE_SIGN", +) +WOD_CAMERA_NAME_2_ID = {e.name: e.value for e in camera_image.CameraName if e.name != "UNKNOWN"} + + +@dataclass +class WoDParserConfig(ADDataParserConfig): + """Waymo Open Dataset config.""" + + _target: Type = field(default_factory=lambda: WoD) + """target class to instantiate""" + sequence: str = "10588771936253546636_2300_000_2320_000" + """Name of the scene (i.e. the so-called context_name).""" + data: Path = Path("/data/dataset/wod/") + """Raw dataset path to WOD""" + parquet_dir: str = "training" + """Change to "validation" when the sequence is in the validation split.""" + output_folder: Path = Path("/data/dataset/wod/images") + """Output folder where images are saved; by default it is placed under the WOD dataset path.""" + train_split_fraction: float = 0.5 + """The
fraction of images to use for training. The remaining images are for eval.""" + start_frame: int = 0 + """Start frame""" + end_frame: Optional[int] = None + """End frame. When set to None, the last frame of the sequence is used.""" + dataset_end_fraction: float = 1.0 + """At what fraction of the dataset to end. Values other than 1.0 are not supported by the current WOD implementation.""" + cameras: Tuple[Literal["FRONT", "FRONT_LEFT", "FRONT_RIGHT", "SIDE_LEFT", "SIDE_RIGHT"], ...] = ( + "FRONT", + "FRONT_LEFT", + "FRONT_RIGHT", + # "SIDE_LEFT", + # "SIDE_RIGHT", + ) + """Which cameras to use.""" + lidars: Tuple[Literal["Top"], ...] = ("Top",) + """Which lidars to use; only the TOP lidar is supported.""" + load_cuboids: bool = True + """Whether to load cuboid annotations.""" + cuboids_ids: Optional[Tuple[int, ...]] = None + """Selection of cuboid ids when load_cuboids is True. If None, all dynamic cuboids will be exported.""" + annotation_interval: float = 0.1 # 10 Hz of capture + """Interval between annotations in seconds.""" + correct_cuboid_time: bool = True + """Whether to correct the cuboid time to match the actual time of observation, not the end of the lidar sweep.""" + min_lidar_dist: Tuple[float, float, float] = (1.0, 1.0, 1.0) + """WOD TOP lidar is x-forward, y-left, z-up.""" + add_missing_points: bool = True + """Whether to add missing points (rays that did not return) to the point clouds.""" + lidar_elevation_mapping: Dict[str, Dict] = field(default_factory=lambda: WOD_ELEVATION_MAPPING) + """Elevation mapping for each lidar.""" + skip_elevation_channels: Dict[str, Tuple] = field(default_factory=lambda: WOD_SKIP_ELEVATION_CHANNELS) + """Channels to skip when adding missing points.""" + lidar_azimuth_resolution: Dict[str, float] = field(default_factory=lambda: WOD_AZIMUT_RESOLUTION) + """Azimuth resolution for each lidar.""" + rolling_shutter_offsets: Tuple[float, float] = (-0.022, 0.022) + """In Waymo, image capture time varies across columns (horizontal rolling shutter, left-to-right or right-to-left).""" + paint_points: bool = True + """Whether to paint the points in the point cloud.""" + + +@dataclass +class WoD(ADDataParser): + """Waymo Open Dataset DatasetParser""" + + config: WoDParserConfig + + def _get_cameras(self) -> Tuple[Cameras, List[Path]]: + """Images are exported from parquet files to jpg in the dataset folder, and filepaths are returned along with the Cameras.""" + + output_folder_name = f"{self.config.sequence}_start{self.config.start_frame}_end{self.config.end_frame}" + output_folder_name += "_cameras_" + "_".join([str(id) for id in self.cameras_ids]) + images_output_folder: Path = Path(self.config.output_folder) / output_folder_name # type: ignore + + export_images = ExportImages( + self.parquet_reader, + output_folder=str(images_output_folder), + select_ts=self.select_ts, + cameras_ids=self.cameras_ids, + ) + + data_out, (rolling_shutter, rolling_shutter_direction) = export_images.process() + rolling_shutter = round(rolling_shutter, 3) + rs_offsets = (-rolling_shutter, rolling_shutter) + + # reverse the rolling shutter offsets when the readout is right to left + if rolling_shutter_direction == 4: # RIGHT_TO_LEFT + rs_offsets = (rolling_shutter, -rolling_shutter) + + self.config.rolling_shutter_offsets = rs_offsets + rs_direction = "Horizontal" if rolling_shutter_direction in (2, 4) else "Vertical" + + img_filenames = [] + intrinsics = [] + poses = [] + idxs = [] + heights = [] + widths = [] + times = [] + for frame in data_out["frames"]: + img_filenames.append(str(images_output_folder / frame["file_path"]))
+ poses.append(frame["transform_matrix"]) + intrinsic = np.array( + [ + [frame["f_u"], 0, frame["c_u"]], + [0, frame["f_v"], frame["c_v"]], + [0, 0, 1], + ] + ) + intrinsics.append(intrinsic) + idxs.append(frame["sensor_id"]) + heights.append(frame["h"]) + widths.append(frame["w"]) + times.append(frame["time"]) + + intrinsics = torch.tensor(np.array(intrinsics), dtype=torch.float32) + poses = torch.tensor(np.array(poses), dtype=torch.float32) + times = torch.tensor(times, dtype=torch.float64) + idxs = torch.tensor(idxs).int().unsqueeze(-1) + cameras = Cameras( + fx=intrinsics[:, 0, 0], + fy=intrinsics[:, 1, 1], + cx=intrinsics[:, 0, 2], + cy=intrinsics[:, 1, 2], + height=torch.tensor(heights), + width=torch.tensor(widths), + camera_to_worlds=poses[:, :3, :4], + camera_type=CameraType.PERSPECTIVE, + times=times, + metadata={ + "sensor_idxs": idxs, + "rs_direction": rs_direction, + }, + ) + return cameras, img_filenames + + def _get_lidars(self) -> Tuple[Lidars, Tuple[List[torch.Tensor], List[torch.Tensor]]]: + """The implementation of _get_lidar for WoD dataset actually returns directly tensors for pts_lidar and pts_missing, while + other dataparser provide links to files containing the point-cloud which are then processed with _read_lidar function in + _generate_dataparser_output. With WoD all lidar point-cloud are stored in parquet files, and points cloud are eventually + stored in memory in DataParserOutput object. So most of the job is done within _get_lidars function. + + :return: Tuple[Lidars, Tuple[List[Point-clouds],List[MissingPointsPcd]]] + """ + if self.config.load_cuboids: + objects_id_to_extract = ( + list(self.config.cuboids_ids) if self.config.cuboids_ids is not None else self.objects_id.dynamic_id + ) + else: + objects_id_to_extract = [] + + export_lidar = ExportLidar(self.parquet_reader, self.select_ts, self.objects_id, self.config.output_folder) + poses, pts_lidar_list, missing_pts_list, times, actors = export_lidar.process( + objects_id_to_extract=objects_id_to_extract + ) + + # save actors for later trajectories calculation + self.actors = actors + + pts_lidar_list = [torch.from_numpy(pts) for pts in pts_lidar_list] + missing_pts_list = [torch.from_numpy(pts) for pts in missing_pts_list] + + times = torch.tensor(times, dtype=torch.float64) + idxs = torch.zeros_like(times).int().unsqueeze(-1) + + poses = torch.from_numpy(np.array(poses)) + lidars = Lidars( + lidar_to_worlds=poses[:, :3, :4], + lidar_type=LidarType.WOD64, + times=times, + metadata={"sensor_idxs": idxs}, + horizontal_beam_divergence=HORIZONTAL_BEAM_DIVERGENCE, + vertical_beam_divergence=VERTICAL_BEAM_DIVERGENCE, + valid_lidar_distance_threshold=DUMMY_DISTANCE_VALUE / 2, + ) + return lidars, (pts_lidar_list, missing_pts_list) + + def _read_lidars( + self, lidars: Lidars, pts_list_tuple: Tuple[List[torch.Tensor], List[torch.Tensor]] + ) -> List[torch.Tensor]: + """Reads the point clouds from the given filenames. Should be in x,y,z,r,t order. t is optional.""" + + pts_lidar_list, missing_pts_list = pts_list_tuple + if self.config.add_missing_points: + """Currently this part has been done during wod_export, here we only concatenate together. 
+ For future modification, refer to _read_lidars method from pandaset_dataparser.py + """ + point_clouds = [torch.cat([pc, missing], dim=0) for pc, missing in zip(pts_lidar_list, missing_pts_list)] + else: + point_clouds = pts_lidar_list + + lidars.lidar_to_worlds = lidars.lidar_to_worlds.float() + return point_clouds + + def _get_actor_trajectories(self) -> List[Dict]: + """Returns a list of actor trajectories. + + Each trajectory is a dictionary with the following keys: + - poses: the poses of the actor (float32) + - timestamps: the timestamps of the actor (float64) + - dims: the dimensions of the actor, wlh order (float32) + - label: the label of the actor (str) + - stationary: whether the actor is stationary (bool) + - symmetric: whether the actor is expected to be symmetric (bool) + - deformable: whether the actor is expected to be deformable (e.g. pedestrian) + """ + trajs_list = [] + allowed_classes = ALLOWED_RIGID_CLASSES + if self.config.include_deformable_actors: + allowed_classes += ALLOWED_DEFORMABLE_CLASSES + + rot_minus_90 = np.eye(4) + rot_minus_90[:3, :3] = transforms3d.euler.euler2mat(0.0, 0.0, -np.pi / 2) + + for index, actor in self.actors.items(): + actor_type = actor["label"] + + if actor_type not in allowed_classes: + continue + poses = np.array(actor["poses"]) @ rot_minus_90 + timestamps = actor["timestamps"] + actor_dimensions = self.objects_id.id2box_dimensions[index] # (length, width, height) + lenght, width, height = actor_dimensions.values() + dims = np.array([width, lenght, height], dtype=np.float32) + + symmetric = actor_type == "TYPE_VEHICLE" + deformable = actor_type in ALLOWED_DEFORMABLE_CLASSES + + trajs_list.append( + { + "poses": torch.tensor(poses).float(), + "timestamps": torch.tensor(timestamps, dtype=torch.float64), + "dims": torch.tensor(dims, dtype=torch.float32), + "label": actor_type, + "stationary": False, # Only 'export' dynamic objects from ExportLidar + "symmetric": symmetric, + "deformable": deformable, + } + ) + return trajs_list + + def _generate_dataparser_outputs(self, split="train") -> DataparserOutputs: + assert ( + self.config.dataset_end_fraction == 1.0 + ), f"Wod data parser only support dataset_end_fraction == 1.0, value received {self.config.dataset_end_fraction}" + self.cameras_ids = [WOD_CAMERA_NAME_2_ID[cam] for cam in self.config.cameras] + parquet_dir = str(self.config.data / self.config.parquet_dir) + self.parquet_reader = ParquetReader(self.config.sequence, dataset_dir=parquet_dir) + self.select_ts = SelectedTimestamp(self.parquet_reader, self.config.start_frame, self.config.end_frame) + self.objects_id = ObjectsID(self.parquet_reader, self.select_ts) + + return super()._generate_dataparser_outputs(split) + + +if __name__ == "__main__": + wod_test = WoD(config=WoDParserConfig()) + do = wod_test._generate_dataparser_outputs() + print(do) diff --git a/nerfstudio/data/dataparsers/wod_utils.py b/nerfstudio/data/dataparsers/wod_utils.py new file mode 100644 index 00000000..8d790ba3 --- /dev/null +++ b/nerfstudio/data/dataparsers/wod_utils.py @@ -0,0 +1,654 @@ +# Copyright 2024 the authors of NeuRAD and contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import glob +import os +import warnings +from copy import deepcopy +from dataclasses import asdict +from typing import Dict, List, Optional, Tuple, TypedDict + +import dask.dataframe as dd +import numpy as np +import numpy.typing as npt +import tensorflow as tf +import transforms3d +from tqdm import tqdm +from waymo_open_dataset import v2 +from waymo_open_dataset.utils import box_utils, transform_utils +from waymo_open_dataset.v2.perception import ( + box as _v2_box, + camera_image as _v2_camera_image, + context as _v2_context, + lidar as _v2_lidar, + pose as _v2_pose, +) +from waymo_open_dataset.v2.perception.utils.lidar_utils import convert_range_image_to_cartesian + +tf.config.set_visible_devices([], "GPU") # Not useful for parsing data. + +# Disable annoying warnings from PyArrow using under the hood. +warnings.simplefilter(action="ignore", category=FutureWarning) + +DATA_FREQUENCY = 10.0 # 10 Hz +DUMMY_DISTANCE_VALUE = 2e3 # meters, used for missing points +TIME_OFFSET = 50e-3 # => 50ms ,time offset in sec; half scanning period + + +class ActorsDict(TypedDict): + poses: List[np.ndarray] + timestamps: List[float] + label: str + + +class ImageFrame(TypedDict): + file_path: str + transform_matrix: List[List[float]] + frame_id: int + time: float + sensor_id: int + f_u: float + f_v: float + c_u: float + c_v: float + k1: float + k2: float + p1: float + p2: float + k3: float + h: int + w: int + + +def get_camera_names(): + return [f"{e.value}:{e.name}" for e in _v2_camera_image.CameraName if e.name != "UNKNOWN"] + + +def get_mock_timestamps(points: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]: + """Get mock relative timestamps for the wod points.""" + # the wod has x forward, y left, z up and the sweep is split behind the car. + # it is also rotating clockwise, meaning that the angles close to -pi are the + # first ones in the sweep and the ones close to pi are the last ones in the sweep. 
+ angles = -np.arctan2(points[:, 1], points[:, 0]) # N, [-pi, pi] + # angles += np.pi # N, [0, 2pi] + # see how much of the rotation have finished + fraction_of_rotation = angles / (2 * np.pi) # N, [0, 1] + # get the pseudo timestamps based on the total rotation time + timestamps = fraction_of_rotation * 1.0 / DATA_FREQUENCY + return timestamps + + +class ParquetReader: + def __init__(self, context_name: str, dataset_dir: str = "/data/dataset/wod/training", nb_partitions: int = 120): + self.context_name = context_name + self.dataset_dir = dataset_dir + self.nb_partitions = nb_partitions + + def read(self, tag: str) -> dd.DataFrame: + """Creates a Dask DataFrame for the component specified by its tag.""" + paths = glob.glob(f"{self.dataset_dir}/{tag}/{self.context_name}.parquet") + return dd.read_parquet(paths, npartitions=self.nb_partitions) # type: ignore + + def __call__(self, tag: str) -> dd.DataFrame: + return self.read(tag) + + +class SelectedTimestamp: + def __init__(self, reader: ParquetReader, start_frame: int = 0, end_frame: Optional[int] = None): + cam_image_df = reader("camera_image") + cam_image_df = cam_image_df["key.frame_timestamp_micros"] + self.ts_list = np.unique(np.array(cam_image_df.compute())) + self.ts_selected = self.ts_list[start_frame:end_frame] + pass + + def __len__(self) -> int: + return len(self.ts_selected) + + def sequence_len(self) -> int: + return len(self.ts_list) + + def get_selected_ts(self) -> List[int]: + return self.ts_selected.tolist() + + def is_selected(self, ts: int) -> bool: + return ts in self.ts_selected + + def ts2frame_idx(self, ts: int) -> int: + if ts not in self.ts_selected: + raise IndexError(f"{ts} is not in selected timestamps") + return np.where(self.ts_selected == ts)[0][0] + + +class ObjectsID: + """Helper extraction object static/dynamic IDs to be processed by ExportLidar class.""" + + def __init__(self, reader: ParquetReader, selected_ts: SelectedTimestamp, speed_static_threshold: float = 0.2): + self.reader = reader + self.speed_static_threshold = speed_static_threshold + self.dynamic_id: list[int] = [] + self.dynamic_uuid: list[str] = [] + self.dynamic_type: list[str] = [] + self.id2uuid: dict[int, str] = {} + self.uuid2id: dict[str, int] = {} + self.id2box_dimensions: dict[int, dict[str, float]] = {} + self.selected_ts = selected_ts + self.keep_id_after_lidar_extraction = [] + self.build_dict() + + def build_dict(self): + lidar_box_df = self.reader("lidar_box") + + lidar_box_df2 = ( + lidar_box_df.groupby(["key.segment_context_name", "key.laser_object_id"]).agg(list).reset_index() + ) + + for object_id, (_, r) in enumerate(lidar_box_df2.iterrows()): + LiDARBoxCom = v2.LiDARBoxComponent.from_dict(r) + ts_mask = np.isin(np.array(LiDARBoxCom.key.frame_timestamp_micros), self.selected_ts.get_selected_ts()) + if not np.any(ts_mask): + continue + dimensions = LiDARBoxCom.box.size + + length, width, height = ( + np.array(dimensions.x)[ts_mask][0], + np.array(dimensions.y)[ts_mask][0], + np.array(dimensions.z)[ts_mask][0], + ) + + self.id2box_dimensions[object_id] = {"length": length, "width": width, "height": height} + object_uuid = LiDARBoxCom.key.laser_object_id + self.id2uuid[object_id] = object_uuid + # object is considered static if static in frames selection ( < speed threshold ) + speed = np.array( + [ + np.array(LiDARBoxCom.speed.x)[ts_mask], # type: ignore + np.array(LiDARBoxCom.speed.y)[ts_mask], # type: ignore + np.array(LiDARBoxCom.speed.z)[ts_mask], # type: ignore + ] + ) + speed = speed[~np.isnan(speed).any(axis=1)] + 
speed = np.linalg.norm(speed, axis=0) + dynamic = np.any(speed > self.speed_static_threshold) + if dynamic: + self.dynamic_id.append(object_id) + self.dynamic_uuid.append(object_uuid) + self.dynamic_type.append(_v2_box.BoxType(LiDARBoxCom.type[0]).name) # type: ignore + + for id, uuid in self.id2uuid.items(): + self.uuid2id[uuid] = id + + def is_dynamic(self, id: int | str): + if isinstance(id, int): + return id in self.dynamic_id + if isinstance(id, str): + return self.uuid2id[id] in self.dynamic_id + + def get_box_dimensions(self, id: int | str): + if isinstance(id, int): + return self.id2box_dimensions[id] + if isinstance(id, str): + return self.id2box_dimensions[self.uuid2id[id]] + + def get_box_coordinates(self, dynamic_only: bool = True) -> Dict[str, np.ndarray]: + lidar_box_df = self.reader("lidar_box") + + lidar_box_df2 = ( + lidar_box_df.groupby(["key.segment_context_name", "key.laser_object_id"]).agg(list).reset_index() + ) + + objects_coordinates = {} + for object_id, (_, r) in enumerate(lidar_box_df2.iterrows()): + LiDARBoxCom = v2.LiDARBoxComponent.from_dict(r) + ts_mask = np.isin(np.array(LiDARBoxCom.key.frame_timestamp_micros), self.selected_ts.get_selected_ts()) + if not np.any(ts_mask): + continue + + object_uuid = LiDARBoxCom.key.laser_object_id + object_id = self.uuid2id[object_uuid] + + if dynamic_only: + if object_id in self.dynamic_id: + objects_coordinates[object_id] = LiDARBoxCom.box.center + else: + objects_coordinates[object_id] = LiDARBoxCom.box.center + + return objects_coordinates + + def print_dynamic(self): + for id, type in zip(self.dynamic_id, self.dynamic_type): + print(f"{id}:{type}, ", end="") + + +class ExportImages: + """ + Used to create folder and save image in images, and returns a tuple with: + - a list of images dict with (image path, frame_id, time, pose (nerf), sensor_id, intrinsic)) + - a tuple of rolling shutter information (duration and direction) + + :param reader: ParquetReader object + :param select_ts: SelectedTimestamp object + :param output_folder: Root folder images where will be saved. 
+ :param cameras_ids: Select which cameras_ids to export, defaults to list(range(1, len(get_camera_names()) + 1)) + """ + + IMAGE_FOLDER = "images" + + def __init__( + self, + reader: ParquetReader, + select_ts: SelectedTimestamp, + output_folder: str, + cameras_ids: List[int] = list(range(1, len(get_camera_names()) + 1)), + ): + self.reader: ParquetReader = reader + self.select_ts = select_ts + self.cameras_ids = cameras_ids + + self.output_folder = os.path.join(output_folder, self.IMAGE_FOLDER) + if not os.path.exists(self.output_folder): + os.makedirs(self.output_folder) + + def process(self) -> Tuple[dict[str, List[ImageFrame]], Tuple[float, int]]: + cam_calib = self.reader("camera_calibration") + camera_calib = {} + data_out: dict[str, List[ImageFrame]] = {} + + data_out["frames"] = [] + for i, (_, r) in enumerate(cam_calib.iterrows()): + calib = v2.CameraCalibrationComponent.from_dict(r) + camera_calib["cam" + v2.perception.camera_image.CameraName(calib.key.camera_name).name] = ( # type: ignore + calib.extrinsic.transform.reshape(4, 4) # type: ignore + ) + camera_calib["cam" + v2.perception.camera_image.CameraName(calib.key.camera_name).name + "_intrinsics"] = ( # type: ignore + asdict(calib.intrinsic) | {"h": calib.height, "w": calib.width} # type: ignore + ) + # rolling shutter direction for offset calculation + rolling_shutter_direction = calib.rolling_shutter_direction + + print("Camera processing...") + cam_image_df = self.reader("camera_image") + cam_image_df = cam_image_df[ + (cam_image_df["key.camera_name"].isin(self.cameras_ids)) # type: ignore + & (cam_image_df["key.frame_timestamp_micros"].isin(self.select_ts.get_selected_ts())) # type: ignore + ] + camera_poses = [] + rolling_shutter_list = [] + for i, (_, r) in tqdm(enumerate(cam_image_df.iterrows())): # type: ignore + CamComp = v2.CameraImageComponent.from_dict(r) + tr_image = CamComp.pose.transform.reshape(4, 4) # type: ignore + delta_time = ( + CamComp.rolling_shutter_params.camera_readout_done_time + + CamComp.rolling_shutter_params.camera_trigger_time + ) / 2 - CamComp.pose_timestamp + + rolling_shutter = ( + CamComp.rolling_shutter_params.camera_readout_done_time + - CamComp.rolling_shutter_params.camera_trigger_time + ) / 2 + rolling_shutter_list.append(rolling_shutter) + + avx, avy, avz = ( + CamComp.velocity.angular_velocity.x, + CamComp.velocity.angular_velocity.y, + CamComp.velocity.angular_velocity.z, + ) + skm = np.array([[0, -avz, avy], [avz, 0, -avx], [-avy, avx, 0]]) + r_image = tr_image[:3, :3] + + r_updated = ( + (np.eye(3) + delta_time * skm) @ r_image + ) # probably another way to do it : R_derivative = skm@r_image; r_image + delta_time * R_derivative + t_updated = tr_image[:3, 3] + delta_time * np.array( + [ + CamComp.velocity.linear_velocity.x, + CamComp.velocity.linear_velocity.y, + CamComp.velocity.linear_velocity.z, + ] + ) + tr_updated = np.eye(4) + tr_updated[:3, 3] = t_updated + tr_updated[:3, :3] = r_updated + + frame_id = self.select_ts.ts2frame_idx(CamComp.key.frame_timestamp_micros) + filename = f"{v2.perception.camera_image.CameraName(CamComp.key.camera_name).name}_{frame_id:08d}.jpg" # type: ignore + + nerfstudio2waymo = np.eye(4) + nerfstudio2waymo[:3, :3] = np.array([[0, -1, 0], [0, 0, 1], [-1, 0, 0]]).T + # opencv2waymo = np.eye(4) + # opencv2waymo[:3,:3] = np.array([[0,-1,0],[0,0,-1],[1,0,0]]).T + calib = camera_calib["cam" + v2.perception.camera_image.CameraName(CamComp.key.camera_name).name] # type: ignore + camera_poses.append(tr_updated @ calib @ nerfstudio2waymo) + 
data_out["frames"].append( + { + "file_path": os.path.join(self.IMAGE_FOLDER, filename), + "transform_matrix": (camera_poses[-1]).tolist(), + "frame_id": int(frame_id), + "time": delta_time + CamComp.pose_timestamp, + "sensor_id": CamComp.key.camera_name - 1, # sensor_id for NeuRAD, WOD 0 == Unkown + } + | camera_calib[ + "cam" + v2.perception.camera_image.CameraName(CamComp.key.camera_name).name + "_intrinsics" # type: ignore + ] + ) + + save_file = os.path.join(self.output_folder, filename) + if not os.path.exists(save_file): + with open(save_file, "wb") as binary_file: + binary_file.write(CamComp.image) + + # get the mean value for rolling shutter + rolling_shutter = sum(rolling_shutter_list) / (i + 1) + return (data_out, (rolling_shutter, rolling_shutter_direction)) + + +class ExportLidar: + """Utility class for extracting lidar point-cloud and objects from parquet files of WoD v2 dataset.""" + + def __init__( + self, + reader: ParquetReader, + select_ts: SelectedTimestamp, + objects_id: ObjectsID, + output_folder: str, + extract_objects=True, + cameras_ids: List[int] = list(range(1, len(get_camera_names()) + 1)), + ): + self.reader: ParquetReader = reader + self.select_ts = select_ts + self.cameras_ids = cameras_ids + + self.output_folder = output_folder + self.extract_objects = extract_objects + self.objects_id = objects_id + self.cameras_calibration = None + + def convert_range_image_to_point_cloud( + self, + range_image: _v2_lidar.RangeImage, + calibration: _v2_context.LiDARCalibrationComponent, + pixel_pose: Optional[_v2_lidar.PoseRangeImage] = None, + frame_pose: Optional[_v2_pose.VehiclePoseComponent] = None, + keep_polar_features=False, + ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: + """Converts one range image from polar coordinates to point cloud. + same as in wod api, but return the mask in addition plus channel id + + Args: + range_image: One range image return captured by a LiDAR sensor. + calibration: Parameters for calibration of a LiDAR sensor. + pixel_pose: If not none, it sets pose for each range image pixel. + frame_pose: This must be set when `pose` is set. + keep_polar_features: If true, keep the features from the polar range image + (i.e. range, intensity, and elongation) as the first features in the + output range image. + + Returns: + A 3 [N, D] tensor of 3D LiDAR points. D will be 3 if keep_polar_features is + False (x, y, z) and 6 if keep_polar_features is True (range, intensity, + elongation, x, y, z). + 1. Lidar points-cloud + 2. Missing points points-cloud + 3. Range image mask above dummy distance. + + """ + + # missing points are found directly from range image + val_clone = deepcopy(range_image.tensor.numpy()) # type: ignore + no_return = val_clone[..., 0] == -1 # where range is -1 + val_clone[..., 0][no_return] = DUMMY_DISTANCE_VALUE + # re-assign the field + object.__setattr__(range_image, "values", val_clone.flatten()) + + # From range image, missing points do not have a pose. + # So we replace their pose with the vehicle pose. 
+ # pixel pose & frame pose + pixel_pose_clone = deepcopy(pixel_pose.tensor.numpy()) # type: ignore + pixel_pose_mask = pixel_pose_clone[..., 0] == 0 + tr_orig = frame_pose.world_from_vehicle.transform.reshape(4, 4) # type: ignore + rot = tr_orig[:3, :3] + x, y, z = tr_orig[:3, 3] + yaw, pitch, roll = transforms3d.euler.mat2euler(rot, "szyx") + # ` [roll, pitch, yaw, x, y, z]` + pixel_pose_clone[..., 0][pixel_pose_mask] = roll + pixel_pose_clone[..., 1][pixel_pose_mask] = pitch + pixel_pose_clone[..., 2][pixel_pose_mask] = yaw + pixel_pose_clone[..., 3][pixel_pose_mask] = x + pixel_pose_clone[..., 4][pixel_pose_mask] = y + pixel_pose_clone[..., 5][pixel_pose_mask] = z + # re-assign the field + object.__setattr__(pixel_pose, "values", pixel_pose_clone.flatten()) + + range_image_cartesian = convert_range_image_to_cartesian( + range_image=range_image, + calibration=calibration, + pixel_pose=pixel_pose, + frame_pose=frame_pose, + keep_polar_features=keep_polar_features, + ) + + range_image_tensor = range_image.tensor + range_image_mask = DUMMY_DISTANCE_VALUE / 2 > range_image_tensor[..., 0] # 0 # type: ignore + points_tensor = tf.gather_nd(range_image_cartesian, tf.compat.v1.where(range_image_mask)) + missing_points_tensor = tf.gather_nd(range_image_cartesian, tf.compat.v1.where(~range_image_mask)) + + return points_tensor, missing_points_tensor, range_image_mask + + def is_within_box_3d(self, point, box, name=None): + """Checks whether a point is in a 3d box given a set of points and boxes. + + Args: + point: [N, 3] tensor. Inner dims are: [x, y, z]. + box: [M, 7] tensor. Inner dims are: [center_x, center_y, center_z, length, + width, height, heading]. + name: tf name scope. + + Returns: + point_in_box; [N, M] boolean tensor. + + """ + + with tf.compat.v1.name_scope(name, "IsWithinBox3D", [point, box]): + center = box[:, 0:3] + dim = box[:, 3:6] + heading = box[:, 6] + # [M, 3, 3] + rotation = transform_utils.get_yaw_rotation(heading) + # [M, 4, 4] + transform = transform_utils.get_transform(rotation, center) + # [M, 4, 4] + transform = tf.linalg.inv(transform) + # [M, 3, 3] + rotation = transform[:, 0:3, 0:3] # type: ignore + # [M, 3] + translation = transform[:, 0:3, 3] # type: ignore + + # [N, M, 3] + point_in_box_frame = tf.einsum("nj,mij->nmi", point, rotation) + translation + # [N, M, 3] + point_in_box = tf.logical_and( + tf.logical_and(point_in_box_frame <= dim * 0.5, point_in_box_frame >= -dim * 0.5), + tf.reduce_all(tf.not_equal(dim, 0), axis=-1, keepdims=True), + ) + # [N, M] + point_in_box = tf.cast( + tf.reduce_prod(input_tensor=tf.cast(point_in_box, dtype=tf.uint8), axis=-1), dtype=tf.bool + ) + + return point_in_box, point_in_box_frame[point_in_box] + + def _load_camera_calibration(self): + """Loads camera calibration from parquet file to dictionnary.""" + cam_calib_df = self.reader("camera_calibration").compute() + self.cameras_calibration = {} + for i, (_, r) in enumerate(cam_calib_df.iterrows()): + calib = v2.CameraCalibrationComponent.from_dict(r) + self.cameras_calibration[calib.key.camera_name] = calib + + def process( + self, objects_id_to_extract: List[int] = [] + ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray], List[float], Dict[int, ActorsDict]]: + print("Lidar processing...") + objects_uuid_to_extract = [ + self.objects_id.id2uuid[object_id_to_extract] for object_id_to_extract in objects_id_to_extract + ] + + self._load_camera_calibration() + lidar_calib = self.reader("lidar_calibration").compute() + + lidar_df = self.reader("lidar").compute() + 
lidar_df = lidar_df[ + (lidar_df["key.laser_name"] == _v2_lidar.LaserName.TOP.value) # Only lidar TOP is used + & (lidar_df["key.frame_timestamp_micros"].isin(self.select_ts.get_selected_ts())) + ] + + lidar_pose_df = self.reader("lidar_pose").compute() + + vehicle_pose_df = self.reader("vehicle_pose").compute() + vehicle_pose_df = vehicle_pose_df[ + vehicle_pose_df["key.frame_timestamp_micros"].isin(self.select_ts.get_selected_ts()) + ] + + lidar_box_df = self.reader("lidar_box").compute() + lidar_box_df = lidar_box_df[lidar_box_df["key.frame_timestamp_micros"].isin(self.select_ts.get_selected_ts())] + + pts_lidar_list = [] + missing_pts_list = [] + poses = [] + times = [] + + # Neurad actor trajectories + actors: Dict[int, ActorsDict] = {} + + for i, (_, r) in tqdm(enumerate(lidar_df.iterrows())): + LidarComp = v2.LiDARComponent.from_dict(r) + lidar_pose_df_ = lidar_pose_df[ + (lidar_pose_df["key.frame_timestamp_micros"] == LidarComp.key.frame_timestamp_micros) + & (lidar_pose_df["key.laser_name"] == _v2_lidar.LaserName.TOP.value) + ] + LidarPoseComp = v2.LiDARPoseComponent.from_dict(lidar_pose_df_.iloc[0]) + lidar_calib_ = lidar_calib[lidar_calib["key.laser_name"] == _v2_lidar.LaserName.TOP.value] + LidarCalibComp = v2.LiDARCalibrationComponent.from_dict(lidar_calib_.iloc[0]) + vehicle_pose_df_ = vehicle_pose_df[ + vehicle_pose_df["key.frame_timestamp_micros"] == LidarComp.key.frame_timestamp_micros + ] + VehiclePoseCom = v2.VehiclePoseComponent.from_dict(vehicle_pose_df_.iloc[0]) + + lidar_box_df_ = lidar_box_df[ + (lidar_box_df["key.frame_timestamp_micros"] == LidarComp.key.frame_timestamp_micros) + & (lidar_box_df["key.laser_object_id"].isin(self.objects_id.dynamic_uuid)) + ] + + pts_lidar, missing_pts, _ = self.convert_range_image_to_point_cloud( + LidarComp.range_image_return1, + LidarCalibComp, + LidarPoseComp.range_image_return1, + VehiclePoseCom, + keep_polar_features=True, + ) + missing_pts = missing_pts.numpy() + + # compute timestamp for each lidar frame + time = LidarComp.key.frame_timestamp_micros / 1e6 + TIME_OFFSET # convert to seconds + times.append(time) + + timestamps = get_mock_timestamps(pts_lidar[:, 3:6]) # (N, 6)->(..., x,y,z) + timestamps = np.expand_dims(timestamps, axis=1) + + timestamps_miss = get_mock_timestamps(missing_pts[:, 3:6]) # (N, 6)->(..., x,y,z) + timestamps_miss = np.expand_dims(timestamps_miss, axis=1) + + pts_lidar = pts_lidar.numpy() + intensity = pts_lidar[:, 1:2] # (range, intensity, elongation, x, y, z) => (N, 1) + intensity = self._normalize(intensity) # => [0.0, 1.0] + + pts_lidar = np.hstack((pts_lidar[:, 3:6], np.ones((pts_lidar.shape[0], 1)))) + + pts_lidar_in_vehicle = pts_lidar + l2v = LidarCalibComp.extrinsic.transform.reshape(4, 4) # type: ignore + pts_lidar_sensor = (np.linalg.inv(l2v) @ pts_lidar_in_vehicle.T).T[:, :3] + v2w = VehiclePoseCom.world_from_vehicle.transform.reshape(4, 4) # type: ignore + l2w = v2w @ l2v + + pts_lidar_world = (v2w @ pts_lidar_in_vehicle.T).T[:, :3] + + lidar_box_df_selected_boxes = lidar_box_df_[ + lidar_box_df_["key.laser_object_id"].isin(objects_uuid_to_extract) + ] + for _, lidar_box in lidar_box_df_selected_boxes.iterrows(): + v1_box = tf.transpose( + tf.constant( + [ + lidar_box["[LiDARBoxComponent].box.center.x"], + lidar_box["[LiDARBoxComponent].box.center.y"], + lidar_box["[LiDARBoxComponent].box.center.z"], + lidar_box["[LiDARBoxComponent].box.size.x"], + lidar_box["[LiDARBoxComponent].box.size.y"], + lidar_box["[LiDARBoxComponent].box.size.z"], + lidar_box["[LiDARBoxComponent].box.heading"], + 
], + dtype=tf.float32, + ) + ) + v1_box = tf.reshape(v1_box, (1, -1)) + v1_box_world = box_utils.transform_box( + v1_box, + VehiclePoseCom.world_from_vehicle.transform.reshape((4, 4)).astype("float32"), + tf.eye(4), # type: ignore + ) + mask_object = box_utils.is_within_box_3d(pts_lidar_world[:, :3], v1_box_world).numpy() # type: ignore + mask_object = np.any(mask_object, axis=1) + + mean_ts_from_lidar_pts = timestamps[ + mask_object + ].mean() # timestamp of actor is taken from mean of lidar points inside the bbox + object_timestamp = ( + time + mean_ts_from_lidar_pts if np.any(mask_object) else time + ) # If no lidar points in box, timestamp of frame + + # actor pose + # actor ids + uuids = lidar_box["key.laser_object_id"] + actor_id = self.objects_id.uuid2id[uuids] + + # actor type + type_ = lidar_box["[LiDARBoxComponent].type"] + type_names = _v2_box.BoxType(type_).name + + tr_object = np.eye(4) + tr_object[:3, :3] = transforms3d.euler.euler2mat(0, 0, v1_box_world.numpy().ravel()[6]) # type: ignore + tr_object[:3, 3] = v1_box_world.numpy().ravel()[:3] # type: ignore + + if actor_id in actors: + actors[actor_id]["poses"].append(tr_object) + actors[actor_id]["timestamps"].append(object_timestamp) + else: + actors[actor_id] = {"poses": [tr_object], "timestamps": [object_timestamp], "label": type_names} + + pts_lidar = np.hstack((pts_lidar_sensor, intensity, timestamps)) # => (N, 5) == (x, y, z, int, t) + pts_lidar_list.append(pts_lidar) + + missing_intensity = np.zeros_like(missing_pts[:, 1:2]) # 0 for missing point intensity + missing_pts_list.append(np.hstack((missing_pts[:, 3:6], missing_intensity, timestamps_miss))) + + poses.append(l2w) + + return poses, pts_lidar_list, missing_pts_list, times, actors + + def _normalize(self, points: np.ndarray) -> np.ndarray: + max_ = points.max() + min_ = points.min() + + points = (points - min_) / (max_ - min_) + return points diff --git a/nerfstudio/data/utils/lidar_elevation_mappings.py b/nerfstudio/data/utils/lidar_elevation_mappings.py index f2c891a5..15bc53da 100644 --- a/nerfstudio/data/utils/lidar_elevation_mappings.py +++ b/nerfstudio/data/utils/lidar_elevation_mappings.py @@ -250,3 +250,70 @@ } VELODYNE_HDL32E_ELEVATION_MAPPING = dict(zip(np.arange(32), tuple(np.linspace(-30.67, 10.67, 32)))) + +WOD64_ELEVATION_MAPPING = { + 0: 2.5028389775650304, + 1: 2.321411751659905, + 2: 2.160192256145731, + 3: 1.9888398480248883, + 4: 1.8209349283573786, + 5: 1.6502418044970433, + 6: 1.4938679389287557, + 7: 1.3221564279311344, + 8: 1.1632512247221256, + 9: 0.9913750200128197, + 10: 0.8101498633691424, + 11: 0.6482041237244122, + 12: 0.48336997052669073, + 13: 0.3201589105532588, + 14: 0.16462286430089693, + 15: -0.011621928777127347, + 16: -0.1892787856748749, + 17: -0.34201145065403127, + 18: -0.5054471288374568, + 19: -0.6827621682735187, + 20: -0.8449790324744345, + 21: -1.0197501521052226, + 22: -1.1886280361746464, + 23: -1.3669402000816122, + 24: -1.5409274243550963, + 25: -1.7570629940063032, + 26: -1.9649363657632477, + 27: -2.1894398590475905, + 28: -2.4374471868305987, + 29: -2.6683997977793497, + 30: -2.9254801778651274, + 31: -3.208793362354923, + 32: -3.4652440977914574, + 33: -3.770654905928011, + 34: -4.068046596015399, + 35: -4.365557254206326, + 36: -4.68136205944531, + 37: -5.023904856877318, + 38: -5.360837632630594, + 39: -5.715495138382295, + 40: -6.091110098376429, + 41: -6.457270941426794, + 42: -6.8451480987631, + 43: -7.24803061771811, + 44: -7.645534995724646, + 45: -8.08179034271091, + 46: -8.522502366939104, + 47: 
-8.957247796204939, + 48: -9.421474930460981, + 49: -9.885265834826649, + 50: -10.369068098135806, + 51: -10.829727642824542, + 52: -11.332199121554261, + 53: -11.822915504645561, + 54: -12.364441979859368, + 55: -12.908557767713962, + 56: -13.437836414956127, + 57: -13.983840803683233, + 58: -14.537462865288743, + 59: -15.076443690248071, + 60: -15.689281398977771, + 61: -16.300273448699592, + 62: -16.911934322750316, + 63: -17.546811286086175, +} # degrees diff --git a/nerfstudio/scripts/render.py b/nerfstudio/scripts/render.py index a6b885de..4e03b171 100644 --- a/nerfstudio/scripts/render.py +++ b/nerfstudio/scripts/render.py @@ -53,6 +53,7 @@ get_spiral_path, ) from nerfstudio.cameras.cameras import Cameras, CameraType, RayBundle +from nerfstudio.cameras.lidars import transform_points from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManager, VanillaDataManagerConfig from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanagerConfig from nerfstudio.data.datamanagers.parallel_datamanager import ParallelDataManager @@ -1159,16 +1160,12 @@ def update_config(config: TrainerConfig) -> TrainerConfig: if "ray_drop_prob" in lidar_output: points_in_local = points_in_local[(lidar_output["ray_drop_prob"] < 0.5).squeeze(-1)] - points_in_world = ( - lidar.lidar_to_worlds[0] - @ torch.cat( - [ - points_in_local, - torch.ones_like(points_in_local[..., :1]), - ], - dim=-1, - ).unsqueeze(-1) - ).squeeze() + points_in_world = transform_points(points_in_local, lidar.lidar_to_worlds[0]) + # get ground truth for comparison + gt_point_in_world = transform_points(batch["lidar"][..., :3], lidar.lidar_to_worlds[0]) + plot_lidar_points( + gt_point_in_world.cpu().detach().numpy(), output_path / f"gt-lidar_{lidar_idx}.png" + ) plot_lidar_points( points_in_world.cpu().detach().numpy(), output_path / f"lidar_{lidar_idx}.png" ) diff --git a/pyproject.toml b/pyproject.toml index d40d50eb..a2f6a109 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ version = "0.1.0" description = "Neural Rendering methods that are specialized for Autonomous Driving (NeuRAD)." readme = "README.md" license = { text="Apache 2.0"} -requires-python = ">=3.8.0" +requires-python = ">=3.10.0" classifiers = [ "Development Status :: 3 - Alpha", "Programming Language :: Python", @@ -56,6 +56,7 @@ dependencies = [ "tensorboard>=2.13.0", "torch>=1.13.1", "torchvision>=0.14.1", + # Added to README for later installation "torchmetrics[image]>=1.0.1", "typing_extensions>=4.4.0", "viser@git+https://github.com/atonderski/viser.git", @@ -72,6 +73,7 @@ dependencies = [ "pandaset@git+https://github.com/scaleapi/pandaset-devkit.git#egg=pandaset&subdirectory=python", "av2==0.2.1", "fastapi[all]==0.110", + "numba==0.57", ] [project.optional-dependencies] @@ -155,7 +157,7 @@ reportMissingImports = "warning" reportMissingTypeStubs = false reportPrivateImportUsage = false -pythonVersion = "3.8" +pythonVersion = "3.10" pythonPlatform = "Linux" [tool.ruff]