From 11120b1c33f7cc9783f9addfa8b95d72e3e791c5 Mon Sep 17 00:00:00 2001 From: Isabelle Bouchard Date: Wed, 10 Apr 2024 15:05:18 -0400 Subject: [PATCH] optimize training --- Dockerfile | 3 +- gaussian_splatting/dataset/cameras.py | 8 +-- gaussian_splatting/dataset/dataset.py | 1 - .../metrics/lpipPyTorch/modules/lpips.py | 1 - gaussian_splatting/model.py | 1 - gaussian_splatting/optimizer.py | 1 - gaussian_splatting/trainer.py | 57 +++++++++++-------- scripts/convert.py | 3 +- scripts/extract_video_frames.py | 2 +- scripts/metrics.py | 1 - setup.py | 5 +- 11 files changed, 44 insertions(+), 39 deletions(-) diff --git a/Dockerfile b/Dockerfile index fbe2ec6b7..f92254c41 100644 --- a/Dockerfile +++ b/Dockerfile @@ -26,7 +26,8 @@ RUN conda init RUN conda env create --file environment.yml # COLMAP dependencies -RUN DEBIAN_FRONTEND=noninteractive apt-get install -y \ +RUN apt update && \ + DEBIAN_FRONTEND=noninteractive apt-get install -y \ git \ cmake \ libboost-program-options-dev \ diff --git a/gaussian_splatting/dataset/cameras.py b/gaussian_splatting/dataset/cameras.py index f2c34563d..4abf6eb8e 100644 --- a/gaussian_splatting/dataset/cameras.py +++ b/gaussian_splatting/dataset/cameras.py @@ -44,16 +44,12 @@ def __init__( self.data_device = torch.device("cuda") - self.original_image = image.clamp(0.0, 1.0).to(self.data_device) + self.original_image = image.clamp(0.0, 1.0) self.image_width = self.original_image.shape[2] self.image_height = self.original_image.shape[1] if gt_alpha_mask is not None: - self.original_image *= gt_alpha_mask.to(self.data_device) - else: - self.original_image *= torch.ones( - (1, self.image_height, self.image_width), device=self.data_device - ) + self.original_image *= gt_alpha_mask self.zfar = 100.0 self.znear = 0.01 diff --git a/gaussian_splatting/dataset/dataset.py b/gaussian_splatting/dataset/dataset.py index 4e110cd17..469b9ceee 100644 --- a/gaussian_splatting/dataset/dataset.py +++ b/gaussian_splatting/dataset/dataset.py @@ -18,7 +18,6 @@ class Dataset: - def __init__( self, source_path, diff --git a/gaussian_splatting/metrics/lpipPyTorch/modules/lpips.py b/gaussian_splatting/metrics/lpipPyTorch/modules/lpips.py index 066bd83ad..2c0f14b4b 100644 --- a/gaussian_splatting/metrics/lpipPyTorch/modules/lpips.py +++ b/gaussian_splatting/metrics/lpipPyTorch/modules/lpips.py @@ -17,7 +17,6 @@ class LPIPS(nn.Module): """ def __init__(self, net_type: str = "alex", version: str = "0.1"): - assert version in ["0.1"], "v0.1 is only supported now" super(LPIPS, self).__init__() diff --git a/gaussian_splatting/model.py b/gaussian_splatting/model.py index a23e3fbfc..d9b4438ba 100644 --- a/gaussian_splatting/model.py +++ b/gaussian_splatting/model.py @@ -33,7 +33,6 @@ def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): class GaussianModel: - def __init__(self, sh_degree: int = 3): self.active_sh_degree = 0 self.max_sh_degree = sh_degree diff --git a/gaussian_splatting/optimizer.py b/gaussian_splatting/optimizer.py index b4f1379fb..eb6482f12 100644 --- a/gaussian_splatting/optimizer.py +++ b/gaussian_splatting/optimizer.py @@ -125,7 +125,6 @@ def concatenate_points(self, tensors_dict): extension_tensor = tensors_dict[group["name"]] stored_state = self._optimizer.state.get(group["params"][0], None) if stored_state is not None: - stored_state["exp_avg"] = torch.cat( (stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0 ) diff --git a/gaussian_splatting/trainer.py b/gaussian_splatting/trainer.py index 12a2decdc..fd322869b 100644 --- a/gaussian_splatting/trainer.py +++ b/gaussian_splatting/trainer.py @@ -22,11 +22,15 @@ def __init__( resolution=-1, sh_degree=3, checkpoint_path=None, + output_path=None, + min_num_cameras=15, ): - self._model_path = self._prepare_model_path() + self.model_path = self._prepare_model_path(output_path) self.dataset = Dataset(source_path, keep_eval=keep_eval, resolution=resolution) - self.dataset.save_scene_info(self._model_path) + if len(self.dataset.get_train_cameras()) <= min_num_cameras: + raise ValueError("Not enough cameras to reconstruct the scene!") + self.dataset.save_scene_info(self.model_path) self.gaussian_model = GaussianModel(sh_degree) self.gaussian_model.initialize(self.dataset) @@ -35,18 +39,18 @@ def __init__( self._checkpoint_path = checkpoint_path - self._debug = False + self._debug = True - self._iterations = 30000 - self._testing_iterations = [7000, 30000] - self._saving_iterations = [7000, 30000] + self._iterations = 10000 + self._testing_iterations = [20, 250, 1000, 2500, 7000, 30000] + self._saving_iterations = [20, 250, 1000, 2500, 3500, 4500, 5000, 7000, 10000] self._checkpoint_iterations = [] # Loss function self._lambda_dssim = 0.2 # Densification and pruning - self._opacity_reset_interval = 3000 + self._opacity_reset_interval = 30000 self._min_opacity = 0.005 self._max_screen_size = 20 self._percent_dense = 0.01 @@ -60,9 +64,11 @@ def __init__( def run(self): first_iter = 0 if self._checkpoint_path: - gaussian_model_state_dict, self.optimizer_state_dict, first_iter = ( - torch.load(checkpoint_path) - ) + ( + gaussian_model_state_dict, + self.optimizer_state_dict, + first_iter, + ) = torch.load(checkpoint_path) self.gaussian_model.load_state_dict(gaussian_model_state_dict) self.optimizer.load_state_dict(optmizer_state_dict) @@ -101,6 +107,15 @@ def run(self): # except Exception: # import pdb; pdb.set_trace() + if iteration in self._saving_iterations: + print("\n[ITER {}] Saving Gaussians".format(iteration)) + point_cloud_path = os.path.join( + self.model_path, "point_cloud/iteration_{}".format(iteration) + ) + self.gaussian_model.save_ply( + os.path.join(point_cloud_path, "point_cloud.ply") + ) + with torch.no_grad(): # Progress bar ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log @@ -111,17 +126,8 @@ def run(self): progress_bar.close() # Log and save - if iteration in self._testing_iterations: - self._report(iteration) - - if iteration in self._saving_iterations: - print("\n[ITER {}] Saving Gaussians".format(iteration)) - point_cloud_path = os.path.join( - self.model_path, "point_cloud/iteration_{}".format(iteration) - ) - self.gaussian_model.save_ply( - os.path.join(point_cloud_path, "point_cloud.ply") - ) + # if iteration in self._testing_iterations: + # self._report(iteration) # Densification if iteration < self._densification_iteration_stop: @@ -158,9 +164,12 @@ def run(self): self.model_path + "/chkpnt" + str(iteration) + ".pth", ) - def _prepare_model_path(self): + def _prepare_model_path(self, output_path): unique_str = str(uuid.uuid4()) - model_path = os.path.join("./output/", unique_str[0:10]) + model_path = os.path.join( + "./output/" if output_path is None else f"./output/{output_path}/", + unique_str[0:10], + ) # Set up output folder print("Output folder: {}".format(model_path)) @@ -181,7 +190,7 @@ def _report(self, iteration): ], } - for config_name, cameras in validation_configs: + for config_name, cameras in validation_configs.items(): if not cameras or len(cameras) == 0: continue diff --git a/scripts/convert.py b/scripts/convert.py index a0079c331..5c83f1067 100644 --- a/scripts/convert.py +++ b/scripts/convert.py @@ -52,7 +52,8 @@ def main(source_path, resize, use_gpu): step=" Feature matching", command=f"colmap sequential_matcher \ --database_path {source_path / 'distorted' / 'database.db'} \ - --SiftMatching.use_gpu {use_gpu}", + --SiftMatching.use_gpu {use_gpu} \ + --SiftMatching.max_num_matches 10000", ) _run_command( diff --git a/scripts/extract_video_frames.py b/scripts/extract_video_frames.py index 8e9ceeb81..11e0f298b 100644 --- a/scripts/extract_video_frames.py +++ b/scripts/extract_video_frames.py @@ -23,7 +23,7 @@ def main(video_filename: Path, output_path: Path, k: int = 10): current_frame += 1 continue - output_filename = output_path / f"{current_frame}.jpg" + output_filename = output_path / f"{str(current_frame).zfill(6)}.jpg" cv2.imwrite(output_filename.as_posix(), frame) print(f"Writing {output_filename}.") diff --git a/scripts/metrics.py b/scripts/metrics.py index 901c3d193..e1681a07a 100644 --- a/scripts/metrics.py +++ b/scripts/metrics.py @@ -38,7 +38,6 @@ def readImages(renders_dir, gt_dir): def evaluate(model_paths): - full_dict = {} per_view_dict = {} full_dict_polytopeonly = {} diff --git a/setup.py b/setup.py index 26d554c14..76bcb5def 100644 --- a/setup.py +++ b/setup.py @@ -6,6 +6,9 @@ packages=find_packages(), # Automatically discover and include all packages scripts=[], install_requires=[ - "pyyaml" + "pyyaml", + "black", + "isort", + "importchecker" ] )