From 78fd273f4cc3a4f416b91a9db7e153325f1d56af Mon Sep 17 00:00:00 2001
From: Isabelle Bouchard
Date: Wed, 27 Mar 2024 17:15:36 -0400
Subject: [PATCH] cleaning

---
 gaussian_splatting/.DS_Store                       | Bin 0 -> 6148 bytes
 .../colmap_free/global_trainer.py                  |  86 ++++++++
 .../local_initialization_trainer.py                |   1 -
 .../local_transformation_trainer.py                |  75 +------
 .../colmap_free/transformation_model.py            |  72 +++++++
 gaussian_splatting/global_trainer.py               | 198 ------------------
 gaussian_splatting/utils/loss.py                   |   5 +-
 scripts/train_colmap_free.py                       |  43 ++--
 8 files changed, 190 insertions(+), 290 deletions(-)
 create mode 100644 gaussian_splatting/.DS_Store
 create mode 100644 gaussian_splatting/colmap_free/global_trainer.py
 rename gaussian_splatting/{ => colmap_free}/local_initialization_trainer.py (99%)
 rename gaussian_splatting/{ => colmap_free}/local_transformation_trainer.py (55%)
 create mode 100644 gaussian_splatting/colmap_free/transformation_model.py
 delete mode 100644 gaussian_splatting/global_trainer.py

diff --git a/gaussian_splatting/.DS_Store b/gaussian_splatting/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..a5239b26ce18cab5b962df047942ac81b27f57cd
GIT binary patch
(binary .DS_Store blob omitted)

diff --git a/gaussian_splatting/global_trainer.py b/gaussian_splatting/global_trainer.py
deleted file mode 100644
--- a/gaussian_splatting/global_trainer.py
+++ /dev/null
-            if (
-                iteration >= self._densification_iteration_start
-                and iteration % self._densification_interval == 0
-            ):
-                self._densify_and_prune(
-                    iteration > self._opacity_reset_interval
-                )
-
-            # Reset opacity interval
-            if iteration % self._opacity_reset_interval == 0:
-                self._reset_opacity()
-
-            # Optimizer step
-            if iteration < self._iterations:
-                self.optimizer.step()
-                self.optimizer.zero_grad(set_to_none=True)
-
-    def _prepare_model_path(self):
-        unique_str = str(uuid.uuid4())
-        model_path = os.path.join("./output/", unique_str[0:10])
-
-        # Set up output folder
-        print("Output folder: {}".format(model_path))
-        os.makedirs(model_path, exist_ok=True)
-
-        return model_path
-
-    def _densify_and_prune(self, prune_big_points):
-        # Clone large gaussian in over-reconstruction areas
-        self._clone_points()
-        # Split small gaussians in under-construction areas.
-        self._split_points()
-
-        # Prune transparent and large gaussians.
-        prune_mask = (self.gaussian_model.get_opacity < self._min_opacity).squeeze()
-        if prune_big_points:
-            # Viewspace
-            big_points_vs = self.gaussian_model.max_radii2D > self._max_screen_size
-            # World space
-            big_points_ws = (
-                self.gaussian_model.get_scaling.max(dim=1).values
-                > 0.1 * self.gaussian_model.camera_extent
-            )
-            prune_mask = torch.logical_or(
-                torch.logical_or(prune_mask, big_points_vs), big_points_ws
-            )
-        if self._debug:
-            print(f"Pruning: {prune_mask.sum().item()} points.")
-        self._prune_points(valid_mask=~prune_mask)
-
-        torch.cuda.empty_cache()
-
-    def _split_points(self):
-        new_points, split_mask = self.gaussian_model.split_points(
-            self._densification_grad_threshold, self._percent_dense
-        )
-        self._concatenate_points(new_points)
-
-        prune_mask = torch.cat(
-            (
-                split_mask,
-                torch.zeros(2 * split_mask.sum(), device="cuda", dtype=bool),
-            )
-        )
-        if self._debug:
-            print(f"Densification: split {split_mask.sum().item()} points.")
-        self._prune_points(valid_mask=~prune_mask)
-
-    def _clone_points(self):
-        new_points, clone_mask = self.gaussian_model.clone_points(
-            self._densification_grad_threshold, self._percent_dense
-        )
-        if self._debug:
-            print(f"Densification: clone {clone_mask.sum().item()} points.")
-        self._concatenate_points(new_points)
-
-    def _reset_opacity(self):
-        new_opacity = self.gaussian_model.reset_opacity()
-        optimizable_tensors = self.optimizer.replace_points(new_opacity, "opacity")
-        self.gaussian_model.set_optimizable_tensors(optimizable_tensors)
-
-    def _prune_points(self, valid_mask):
-        optimizable_tensors = self.optimizer.prune_points(valid_mask)
-        self.gaussian_model.set_optimizable_tensors(optimizable_tensors)
-        self.gaussian_model.mask_stats(valid_mask)
-
-    def _concatenate_points(self, new_tensors):
-        optimizable_tensors = self.optimizer.concatenate_points(new_tensors)
-        self.gaussian_model.set_optimizable_tensors(optimizable_tensors)
-        self.gaussian_model.reset_stats()

diff --git a/gaussian_splatting/utils/loss.py b/gaussian_splatting/utils/loss.py
index 9549d5678..1a12e1ca9 100644
--- a/gaussian_splatting/utils/loss.py
+++ b/gaussian_splatting/utils/loss.py
@@ -24,10 +24,13 @@ def __call__(self, network_output, gt):
         l1_value = l1_loss(network_output, gt)
         ssim_value = ssim(network_output, gt)
 
-        loss = (1.0 - self._lambda_dssim) * l1_value + self._lambda_dssim * (1.0 - ssim_value)
+        loss = (1.0 - self._lambda_dssim) * l1_value + self._lambda_dssim * (
+            1.0 - ssim_value
+        )
 
         return loss
 
+
 def l1_loss(network_output, gt):
     return torch.abs((network_output - gt)).mean()

diff --git a/scripts/train_colmap_free.py b/scripts/train_colmap_free.py
index 2727d16c0..97ce97137 100644
--- a/scripts/train_colmap_free.py
+++ b/scripts/train_colmap_free.py
@@ -1,19 +1,19 @@
+import copy
 from pathlib import Path
-import copy
 
 import numpy as np
-import torchvision
 import torch
+import torchvision
 from tqdm import tqdm
 
-from gaussian_splatting.dataset.cameras import Camera
-from gaussian_splatting.render import render
-from gaussian_splatting.dataset.image_dataset import ImageDataset
-from gaussian_splatting.local_initialization_trainer import \
+from gaussian_splatting.colmap_free.global_trainer import GlobalTrainer
+from gaussian_splatting.colmap_free.local_initialization_trainer import \
     LocalInitializationTrainer
-from gaussian_splatting.local_transformation_trainer import \
+from gaussian_splatting.colmap_free.local_transformation_trainer import \
     LocalTransformationTrainer
-from gaussian_splatting.global_trainer import GlobalTrainer
+from gaussian_splatting.dataset.cameras import Camera
+from gaussian_splatting.dataset.image_dataset import ImageDataset
+from gaussian_splatting.render import render
 from gaussian_splatting.utils.general import PILtoTorch
 from gaussian_splatting.utils.loss import PhotometricLoss
 
@@ -23,6 +23,7 @@ def main():
     iteration_step_size = 5
     initialization_iterations = 50
     transformation_iterations = 50
+    global_iterations = 50
     photometric_loss = PhotometricLoss(lambda_dssim=0.2)
 
     dataset = ImageDataset(images_path=Path("data/phil/1/input/"))
@@ -32,16 +33,16 @@ def main():
     local_initialization_trainer = LocalInitializationTrainer(current_image)
     local_initialization_trainer.run(iterations=initialization_iterations)
 
+    # We set a copy of the initialized model to both the local transformation and the
+    # global models.
     next_gaussian_model = local_initialization_trainer.gaussian_model
     local_transformation_trainer = LocalTransformationTrainer(next_gaussian_model)
-    #global_trainer = GlobalTrainer(
-    #    gaussian_model=gaussian_model,
-    #    cameras=[camera],
-    #    iterations=100
-    #)
+    # global_trainer = GlobalTrainer(copy.deepcopy(next_gaussian_model))
 
     current_camera = local_initialization_trainer.camera
-    for iteration in tqdm(range(iteration_step_size, len(dataset), iteration_step_size)):
+    for iteration in tqdm(
+        range(iteration_step_size, len(dataset), iteration_step_size)
+    ):
         # Keep a copy of current gaussians to compare
         with torch.no_grad():
             current_gaussian_model = copy.deepcopy(next_gaussian_model)
@@ -74,12 +75,18 @@ def main():
         loss = photometric_loss(next_camera_image, next_gaussian_image)
         assert loss < 0.01
 
-        torchvision.utils.save_image(next_camera_image, f"artifacts/global/next_camera_{iteration}.png")
-        torchvision.utils.save_image(next_gaussian_image, f"artifacts/global/next_gaussian_{iteration}.png")
-        #global_trainer.add_camera(next_camera)
-        #global_trainer.run()
+        torchvision.utils.save_image(
+            next_camera_image, f"artifacts/global/next_camera_{iteration}.png"
+        )
+        torchvision.utils.save_image(
+            next_gaussian_image, f"artifacts/global/next_gaussian_{iteration}.png"
+        )
+
+        # global_trainer.add_camera(next_camera)
+        # global_trainer.run(global_iterations)
 
         current_camera = next_camera
 
+
 if __name__ == "__main__":
     main()
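
Note on how the relocated pieces fit together: the patch still leaves GlobalTrainer commented out in scripts/train_colmap_free.py. The following is a minimal sketch, based only on the calls visible in this diff, of how the colmap_free trainers would be wired once the global step is enabled. The per-frame camera estimation between the local and global steps is elided, and dataset[0] is a hypothetical accessor (this patch does not show how frames are read).

import copy
from pathlib import Path

from gaussian_splatting.colmap_free.global_trainer import GlobalTrainer
from gaussian_splatting.colmap_free.local_initialization_trainer import \
    LocalInitializationTrainer
from gaussian_splatting.colmap_free.local_transformation_trainer import \
    LocalTransformationTrainer
from gaussian_splatting.dataset.image_dataset import ImageDataset

dataset = ImageDataset(images_path=Path("data/phil/1/input/"))

# Fit an initial set of gaussians to the first frame.
current_image = dataset[0]  # hypothetical accessor, not shown in this patch
local_initialization_trainer = LocalInitializationTrainer(current_image)
local_initialization_trainer.run(iterations=50)

# A copy of the initialized model seeds both the local transformation trainer
# and the global trainer, mirroring the comment added in main().
next_gaussian_model = local_initialization_trainer.gaussian_model
local_transformation_trainer = LocalTransformationTrainer(next_gaussian_model)
global_trainer = GlobalTrainer(copy.deepcopy(next_gaussian_model))

# For each new frame, main() derives next_camera from the local transformation
# step (elided here), then registers it and refines the shared model globally.
next_camera = local_initialization_trainer.camera  # placeholder for the per-frame camera
global_trainer.add_camera(next_camera)
global_trainer.run(50)  # global_iterations in the updated script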