diff --git a/gaussian_splatting/.DS_Store b/gaussian_splatting/.DS_Store new file mode 100644 index 000000000..a5239b26c Binary files /dev/null and b/gaussian_splatting/.DS_Store differ
diff --git a/gaussian_splatting/colmap_free/global_trainer.py b/gaussian_splatting/colmap_free/global_trainer.py
new file mode 100644
index 000000000..8da863a26
--- /dev/null
+++ b/gaussian_splatting/colmap_free/global_trainer.py
@@ -0,0 +1,124 @@
+import os
+import uuid
+from random import randint
+
+import torch
+from tqdm import tqdm
+
+from gaussian_splatting.optimizer import Optimizer
+from gaussian_splatting.render import render
+from gaussian_splatting.trainer import Trainer
+from gaussian_splatting.utils.general import safe_state
+from gaussian_splatting.utils.loss import PhotometricLoss
+
+
+class GlobalTrainer(Trainer):
+    """Optimize one Gaussian model against every camera registered so far.
+
+    Cameras are added with ``add_camera``. Each ``run`` call trains on all of
+    them, saves the point cloud and finishes with one densification, pruning
+    and opacity-reset pass.
+    """
+
+    def __init__(self, gaussian_model):
+        """Set up optimizer, loss and densification thresholds.
+
+        Args:
+            gaussian_model: Model to optimize; passed to ``Optimizer`` and
+                ``render``.
+        """
+        # NOTE(review): ``_prepare_model_path`` is not defined in this class;
+        # presumably inherited from ``Trainer`` -- confirm.
+        self._model_path = self._prepare_model_path()
+
+        self.gaussian_model = gaussian_model
+        self.cameras = []
+
+        self.optimizer = Optimizer(self.gaussian_model)
+        self._photometric_loss = PhotometricLoss(lambda_dssim=0.2)
+
+        self._debug = False
+
+        # Densification and pruning
+        self._min_opacity = 0.005
+        self._max_screen_size = 20
+        self._percent_dense = 0.01
+        self._densification_grad_threshold = 0.0002
+
+        safe_state()
+
+    def add_camera(self, camera):
+        """Register ``camera`` as an additional training view."""
+        self.cameras.append(camera)
+
+    def run(self, iterations: int = 1000):
+        """Train for ``iterations`` steps, save the model, then densify once.
+
+        Every step renders one camera. Cameras are drawn at random without
+        replacement; once all have been used the pool is refilled.
+
+        Args:
+            iterations: Number of optimization steps. Values below 1 make
+                the call a no-op.
+
+        Raises:
+            ValueError: If no camera has been registered.
+        """
+        if iterations < 1:
+            # Nothing gets rendered, so there is nothing to save or densify.
+            return
+        if not self.cameras:
+            raise ValueError(
+                "GlobalTrainer.run() needs at least one camera; "
+                "call add_camera() first."
+            )
+
+        ema_loss_for_log = 0.0
+        cameras = None
+        # ``total`` equals the number of loop passes so the bar ends at 100%.
+        with tqdm(total=iterations, desc="Training progress") as progress_bar:
+            for iteration in range(1, iterations + 1):
+                self.optimizer.update_learning_rate(iteration)
+
+                # Every 1000 iterations, step up the SH degree.
+                if iteration % 1000 == 0:
+                    self.gaussian_model.oneupSHdegree()
+
+                # Pick a random camera
+                if not cameras:
+                    cameras = self.cameras.copy()
+                camera = cameras.pop(randint(0, len(cameras) - 1))
+
+                # Render image
+                rendered_image, viewspace_point_tensor, visibility_filter, radii = (
+                    render(camera, self.gaussian_model)
+                )
+
+                # Loss
+                gt_image = camera.original_image.cuda()
+                loss = self._photometric_loss(rendered_image, gt_image)
+                loss.backward()
+
+                # Optimizer step
+                self.optimizer.step()
+                self.optimizer.zero_grad(set_to_none=True)
+
+                # Progress bar
+                ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
+                progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
+                progress_bar.update(1)
+
+        point_cloud_path = os.path.join(
+            self._model_path, f"point_cloud/iteration_{iterations}"
+        )
+        self.gaussian_model.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))
+
+        # Densification uses the statistics of the last rendered view only.
+        # NOTE(review): run under no_grad on the assumption that the
+        # densify/prune/reset helpers only edit the point set and need no
+        # gradients of their own -- confirm against ``Trainer``.
+        with torch.no_grad():
+            self.gaussian_model.update_stats(
+                viewspace_point_tensor, visibility_filter, radii
+            )
+            self._densify_and_prune(True)
+            self._reset_opacity()
diff --git a/gaussian_splatting/local_initialization_trainer.py b/gaussian_splatting/colmap_free/local_initialization_trainer.py similarity index 99% rename from gaussian_splatting/local_initialization_trainer.py rename to gaussian_splatting/colmap_free/local_initialization_trainer.py index be3b7ff75..1c19b8ef2 100644 --- a/gaussian_splatting/local_initialization_trainer.py +++ b/gaussian_splatting/colmap_free/local_initialization_trainer.py @@ -82,7 +82,6 @@ def run(self, iterations: int = 3000): best_iteration = iteration losses.append(loss.cpu().item()) - with torch.no_grad(): # Densification if iteration < self._densification_iteration_stop: diff --git a/gaussian_splatting/local_transformation_trainer.py b/gaussian_splatting/colmap_free/local_transformation_trainer.py similarity index 55% rename from gaussian_splatting/local_transformation_trainer.py rename to gaussian_splatting/colmap_free/local_transformation_trainer.py index ad05c5a21..724ee16d1 100644 --- a/gaussian_splatting/local_transformation_trainer.py +++ b/gaussian_splatting/colmap_free/local_transformation_trainer.py @@ -1,90 +1,21 @@ import torch -import torch.nn as nn import torchvision from 
matplotlib import pyplot as plt from tqdm import tqdm +from gaussian_splatting.colmap_free.transformation_model import \ + AffineTransformationModel from gaussian_splatting.render import render from gaussian_splatting.trainer import Trainer from gaussian_splatting.utils.general import safe_state from gaussian_splatting.utils.loss import PhotometricLoss -class QuaternionRotation(nn.Module): - def __init__(self): - super().__init__() - self.quaternion = nn.Parameter(torch.Tensor([[0, 0, 0, 1]])) - - def forward(self, input_tensor): - rotation_matrix = self.get_rotation_matrix() - rotated_tensor = torch.matmul(input_tensor, rotation_matrix) - - return rotated_tensor - - def get_rotation_matrix(self): - # Normalize quaternion to ensure unit magnitude - quaternion_norm = torch.norm(self.quaternion, p=2, dim=1, keepdim=True) - normalized_quaternion = self.quaternion / quaternion_norm - - x, y, z, w = normalized_quaternion[0] - rotation_matrix = torch.zeros( - 3, 3, dtype=torch.float32, device=self.quaternion.device - ) - rotation_matrix[0, 0] = 1 - 2 * (y**2 + z**2) - rotation_matrix[0, 1] = 2 * (x * y - w * z) - rotation_matrix[0, 2] = 2 * (x * z + w * y) - rotation_matrix[1, 0] = 2 * (x * y + w * z) - rotation_matrix[1, 1] = 1 - 2 * (x**2 + z**2) - rotation_matrix[1, 2] = 2 * (y * z - w * x) - rotation_matrix[2, 0] = 2 * (x * z - w * y) - rotation_matrix[2, 1] = 2 * (y * z + w * x) - rotation_matrix[2, 2] = 1 - 2 * (x**2 + y**2) - - return rotation_matrix - - -class Translation(nn.Module): - def __init__(self): - super().__init__() - self.translation = nn.Parameter(torch.Tensor(1, 3)) - nn.init.zeros_(self.translation) - - def forward(self, input_tensor): - translated_tensor = torch.add(self.translation, input_tensor) - - return translated_tensor - - -class TransformationModel(nn.Module): - def __init__(self): - super().__init__() - self._rotation = QuaternionRotation() - self._translation = Translation() - - def forward(self, xyz): - transformed_xyz = 
self._rotation(xyz) - transformed_xyz = self._translation(transformed_xyz) - - return transformed_xyz - - @property - def rotation(self): - rotation = self._rotation.get_rotation_matrix().detach().cpu() - - return rotation - - @property - def translation(self): - translation = self._translation.translation.detach().cpu() - - return translation - - class LocalTransformationTrainer(Trainer): def __init__(self, gaussian_model): self.gaussian_model = gaussian_model - self.transformation_model = TransformationModel() + self.transformation_model = AffineTransformationModel() self.transformation_model.to(gaussian_model.get_xyz.device) self.optimizer = torch.optim.Adam(
diff --git a/gaussian_splatting/colmap_free/transformation_model.py b/gaussian_splatting/colmap_free/transformation_model.py
new file mode 100644
index 000000000..374d061bb
--- /dev/null
+++ b/gaussian_splatting/colmap_free/transformation_model.py
@@ -0,0 +1,75 @@
+import torch
+from torch import nn
+
+
+class QuaternionRotationLayer(nn.Module):
+    """Learnable 3D rotation parameterized by a quaternion in (x, y, z, w) order."""
+
+    def __init__(self):
+        super().__init__()
+        # Identity rotation: zero vector part, unit scalar part.
+        self.quaternion = nn.Parameter(torch.tensor([[0.0, 0.0, 0.0, 1.0]]))
+
+    def forward(self, input_tensor):
+        """Rotate points given as row vectors.
+
+        Args:
+            input_tensor: Tensor whose last dimension has size 3.
+
+        Returns:
+            ``input_tensor @ R`` where ``R`` is ``get_rotation_matrix()``.
+        """
+        return torch.matmul(input_tensor, self.get_rotation_matrix())
+
+    def get_rotation_matrix(self):
+        """Return the 3x3 rotation matrix of the normalized quaternion.
+
+        The matrix is assembled with ``torch.stack`` so it inherits the
+        parameter's dtype and device and stays differentiable.
+        """
+        # ``normalize`` clamps the norm away from zero, so an all-zero
+        # quaternion does not produce NaNs.
+        unit_quaternion = nn.functional.normalize(self.quaternion, p=2, dim=1)
+        x, y, z, w = unit_quaternion[0]
+
+        rows = (
+            (1 - 2 * (y**2 + z**2), 2 * (x * y - w * z), 2 * (x * z + w * y)),
+            (2 * (x * y + w * z), 1 - 2 * (x**2 + z**2), 2 * (y * z - w * x)),
+            (2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x**2 + y**2)),
+        )
+        return torch.stack([torch.stack(row) for row in rows])
+
+
+class TranslationLayer(nn.Module):
+    """Learnable 3D translation, initialized to zero."""
+
+    def __init__(self):
+        super().__init__()
+        self.translation = nn.Parameter(torch.zeros(1, 3))
+
+    def forward(self, input_tensor):
+        """Return ``input_tensor`` shifted by the translation (broadcast add)."""
+        return torch.add(self.translation, input_tensor)
+
+
+class AffineTransformationModel(nn.Module):
+    """Rotation followed by translation, both learnable."""
+
+    def __init__(self):
+        super().__init__()
+        self._rotation = QuaternionRotationLayer()
+        self._translation = TranslationLayer()
+
+    def forward(self, xyz):
+        """Rotate ``xyz`` and then translate the result."""
+        return self._translation(self._rotation(xyz))
+
+    @property
+    def rotation(self):
+        """Current 3x3 rotation matrix, detached and moved to the CPU."""
+        return self._rotation.get_rotation_matrix().detach().cpu()
+
+    @property
+    def translation(self):
+        """Current translation parameter, detached and moved to the CPU."""
+        return self._translation.translation.detach().cpu()
diff --git a/gaussian_splatting/global_trainer.py b/gaussian_splatting/global_trainer.py deleted file mode 100644 index 6acb75ae1..000000000 --- a/gaussian_splatting/global_trainer.py +++ /dev/null @@ -1,198 +0,0 @@ -import os -import uuid -from random import randint - -import torch -from tqdm import tqdm - -from gaussian_splatting.model import GaussianModel -from gaussian_splatting.optimizer import Optimizer -from gaussian_splatting.render import render -from gaussian_splatting.utils.general import safe_state -from gaussian_splatting.utils.image import psnr -from gaussian_splatting.utils.loss import l1_loss, ssim - - -class GlobalTrainer: - def __init__(self, gaussian_model, cameras, iterations: int = 1000): - self._model_path = self._prepare_model_path() - - self.gaussian_model = gaussian_model - self.cameras = cameras - - self.optimizer = Optimizer(self.gaussian_model) - - self._debug = False - - self._iterations = iterations - self._testing_iterations = [iterations, 7000, 30000] - 
self._saving_iterations = [iterations - 1, 7000, 30000] - self._checkpoint_iterations = [] - - # Loss function - self._lambda_dssim = 0.2 - - # Densification and pruning - self._opacity_reset_interval = 3000 - self._min_opacity = 0.005 - self._max_screen_size = 20 - self._percent_dense = 0.01 - self._densification_interval = 100 - self._densification_iteration_start = 500 - self._densification_iteration_stop = 15000 - self._densification_grad_threshold = 0.0002 - - safe_state() - - def add_camera(self, camera): - self.cameras.append(camera) - - def run(self): - first_iter = 0 - - ema_loss_for_log = 0.0 - cameras = None - progress_bar = tqdm( - range(first_iter, self._iterations), desc="Training progress" - ) - first_iter += 1 - for iteration in range(first_iter, self._iterations + 1): - self.optimizer.update_learning_rate(iteration) - - # Every 1000 its we increase the levels of SH up to a maximum degree - if iteration % 1000 == 0: - self.gaussian_model.oneupSHdegree() - - # Pick a random camera - if not cameras: - cameras = self.cameras.copy() - camera = cameras.pop(randint(0, len(cameras) - 1)) - - # Render image - rendered_image, viewspace_point_tensor, visibility_filter, radii = render( - camera, self.gaussian_model - ) - - # Loss - gt_image = camera.original_image.cuda() - Ll1 = l1_loss(rendered_image, gt_image) - loss = (1.0 - self._lambda_dssim) * Ll1 + self._lambda_dssim * ( - 1.0 - ssim(rendered_image, gt_image) - ) - - loss.backward() - - with torch.no_grad(): - # Progress bar - ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log - if iteration % 10 == 0: - progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"}) - progress_bar.update(10) - if iteration == self._iterations: - progress_bar.close() - - if iteration in self._saving_iterations: - print("\n[ITER {}] Saving Gaussians".format(iteration)) - point_cloud_path = os.path.join( - self._model_path, "point_cloud/iteration_{}".format(iteration) - ) - self.gaussian_model.save_ply( - 
os.path.join(point_cloud_path, "point_cloud.ply") - ) - - # Densification - if iteration < self._densification_iteration_stop: - self.gaussian_model.update_stats( - viewspace_point_tensor, visibility_filter, radii - ) - - if ( - iteration >= self._densification_iteration_start - and iteration % self._densification_interval == 0 - ): - self._densify_and_prune( - iteration > self._opacity_reset_interval - ) - - # Reset opacity interval - if iteration % self._opacity_reset_interval == 0: - self._reset_opacity() - - # Optimizer step - if iteration < self._iterations: - self.optimizer.step() - self.optimizer.zero_grad(set_to_none=True) - - def _prepare_model_path(self): - unique_str = str(uuid.uuid4()) - model_path = os.path.join("./output/", unique_str[0:10]) - - # Set up output folder - print("Output folder: {}".format(model_path)) - os.makedirs(model_path, exist_ok=True) - - return model_path - - def _densify_and_prune(self, prune_big_points): - # Clone large gaussian in over-reconstruction areas - self._clone_points() - # Split small gaussians in under-construction areas. - self._split_points() - - # Prune transparent and large gaussians. 
- prune_mask = (self.gaussian_model.get_opacity < self._min_opacity).squeeze() - if prune_big_points: - # Viewspace - big_points_vs = self.gaussian_model.max_radii2D > self._max_screen_size - # World space - big_points_ws = ( - self.gaussian_model.get_scaling.max(dim=1).values - > 0.1 * self.gaussian_model.camera_extent - ) - prune_mask = torch.logical_or( - torch.logical_or(prune_mask, big_points_vs), big_points_ws - ) - if self._debug: - print(f"Pruning: {prune_mask.sum().item()} points.") - self._prune_points(valid_mask=~prune_mask) - - torch.cuda.empty_cache() - - def _split_points(self): - new_points, split_mask = self.gaussian_model.split_points( - self._densification_grad_threshold, self._percent_dense - ) - self._concatenate_points(new_points) - - prune_mask = torch.cat( - ( - split_mask, - torch.zeros(2 * split_mask.sum(), device="cuda", dtype=bool), - ) - ) - if self._debug: - print(f"Densification: split {split_mask.sum().item()} points.") - self._prune_points(valid_mask=~prune_mask) - - def _clone_points(self): - new_points, clone_mask = self.gaussian_model.clone_points( - self._densification_grad_threshold, self._percent_dense - ) - if self._debug: - print(f"Densification: clone {clone_mask.sum().item()} points.") - self._concatenate_points(new_points) - - def _reset_opacity(self): - new_opacity = self.gaussian_model.reset_opacity() - optimizable_tensors = self.optimizer.replace_points(new_opacity, "opacity") - self.gaussian_model.set_optimizable_tensors(optimizable_tensors) - - def _prune_points(self, valid_mask): - optimizable_tensors = self.optimizer.prune_points(valid_mask) - self.gaussian_model.set_optimizable_tensors(optimizable_tensors) - self.gaussian_model.mask_stats(valid_mask) - - def _concatenate_points(self, new_tensors): - optimizable_tensors = self.optimizer.concatenate_points(new_tensors) - self.gaussian_model.set_optimizable_tensors(optimizable_tensors) - self.gaussian_model.reset_stats() diff --git a/gaussian_splatting/utils/loss.py 
b/gaussian_splatting/utils/loss.py index 9549d5678..1a12e1ca9 100644 --- a/gaussian_splatting/utils/loss.py +++ b/gaussian_splatting/utils/loss.py @@ -24,10 +24,13 @@ def __call__(self, network_output, gt): l1_value = l1_loss(network_output, gt) ssim_value = ssim(network_output, gt) - loss = (1.0 - self._lambda_dssim) * l1_value + self._lambda_dssim * (1.0 - ssim_value) + loss = (1.0 - self._lambda_dssim) * l1_value + self._lambda_dssim * ( + 1.0 - ssim_value + ) return loss + def l1_loss(network_output, gt): return torch.abs((network_output - gt)).mean() diff --git a/scripts/train_colmap_free.py b/scripts/train_colmap_free.py index 2727d16c0..97ce97137 100644 --- a/scripts/train_colmap_free.py +++ b/scripts/train_colmap_free.py @@ -1,19 +1,19 @@ +import copy from pathlib import Path -import copy import numpy as np -import torchvision import torch +import torchvision from tqdm import tqdm -from gaussian_splatting.dataset.cameras import Camera -from gaussian_splatting.render import render -from gaussian_splatting.dataset.image_dataset import ImageDataset -from gaussian_splatting.local_initialization_trainer import \ +from gaussian_splatting.colmap_free.global_trainer import GlobalTrainer +from gaussian_splatting.colmap_free.local_initialization_trainer import \ LocalInitializationTrainer -from gaussian_splatting.local_transformation_trainer import \ +from gaussian_splatting.colmap_free.local_transformation_trainer import \ LocalTransformationTrainer -from gaussian_splatting.global_trainer import GlobalTrainer +from gaussian_splatting.dataset.cameras import Camera +from gaussian_splatting.dataset.image_dataset import ImageDataset +from gaussian_splatting.render import render from gaussian_splatting.utils.general import PILtoTorch from gaussian_splatting.utils.loss import PhotometricLoss @@ -23,6 +23,7 @@ def main(): iteration_step_size = 5 initialization_iterations = 50 transformation_iterations = 50 + global_iterations = 50 photometric_loss = 
PhotometricLoss(lambda_dssim=0.2) dataset = ImageDataset(images_path=Path("data/phil/1/input/")) @@ -32,16 +33,16 @@ def main(): local_initialization_trainer = LocalInitializationTrainer(current_image) local_initialization_trainer.run(iterations=initialization_iterations) + # We set a copy of the initialized model to both the local transformation and the + # global models. next_gaussian_model = local_initialization_trainer.gaussian_model local_transformation_trainer = LocalTransformationTrainer(next_gaussian_model) - #global_trainer = GlobalTrainer( - # gaussian_model=gaussian_model, - # cameras=[camera], - # iterations=100 - #) + # global_trainer = GlobalTrainer(copy.deepcopy(next_gaussian_model)) current_camera = local_initialization_trainer.camera - for iteration in tqdm(range(iteration_step_size, len(dataset), iteration_step_size)): + for iteration in tqdm( + range(iteration_step_size, len(dataset), iteration_step_size) + ): # Keep a copy of current gaussians to compare with torch.no_grad(): current_gaussian_model = copy.deepcopy(next_gaussian_model) @@ -74,12 +75,18 @@ def main(): loss = photometric_loss(next_camera_image, next_gaussian_image) assert loss < 0.01 - torchvision.utils.save_image(next_camera_image, f"artifacts/global/next_camera_{iteration}.png") - torchvision.utils.save_image(next_gaussian_image, f"artifacts/global/next_gaussian_{iteration}.png") - #global_trainer.add_camera(next_camera) - #global_trainer.run() + torchvision.utils.save_image( + next_camera_image, f"artifacts/global/next_camera_{iteration}.png" + ) + torchvision.utils.save_image( + next_gaussian_image, f"artifacts/global/next_gaussian_{iteration}.png" + ) + + # global_trainer.add_camera(next_camera) + # global_trainer.run(global_iterations) current_camera = next_camera + if __name__ == "__main__": main()