From 1660517a3ae065a69966d7870d702c1d42f60400 Mon Sep 17 00:00:00 2001
From: Isabelle Bouchard
Date: Tue, 26 Mar 2024 18:27:01 -0400
Subject: [PATCH] Added global training

---
 gaussian_splatting/global_trainer.py          | 198 ++++++++++++++++++
 .../local_transformation_trainer.py           |   5 +-
 .../utils/affine_transformation.py            |   7 -
 scripts/train_colmap_free.py                  |  61 +++++-
 4 files changed, 255 insertions(+), 16 deletions(-)
 create mode 100644 gaussian_splatting/global_trainer.py
 delete mode 100644 gaussian_splatting/utils/affine_transformation.py

diff --git a/gaussian_splatting/global_trainer.py b/gaussian_splatting/global_trainer.py
new file mode 100644
index 000000000..6acb75ae1
--- /dev/null
+++ b/gaussian_splatting/global_trainer.py
@@ -0,0 +1,198 @@
+import os
+import uuid
+from random import randint
+
+import torch
+from tqdm import tqdm
+
+from gaussian_splatting.model import GaussianModel
+from gaussian_splatting.optimizer import Optimizer
+from gaussian_splatting.render import render
+from gaussian_splatting.utils.general import safe_state
+from gaussian_splatting.utils.image import psnr
+from gaussian_splatting.utils.loss import l1_loss, ssim
+
+
+class GlobalTrainer:
+    def __init__(self, gaussian_model, cameras, iterations: int = 1000):
+        self._model_path = self._prepare_model_path()
+
+        self.gaussian_model = gaussian_model
+        self.cameras = cameras
+
+        self.optimizer = Optimizer(self.gaussian_model)
+
+        self._debug = False
+
+        self._iterations = iterations
+        self._testing_iterations = [iterations, 7000, 30000]
+        self._saving_iterations = [iterations - 1, 7000, 30000]
+        self._checkpoint_iterations = []
+
+        # Loss function
+        self._lambda_dssim = 0.2
+
+        # Densification and pruning
+        self._opacity_reset_interval = 3000
+        self._min_opacity = 0.005
+        self._max_screen_size = 20
+        self._percent_dense = 0.01
+        self._densification_interval = 100
+        self._densification_iteration_start = 500
+        self._densification_iteration_stop = 15000
+        self._densification_grad_threshold = 0.0002
+
+        safe_state()
+
+    def add_camera(self, camera):
+        self.cameras.append(camera)
+
+    def run(self):
+        first_iter = 0
+
+        ema_loss_for_log = 0.0
+        cameras = None
+        progress_bar = tqdm(
+            range(first_iter, self._iterations), desc="Training progress"
+        )
+        first_iter += 1
+        for iteration in range(first_iter, self._iterations + 1):
+            self.optimizer.update_learning_rate(iteration)
+
+            # Every 1000 its we increase the levels of SH up to a maximum degree
+            if iteration % 1000 == 0:
+                self.gaussian_model.oneupSHdegree()
+
+            # Pick a random camera
+            if not cameras:
+                cameras = self.cameras.copy()
+            camera = cameras.pop(randint(0, len(cameras) - 1))
+
+            # Render image
+            rendered_image, viewspace_point_tensor, visibility_filter, radii = render(
+                camera, self.gaussian_model
+            )
+
+            # Loss
+            gt_image = camera.original_image.cuda()
+            Ll1 = l1_loss(rendered_image, gt_image)
+            loss = (1.0 - self._lambda_dssim) * Ll1 + self._lambda_dssim * (
+                1.0 - ssim(rendered_image, gt_image)
+            )
+
+            loss.backward()
+
+            with torch.no_grad():
+                # Progress bar
+                ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
+                if iteration % 10 == 0:
+                    progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
+                    progress_bar.update(10)
+                if iteration == self._iterations:
+                    progress_bar.close()
+
+                if iteration in self._saving_iterations:
+                    print("\n[ITER {}] Saving Gaussians".format(iteration))
+                    point_cloud_path = os.path.join(
+                        self._model_path, "point_cloud/iteration_{}".format(iteration)
+                    )
+                    self.gaussian_model.save_ply(
+                        os.path.join(point_cloud_path, "point_cloud.ply")
+                    )
+
+                # Densification
+                if iteration < self._densification_iteration_stop:
+                    self.gaussian_model.update_stats(
+                        viewspace_point_tensor, visibility_filter, radii
+                    )
+
+                    if (
+                        iteration >= self._densification_iteration_start
+                        and iteration % self._densification_interval == 0
+                    ):
+                        self._densify_and_prune(
+                            iteration > self._opacity_reset_interval
+                        )
+
+                    # Reset opacity interval
+                    if iteration % self._opacity_reset_interval == 0:
+                        self._reset_opacity()
+
+                # Optimizer step
+                if iteration < self._iterations:
+                    self.optimizer.step()
+                    self.optimizer.zero_grad(set_to_none=True)
+
+    def _prepare_model_path(self):
+        unique_str = str(uuid.uuid4())
+        model_path = os.path.join("./output/", unique_str[0:10])
+
+        # Set up output folder
+        print("Output folder: {}".format(model_path))
+        os.makedirs(model_path, exist_ok=True)
+
+        return model_path
+
+    def _densify_and_prune(self, prune_big_points):
+        # Clone large gaussian in over-reconstruction areas
+        self._clone_points()
+        # Split small gaussians in under-construction areas.
+        self._split_points()
+
+        # Prune transparent and large gaussians.
+        prune_mask = (self.gaussian_model.get_opacity < self._min_opacity).squeeze()
+        if prune_big_points:
+            # Viewspace
+            big_points_vs = self.gaussian_model.max_radii2D > self._max_screen_size
+            # World space
+            big_points_ws = (
+                self.gaussian_model.get_scaling.max(dim=1).values
+                > 0.1 * self.gaussian_model.camera_extent
+            )
+            prune_mask = torch.logical_or(
+                torch.logical_or(prune_mask, big_points_vs), big_points_ws
+            )
+        if self._debug:
+            print(f"Pruning: {prune_mask.sum().item()} points.")
+        self._prune_points(valid_mask=~prune_mask)
+
+        torch.cuda.empty_cache()
+
+    def _split_points(self):
+        new_points, split_mask = self.gaussian_model.split_points(
+            self._densification_grad_threshold, self._percent_dense
+        )
+        self._concatenate_points(new_points)
+
+        prune_mask = torch.cat(
+            (
+                split_mask,
+                torch.zeros(2 * split_mask.sum(), device="cuda", dtype=bool),
+            )
+        )
+        if self._debug:
+            print(f"Densification: split {split_mask.sum().item()} points.")
+        self._prune_points(valid_mask=~prune_mask)
+
+    def _clone_points(self):
+        new_points, clone_mask = self.gaussian_model.clone_points(
+            self._densification_grad_threshold, self._percent_dense
+        )
+        if self._debug:
+            print(f"Densification: clone {clone_mask.sum().item()} points.")
+        self._concatenate_points(new_points)
+
+    def _reset_opacity(self):
+        new_opacity = self.gaussian_model.reset_opacity()
+        optimizable_tensors = self.optimizer.replace_points(new_opacity, "opacity")
+        self.gaussian_model.set_optimizable_tensors(optimizable_tensors)
+
+    def _prune_points(self, valid_mask):
+        optimizable_tensors = self.optimizer.prune_points(valid_mask)
+        self.gaussian_model.set_optimizable_tensors(optimizable_tensors)
+        self.gaussian_model.mask_stats(valid_mask)
+
+    def _concatenate_points(self, new_tensors):
+        optimizable_tensors = self.optimizer.concatenate_points(new_tensors)
+        self.gaussian_model.set_optimizable_tensors(optimizable_tensors)
+        self.gaussian_model.reset_stats()
diff --git a/gaussian_splatting/local_transformation_trainer.py b/gaussian_splatting/local_transformation_trainer.py
index 72814f2cf..7bb493f7d 100644
--- a/gaussian_splatting/local_transformation_trainer.py
+++ b/gaussian_splatting/local_transformation_trainer.py
@@ -105,6 +105,7 @@ def run(self):
         progress_bar = tqdm(range(self._iterations), desc="Transformation")
 
         best_loss, best_iteration, losses = None, 0, []
