Skip to content

Commit

Permalink
Remove data device
Browse files Browse the repository at this point in the history
  • Loading branch information
bouchardi committed Mar 15, 2024
1 parent e230133 commit 664b860
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 16 deletions.
1 change: 0 additions & 1 deletion gaussian_splatting/arguments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def __init__(self, parser=None, source_path="", sentinel=False):
self._images = "images"
self._resolution = -1
self._white_background = False
self.data_device = "cuda"
self.eval = False
super().__init__(parser, "Loading Parameters", sentinel)

Expand Down
25 changes: 15 additions & 10 deletions gaussian_splatting/scene/cameras.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,20 @@
from gaussian_splatting.utils.graphics_utils import getWorld2View2, getProjectionMatrix

class Camera(nn.Module):
def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
image_name, uid,
trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
):
def __init__(
self,
colmap_id,
R,
T,
FoVx,
FoVy,
image,
gt_alpha_mask,
image_name,
uid,
trans=np.array([0.0, 0.0, 0.0]),
scale=1.0
):
super(Camera, self).__init__()

self.uid = uid
Expand All @@ -29,12 +39,7 @@ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
self.FoVy = FoVy
self.image_name = image_name

try:
self.data_device = torch.device(data_device)
except Exception as e:
print(e)
print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
self.data_device = torch.device("cuda")
self.data_device = torch.device("cuda")

self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
self.image_width = self.original_image.shape[2]
Expand Down
15 changes: 11 additions & 4 deletions gaussian_splatting/utils/camera_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,17 @@ def loadCam(args, id, cam_info, resolution_scale):
if resized_image_rgb.shape[1] == 4:
loaded_mask = resized_image_rgb[3:4, ...]

return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
FoVx=cam_info.FovX, FoVy=cam_info.FovY,
image=gt_image, gt_alpha_mask=loaded_mask,
image_name=cam_info.image_name, uid=id, data_device=args.data_device)
return Camera(
colmap_id=cam_info.uid,
R=cam_info.R,
T=cam_info.T,
FoVx=cam_info.FovX,
FoVy=cam_info.FovY,
image=gt_image,
gt_alpha_mask=loaded_mask,
image_name=cam_info.image_name,
uid=id
)

def cameraList_from_camInfos(cam_infos, resolution_scale, args):
camera_list = []
Expand Down
1 change: 0 additions & 1 deletion scripts/train_modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ def __init__(self,):
self.images = "images"
self.resolution = -1
self.white_background = False
self.data_device = "cuda"
self.eval = False

class Pipeline():
Expand Down

0 comments on commit 664b860

Please sign in to comment.