@@ -310,10 +310,11 @@ def _adjust_times(
310
310
def _adjust_poses (self , cameras : Cameras , lidars : Lidars , trajectories : List [Dict ]):
311
311
"""Determines a new, centered, world coordinate system, and adjusts all poses."""
312
312
w2m = _get_world_to_mean_transform (cameras , lidars )
313
- cameras .camera_to_worlds = pose_multiply (w2m , cameras .camera_to_worlds )
314
- lidars .lidar_to_worlds = pose_multiply (w2m , lidars .lidar_to_worlds )
313
+ # Cast poses to float32 only after transforming to local frame to avoid precision loss
314
+ cameras .camera_to_worlds = pose_multiply (w2m , cameras .camera_to_worlds ).to (torch .float32 )
315
+ lidars .lidar_to_worlds = pose_multiply (w2m , lidars .lidar_to_worlds ).to (torch .float32 )
315
316
for traj in trajectories :
316
- traj ["poses" ][:, :3 ] = pose_multiply (w2m , traj ["poses" ][:, :3 ])
317
+ traj ["poses" ][:, :3 ] = pose_multiply (w2m , traj ["poses" ][:, :3 ]). to ( torch . float32 )
317
318
return w2m
318
319
319
320
def _get_train_eval_indices (self , sensors : Union [Cameras , Lidars ]) -> Tuple [Tensor , Tensor ]:
@@ -617,8 +618,8 @@ def _get_world_to_mean_transform(cameras: Cameras, lidars: Lidars):
617
618
m2w = to4x4 (select_poses [0 :1 ])[0 ]
618
619
else :
619
620
# Otherwise
620
- m2w = torch .from_numpy (_get_mean_pose_from_trajectory (select_trajectory ). astype ( np . float32 ) )
621
- return torch .linalg .inv (m2w )[:3 ]
621
+ m2w = torch .from_numpy (_get_mean_pose_from_trajectory (select_trajectory ))
622
+ return torch .linalg .inv (m2w )[:3 ]. to ( poses . dtype )
622
623
623
624
624
625
def _empty_cameras ():
0 commit comments