Add a way to orient based on the main direction of the surface #3364

Open · wants to merge 8 commits into base: main
154 changes: 133 additions & 21 deletions nerfstudio/data/dataparsers/nerfstudio_dataparser.py
@@ -17,11 +17,13 @@

from dataclasses import dataclass, field
from pathlib import Path
from typing import Literal, Optional, Tuple, Type
from typing import Dict, Literal, Optional, Tuple, Type

import numpy as np
import torch
from jaxtyping import Float
from PIL import Image
from torch import Tensor

from nerfstudio.cameras import camera_utils
from nerfstudio.cameras.cameras import CAMERA_MODEL_TO_TYPE, Cameras, CameraType
@@ -53,8 +55,16 @@ class NerfstudioDataParserConfig(DataParserConfig):
"""How much to downscale images. If not set, images are chosen such that the max dimension is <1600px."""
scene_scale: float = 1.0
"""How much to scale the region of interest by."""
orientation_method: Literal["pca", "up", "vertical", "none"] = "up"
orientation_method: Literal[
"pca",
"up",
"vertical",
"align",
"none",
] = "vertical"
Collaborator: Should the default method remain "up"?

"""The method to use for orientation."""
target_normal: Tuple[float, float, float] = (1.0, 0.0, 0.0)
"""The normal vector to align the scene to, represented as a tuple of floats."""
center_method: Literal["poses", "focus", "none"] = "poses"
"""The method to use to center the poses."""
auto_scale_poses: bool = True
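
For orientation, a minimal usage sketch of the two new fields (the `data` path and the `setup()` call follow nerfstudio's usual `InstantiateConfig` pattern; the values shown are illustrative, not from this PR):

```python
from pathlib import Path

from nerfstudio.data.dataparsers.nerfstudio_dataparser import NerfstudioDataParserConfig

# A hedged sketch, assuming the fields land as defined in this diff: pick the
# new "align" method and rotate the dominant plane's normal onto +Z.
# The dataset path is a placeholder.
config = NerfstudioDataParserConfig(
    data=Path("path/to/dataset"),
    orientation_method="align",
    target_normal=(0.0, 0.0, 1.0),
)
parser = config.setup()  # InstantiateConfig.setup() builds the dataparser
# outputs = parser.get_dataparser_outputs(split="train")  # needs real data on disk
```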
@@ -232,6 +242,8 @@ def _generate_dataparser_outputs(self, split="train"):
CONSOLE.log(f"[yellow] Dataset is overriding orientation method to {orientation_method}")
else:
orientation_method = self.config.orientation_method
if orientation_method == "align":
orientation_method = "up"

poses = torch.from_numpy(np.array(poses).astype(np.float32))
poses, transform_matrix = camera_utils.auto_orient_and_center_poses(
@@ -258,6 +270,7 @@

# in x,y,z order
# assumes that the scene is centered at the origin
# _ = self.config.scene_scale
Collaborator: delete this comment?

aabb_scale = self.config.scene_scale
scene_box = SceneBox(
aabb=torch.tensor(
@@ -298,22 +311,6 @@ def _generate_dataparser_outputs(self, split="train"):
if (camera_type in [CameraType.FISHEYE, CameraType.FISHEYE624]) and (fisheye_crop_radius is not None):
metadata["fisheye_crop_radius"] = fisheye_crop_radius

cameras = Cameras(
fx=fx,
fy=fy,
cx=cx,
cy=cy,
distortion_params=distortion_params,
height=height,
width=width,
camera_to_worlds=poses[:, :3, :4],
camera_type=camera_type,
metadata=metadata,
)

assert self.downscale_factor is not None
cameras.rescale_output_resolution(scaling_factor=1.0 / self.downscale_factor)

# The naming is somewhat confusing, but:
# - transform_matrix contains the transformation to dataparser output coordinates from saved coordinates.
# - dataparser_transform_matrix contains the transformation to dataparser output coordinates from original data coordinates.
@@ -348,6 +345,9 @@ def _generate_dataparser_outputs(self, split="train"):
except AttributeError:
self.prompted_user = False

alignment_matrix = None
sparse_points = None

# Load 3D points
if self.config.load_3D_points:
if "ply_file_path" in meta:
@@ -399,10 +399,46 @@

if ply_file_path:
sparse_points = self._load_3D_points(ply_file_path, transform_matrix, scale_factor)
if sparse_points is not None:
metadata.update(sparse_points)

if sparse_points is not None and self.config.orientation_method == "align":
target_normal_tensor = torch.tensor(self.config.target_normal, dtype=torch.float32)
points3D_xyz = sparse_points["points3D_xyz"]
aligned_points3D, alignment_matrix = self._align_points_to_target_plane(
points3D_xyz, target_normal_tensor
)
sparse_points["points3D_xyz"] = aligned_points3D[:, :3]
self.prompted_user = True

if alignment_matrix is not None:
num_poses = poses.shape[0]
bottom_row = torch.tensor([0, 0, 0, 1], dtype=torch.float32).unsqueeze(0).expand(num_poses, -1, -1)
poses_homogeneous = torch.cat([poses, bottom_row], dim=1) # Shape: (num_poses, 4, 4)

poses = alignment_matrix @ poses_homogeneous
dataparser_transform_matrix = torch.cat(
[dataparser_transform_matrix, torch.tensor([0, 0, 0, 1], dtype=torch.float32).unsqueeze(0)], dim=0
)
dataparser_transform_matrix = alignment_matrix @ dataparser_transform_matrix
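
As a shape sanity check on the block above, a standalone sketch (dummy identity matrices, illustrative names) of how one 4x4 alignment left-multiplies a batch of homogeneous camera-to-world poses:

```python
import torch

num_poses = 2
poses = torch.eye(4)[:3].expand(num_poses, 3, 4)  # (N, 3, 4) camera-to-world
alignment_matrix = torch.eye(4)                   # stand-in 4x4 alignment

# expand() can prepend a new batch dimension, so (1, 4) -> (N, 1, 4).
bottom_row = torch.tensor([0.0, 0.0, 0.0, 1.0]).unsqueeze(0).expand(num_poses, -1, -1)
poses_h = torch.cat([poses, bottom_row], dim=1)   # (N, 4, 4)

# A (4, 4) matmul broadcasts over the leading batch dimension.
aligned = alignment_matrix @ poses_h
assert aligned.shape == (num_poses, 4, 4)
```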

cameras = Cameras(
fx=fx,
fy=fy,
cx=cx,
cy=cy,
distortion_params=distortion_params,
height=height,
width=width,
camera_to_worlds=poses[:, :3, :4],
camera_type=camera_type,
metadata=metadata,
)

assert self.downscale_factor is not None
cameras.rescale_output_resolution(scaling_factor=1.0 / self.downscale_factor)

if sparse_points is not None:
metadata.update(sparse_points)

dataparser_outputs = DataparserOutputs(
image_filenames=image_filenames,
cameras=cameras,
@@ -419,7 +455,9 @@ def _generate_dataparser_outputs(self, split="train"):
)
return dataparser_outputs

def _load_3D_points(self, ply_file_path: Path, transform_matrix: torch.Tensor, scale_factor: float):
def _load_3D_points(
self, ply_file_path: Path, transform_matrix: torch.Tensor, scale_factor: float
) -> Optional[Dict[str, torch.Tensor]]:
"""Loads point clouds positions and colors from .ply

Args:
@@ -458,6 +496,80 @@ def _load_3D_points(self, ply_file_path: Path, transform_matrix: torch.Tensor, s
}
return out
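
The transform applied to the loaded points is a standard homogeneous transform followed by uniform scaling; a minimal illustrative sketch of that step (the helper below is not the PR's code):

```python
import torch

def transform_and_scale(points_xyz: torch.Tensor,
                        transform_matrix: torch.Tensor,
                        scale_factor: float) -> torch.Tensor:
    """Apply a (3, 4) world transform, then uniform scaling, to (n, 3) points."""
    ones = torch.ones((points_xyz.shape[0], 1), dtype=points_xyz.dtype)
    points_h = torch.cat([points_xyz, ones], dim=1)        # (n, 4) homogeneous
    return (points_h @ transform_matrix.T) * scale_factor  # (n, 3)

pts = torch.rand(10, 3)
identity_3x4 = torch.eye(4)[:3]
assert torch.allclose(transform_and_scale(pts, identity_3x4, 2.0), pts * 2.0)
```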

@staticmethod
def _align_points_to_target_plane(
points: torch.Tensor,
target_normal: Float[Tensor, "3"],
threshold: float = 1.0,
max_iterations: int = 5,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Aligns a set of 3D points (in homogeneous coordinates) to a target plane defined by its normal vector.

Args:
points: A torch tensor of shape (n, 3) or (n, 4); (n, 3) inputs are promoted to homogeneous coordinates.
target_normal: A torch tensor of shape (3, ) representing the normal vector of the target plane.
threshold: The distance threshold for identifying inliers.
max_iterations: The maximum number of iterations for refining inliers.

Returns:
A tuple containing:
- aligned_points: The 3D points aligned to the target plane as a torch tensor of shape (n, 4).
- alignment_matrix: The 4x4 alignment matrix used for alignment.
"""

def filter_outliers(points_xyz, threshold, max_iterations):
inlier_mask = torch.ones(points_xyz.size(0), dtype=torch.bool)

for _ in range(max_iterations):
current_inliers = points_xyz[inlier_mask]
centroid = torch.mean(current_inliers, dim=0)

centered_points = current_inliers - centroid
_, _, vh = torch.linalg.svd(centered_points)

normal = vh[-1]

distances = torch.abs((points_xyz - centroid) @ normal)
new_inlier_mask = distances < threshold

inlier_mask = inlier_mask & new_inlier_mask

threshold *= 0.9  # Tighten the threshold each iteration to reject outliers more aggressively

return inlier_mask

points_xyz = points[:, :3] # Shape: (n, 3)
inlier_mask = filter_outliers(points_xyz, threshold, max_iterations)
inliers = points_xyz[inlier_mask]

# Calculate the centroid using only inliers
centroid = torch.mean(inliers, dim=0) # Shape: (3,)

# Center the inlier points around the centroid
centered_inliers = inliers - centroid # Shape: (m, 3) where m <= n

# Perform SVD on inliers to find the normal
_, _, vh = torch.linalg.svd(centered_inliers) # vh shape: (3, 3)
normal = vh[-1] # Shape: (3,)

# Use the provided helper function to get the rotation matrix
rotation_matrix = camera_utils.rotation_matrix_between(normal, target_normal)

# Create the 4x4 alignment matrix
alignment_matrix = torch.eye(4, dtype=torch.float32) # Shape: (4, 4)
alignment_matrix[:3, :3] = rotation_matrix # Insert rotation part
alignment_matrix[:3, 3] = -rotation_matrix @ centroid # Apply translation

# Ensure the points are in homogeneous coordinates
if points.shape[1] == 3:
points = torch.cat([points, torch.ones((points.shape[0], 1), dtype=torch.float32)], dim=1) # Shape: (n, 4)

# Apply the alignment transformation to all points
aligned_points = alignment_matrix @ points.T # Shape: (4, n)
aligned_points = aligned_points.T # Shape: (n, 4)

return aligned_points, alignment_matrix
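
As a self-contained check of the fit-then-rotate idea, a sketch on synthetic planar points, with a Rodrigues rotation standing in for `camera_utils.rotation_matrix_between` (whose exact edge-case handling is assumed):

```python
import torch

def rotation_between(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Rodrigues rotation taking unit vector a onto unit vector b.

    Stand-in for camera_utils.rotation_matrix_between; the antiparallel
    (a == -b) case is omitted for brevity.
    """
    a, b = a / a.norm(), b / b.norm()
    v = torch.linalg.cross(a, b)
    c = torch.dot(a, b)
    k = torch.zeros(3, 3)  # skew-symmetric cross-product matrix of v
    k[0, 1], k[0, 2] = -v[2], v[1]
    k[1, 0], k[1, 2] = v[2], -v[0]
    k[2, 0], k[2, 1] = -v[1], v[0]
    return torch.eye(3) + k + (k @ k) / (1.0 + c)

# Synthetic points on the z = 0 plane: the fitted normal is (0, 0, 1) up to sign.
pts = torch.rand(500, 3)
pts[:, 2] = 0.0
centered = pts - pts.mean(dim=0)
_, _, vh = torch.linalg.svd(centered, full_matrices=False)
normal = vh[-1]  # right singular vector with the smallest singular value

target = torch.tensor([1.0, 0.0, 0.0])
rotated = (rotation_between(normal, target) @ centered.T).T
# After alignment the points lie in the plane orthogonal to the target normal.
assert rotated[:, 0].abs().max() < 1e-5
```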

def _get_fname(self, filepath: Path, data_dir: Path, downsample_folder_prefix="images_") -> Path:
"""Get the filename of the image file.
downsample_folder_prefix can be used to point to auxiliary image data, e.g. masks