Skip to content

Commit 8453746

Browse files
committed
Allow dataparser poses to be in float64 for higher precision
1 parent 0724589 commit 8453746

7 files changed

+14
-13
lines changed

nerfstudio/data/dataparsers/ad_dataparser.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -310,10 +310,11 @@ def _adjust_times(
310310
def _adjust_poses(self, cameras: Cameras, lidars: Lidars, trajectories: List[Dict]):
311311
"""Determines a new, centered, world coordinate system, and adjusts all poses."""
312312
w2m = _get_world_to_mean_transform(cameras, lidars)
313-
cameras.camera_to_worlds = pose_multiply(w2m, cameras.camera_to_worlds)
314-
lidars.lidar_to_worlds = pose_multiply(w2m, lidars.lidar_to_worlds)
313+
# Cast poses to float32 only after transforming to local frame to avoid precision loss
314+
cameras.camera_to_worlds = pose_multiply(w2m, cameras.camera_to_worlds).to(torch.float32)
315+
lidars.lidar_to_worlds = pose_multiply(w2m, lidars.lidar_to_worlds).to(torch.float32)
315316
for traj in trajectories:
316-
traj["poses"][:, :3] = pose_multiply(w2m, traj["poses"][:, :3])
317+
traj["poses"][:, :3] = pose_multiply(w2m, traj["poses"][:, :3]).to(torch.float32)
317318
return w2m
318319

319320
def _get_train_eval_indices(self, sensors: Union[Cameras, Lidars]) -> Tuple[Tensor, Tensor]:
@@ -617,8 +618,8 @@ def _get_world_to_mean_transform(cameras: Cameras, lidars: Lidars):
617618
m2w = to4x4(select_poses[0:1])[0]
618619
else:
619620
# Otherwise
620-
m2w = torch.from_numpy(_get_mean_pose_from_trajectory(select_trajectory).astype(np.float32))
621-
return torch.linalg.inv(m2w)[:3]
621+
m2w = torch.from_numpy(_get_mean_pose_from_trajectory(select_trajectory))
622+
return torch.linalg.inv(m2w)[:3].to(poses.dtype)
622623

623624

624625
def _empty_cameras():

nerfstudio/data/dataparsers/argoverse2_dataparser.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ class Argoverse2(ADDataParser):
186186
@property
187187
def actor_transform(self) -> torch.Tensor:
188188
"""Argo uses x-forward, so we need to rotate to x-right."""
189-
wlh_to_lwh = np.eye(4, dtype=np.float32)
189+
wlh_to_lwh = np.eye(4)
190190
wlh_to_lwh[:3, :3] = WLH_TO_LWH
191191
return torch.from_numpy(wlh_to_lwh)[:3, :]
192192

@@ -360,11 +360,11 @@ def _read_lidars(self, lidars: Lidars, filepaths: List[Path]) -> List[torch.Tens
360360
assert sweep is not None
361361
uplidar2ego = sweep.ego_SE3_up_lidar
362362
all_lup2w = torch.tensor(
363-
np.array([e2w.compose(uplidar2ego).transform_matrix for e2w in all_ego2w]), dtype=torch.float32
363+
np.array([e2w.compose(uplidar2ego).transform_matrix for e2w in all_ego2w]), dtype=torch.float64
364364
)
365365
downlidar2ego = sweep.ego_SE3_down_lidar
366366
all_ldown2w = torch.tensor(
367-
np.array([e2w.compose(downlidar2ego).transform_matrix for e2w in all_ego2w]), dtype=torch.float32
367+
np.array([e2w.compose(downlidar2ego).transform_matrix for e2w in all_ego2w]), dtype=torch.float64
368368
)
369369
all_times = torch.from_numpy(log_pose_df["timestamp_ns"].to_numpy() / 1e9)
370370

nerfstudio/data/dataparsers/base_dataparser.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ class DataparserOutputs:
6868
"""
6969
dataparser_transform: Float[Tensor, "3 4"] = torch.eye(4)[:3, :]
7070
"""Transform applied by the dataparser to the entire scene."""
71-
actor_transform: Float[Tensor, "3 4"] = torch.eye(4, dtype=torch.float32)[:3, :]
71+
actor_transform: Float[Tensor, "3 4"] = torch.eye(4)[:3, :]
7272
"""Transform applied by the dataparser to each actor's local frame."""
7373
dataparser_scale: float = 1.0
7474
"""Scale applied by the dataparser."""

nerfstudio/data/dataparsers/kittimot_dataparser.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ class KittiMot(ADDataParser):
107107
@property
108108
def actor_transform(self) -> Tensor:
109109
"""The transform needed to convert the actor poses to our desired format (x-right, y-forward, z-up)."""
110-
return torch.from_numpy(RIGHT_FRONT_UP2RIGHT_DOWN_FRONT)[:3, :]
110+
return torch.from_numpy(RIGHT_FRONT_UP2RIGHT_DOWN_FRONT)
111111

112112
def _get_cameras(self) -> Tuple[Cameras, List[Path]]:
113113
"""Returns camera info and image filenames."""

nerfstudio/data/dataparsers/nuscenes_dataparser.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ class NuScenes(ADDataParser):
185185
@property
186186
def actor_transform(self) -> torch.Tensor:
187187
"""Nuscenes uses x-forward, so we need to rotate to x-right."""
188-
return torch.from_numpy(WLH_TO_LWH)[:3, :]
188+
return torch.from_numpy(WLH_TO_LWH)
189189

190190
def _get_cameras(self) -> Tuple[Cameras, List[Path]]:
191191
if "all" in self.config.cameras:

nerfstudio/data/dataparsers/zod_dataparser.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ class Zod(ADDataParser):
191191
@property
192192
def actor_transform(self) -> torch.Tensor:
193193
"""ZOD uses x-forward, so we need to rotate to x-right."""
194-
return torch.from_numpy(WLH_TO_LWH)[:3, :]
194+
return torch.from_numpy(WLH_TO_LWH)
195195

196196
def _get_lane_shift_sign(self, sequence: str) -> Literal[-1, 1]:
197197
return LANE_SHIFT_SIGN.get(sequence, 1)

nerfstudio/utils/poses.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def interpolate_trajectories(poses, pose_times, query_times, pose_valid_mask=Non
174174
right_time = pose_times[right_idx]
175175
left_time = pose_times[left_idx]
176176
time_diff = right_time - left_time + 1e-6
177-
fraction = (qt - left_time) / time_diff # 0 = all left, 1 = all right
177+
fraction = ((qt - left_time) / time_diff).to(poses.dtype) # 0 = all left, 1 = all right
178178
if clamp_frac:
179179
fraction = fraction.clamp(0.0, 1.0) # clamp to handle out of bounds
180180

0 commit comments

Comments
 (0)