diff --git a/gaussian_splatting/arguments/__init__.py b/gaussian_splatting/arguments/__init__.py index fa4493f4d..65618ee52 100644 --- a/gaussian_splatting/arguments/__init__.py +++ b/gaussian_splatting/arguments/__init__.py @@ -59,7 +59,6 @@ def __init__(self, parser=None, source_path="", sentinel=False): self._source_path = source_path self._model_path = "" self._images = "images" - self._resolution = -1 self.eval = False super().__init__(parser, "Loading Parameters", sentinel) diff --git a/gaussian_splatting/scene/__init__.py b/gaussian_splatting/scene/__init__.py index af0c4f5ad..53628da4f 100644 --- a/gaussian_splatting/scene/__init__.py +++ b/gaussian_splatting/scene/__init__.py @@ -31,6 +31,7 @@ def __init__( gaussians: GaussianModel, load_iteration=None, shuffle=True, + resolution=-1, resolution_scales=[1.0], ): """b @@ -86,11 +87,11 @@ def __init__( for resolution_scale in resolution_scales: print("Loading Training Cameras") self.train_cameras[resolution_scale] = cameraList_from_camInfos( - scene_info.train_cameras, resolution_scale, args + scene_info.train_cameras, resolution_scale, resolution ) print("Loading Test Cameras") self.test_cameras[resolution_scale] = cameraList_from_camInfos( - scene_info.test_cameras, resolution_scale, args + scene_info.test_cameras, resolution_scale, resolution ) if self.loaded_iter: diff --git a/gaussian_splatting/training.py b/gaussian_splatting/training.py index 35492f444..e9b8cf3b9 100644 --- a/gaussian_splatting/training.py +++ b/gaussian_splatting/training.py @@ -16,6 +16,7 @@ class Trainer: def __init__( self, + resolution=-1, testing_iterations=None, saving_iterations=None, checkpoint_iterations=None, @@ -23,6 +24,8 @@ def __init__( quiet=False, detect_anomaly=False, ): + self._resolution = resolution + if testing_iterations is None: testing_iterations = [7_000, 30_000] self._testing_iterations = testing_iterations @@ -48,7 +51,7 @@ def run( ): first_iter = 0 gaussians = GaussianModel(dataset.sh_degree) - scene = Scene(dataset, gaussians) + scene = Scene(dataset, gaussians, resolution=self._resolution) gaussians.training_setup(opt) if self._checkpoint_path: diff --git a/gaussian_splatting/utils/camera.py b/gaussian_splatting/utils/camera.py index b1f3002a5..cfffa2db4 100644 --- a/gaussian_splatting/utils/camera.py +++ b/gaussian_splatting/utils/camera.py @@ -18,15 +18,15 @@ WARNED = False -def loadCam(args, id, cam_info, resolution_scale): +def load_camera(resolution, cam_id, cam_info, resolution_scale): orig_w, orig_h = cam_info.image.size - if args.resolution in [1, 2, 4, 8]: - resolution = round(orig_w / (resolution_scale * args.resolution)), round( - orig_h / (resolution_scale * args.resolution) + if resolution in [1, 2, 4, 8]: + resolution = round(orig_w / (resolution_scale * resolution)), round( + orig_h / (resolution_scale * resolution) ) else: # should be a type that converts to float - if args.resolution == -1: + if resolution == -1: if orig_w > 1600: global WARNED if not WARNED: @@ -39,7 +39,7 @@ def loadCam(args, id, cam_info, resolution_scale): else: global_down = 1 else: - global_down = orig_w / args.resolution + global_down = orig_w / resolution scale = float(global_down) * float(resolution_scale) resolution = (int(orig_w / scale), int(orig_h / scale)) @@ -61,15 +61,15 @@ def loadCam(args, id, cam_info, resolution_scale): image=gt_image, gt_alpha_mask=loaded_mask, image_name=cam_info.image_name, - uid=id, + uid=cam_id, ) -def cameraList_from_camInfos(cam_infos, resolution_scale, args): +def cameraList_from_camInfos(cam_infos, resolution_scale, resolution): camera_list = [] - for id, c in enumerate(cam_infos): - camera_list.append(loadCam(args, id, c, resolution_scale)) + for cam_id, c in enumerate(cam_infos): + camera_list.append(load_camera(resolution, cam_id, c, resolution_scale)) return camera_list diff --git a/scripts/render.py b/scripts/render.py index 1138a4983..43ee759c0 100644 --- a/scripts/render.py +++ b/scripts/render.py @@ -42,11 +42,21 @@ def render_set(model_path, name, iteration, views, gaussians): def render_sets( - dataset: ModelParams, iteration: int, skip_train: bool, skip_test: bool + dataset: ModelParams, + iteration: int, + skip_train: bool, + skip_test: bool, + resolution: int = -1, ): with torch.no_grad(): gaussians = GaussianModel(dataset.sh_degree) - scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False) + scene = Scene( + dataset, + gaussians, + load_iteration=iteration, + shuffle=False, + resolution=resolution, + ) if not skip_train: render_set( diff --git a/scripts/train.py b/scripts/train.py index 23a4e03d3..5e2c07db3 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -29,10 +29,12 @@ parser.add_argument("--quiet", action="store_true") parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) parser.add_argument("--checkpoint_path", type=str, default=None) + parser.add_argument("--resolution", default=-1, type=int) args = parser.parse_args(sys.argv[1:]) args.save_iterations.append(args.iterations) trainer = Trainer( + resolution=args.resolution, testing_iterations=args.test_iterations, saving_iterations=args.save_iterations, checkpoint_iterations=args.checkpoint_iterations, diff --git a/scripts/train_modal.py b/scripts/train_modal.py index 584ba38ff..dda8bb002 100644 --- a/scripts/train_modal.py +++ b/scripts/train_modal.py @@ -45,7 +45,6 @@ def __init__( self.source_path = "/workspace/data/phil_open/5" self.model_path = "" self.images = "images" - self.resolution = -1 self.eval = False