diff --git a/gaussian_splatting/local_initialization_trainer.py b/gaussian_splatting/local_initialization_trainer.py
index ae4fe13ec..be3b7ff75 100644
--- a/gaussian_splatting/local_initialization_trainer.py
+++ b/gaussian_splatting/local_initialization_trainer.py
@@ -14,7 +14,7 @@
 from gaussian_splatting.trainer import Trainer
 from gaussian_splatting.utils.general import PILtoTorch, safe_state
 from gaussian_splatting.utils.graphics import BasicPointCloud
-from gaussian_splatting.utils.loss import l1_loss, ssim
+from gaussian_splatting.utils.loss import PhotometricLoss
 
 
 class LocalInitializationTrainer(Trainer):
@@ -32,12 +32,10 @@ def __init__(self, image, sh_degree: int = 3, iterations: int = 10000):
 
         # TODO: set camera extent???
         self.optimizer = Optimizer(self.gaussian_model)
+        self._photometric_loss = PhotometricLoss(lambda_dssim=0.2)
 
         self.camera = self._get_orthogonal_camera(image)
 
-        self._iterations = iterations
-        self._lambda_dssim = 0.2
-
         # Densification and pruning
         self._opacity_reset_interval = 10001
         self._min_opacity = 0.005
@@ -52,11 +50,11 @@ def __init__(self, image, sh_degree: int = 3, iterations: int = 10000):
         safe_state(seed=2234)
 
-    def run(self):
-        progress_bar = tqdm(range(self._iterations), desc="Initialization")
+    def run(self, iterations: int = 3000):
+        progress_bar = tqdm(range(iterations), desc="Initialization")
 
         best_loss, best_iteration, losses = None, 0, []
-        for iteration in range(self._iterations):
+        for iteration in range(iterations):
             self.optimizer.update_learning_rate(iteration)
             rendered_image, viewspace_point_tensor, visibility_filter, radii = render(
                 self.camera, self.gaussian_model
             )
@@ -73,19 +71,17 @@ def run(self):
             )
 
             gt_image = self.camera.original_image.cuda()
-            Ll1 = l1_loss(rendered_image, gt_image)
-            loss = (1.0 - self._lambda_dssim) * Ll1 + self._lambda_dssim * (
-                1.0 - ssim(rendered_image, gt_image)
-            )
+            loss = self._photometric_loss(rendered_image, gt_image)
+            loss.backward()
+
+            self.optimizer.step()
+            self.optimizer.zero_grad(set_to_none=True)
+
             if best_loss is None or best_loss > loss:
                 best_loss = loss.cpu().item()
                 best_iteration = iteration
             losses.append(loss.cpu().item())
 
-            loss.backward()
-
-            self.optimizer.step()
-            self.optimizer.zero_grad(set_to_none=True)
 
             with torch.no_grad():
                 # Densification
diff --git a/gaussian_splatting/local_transformation_trainer.py b/gaussian_splatting/local_transformation_trainer.py
index 7bb493f7d..ad05c5a21 100644
--- a/gaussian_splatting/local_transformation_trainer.py
+++ b/gaussian_splatting/local_transformation_trainer.py
@@ -6,8 +6,8 @@
 
 from gaussian_splatting.render import render
 from gaussian_splatting.trainer import Trainer
-from gaussian_splatting.utils.general import PILtoTorch, safe_state
-from gaussian_splatting.utils.loss import l1_loss, ssim
+from gaussian_splatting.utils.general import safe_state
+from gaussian_splatting.utils.loss import PhotometricLoss
 
 
 class QuaternionRotation(nn.Module):
@@ -81,84 +81,73 @@ def translation(self):
 
 
 class LocalTransformationTrainer(Trainer):
-    def __init__(self, image, camera, gaussian_model, iterations: int = 100):
-        self.camera = camera
+    def __init__(self, gaussian_model):
         self.gaussian_model = gaussian_model
-        self.xyz = gaussian_model.get_xyz.detach()
-
         self.transformation_model = TransformationModel()
-        self.transformation_model.to(self.xyz.device)
-
-        self.image = PILtoTorch(image).to(self.xyz.device)
+        self.transformation_model.to(gaussian_model.get_xyz.device)
 
         self.optimizer = torch.optim.Adam(
             self.transformation_model.parameters(), lr=0.0001
         )
-
-        self._iterations = iterations
-        self._lambda_dssim = 0.2
+        self._photometric_loss = PhotometricLoss(lambda_dssim=0.2)
 
         safe_state(seed=2234)
 
-    def run(self):
-        progress_bar = tqdm(range(self._iterations), desc="Transformation")
+    def run(self, current_camera, gt_image, iterations: int = 1000):
+        gt_image = gt_image.to(self.gaussian_model.get_xyz.device)
+
+        progress_bar = tqdm(range(iterations), desc="Transformation")
 
-        best_loss, best_iteration, losses = None, 0, []
-        best_xyz = None
-        for iteration in range(self._iterations):
-            xyz = self.transformation_model(self.xyz)
+        losses = []
+        best_loss, best_iteration, best_xyz = None, 0, None
+        patience = 0
+        for iteration in range(iterations):
+            xyz = self.transformation_model(self.gaussian_model.get_xyz.detach())
             self.gaussian_model.set_optimizable_tensors({"xyz": xyz})
             rendered_image, viewspace_point_tensor, visibility_filter, radii = render(
-                self.camera, self.gaussian_model
+                current_camera, self.gaussian_model
            )
 
-            if iteration % 10 == 0:
-                plt.cla()
-                plt.plot(losses)
-                plt.yscale("log")
-                plt.savefig("artifacts/local/transfo/losses.png")
-
-                torchvision.utils.save_image(
-                    rendered_image, f"artifacts/local/transfo/rendered_{iteration}.png"
-                )
-
-            gt_image = self.image
-            Ll1 = l1_loss(rendered_image, gt_image)
-            loss = (1.0 - self._lambda_dssim) * Ll1 + self._lambda_dssim * (
-                1.0 - ssim(rendered_image, gt_image)
-            )
-            if best_loss is None or best_loss > loss:
-                best_loss = loss.cpu().item()
-                best_iteration = iteration
-                best_xyz = xyz
-            losses.append(loss.cpu().item())
-
+            loss = self._photometric_loss(rendered_image, gt_image)
             loss.backward()
-
             self.optimizer.step()
-            progress_bar.set_postfix(
-                {
-                    "Loss": f"{loss:.{5}f}",
-                }
-            )
+            losses.append(loss.cpu().item())
+
+            progress_bar.set_postfix({"Loss": f"{loss:.{5}f}"})
             progress_bar.update(1)
+            if iteration % 10 == 0:
+                self._save_artifacts(losses, rendered_image, iteration)
+
+            if best_loss is None or best_loss > loss:
+                best_loss = loss.cpu().item()
+                best_iteration = iteration
+                best_xyz = xyz.detach()
+            elif best_loss < loss and patience > 10:
+                self._save_artifacts(losses, rendered_image, iteration)
+                break
+            else:
+                patience += 1
+
         progress_bar.close()
-        print(
-            f"Training done. Best loss = {best_loss:.{5}f} at iteration {best_iteration}."
-        )
-        self.gaussian_model.set_optimizable_tensors({"xyz": best_xyz})
-        torchvision.utils.save_image(
-            rendered_image, f"artifacts/local/transfo/rendered_best.png"
-        )
-        torchvision.utils.save_image(gt_image, f"artifacts/local/transfo/gt.png")
+        print(f"Best loss = {best_loss:.{5}f} at iteration {best_iteration}.")
+        self.gaussian_model.set_optimizable_tensors({"xyz": best_xyz})
 
-    def get_affine_transformation(self):
         rotation = self.transformation_model.rotation.numpy()
         translation = self.transformation_model.translation.numpy()
 
         return rotation, translation
+
+    def _save_artifacts(self, losses, rendered_image, iteration):
+        plt.cla()
+        plt.plot(losses)
+        plt.yscale("log")
+        plt.savefig("artifacts/local/transfo/losses.png")
+
+        torchvision.utils.save_image(
+            rendered_image, f"artifacts/local/transfo/rendered_{iteration}.png"
+        )
diff --git a/gaussian_splatting/utils/loss.py b/gaussian_splatting/utils/loss.py
index 06bad7e0d..9549d5678 100644
--- a/gaussian_splatting/utils/loss.py
+++ b/gaussian_splatting/utils/loss.py
@@ -16,6 +16,18 @@
 from torch.autograd import Variable
 
 
+class PhotometricLoss:
+    def __init__(self, lambda_dssim: float = 0.2):
+        self._lambda_dssim = lambda_dssim
+
+    def __call__(self, network_output, gt):
+        l1_value = l1_loss(network_output, gt)
+        ssim_value = ssim(network_output, gt)
+
+        loss = (1.0 - self._lambda_dssim) * l1_value + self._lambda_dssim * (1.0 - ssim_value)
+
+        return loss
+
 def l1_loss(network_output, gt):
     return torch.abs((network_output - gt)).mean()
 
diff --git a/scripts/train_colmap_free.py b/scripts/train_colmap_free.py
index cd9f69e2f..2727d16c0 100644
--- a/scripts/train_colmap_free.py
+++ b/scripts/train_colmap_free.py
@@ -1,72 +1,85 @@
-import copy
 from pathlib import Path
 
+import copy
 import numpy as np
 import torchvision
+import torch
+from tqdm import tqdm
 
 from gaussian_splatting.dataset.cameras import Camera
+from gaussian_splatting.render import render
 from gaussian_splatting.dataset.image_dataset import ImageDataset
-from gaussian_splatting.global_trainer import GlobalTrainer
 from gaussian_splatting.local_initialization_trainer import \
     LocalInitializationTrainer
 from gaussian_splatting.local_transformation_trainer import \
     LocalTransformationTrainer
-from gaussian_splatting.render import render
+from gaussian_splatting.global_trainer import GlobalTrainer
 from gaussian_splatting.utils.general import PILtoTorch
+from gaussian_splatting.utils.loss import PhotometricLoss
 
 
 def main():
+    debug = True
+    iteration_step_size = 5
+    initialization_iterations = 50
+    transformation_iterations = 50
+
+    photometric_loss = PhotometricLoss(lambda_dssim=0.2)
     dataset = ImageDataset(images_path=Path("data/phil/1/input/"))
 
-    first_image = dataset.get_frame(0)
-    local_initialization_trainer = LocalInitializationTrainer(
-        first_image, iterations=500
-    )
-    local_initialization_trainer.run()
+    # Initialize Local3DGS gaussians
+    current_image = dataset.get_frame(0)
+    local_initialization_trainer = LocalInitializationTrainer(current_image)
+    local_initialization_trainer.run(iterations=initialization_iterations)
+
+    next_gaussian_model = local_initialization_trainer.gaussian_model
+    local_transformation_trainer = LocalTransformationTrainer(next_gaussian_model)
+    #global_trainer = GlobalTrainer(
+    #    gaussian_model=gaussian_model,
+    #    cameras=[camera],
+    #    iterations=100
+    #)
+
+    current_camera = local_initialization_trainer.camera
+    for iteration in tqdm(range(iteration_step_size, len(dataset), iteration_step_size)):
+        # Keep a copy of current gaussians to compare
+        with torch.no_grad():
+            current_gaussian_model = copy.deepcopy(next_gaussian_model)
 
-    step = 5
-    gaussian_model = local_initialization_trainer.gaussian_model
-    camera = local_initialization_trainer.camera
-    global_trainer = GlobalTrainer(
-        gaussian_model=gaussian_model, cameras=[camera], iterations=100
-    )
-    for iteration in range(step, len(dataset), step):
-        next_image = dataset.get_frame(iteration)
-        local_transformation_trainer = LocalTransformationTrainer(
+        # Find transformation from current to next camera poses
+        next_image = PILtoTorch(dataset.get_frame(iteration))
+        rotation, translation = local_transformation_trainer.run(
+            current_camera,
             next_image,
-            camera=camera,
-            gaussian_model=copy.deepcopy(gaussian_model),
-            iterations=250,
+            iterations=transformation_iterations,
         )
-        local_transformation_trainer.run()
-        rotation, translation = local_transformation_trainer.get_affine_transformation()
 
-        next_image = PILtoTorch(next_image)
+        # Add new camera to Global3DGS training cameras
         next_camera = Camera(
-            R=np.matmul(camera.R, rotation),
-            T=camera.T + translation,
-            FoVx=camera.FoVx,
-            FoVy=camera.FoVy,
+            R=np.matmul(current_camera.R, rotation),
+            T=current_camera.T + translation,
+            FoVx=current_camera.FoVx,
+            FoVy=current_camera.FoVy,
             image=next_image,
             gt_alpha_mask=None,
             image_name="patate",
             colmap_id=iteration,
             uid=iteration,
         )
-        global_trainer.add_camera(next_camera)
-        global_trainer.run()
 
-        rendered_image, viewspace_point_tensor, visibility_filter, radii = render(
-            next_camera,
-            local_initialization_trainer.gaussian_model,
-        )
-        torchvision.utils.save_image(
-            rendered_image, f"artifacts/global/rendered_{iteration}.png"
-        )
-        torchvision.utils.save_image(next_image, f"artifacts/global/gt_{iteration}.png")
+        if debug:
+            # Save artifact
+            next_camera_image, _, _, _ = render(next_camera, current_gaussian_model)
+            next_gaussian_image, _, _, _ = render(current_camera, next_gaussian_model)
+            loss = photometric_loss(next_camera_image, next_gaussian_image)
+            assert loss < 0.01
 
-        print(f">>> Iteration {iteration} / {len(dataset) // step}")
+            torchvision.utils.save_image(next_camera_image, f"artifacts/global/next_camera_{iteration}.png")
+            torchvision.utils.save_image(next_gaussian_image, f"artifacts/global/next_gaussian_{iteration}.png")
+        #global_trainer.add_camera(next_camera)
+        #global_trainer.run()
+        current_camera = next_camera
 
 
 if __name__ == "__main__":
    main()
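
Usage sketch for the new PhotometricLoss in gaussian_splatting/utils/loss.py (illustrative only; the tensor shapes, sizes, and the requires_grad flag below are assumptions for the example, not taken from this patch — in the trainers the inputs are the rendered image and the camera's ground-truth image on the same device):

    import torch

    from gaussian_splatting.utils.loss import PhotometricLoss

    # Same weighting the trainers use: 0.8 * L1 + 0.2 * (1 - SSIM).
    photometric_loss = PhotometricLoss(lambda_dssim=0.2)

    # Assumed (3, H, W) image tensors standing in for the rendered image
    # and the ground-truth image.
    rendered_image = torch.rand(3, 64, 64, requires_grad=True)
    gt_image = torch.rand(3, 64, 64)

    loss = photometric_loss(rendered_image, gt_image)  # scalar tensor
    loss.backward()  # gradients flow back into whatever produced rendered_image

Wrapping the L1/D-SSIM blend in a single callable keeps lambda_dssim defined in one place instead of being duplicated in each trainer's loop.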