diff --git a/gaussian_splatting/arguments/__init__.py b/gaussian_splatting/arguments/__init__.py index 50110d0e5..3f17a9d9c 100644 --- a/gaussian_splatting/arguments/__init__.py +++ b/gaussian_splatting/arguments/__init__.py @@ -55,7 +55,6 @@ def __init__(self, parser=None, source_path="", sentinel=False): self._images = "images" self._resolution = -1 self._white_background = False - self.data_device = "cuda" self.eval = False super().__init__(parser, "Loading Parameters", sentinel) diff --git a/gaussian_splatting/scene/cameras.py b/gaussian_splatting/scene/cameras.py index c83674f32..4398b3270 100644 --- a/gaussian_splatting/scene/cameras.py +++ b/gaussian_splatting/scene/cameras.py @@ -15,10 +15,20 @@ from gaussian_splatting.utils.graphics_utils import getWorld2View2, getProjectionMatrix class Camera(nn.Module): - def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, - image_name, uid, - trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda" - ): + def __init__( + self, + colmap_id, + R, + T, + FoVx, + FoVy, + image, + gt_alpha_mask, + image_name, + uid, + trans=np.array([0.0, 0.0, 0.0]), + scale=1.0 + ): super(Camera, self).__init__() self.uid = uid @@ -29,12 +39,7 @@ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, self.FoVy = FoVy self.image_name = image_name - try: - self.data_device = torch.device(data_device) - except Exception as e: - print(e) - print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) - self.data_device = torch.device("cuda") + self.data_device = torch.device("cuda") self.original_image = image.clamp(0.0, 1.0).to(self.data_device) self.image_width = self.original_image.shape[2] diff --git a/gaussian_splatting/utils/camera_utils.py b/gaussian_splatting/utils/camera_utils.py index e8fd5de2d..0fa4742d4 100644 --- a/gaussian_splatting/utils/camera_utils.py +++ b/gaussian_splatting/utils/camera_utils.py @@ -46,10 +46,17 @@ def loadCam(args, id, cam_info, resolution_scale): if resized_image_rgb.shape[1] == 4: loaded_mask = resized_image_rgb[3:4, ...] - return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, - FoVx=cam_info.FovX, FoVy=cam_info.FovY, - image=gt_image, gt_alpha_mask=loaded_mask, - image_name=cam_info.image_name, uid=id, data_device=args.data_device) + return Camera( + colmap_id=cam_info.uid, + R=cam_info.R, + T=cam_info.T, + FoVx=cam_info.FovX, + FoVy=cam_info.FovY, + image=gt_image, + gt_alpha_mask=loaded_mask, + image_name=cam_info.image_name, + uid=id + ) def cameraList_from_camInfos(cam_infos, resolution_scale, args): camera_list = [] diff --git a/scripts/train_modal.py b/scripts/train_modal.py index df31cace4..a557e9fdd 100644 --- a/scripts/train_modal.py +++ b/scripts/train_modal.py @@ -58,7 +58,6 @@ def __init__(self,): self.images = "images" self.resolution = -1 self.white_background = False - self.data_device = "cuda" self.eval = False class Pipeline():