diff --git a/gaussian_splatting/local_trainer.py b/gaussian_splatting/local_trainer.py index d3d87bd63..889d4dfb6 100644 --- a/gaussian_splatting/local_trainer.py +++ b/gaussian_splatting/local_trainer.py @@ -1,4 +1,5 @@ import math +from matplotlib import pyplot as plt import numpy as np from transformers import pipeline from tqdm import tqdm @@ -23,11 +24,10 @@ def __init__(self, image, sh_degree: int = 3): depth_estimation = DPT(image)["predicted_depth"] image = PILtoTorch(image) - initial_point_cloud = self._get_initial_point_cloud( image, depth_estimation, - step=10 + step=25 ) self.gaussian_model = GaussianModel(sh_degree) @@ -38,10 +38,11 @@ def __init__(self, image, sh_degree: int = 3): self._camera = self._get_orthogonal_camera(image) - self._iterations = 1000 + self._iterations = 10000 self._lambda_dssim = 0.2 - self._opacity_reset_interval = 1000 + # Densification and pruning + self._opacity_reset_interval = 10001 self._min_opacity = 0.005 self._max_screen_size = 20 self._percent_dense = 0.01 @@ -58,13 +59,20 @@ def run(self): progress_bar = tqdm( range(self._iterations), desc="Training progress" ) + + best_loss, best_iteration, losses = None, 0, [] for iteration in range(self._iterations): - # self.optimizer.update_learning_rate(iteration) + self.optimizer.update_learning_rate(iteration) rendered_image, viewspace_point_tensor, visibility_filter, radii = render( self._camera, self.gaussian_model ) if iteration % 100 == 0: + plt.cla() + plt.plot(losses) + plt.yscale('log') + plt.savefig('artifacts/losses.png') + torchvision.utils.save_image(rendered_image, f"artifacts/rendered_{iteration}.png") gt_image = self._camera.original_image.cuda() @@ -72,6 +80,10 @@ def run(self): loss = (1.0 - self._lambda_dssim) * Ll1 + self._lambda_dssim * ( 1.0 - ssim(rendered_image, gt_image) ) + if best_loss is None or best_loss > loss: + best_loss = loss.cpu().item() + best_iteartion = iteration + losses.append(loss.cpu().item()) loss.backward() @@ -100,10 +112,13 @@ def run(self): progress_bar.set_postfix({ "Loss": f"{loss:.{5}f}", - "Num_visible": f"{visibility_filter.int().sum().item()}" + "Num_visible": + f"{visibility_filter.int().sum().item()}/{len(visibility_filter)}" }) progress_bar.update(1) + print(f"Training done. Best loss = {best_loss} at iteration {best_iteration}.") + torchvision.utils.save_image(rendered_image, f"artifacts/rendered_{iteration}.png") torchvision.utils.save_image(gt_image, f"artifacts/gt.png") diff --git a/gaussian_splatting/trainer.py b/gaussian_splatting/trainer.py index 9a4b43f1a..b06de44e6 100644 --- a/gaussian_splatting/trainer.py +++ b/gaussian_splatting/trainer.py @@ -220,7 +220,9 @@ def _densify_and_prune(self, prune_big_points): # Prune transparent and large gaussians. prune_mask = (self.gaussian_model.get_opacity < self._min_opacity).squeeze() if prune_big_points: + # Viewspace big_points_vs = self.gaussian_model.max_radii2D > self._max_screen_size + # World space big_points_ws = ( self.gaussian_model.get_scaling.max(dim=1).values > 0.1 * self.gaussian_model.camera_extent diff --git a/scripts/train_colmap_free.py b/scripts/train_colmap_free.py index 251946dea..dd11d9047 100644 --- a/scripts/train_colmap_free.py +++ b/scripts/train_colmap_free.py @@ -12,7 +12,5 @@ def main(): local_trainer.run() - import pdb; pdb.set_trace() - if __name__ == "__main__": main()