diff --git a/gaussian_splatting/local_trainer.py b/gaussian_splatting/local_trainer.py index 133ebd3dd..d3d87bd63 100644 --- a/gaussian_splatting/local_trainer.py +++ b/gaussian_splatting/local_trainer.py @@ -1,3 +1,4 @@ +import math import numpy as np from transformers import pipeline from tqdm import tqdm @@ -5,6 +6,7 @@ import torch import torchvision +from gaussian_splatting.utils.general import safe_state from gaussian_splatting.model import GaussianModel from gaussian_splatting.optimizer import Optimizer from gaussian_splatting.render import render @@ -22,7 +24,11 @@ def __init__(self, image, sh_degree: int = 3): image = PILtoTorch(image) - initial_point_cloud = self._get_initial_point_cloud(image, depth_estimation) + initial_point_cloud = self._get_initial_point_cloud( + image, + depth_estimation, + step=10 + ) self.gaussian_model = GaussianModel(sh_degree) self.gaussian_model.initialize_from_point_cloud(initial_point_cloud) @@ -32,10 +38,10 @@ def __init__(self, image, sh_degree: int = 3): self._camera = self._get_orthogonal_camera(image) - self._iterations = 10000 + self._iterations = 1000 self._lambda_dssim = 0.2 - self._opacity_reset_interval = 3000 + self._opacity_reset_interval = 1000 self._min_opacity = 0.005 self._max_screen_size = 20 self._percent_dense = 0.01 @@ -46,18 +52,20 @@ def __init__(self, image, sh_degree: int = 3): self._debug = True + safe_state(seed=2234) + def run(self): progress_bar = tqdm( range(self._iterations), desc="Training progress" ) for iteration in range(self._iterations): - self.optimizer.update_learning_rate(iteration) + # self.optimizer.update_learning_rate(iteration) rendered_image, viewspace_point_tensor, visibility_filter, radii = render( self._camera, self.gaussian_model ) - if iteration == 0: - torchvision.utils.save_image(rendered_image, f"rendered_{iteration}.png") + if iteration % 100 == 0: + torchvision.utils.save_image(rendered_image, f"artifacts/rendered_{iteration}.png") gt_image = self._camera.original_image.cuda() Ll1 = l1_loss(rendered_image, gt_image) @@ -77,30 +85,34 @@ def run(self): viewspace_point_tensor, visibility_filter, radii ) - # if ( - # iteration >= self._densification_iteration_start - # and iteration % self._densification_interval == 0 - # ): - # self._densify_and_prune( - # iteration > self._opacity_reset_interval - # ) + if ( + iteration >= self._densification_iteration_start + and iteration % self._densification_interval == 0 + ): + self._densify_and_prune( + iteration > self._opacity_reset_interval + ) # Reset opacity interval - # if iteration % self._opacity_reset_interval == 0: - # self._reset_opacity() - - progress_bar.set_postfix({"Loss": f"{loss:.{5}f}"}) + if iteration > 0 and iteration % self._opacity_reset_interval == 0: + print("Reset Opacity") + self._reset_opacity() + + progress_bar.set_postfix({ + "Loss": f"{loss:.{5}f}", + "Num_visible": f"{visibility_filter.int().sum().item()}" + }) progress_bar.update(1) - torchvision.utils.save_image(rendered_image, f"rendered_{iteration}.png") - torchvision.utils.save_image(gt_image, f"gt.png") + torchvision.utils.save_image(rendered_image, f"artifacts/rendered_{iteration}.png") + torchvision.utils.save_image(gt_image, f"artifacts/gt.png") def _get_orthogonal_camera(self, image): camera = Camera( R=np.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]), - T=np.array([0.5, 0.5, -1]), - FoVx=1., - FoVy=1., + T=np.array([-0.5, -0.5, 1.]), + FoVx=2 * math.atan(0.5), + FoVy=2 * math.atan(0.5), image=image, gt_alpha_mask=None, image_name="patate", @@ -114,15 +126,19 @@ def _get_initial_point_cloud(self, frame, depth_estimation, step: int = 50): # Frame and depth_estimation width do not exactly match. _, w, h = depth_estimation.shape + _min_depth = depth_estimation.min() + _max_depth = depth_estimation.max() + half_step = step // 2 points, colors, normals = [], [], [] for x in range(step, w - step, step): for y in range(step, h - step, step): - # Normalized h, w + _depth = depth_estimation[0, x, y].item() + # Normalized points points.append([ - x / w, y / h, - depth_estimation[0, x, y].item() + x / w, + (_depth - _min_depth) / (_max_depth - _min_depth) ]) # Average RGB color in the window color around selected pixel colors.append( diff --git a/gaussian_splatting/render.py b/gaussian_splatting/render.py index cc4c5df6b..e3889afea 100644 --- a/gaussian_splatting/render.py +++ b/gaussian_splatting/render.py @@ -76,7 +76,6 @@ def render( colors_precomp=None, cov3D_precomp=None, ) - # Those Gaussians that were frustum culled or had a radius of 0 were not visible. # They will be excluded from value updates used in the splitting criteria. return ( diff --git a/gaussian_splatting/trainer.py b/gaussian_splatting/trainer.py index c110bd324..9a4b43f1a 100644 --- a/gaussian_splatting/trainer.py +++ b/gaussian_splatting/trainer.py @@ -89,7 +89,7 @@ def run(self): if not cameras: cameras = self.dataset.get_train_cameras().copy() camera = cameras.pop(randint(0, len(cameras) - 1)) - import pdb; pdb.set_trace() + # Render image rendered_image, viewspace_point_tensor, visibility_filter, radii = render( camera, self.gaussian_model