Skip to content

Commit

Permalink
remove resolution model params
Browse files Browse the repository at this point in the history
  • Loading branch information
bouchardi committed Mar 18, 2024
1 parent eebb596 commit a655198
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 17 deletions.
1 change: 0 additions & 1 deletion gaussian_splatting/arguments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def __init__(self, parser=None, source_path="", sentinel=False):
self._source_path = source_path
self._model_path = ""
self._images = "images"
self._resolution = -1
self.eval = False
super().__init__(parser, "Loading Parameters", sentinel)

Expand Down
5 changes: 3 additions & 2 deletions gaussian_splatting/scene/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(
gaussians: GaussianModel,
load_iteration=None,
shuffle=True,
resolution=-1,
resolution_scales=[1.0],
):
"""b
Expand Down Expand Up @@ -86,11 +87,11 @@ def __init__(
for resolution_scale in resolution_scales:
print("Loading Training Cameras")
self.train_cameras[resolution_scale] = cameraList_from_camInfos(
scene_info.train_cameras, resolution_scale, args
scene_info.train_cameras, resolution_scale, resolution
)
print("Loading Test Cameras")
self.test_cameras[resolution_scale] = cameraList_from_camInfos(
scene_info.test_cameras, resolution_scale, args
scene_info.test_cameras, resolution_scale, resolution
)

if self.loaded_iter:
Expand Down
5 changes: 4 additions & 1 deletion gaussian_splatting/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@
class Trainer:
def __init__(
self,
resolution=-1,
testing_iterations=None,
saving_iterations=None,
checkpoint_iterations=None,
checkpoint_path=None,
quiet=False,
detect_anomaly=False,
):
self._resolution = resolution

if testing_iterations is None:
testing_iterations = [7_000, 30_000]
self._testing_iterations = testing_iterations
Expand All @@ -48,7 +51,7 @@ def run(
):
first_iter = 0
gaussians = GaussianModel(dataset.sh_degree)
scene = Scene(dataset, gaussians)
scene = Scene(dataset, gaussians, resolution=self._resolution)
gaussians.training_setup(opt)

if self._checkpoint_path:
Expand Down
20 changes: 10 additions & 10 deletions gaussian_splatting/utils/camera.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@
WARNED = False


def loadCam(args, id, cam_info, resolution_scale):
def load_camera(resolution, cam_id, cam_info, resolution_scale):
orig_w, orig_h = cam_info.image.size

if args.resolution in [1, 2, 4, 8]:
resolution = round(orig_w / (resolution_scale * args.resolution)), round(
orig_h / (resolution_scale * args.resolution)
if resolution in [1, 2, 4, 8]:
resolution = round(orig_w / (resolution_scale * resolution)), round(
orig_h / (resolution_scale * resolution)
)
else: # should be a type that converts to float
if args.resolution == -1:
if resolution == -1:
if orig_w > 1600:
global WARNED
if not WARNED:
Expand All @@ -39,7 +39,7 @@ def loadCam(args, id, cam_info, resolution_scale):
else:
global_down = 1
else:
global_down = orig_w / args.resolution
global_down = orig_w / resolution

scale = float(global_down) * float(resolution_scale)
resolution = (int(orig_w / scale), int(orig_h / scale))
Expand All @@ -61,15 +61,15 @@ def loadCam(args, id, cam_info, resolution_scale):
image=gt_image,
gt_alpha_mask=loaded_mask,
image_name=cam_info.image_name,
uid=id,
uid=cam_id,
)


def cameraList_from_camInfos(cam_infos, resolution_scale, args):
def cameraList_from_camInfos(cam_infos, resolution_scale, resolution):
camera_list = []

for id, c in enumerate(cam_infos):
camera_list.append(loadCam(args, id, c, resolution_scale))
for cam_id, c in enumerate(cam_infos):
camera_list.append(load_camera(resolution, cam_id, c, resolution_scale))

return camera_list

Expand Down
14 changes: 12 additions & 2 deletions scripts/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,21 @@ def render_set(model_path, name, iteration, views, gaussians):


def render_sets(
dataset: ModelParams, iteration: int, skip_train: bool, skip_test: bool
dataset: ModelParams,
iteration: int,
skip_train: bool,
skip_test: bool,
resolution: int = -1,
):
with torch.no_grad():
gaussians = GaussianModel(dataset.sh_degree)
scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
scene = Scene(
dataset,
gaussians,
load_iteration=iteration,
shuffle=False,
resolution=resolution,
)

if not skip_train:
render_set(
Expand Down
2 changes: 2 additions & 0 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@
parser.add_argument("--quiet", action="store_true")
parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
parser.add_argument("--checkpoint_path", type=str, default=None)
parser.add_argument("--resolution", default=-1, type=int)
args = parser.parse_args(sys.argv[1:])
args.save_iterations.append(args.iterations)

trainer = Trainer(
resolution=args.resolution,
testing_iterations=args.test_iterations,
saving_iterations=args.save_iterations,
checkpoint_iterations=args.checkpoint_iterations,
Expand Down
1 change: 0 additions & 1 deletion scripts/train_modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def __init__(
self.source_path = "/workspace/data/phil_open/5"
self.model_path = ""
self.images = "images"
self.resolution = -1
self.eval = False


Expand Down

0 comments on commit a655198

Please sign in to comment.