From 05c703f7e32169895f308d73e74bf313c7a12183 Mon Sep 17 00:00:00 2001 From: Isabelle Bouchard Date: Tue, 23 Apr 2024 17:10:22 -0400 Subject: [PATCH] tune parameters --- .../pose_free/depth_estimator.py | 1 - .../pose_free/global_trainer.py | 21 +++++++-------- gaussian_splatting/pose_free/local_trainer.py | 27 ++++++++++++------- .../pose_free/pose_free_trainer.py | 24 +++++++++++++---- gaussian_splatting/utils/loss.py | 8 +++++- 5 files changed, 53 insertions(+), 28 deletions(-) diff --git a/gaussian_splatting/pose_free/depth_estimator.py b/gaussian_splatting/pose_free/depth_estimator.py index 8fa225b4e..03f1f937b 100644 --- a/gaussian_splatting/pose_free/depth_estimator.py +++ b/gaussian_splatting/pose_free/depth_estimator.py @@ -1,7 +1,6 @@ import torch from transformers import pipeline -from gaussian_splatting.dataset.image_dataset import ImageDataset from gaussian_splatting.utils.general import TorchToPIL diff --git a/gaussian_splatting/pose_free/global_trainer.py b/gaussian_splatting/pose_free/global_trainer.py index 95edca85e..9cae5e222 100644 --- a/gaussian_splatting/pose_free/global_trainer.py +++ b/gaussian_splatting/pose_free/global_trainer.py @@ -8,7 +8,7 @@ class GlobalTrainer(Trainer): - def __init__(self, gaussian_model, iterations: int = 100, output_path=None): + def __init__(self, gaussian_model, output_path=None): self._model_path = self._prepare_model_path(output_path) self.gaussian_model = gaussian_model @@ -16,9 +16,7 @@ def __init__(self, gaussian_model, iterations: int = 100, output_path=None): self.optimizer = Optimizer(self.gaussian_model) self._photometric_loss = PhotometricLoss(lambda_dssim=0.2) - self._iterations = iterations - - self._debug = False + self._debug = True # Densification and pruning self._min_opacity = 0.005 @@ -28,9 +26,9 @@ def __init__(self, gaussian_model, iterations: int = 100, output_path=None): safe_state() - def run(self, current_camera, next_camera, progress_bar=None, run_id: int = 0): + def run(self, current_camera, next_camera, iterations: int = 100, progress_bar=None, run_id: int = 0): cameras = (current_camera, next_camera) - for iteration in range(self._iterations): + for iteration in range(iterations): self.optimizer.update_learning_rate(iteration) # Every 1000 its we increase the levels of SH up to a maximum degree @@ -58,7 +56,7 @@ def run(self, current_camera, next_camera, progress_bar=None, run_id: int = 0): progress_bar.set_postfix( { "stage": "global", - "iteration": f"{iteration}/{self._iterations}", + "iteration": f"{iteration}/{iterations}", "loss": f"{loss_value:.5f}", } ) @@ -68,8 +66,9 @@ def run(self, current_camera, next_camera, progress_bar=None, run_id: int = 0): ) # Densification - self.gaussian_model.update_stats( - viewspace_point_tensor, visibility_filter, radii - ) - self._densify_and_prune(True) + #self.gaussian_model.update_stats( + # viewspace_point_tensor, visibility_filter, radii + #) + #self._split_points() + # self._densify_and_prune(True) # self._reset_opacity() diff --git a/gaussian_splatting/pose_free/local_trainer.py b/gaussian_splatting/pose_free/local_trainer.py index 275e24364..8c2159a17 100644 --- a/gaussian_splatting/pose_free/local_trainer.py +++ b/gaussian_splatting/pose_free/local_trainer.py @@ -26,21 +26,22 @@ def __init__( debug: bool = False, ): self._depth_estimator = DepthEstimator() - self._point_cloud_step = 25 + self._point_cloud_step = 2 self._sh_degree = sh_degree - self._photometric_loss = PhotometricLoss(lambda_dssim=0.2) + self._init_loss = PhotometricLoss(lambda_dssim=0.2) + self._transfo_loss = PhotometricLoss(lambda_dssim=0.2) self._init_iterations = init_iterations - self._init_early_stopper = EarlyStopper(patience=10) + self._init_early_stopper = EarlyStopper(patience=50) self._init_save_artifacts_iterations = 100 - self._transfo_lr = 0.00001 + self._transfo_lr = 1.e-5 self._transfo_iterations = transfo_iterations self._transfo_early_stopper = EarlyStopper( - patience=100, + patience=25, ) - self._transfo_save_artifacts_iterations = 10 + self._transfo_save_artifacts_iterations = 25 self._debug = debug @@ -63,7 +64,7 @@ def run_init(self, image, camera, progress_bar=None, run_id: int = 0): rendered_image, _, _, _ = render(camera, gaussian_model) - loss = self._photometric_loss(rendered_image, image) + loss = self._init_loss(rendered_image, image) loss.backward() loss_value = loss.cpu().item() losses.append(loss_value) @@ -115,7 +116,7 @@ def run_transfo( rendered_image, _, _, _ = render(camera, gaussian_model) - loss = self._photometric_loss(rendered_image, image) + loss = self._transfo_loss(rendered_image, image) loss.backward() loss_value = loss.cpu().item() losses.append(loss_value) @@ -123,11 +124,16 @@ def run_transfo( optimizer.step() if self._transfo_early_stopper.step(loss_value): - transformation = self._transfo_early_stopper.get_best_params() + best_params = self._transfo_early_stopper.get_best_params() + gaussian_model.set_optimizable_tensors({"xyz": best_params["xyz"]}) + transformation = best_params["transformation"] break else: transformation = transformation_model.transformation - self._transfo_early_stopper.set_best_params(transformation) + self._transfo_early_stopper.set_best_params({ + "transformation": transformation, + "xyz": xyz + }) if self._debug and ( iteration % self._transfo_save_artifacts_iterations == 0 @@ -149,6 +155,7 @@ def run_transfo( ) if self._debug: + rendered_image, _, _, _ = render(camera, gaussian_model) self._save_artifacts(losses, rendered_image, output_path, "best") save_image(image, output_path / f"ground_truth.png") diff --git a/gaussian_splatting/pose_free/pose_free_trainer.py b/gaussian_splatting/pose_free/pose_free_trainer.py index 85b6d4c85..a79f50852 100644 --- a/gaussian_splatting/pose_free/pose_free_trainer.py +++ b/gaussian_splatting/pose_free/pose_free_trainer.py @@ -1,4 +1,5 @@ from pathlib import Path +import copy from torchvision.utils import save_image from tqdm import tqdm @@ -16,7 +17,7 @@ def __init__(self, source_path: Path): self._debug = True self._dataset = ImageDataset( - images_path=source_path, step_size=5, downscale_factor=1 + images_path=source_path, step_size=10, downscale_factor=1 ) self._local_trainer = LocalTrainer( @@ -31,7 +32,7 @@ def run(self): initial_gaussian_model = self._local_trainer.get_initial_gaussian_model( current_image, self._output_path ) - global_trainer = GlobalTrainer(initial_gaussian_model, iterations=1000) + global_trainer = GlobalTrainer(initial_gaussian_model) current_camera = get_orthogonal_camera(current_image) @@ -42,6 +43,9 @@ def run(self): gaussian_model = self._local_trainer.run_init( current_image, current_camera, progress_bar, run_id=i ) + if self._debug: + current_gaussian_model = copy.deepcopy(gaussian_model) + rotation, translation = self._local_trainer.run_transfo( next_image, current_camera, @@ -52,12 +56,22 @@ def run(self): next_camera = transform_camera( current_camera, rotation, translation, next_image, _id=i ) - global_trainer.run(current_camera, next_camera, progress_bar, run_id=i) + global_trainer.run( + current_camera, + next_camera, + iterations=(1000 if i == 0 else 100), + progress_bar=progress_bar, + run_id=i + ) if self._debug: - rendered_image, _, _, _ = render(next_camera, gaussian_model) + rendered_image, _, _, _ = render(next_camera, current_gaussian_model) + save_image( + rendered_image, self._output_path / f"{i}_camera_rendered_image.png" + ) + rendered_image, _, _, _ = render(current_camera, gaussian_model) save_image( - rendered_image, self._output_path / f"{i}_rendered_image.png" + rendered_image, self._output_path / f"{i}_gaussian_rendered_image.png" ) save_image(next_image, self._output_path / f"{i}_image.png") diff --git a/gaussian_splatting/utils/loss.py b/gaussian_splatting/utils/loss.py index 1a12e1ca9..9938084c4 100644 --- a/gaussian_splatting/utils/loss.py +++ b/gaussian_splatting/utils/loss.py @@ -17,10 +17,16 @@ class PhotometricLoss: - def __init__(self, lambda_dssim: float = 0.2): + def __init__(self, lambda_dssim: float = 0.2, mask_white_pixels: bool = False): self._lambda_dssim = lambda_dssim + self._mask_white_pixels = mask_white_pixels def __call__(self, network_output, gt): + if self._mask_white_pixels: + mask = (network_output != 1.0).int() + network_output = network_output * mask + gt = gt * mask + l1_value = l1_loss(network_output, gt) ssim_value = ssim(network_output, gt)