Skip to content

Commit

Permalink
Fix orthogonal camera initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
bouchardi committed Apr 10, 2024
1 parent 3c6f5f8 commit 1c4dd8e
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 27 deletions.
66 changes: 41 additions & 25 deletions gaussian_splatting/local_trainer.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import math
import numpy as np
from transformers import pipeline
from tqdm import tqdm

import torch
import torchvision

from gaussian_splatting.utils.general import safe_state
from gaussian_splatting.model import GaussianModel
from gaussian_splatting.optimizer import Optimizer
from gaussian_splatting.render import render
Expand All @@ -22,7 +24,11 @@ def __init__(self, image, sh_degree: int = 3):

image = PILtoTorch(image)

initial_point_cloud = self._get_initial_point_cloud(image, depth_estimation)
initial_point_cloud = self._get_initial_point_cloud(
image,
depth_estimation,
step=10
)

self.gaussian_model = GaussianModel(sh_degree)
self.gaussian_model.initialize_from_point_cloud(initial_point_cloud)
Expand All @@ -32,10 +38,10 @@ def __init__(self, image, sh_degree: int = 3):

self._camera = self._get_orthogonal_camera(image)

self._iterations = 10000
self._iterations = 1000
self._lambda_dssim = 0.2

self._opacity_reset_interval = 3000
self._opacity_reset_interval = 1000
self._min_opacity = 0.005
self._max_screen_size = 20
self._percent_dense = 0.01
Expand All @@ -46,18 +52,20 @@ def __init__(self, image, sh_degree: int = 3):

self._debug = True

safe_state(seed=2234)

def run(self):
progress_bar = tqdm(
range(self._iterations), desc="Training progress"
)
for iteration in range(self._iterations):
self.optimizer.update_learning_rate(iteration)
# self.optimizer.update_learning_rate(iteration)
rendered_image, viewspace_point_tensor, visibility_filter, radii = render(
self._camera, self.gaussian_model
)

if iteration == 0:
torchvision.utils.save_image(rendered_image, f"rendered_{iteration}.png")
if iteration % 100 == 0:
torchvision.utils.save_image(rendered_image, f"artifacts/rendered_{iteration}.png")

gt_image = self._camera.original_image.cuda()
Ll1 = l1_loss(rendered_image, gt_image)
Expand All @@ -77,30 +85,34 @@ def run(self):
viewspace_point_tensor, visibility_filter, radii
)

# if (
# iteration >= self._densification_iteration_start
# and iteration % self._densification_interval == 0
# ):
# self._densify_and_prune(
# iteration > self._opacity_reset_interval
# )
if (
iteration >= self._densification_iteration_start
and iteration % self._densification_interval == 0
):
self._densify_and_prune(
iteration > self._opacity_reset_interval
)

# Reset opacity interval
# if iteration % self._opacity_reset_interval == 0:
# self._reset_opacity()

progress_bar.set_postfix({"Loss": f"{loss:.{5}f}"})
if iteration > 0 and iteration % self._opacity_reset_interval == 0:
print("Reset Opacity")
self._reset_opacity()

progress_bar.set_postfix({
"Loss": f"{loss:.{5}f}",
"Num_visible": f"{visibility_filter.int().sum().item()}"
})
progress_bar.update(1)

torchvision.utils.save_image(rendered_image, f"rendered_{iteration}.png")
torchvision.utils.save_image(gt_image, f"gt.png")
torchvision.utils.save_image(rendered_image, f"artifacts/rendered_{iteration}.png")
torchvision.utils.save_image(gt_image, f"artifacts/gt.png")

def _get_orthogonal_camera(self, image):
camera = Camera(
R=np.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]),
T=np.array([0.5, 0.5, -1]),
FoVx=1.,
FoVy=1.,
T=np.array([-0.5, -0.5, 1.]),
FoVx=2 * math.atan(0.5),
FoVy=2 * math.atan(0.5),
image=image,
gt_alpha_mask=None,
image_name="patate",
Expand All @@ -114,15 +126,19 @@ def _get_initial_point_cloud(self, frame, depth_estimation, step: int = 50):
# The frame and depth_estimation widths do not match exactly.
_, w, h = depth_estimation.shape

_min_depth = depth_estimation.min()
_max_depth = depth_estimation.max()

half_step = step // 2
points, colors, normals = [], [], []
for x in range(step, w - step, step):
for y in range(step, h - step, step):
# Normalized h, w
_depth = depth_estimation[0, x, y].item()
# Normalized points
points.append([
x / w,
y / h,
depth_estimation[0, x, y].item()
x / w,
(_depth - _min_depth) / (_max_depth - _min_depth)
])
# Average RGB color in the window around the selected pixel
colors.append(
Expand Down
1 change: 0 additions & 1 deletion gaussian_splatting/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ def render(
colors_precomp=None,
cov3D_precomp=None,
)

# Those Gaussians that were frustum culled or had a radius of 0 were not visible.
# They will be excluded from value updates used in the splitting criteria.
return (
Expand Down
2 changes: 1 addition & 1 deletion gaussian_splatting/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def run(self):
if not cameras:
cameras = self.dataset.get_train_cameras().copy()
camera = cameras.pop(randint(0, len(cameras) - 1))
import pdb; pdb.set_trace()

# Render image
rendered_image, viewspace_point_tensor, visibility_filter, radii = render(
camera, self.gaussian_model
Expand Down

0 comments on commit 1c4dd8e

Please sign in to comment.