-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add a way to orient based on main direction of the surface #3364
Open
hoanhle
wants to merge
8
commits into
nerfstudio-project:main
Choose a base branch
from
hoanhle:feature/implement_a_new_orientation_method
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+133
−21
Open
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
34bf44e
add a way to orient based on main direction of the flat surface
hoanhle dde9d28
Merge branch 'main' into feature/implement_a_new_orientation_method
jb-ye 50863bf
add a way to align main flat surface
hoanhle 39236ad
use default param for scene box
hoanhle 72006e1
annotate _load_3D_points output
hoanhle d0c97f2
change arg type
hoanhle a3fceab
add missing types
hoanhle 28b8807
Merge branch 'main' into feature/implement_a_new_orientation_method
hoanhle File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,11 +17,13 @@ | |
|
||
from dataclasses import dataclass, field | ||
from pathlib import Path | ||
from typing import Literal, Optional, Tuple, Type | ||
from typing import Dict, Literal, Optional, Tuple, Type | ||
|
||
import numpy as np | ||
import torch | ||
from jaxtyping import Float | ||
from PIL import Image | ||
from torch import Tensor | ||
|
||
from nerfstudio.cameras import camera_utils | ||
from nerfstudio.cameras.cameras import CAMERA_MODEL_TO_TYPE, Cameras, CameraType | ||
|
@@ -53,8 +55,16 @@ class NerfstudioDataParserConfig(DataParserConfig): | |
"""How much to downscale images. If not set, images are chosen such that the max dimension is <1600px.""" | ||
scene_scale: float = 1.0 | ||
"""How much to scale the region of interest by.""" | ||
orientation_method: Literal["pca", "up", "vertical", "none"] = "up" | ||
orientation_method: Literal[ | ||
"pca", | ||
"up", | ||
"vertical", | ||
"align", | ||
"none", | ||
] = "vertical" | ||
"""The method to use for orientation.""" | ||
target_normal: Tuple[float, float, float] = (1.0, 0.0, 0.0) | ||
"""The normal vector to align the scene to, represented as a tuple of floats.""" | ||
center_method: Literal["poses", "focus", "none"] = "poses" | ||
"""The method to use to center the poses.""" | ||
auto_scale_poses: bool = True | ||
|
@@ -232,6 +242,8 @@ def _generate_dataparser_outputs(self, split="train"): | |
CONSOLE.log(f"[yellow] Dataset is overriding orientation method to {orientation_method}") | ||
else: | ||
orientation_method = self.config.orientation_method | ||
if orientation_method == "align": | ||
orientation_method = "up" | ||
|
||
poses = torch.from_numpy(np.array(poses).astype(np.float32)) | ||
poses, transform_matrix = camera_utils.auto_orient_and_center_poses( | ||
|
@@ -258,6 +270,7 @@ def _generate_dataparser_outputs(self, split="train"): | |
|
||
# in x,y,z order | ||
# assumes that the scene is centered at the origin | ||
# _ = self.config.scene_scale | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Delete this comment? |
||
aabb_scale = self.config.scene_scale | ||
scene_box = SceneBox( | ||
aabb=torch.tensor( | ||
|
@@ -298,22 +311,6 @@ def _generate_dataparser_outputs(self, split="train"): | |
if (camera_type in [CameraType.FISHEYE, CameraType.FISHEYE624]) and (fisheye_crop_radius is not None): | ||
metadata["fisheye_crop_radius"] = fisheye_crop_radius | ||
|
||
cameras = Cameras( | ||
fx=fx, | ||
fy=fy, | ||
cx=cx, | ||
cy=cy, | ||
distortion_params=distortion_params, | ||
height=height, | ||
width=width, | ||
camera_to_worlds=poses[:, :3, :4], | ||
camera_type=camera_type, | ||
metadata=metadata, | ||
) | ||
|
||
assert self.downscale_factor is not None | ||
cameras.rescale_output_resolution(scaling_factor=1.0 / self.downscale_factor) | ||
|
||
# The naming is somewhat confusing, but: | ||
# - transform_matrix contains the transformation to dataparser output coordinates from saved coordinates. | ||
# - dataparser_transform_matrix contains the transformation to dataparser output coordinates from original data coordinates. | ||
|
@@ -348,6 +345,9 @@ def _generate_dataparser_outputs(self, split="train"): | |
except AttributeError: | ||
self.prompted_user = False | ||
|
||
alignment_matrix = None | ||
sparse_points = None | ||
|
||
# Load 3D points | ||
if self.config.load_3D_points: | ||
if "ply_file_path" in meta: | ||
|
@@ -399,10 +399,46 @@ def _generate_dataparser_outputs(self, split="train"): | |
|
||
if ply_file_path: | ||
sparse_points = self._load_3D_points(ply_file_path, transform_matrix, scale_factor) | ||
jb-ye marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if sparse_points is not None: | ||
metadata.update(sparse_points) | ||
|
||
if sparse_points is not None and self.config.orientation_method == "align": | ||
target_normal_tensor = torch.tensor(self.config.target_normal, dtype=torch.float32) | ||
points3D_xyz = sparse_points["points3D_xyz"] | ||
aligned_points3D, alignment_matrix = self._align_points_to_target_plane( | ||
points3D_xyz, target_normal_tensor | ||
) | ||
sparse_points["points3D_xyz"] = aligned_points3D[:, :3] | ||
self.prompted_user = True | ||
|
||
if alignment_matrix is not None: | ||
num_poses = poses.shape[0] | ||
bottom_row = torch.tensor([0, 0, 0, 1], dtype=torch.float32).unsqueeze(0).expand(num_poses, -1, -1) | ||
poses_homogeneous = torch.cat([poses, bottom_row], dim=1) # Shape: (num_poses, 4, 4) | ||
|
||
poses = alignment_matrix @ poses_homogeneous | ||
dataparser_transform_matrix = torch.cat( | ||
[dataparser_transform_matrix, torch.tensor([0, 0, 0, 1], dtype=torch.float32).unsqueeze(0)], dim=0 | ||
) | ||
dataparser_transform_matrix = alignment_matrix @ dataparser_transform_matrix | ||
|
||
cameras = Cameras( | ||
fx=fx, | ||
fy=fy, | ||
cx=cx, | ||
cy=cy, | ||
distortion_params=distortion_params, | ||
height=height, | ||
width=width, | ||
camera_to_worlds=poses[:, :3, :4], | ||
camera_type=camera_type, | ||
metadata=metadata, | ||
) | ||
|
||
assert self.downscale_factor is not None | ||
cameras.rescale_output_resolution(scaling_factor=1.0 / self.downscale_factor) | ||
|
||
if sparse_points is not None: | ||
metadata.update(sparse_points) | ||
|
||
dataparser_outputs = DataparserOutputs( | ||
image_filenames=image_filenames, | ||
cameras=cameras, | ||
|
@@ -419,7 +455,9 @@ def _generate_dataparser_outputs(self, split="train"): | |
) | ||
return dataparser_outputs | ||
|
||
def _load_3D_points(self, ply_file_path: Path, transform_matrix: torch.Tensor, scale_factor: float): | ||
def _load_3D_points( | ||
self, ply_file_path: Path, transform_matrix: torch.Tensor, scale_factor: float | ||
) -> Optional[Dict[str, torch.Tensor]]: | ||
"""Loads point clouds positions and colors from .ply | ||
|
||
Args: | ||
|
@@ -458,6 +496,80 @@ def _load_3D_points(self, ply_file_path: Path, transform_matrix: torch.Tensor, s | |
} | ||
return out | ||
|
||
@staticmethod | ||
def _align_points_to_target_plane( | ||
points: torch.Tensor, | ||
target_normal: Float[Tensor, "3"], | ||
threshold: float = 1.0, | ||
max_iterations: int = 5, | ||
) -> Tuple[torch.Tensor, torch.Tensor]: | ||
"""Aligns a set of 3D points (in homogeneous coordinates) to a target plane defined by its normal vector. | ||
|
||
Args: | ||
points: A torch tensor of shape (n, 4) representing the 3D points in homogeneous coordinates. | ||
target_normal: A torch tensor of shape (3, ) representing the normal vector of the target plane. | ||
threshold: The distance threshold for identifying inliers. | ||
max_iterations: The maximum number of iterations for refining inliers. | ||
|
||
Returns: | ||
A tuple containing: | ||
- aligned_points: The 3D points aligned to the target plane as a torch tensor of shape (n, 4). | ||
- alignment_matrix: The 4x4 alignment matrix used for alignment. | ||
""" | ||
|
||
def filter_outliers(points_xyz, threshold, max_iterations): | ||
inlier_mask = torch.ones(points_xyz.size(0), dtype=torch.bool) | ||
|
||
for _ in range(max_iterations): | ||
current_inliers = points_xyz[inlier_mask] | ||
centroid = torch.mean(current_inliers, dim=0) | ||
|
||
centered_points = current_inliers - centroid | ||
_, _, vh = torch.linalg.svd(centered_points) | ||
|
||
normal = vh[-1] | ||
|
||
distances = torch.abs((points_xyz - centroid) @ normal) | ||
new_inlier_mask = distances < threshold | ||
|
||
inlier_mask = inlier_mask & new_inlier_mask | ||
|
||
threshold *= 0.9 # Reduce threshold for more aggressiveness | ||
|
||
return inlier_mask | ||
|
||
points_xyz = points[:, :3] # Shape: (n, 3) | ||
inlier_mask = filter_outliers(points_xyz, threshold, max_iterations) | ||
inliers = points_xyz[inlier_mask] | ||
|
||
# Calculate the centroid using only inliers | ||
centroid = torch.mean(inliers, dim=0) # Shape: (3,) | ||
|
||
# Center the inlier points around the centroid | ||
centered_inliers = inliers - centroid # Shape: (m, 3) where m <= n | ||
|
||
# Perform SVD on inliers to find the normal | ||
_, _, vh = torch.linalg.svd(centered_inliers) # vh shape: (3, 3) | ||
normal = vh[-1] # Shape: (3,) | ||
|
||
# Use the provided helper function to get the rotation matrix | ||
rotation_matrix = camera_utils.rotation_matrix_between(normal, target_normal) | ||
|
||
# Create the 4x4 alignment matrix | ||
alignment_matrix = torch.eye(4, dtype=torch.float32) # Shape: (4, 4) | ||
alignment_matrix[:3, :3] = rotation_matrix # Insert rotation part | ||
alignment_matrix[:3, 3] = -rotation_matrix @ centroid # Apply translation | ||
|
||
# Ensure the points are in homogeneous coordinates | ||
if points.shape[1] == 3: | ||
points = torch.cat([points, torch.ones((points.shape[0], 1), dtype=torch.float32)], dim=1) # Shape: (n, 4) | ||
|
||
# Apply the alignment transformation to all points | ||
aligned_points = alignment_matrix @ points.T # Shape: (4, n) | ||
aligned_points = aligned_points.T # Shape: (n, 4) | ||
|
||
return aligned_points, alignment_matrix | ||
|
||
def _get_fname(self, filepath: Path, data_dir: Path, downsample_folder_prefix="images_") -> Path: | ||
"""Get the filename of the image file. | ||
downsample_folder_prefix can be used to point to auxiliary image data, e.g. masks | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should you keep the default method as "up"?