From b48029b893a30bd4be281ab48099e7afe688d59d Mon Sep 17 00:00:00 2001 From: Isabelle Bouchard Date: Thu, 11 Apr 2024 17:11:06 -0400 Subject: [PATCH] global reconstruction with no pose prior --- .../colmap_free/global_trainer.py | 7 ++-- .../local_transformation_trainer.py | 6 +-- scripts/train_colmap_free.py | 39 ++++++++++--------- 3 files changed, 28 insertions(+), 24 deletions(-) diff --git a/gaussian_splatting/colmap_free/global_trainer.py b/gaussian_splatting/colmap_free/global_trainer.py index 58dbf1a52..dabb602fc 100644 --- a/gaussian_splatting/colmap_free/global_trainer.py +++ b/gaussian_splatting/colmap_free/global_trainer.py @@ -11,8 +11,8 @@ class GlobalTrainer(Trainer): - def __init__(self, gaussian_model): - self._model_path = self._prepare_model_path() + def __init__(self, gaussian_model, output_path = None): + self._model_path = self._prepare_model_path(output_path) self.gaussian_model = gaussian_model self.cameras = [] @@ -81,4 +81,5 @@ def run(self, iterations: int = 1000): viewspace_point_tensor, visibility_filter, radii ) self._densify_and_prune(True) - self._reset_opacity() + #self._reset_opacity() + diff --git a/gaussian_splatting/colmap_free/local_transformation_trainer.py b/gaussian_splatting/colmap_free/local_transformation_trainer.py index fb47d5e86..e7caf7696 100644 --- a/gaussian_splatting/colmap_free/local_transformation_trainer.py +++ b/gaussian_splatting/colmap_free/local_transformation_trainer.py @@ -21,7 +21,7 @@ def __init__(self, gaussian_model): self.transformation_model.to(gaussian_model.get_xyz.device) self.optimizer = torch.optim.Adam( - self.transformation_model.parameters(), lr=0.0005 + self.transformation_model.parameters(), lr=0.0001 ) self._photometric_loss = PhotometricLoss(lambda_dssim=0.2) @@ -57,8 +57,8 @@ def run(self, current_camera, gt_image, iterations: int = 1000, run: int = 0): progress_bar.set_postfix({"Loss": f"{loss:.{5}f}"}) progress_bar.update(1) - if iteration % 50 == 0 or iteration == iterations - 1: - self._save_artifacts(losses, rendered_image, iteration, run) + #if iteration % 50 == 0 or iteration == iterations - 1: + # self._save_artifacts(losses, rendered_image, iteration, run) if best_loss is None or best_loss > loss: best_loss = loss.cpu().item() diff --git a/scripts/train_colmap_free.py b/scripts/train_colmap_free.py index 87c80eaee..ed46678a5 100644 --- a/scripts/train_colmap_free.py +++ b/scripts/train_colmap_free.py @@ -5,7 +5,7 @@ import numpy as np import torchvision -# from gaussian_splatting.colmap_free.global_trainer import GlobalTrainer +from gaussian_splatting.colmap_free.global_trainer import GlobalTrainer from gaussian_splatting.colmap_free.local_initialization_trainer import \ LocalInitializationTrainer from gaussian_splatting.colmap_free.local_transformation_trainer import \ @@ -18,41 +18,43 @@ def main(): debug = True - iteration_step_size = 50 - global_iterations = 5 + step_size = 10 + initialization_iterations = 1000 + transformation_iterations = 250 + global_iterations = 50 photometric_loss = PhotometricLoss(lambda_dssim=0.2) dataset = ImageDataset(images_path=Path("data/phil/1/input/")) global_trainer = None - for iteration in range(0, len(dataset), iteration_step_size): - print(f">>> Current: {iteration} / Next: {iteration + iteration_step_size}") + for i in range(len(dataset) // step_size): + print(f">>> Current: {i * step_size} / Next: {(i + 1) * step_size}") - current_image = dataset.get_frame(iteration) - next_image = dataset.get_frame(iteration + iteration_step_size) + current_image = dataset.get_frame(i * step_size) + next_image = dataset.get_frame((i + 1) * step_size) - if iteration == 0: + if i == 0: current_camera = _get_orthogonal_camera(current_image) else: current_camera = next_camera current_camera, current_gaussian_model = _initialize_local_gaussian_model( - current_image, current_camera, run=iteration + current_image, current_camera, iterations=initialization_iterations, run=i ) - # if iteration == 0: - # global_gaussian_model = copy.deepcopy(current_gaussian_model) - # global_trainer = GlobalTrainer(global_gaussian_model) + if i == 0: + global_gaussian_model = copy.deepcopy(current_gaussian_model) + global_trainer = GlobalTrainer(global_gaussian_model) if debug: current_gaussian_model_copy = copy.deepcopy(current_gaussian_model) next_camera, next_gaussian_model = _transform_local_gaussian_model( - next_image, current_camera, current_gaussian_model, run=iteration + next_image, current_camera, current_gaussian_model, + iterations=transformation_iterations, run=i ) - # global_trainer.add_camera(next_camera) - # global_trainer.run(global_iterations) + global_trainer.add_camera(next_camera) if debug: save_artifacts( @@ -62,11 +64,12 @@ def main(): next_camera, next_gaussian_model, next_image, - iteration, + i, ) - if iteration >= 50: - break + global_trainer.run((i + 1) * global_iterations) + + #global_trainer.run(global_iterations) def _initialize_local_gaussian_model( image, camera, iterations: int = 250, run: int = 0