Load training/test cameras lazily instead of building every Camera up front.

Review notes:
- finetune.py samples a random training view by index (with replacement);
  next(iter(...)) restarts the iterator and would always yield camera 0.
- The whitespace-only change to the closing parenthesis in scene/__init__.py
  was dropped.
- The post-image blob ids on the "index" lines predate these review changes.

diff --git a/compress.py b/compress.py
index c571820..69e2cbc 100644
--- a/compress.py
+++ b/compress.py
@@ -28,6 +28,7 @@
 from utils.image_utils import psnr
 from utils.loss_utils import ssim
 
+from utils.camera_utils import LazyCameraLoader
 
 def unique_output_folder():
     if os.getenv("OAR_JOB_ID"):
@@ -72,6 +73,10 @@ def calc_importance(
         loss.backward()
         num_pixels += rendering.shape[1]*rendering.shape[2]
 
+        # Free up memory
+        del camera
+        torch.cuda.empty_cache()
+
     importance = torch.cat(
         [gaussians._features_dc.grad, gaussians._features_rest.grad],
         1,
diff --git a/finetune.py b/finetune.py
index e431c5c..41ebdb4 100644
--- a/finetune.py
+++ b/finetune.py
@@ -24,7 +24,6 @@ def finetune(scene: Scene, dataset, opt, comp, pipe, testing_iterations, debug_f
     scene.gaussians.training_setup(opt)
     scene.gaussians.update_learning_rate(first_iter)
 
-    viewpoint_stack = None
     ema_loss_for_log = 0.0
     progress_bar = tqdm(range(first_iter, max_iter), desc="Training progress")
     first_iter += 1
@@ -32,9 +31,9 @@ def finetune(scene: Scene, dataset, opt, comp, pipe, testing_iterations, debug_f
         iter_start.record()
 
         # Pick a random Camera
-        if not viewpoint_stack:
-            viewpoint_stack = scene.getTrainCameras().copy()
-        viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1))
+        # Index the lazy loader so that only the sampled view gets loaded.
+        train_cameras = scene.getTrainCameras()
+        viewpoint_cam = train_cameras[randint(0, len(train_cameras) - 1)]
 
         # Render
         if (iteration - 1) == debug_from:
@@ -55,6 +54,10 @@ def finetune(scene: Scene, dataset, opt, comp, pipe, testing_iterations, debug_f
         )
         loss.backward()
 
+        # Free up memory
+        del viewpoint_cam
+        torch.cuda.empty_cache()
+
         iter_end.record()
         scene.gaussians.update_learning_rate(iteration)
 
diff --git a/scene/__init__.py b/scene/__init__.py
index be4e04b..2e5a213 100644
--- a/scene/__init__.py
+++ b/scene/__init__.py
@@ -16,7 +16,7 @@
 from scene.dataset_readers import sceneLoadTypeCallbacks
 from scene.gaussian_model import GaussianModel
 from arguments import ModelParams
-from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON
+from utils.camera_utils import LazyCameraLoader, camera_to_JSON
 from glob import glob
 
 
@@ -91,11 +91,11 @@ def __init__(
 
         for resolution_scale in resolution_scales:
             print("Loading Training Cameras")
-            self.train_cameras[resolution_scale] = cameraList_from_camInfos(
+            self.train_cameras[resolution_scale] = LazyCameraLoader(
                 scene_info.train_cameras, resolution_scale, args
             )
             print("Loading Test Cameras")
-            self.test_cameras[resolution_scale] = cameraList_from_camInfos(
+            self.test_cameras[resolution_scale] = LazyCameraLoader(
                 scene_info.test_cameras, resolution_scale, args
             )
 
diff --git a/utils/camera_utils.py b/utils/camera_utils.py
index 1a54d0a..23585e4 100644
--- a/utils/camera_utils.py
+++ b/utils/camera_utils.py
@@ -11,11 +11,42 @@
 
 from scene.cameras import Camera
 import numpy as np
+import torch
 from utils.general_utils import PILtoTorch
 from utils.graphics_utils import fov2focal
 
 WARNED = False
 
+class LazyCameraLoader:
+    """Sequence-like view over camera infos that builds each Camera on demand.
+
+    Every index access or iteration step calls loadCam again; the loader
+    itself keeps no reference to the Camera objects it returns.
+    """
+
+    def __init__(self, cam_infos, resolution_scale, args):
+        self.cam_infos = cam_infos
+        self.resolution_scale = resolution_scale
+        self.args = args
+
+    def get_camera(self, index):
+        """Load the Camera for cam_infos[index]; index is passed on as its id."""
+        cam_info = self.cam_infos[index]
+        return loadCam(self.args, index, cam_info, self.resolution_scale)
+
+    def __len__(self):
+        return len(self.cam_infos)
+
+    def __getitem__(self, index):
+        return self.get_camera(index)
+
+    def __iter__(self):
+        for index in range(len(self.cam_infos)):
+            yield self.get_camera(index)
+            # Runs only once the consumer asks for the next camera, i.e.
+            # after it has finished with the one just yielded.
+            torch.cuda.empty_cache()
+
 def loadCam(args, id, cam_info, resolution_scale):
     orig_w, orig_h = cam_info.image.size
 
@@ -52,12 +83,7 @@ def loadCam(args, id, cam_info, resolution_scale):
                   image_name=cam_info.image_name, uid=id, data_device=args.data_device)
 
 def cameraList_from_camInfos(cam_infos, resolution_scale, args):
-    camera_list = []
-
-    for id, c in enumerate(cam_infos):
-        camera_list.append(loadCam(args, id, c, resolution_scale))
-
-    return camera_list
+    return LazyCameraLoader(cam_infos, resolution_scale, args)
 
 def camera_to_JSON(id, camera : Camera):
     Rt = np.zeros((4, 4))