+        best_xyz = None
         for iteration in range(self._iterations):
             xyz = self.transformation_model(self.xyz)
             self.gaussian_model.set_optimizable_tensors({"xyz": xyz})
@@ -131,6 +132,7 @@ def run(self):
             if best_loss is None or best_loss > loss:
                 best_loss = loss.cpu().item()
                 best_iteration = iteration
+                best_xyz = xyz
             losses.append(loss.cpu().item())
 
             loss.backward()
@@ -148,9 +150,10 @@ def run(self):
         print(
             f"Training done. Best loss = {best_loss:.{5}f} at iteration {best_iteration}."
         )
+        self.gaussian_model.set_optimizable_tensors({"xyz": best_xyz})
 
         torchvision.utils.save_image(
-            rendered_image, f"artifacts/local/transfo/rendered_{iteration}.png"
+            rendered_image, f"artifacts/local/transfo/rendered_best.png"
         )
         torchvision.utils.save_image(gt_image, f"artifacts/local/transfo/gt.png")
 
diff --git a/gaussian_splatting/utils/affine_transformation.py b/gaussian_splatting/utils/affine_transformation.py
deleted file mode 100644
index 27637ebab..000000000
--- a/gaussian_splatting/utils/affine_transformation.py
+++ /dev/null
@@ -1,7 +0,0 @@
-import numpy as np
-
-
-def apply_affine_transformation(xyz, rotation, translation):
-    transformed_xyz = np.matmul(xyz, rotation.T) + translation
-
-    return transformed_xyz
diff --git a/scripts/train_colmap_free.py b/scripts/train_colmap_free.py
index 0bb11d69e..cd9f69e2f 100644
--- a/scripts/train_colmap_free.py
+++ b/scripts/train_colmap_free.py
@@ -1,26 +1,71 @@
+import copy
 from pathlib import Path
 
