From 5a1a3c3df5eae32678cc7d3c58f73b95dcbe726b Mon Sep 17 00:00:00 2001 From: Isabelle Bouchard Date: Fri, 15 Mar 2024 11:10:15 -0400 Subject: [PATCH] training() -> Trainer.run() --- gaussian_splatting/training.py | 295 +++++++++++++++++---------------- scripts/train.py | 12 +- scripts/train_modal.py | 7 +- 3 files changed, 166 insertions(+), 148 deletions(-) diff --git a/gaussian_splatting/training.py b/gaussian_splatting/training.py index 0abd407a3..529a26d58 100644 --- a/gaussian_splatting/training.py +++ b/gaussian_splatting/training.py @@ -19,146 +19,160 @@ TENSORBOARD_FOUND = False -def training( - dataset, - opt, - pipe, - testing_iterations=None, - saving_iterations=None, - checkpoint_iterations=None, - checkpoint=None, - debug_from=-1, - quiet=False, - ip="127.0.0.1", - port=6009, - detect_anomaly=False -): - if dataset is None: - dataset = ModelParams(source_path=source_path) - - if pipe is None: - pipe = PipelineParams - - if opt is None: - opt = OptimizationParams() - - if testing_iterations is None: - testing_iterations = [7_000, 30_000] - - if saving_iterations is None: - saving_iterations = [7_000, 30_000] - - if checkpoint_iterations is None: - checkpoint_iterations = [] - - # Initialize system state (RNG) - safe_state(quiet) - - # Start GUI server, configure and run training - network_gui.init(ip, port) - torch.autograd.set_detect_anomaly(detect_anomaly) - - first_iter = 0 - tb_writer = prepare_output_and_logger(dataset) - gaussians = GaussianModel(dataset.sh_degree) - scene = Scene(dataset, gaussians) - gaussians.training_setup(opt) - if checkpoint: - (model_params, first_iter) = torch.load(checkpoint) - gaussians.restore(model_params, opt) - - bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] - background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") - - iter_start = torch.cuda.Event(enable_timing = True) - iter_end = torch.cuda.Event(enable_timing = True) - - viewpoint_stack = None - ema_loss_for_log = 0.0 - progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress") - first_iter += 1 - for iteration in range(first_iter, opt.iterations + 1): - if network_gui.conn == None: - network_gui.try_connect() - while network_gui.conn != None: - try: - net_image_bytes = None - custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive() - if custom_cam != None: - net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer)["render"] - net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy()) - network_gui.send(net_image_bytes, dataset.source_path) - if do_training and ((iteration < int(opt.iterations)) or not keep_alive): - break - except Exception as e: - network_gui.conn = None - - iter_start.record() - - gaussians.update_learning_rate(iteration) - - # Every 1000 its we increase the levels of SH up to a maximum degree - if iteration % 1000 == 0: - gaussians.oneupSHdegree() - - # Pick a random Camera - if not viewpoint_stack: - viewpoint_stack = scene.getTrainCameras().copy() - viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1)) - - # Render - if (iteration - 1) == debug_from: - pipe.debug = True - - bg = torch.rand((3), device="cuda") if opt.random_background else background - - render_pkg = render(viewpoint_cam, gaussians, pipe, bg) - image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] - - # Loss - gt_image = viewpoint_cam.original_image.cuda() - Ll1 = l1_loss(image, gt_image) - loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image)) - loss.backward() - - iter_end.record() - - with torch.no_grad(): - # Progress bar - ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log - if iteration % 10 == 0: - progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"}) - progress_bar.update(10) - if iteration == opt.iterations: - progress_bar.close() - - # Log and save - training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background)) - if (iteration in saving_iterations): - print("\n[ITER {}] Saving Gaussians".format(iteration)) - scene.save(iteration) - - # Densification - if iteration < opt.densify_until_iter: - # Keep track of max radii in image-space for pruning - gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter]) - gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter) - - if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0: - size_threshold = 20 if iteration > opt.opacity_reset_interval else None - gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold) - - if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter): - gaussians.reset_opacity() - - # Optimizer step - if iteration < opt.iterations: - gaussians.optimizer.step() - gaussians.optimizer.zero_grad(set_to_none = True) - - if (iteration in checkpoint_iterations): - print("\n[ITER {}] Saving Checkpoint".format(iteration)) - torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth") +class Trainer: + def __init__( + self, + testing_iterations=None, + saving_iterations=None, + checkpoint_iterations=None, + checkpoint=None, + debug_from=-1, + quiet=False, + ip="127.0.0.1", + port=6009, + detect_anomaly=False + ): + if testing_iterations is None: + testing_iterations = [7_000, 30_000] + self._testing_iterations = testing_iterations + + if saving_iterations is None: + saving_iterations = [7_000, 30_000] + self._saving_iterations = saving_iterations + + if checkpoint_iterations is None: + checkpoint_iterations = [] + self._checkpoint_iterations = checkpoint_iterations + + self._checkpoint = checkpoint + self._debug_from = debug_from + + # Initialize system state (RNG) + safe_state(quiet) + # Start GUI server, configure and run training + network_gui.init(ip, port) + torch.autograd.set_detect_anomaly(detect_anomaly) + + + def run( + self, + dataset, + opt, + pipe, + ): + first_iter = 0 + tb_writer = prepare_output_and_logger(dataset) + gaussians = GaussianModel(dataset.sh_degree) + scene = Scene(dataset, gaussians) + gaussians.training_setup(opt) + if self._checkpoint: + (model_params, first_iter) = torch.load(self._checkpoint) + gaussians.restore(model_params, opt) + + bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] + background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") + + iter_start = torch.cuda.Event(enable_timing = True) + iter_end = torch.cuda.Event(enable_timing = True) + + viewpoint_stack = None + ema_loss_for_log = 0.0 + progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress") + first_iter += 1 + for iteration in range(first_iter, opt.iterations + 1): + if network_gui.conn == None: + network_gui.try_connect() + while network_gui.conn != None: + try: + net_image_bytes = None + custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive() + if custom_cam != None: + net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer)["render"] + net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy()) + network_gui.send(net_image_bytes, dataset.source_path) + if do_training and ((iteration < int(opt.iterations)) or not keep_alive): + break + except Exception as e: + network_gui.conn = None + + iter_start.record() + + gaussians.update_learning_rate(iteration) + + # Every 1000 its we increase the levels of SH up to a maximum degree + if iteration % 1000 == 0: + gaussians.oneupSHdegree() + + # Pick a random Camera + if not viewpoint_stack: + viewpoint_stack = scene.getTrainCameras().copy() + viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1)) + + # Render + if (iteration - 1) == self._debug_from: + pipe.debug = True + + bg = torch.rand((3), device="cuda") if opt.random_background else background + + render_pkg = render(viewpoint_cam, gaussians, pipe, bg) + image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] + + # Loss + gt_image = viewpoint_cam.original_image.cuda() + Ll1 = l1_loss(image, gt_image) + loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image)) + loss.backward() + + iter_end.record() + + with torch.no_grad(): + # Progress bar + ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log + if iteration % 10 == 0: + progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"}) + progress_bar.update(10) + if iteration == opt.iterations: + progress_bar.close() + + # Log and save + training_report( + tb_writer, + iteration, + Ll1, + loss, + l1_loss, + iter_start.elapsed_time(iter_end), + self._testing_iterations, + scene, + render, + (pipe, background) + ) + if (iteration in self._saving_iterations): + print("\n[ITER {}] Saving Gaussians".format(iteration)) + scene.save(iteration) + + # Densification + if iteration < opt.densify_until_iter: + # Keep track of max radii in image-space for pruning + gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter]) + gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter) + + if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0: + size_threshold = 20 if iteration > opt.opacity_reset_interval else None + gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold) + + if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter): + gaussians.reset_opacity() + + # Optimizer step + if iteration < opt.iterations: + gaussians.optimizer.step() + gaussians.optimizer.zero_grad(set_to_none = True) + + if (iteration in self._checkpoint_iterations): + print("\n[ITER {}] Saving Checkpoint".format(iteration)) + torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth") + def prepare_output_and_logger(args): if not args.model_path: @@ -182,6 +196,7 @@ def prepare_output_and_logger(args): print("Tensorboard not available: not logging progress") return tb_writer + def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs): if tb_writer: tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration) diff --git a/scripts/train.py b/scripts/train.py index 268d0b46d..1f3d3b108 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -11,7 +11,7 @@ import sys from argparse import ArgumentParser from gaussian_splatting.arguments import ModelParams, PipelineParams, OptimizationParams -from gaussian_splatting.training import training +from gaussian_splatting.training import Trainer if __name__ == "__main__": @@ -32,10 +32,7 @@ args = parser.parse_args(sys.argv[1:]) args.save_iterations.append(args.iterations) - training( - dataset=lp.extract(args), - opt=op.extract(args), - pipe=pp.extract(args), + trainer = Trainer( testing_iterations=args.test_iterations, saving_iterations=args.save_iterations, checkpoint_iterations=args.checkpoint_iterations, @@ -46,3 +43,8 @@ port=args.port, detect_anomaly=args.detect_anomaly ) + trainer.run( + dataset=lp.extract(args), + opt=op.extract(args), + pipe=pp.extract(args), + ) diff --git a/scripts/train_modal.py b/scripts/train_modal.py index bc11cb5a2..c63274efd 100644 --- a/scripts/train_modal.py +++ b/scripts/train_modal.py @@ -53,7 +53,7 @@ class Dataset(): def __init__(self,): self.sh_degree = 3 - self.source_path = "/workspace/data/train/" + self.source_path = "/workspace/data/phil_open/5" self.model_path = "" self.images = "images" self.resolution = -1 @@ -97,9 +97,10 @@ def __init__(self): timeout=10800 ) def f(): - from gaussian_splatting.training import training + from gaussian_splatting.training import Trainer - training( + trainer = Trainer() + trainer.run( dataset=Dataset(), opt=Optimization(), pipe=Pipeline()