+import numpy as np
+import torchvision
+
+from gaussian_splatting.dataset.cameras import Camera
 from gaussian_splatting.dataset.image_dataset import ImageDataset
+from gaussian_splatting.global_trainer import GlobalTrainer
 from gaussian_splatting.local_initialization_trainer import \
     LocalInitializationTrainer
 from gaussian_splatting.local_transformation_trainer import \
     LocalTransformationTrainer
+from gaussian_splatting.render import render
+from gaussian_splatting.utils.general import PILtoTorch
 
 
 def main():
     dataset = ImageDataset(images_path=Path("data/phil/1/input/"))
-    image_0 = dataset.get_frame(0)
-    image_1 = dataset.get_frame(10)
+    first_image = dataset.get_frame(0)
 
-    local_initialization_trainer = LocalInitializationTrainer(image_0, iterations=100)
+    local_initialization_trainer = LocalInitializationTrainer(
+        first_image, iterations=500
+    )
     local_initialization_trainer.run()
 
-    local_transformation_trainer = LocalTransformationTrainer(
-        image_1,
-        camera=local_initialization_trainer.camera,
-        gaussian_model=local_initialization_trainer.gaussian_model,
+    step = 5
+    gaussian_model = local_initialization_trainer.gaussian_model
+    camera = local_initialization_trainer.camera
+    global_trainer = GlobalTrainer(
+        gaussian_model=gaussian_model, cameras=[camera], iterations=100
     )
-    local_transformation_trainer.run()
+    for iteration in range(step, len(dataset), step):
+        next_image = dataset.get_frame(iteration)
+        local_transformation_trainer = LocalTransformationTrainer(
+            next_image,
+            camera=camera,
+            gaussian_model=copy.deepcopy(gaussian_model),
+            iterations=250,
+        )
+        local_transformation_trainer.run()
+        rotation, translation = local_transformation_trainer.get_affine_transformation()
+
+        next_image = PILtoTorch(next_image)
+        next_camera = Camera(
+            R=np.matmul(camera.R, rotation),
+            T=camera.T + translation,
+            FoVx=camera.FoVx,
+            FoVy=camera.FoVy,
+            image=next_image,
+            gt_alpha_mask=None,
+            image_name="patate",
+            colmap_id=iteration,
+            uid=iteration,
+        )
+        global_trainer.add_camera(next_camera)
+        global_trainer.run()
+
+        rendered_image, viewspace_point_tensor, visibility_filter, radii = render(
+            next_camera,
+            local_initialization_trainer.gaussian_model,
+        )
+        torchvision.utils.save_image(
+            rendered_image, f"artifacts/global/rendered_{iteration}.png"
+        )
+        torchvision.utils.save_image(next_image, f"artifacts/global/gt_{iteration}.png")
+
+        print(f">>> Iteration {iteration} / {len(dataset) // step}")
 
 
 if __name__ == "__main__":