Skip to content

Commit

Permalink
no opacity reset in local trainer
Browse files Browse the repository at this point in the history
  • Loading branch information
bouchardi committed Apr 10, 2024
1 parent 1c4dd8e commit 8cbad5f
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 8 deletions.
27 changes: 21 additions & 6 deletions gaussian_splatting/local_trainer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import math
from matplotlib import pyplot as plt
import numpy as np
from transformers import pipeline
from tqdm import tqdm
Expand All @@ -23,11 +24,10 @@ def __init__(self, image, sh_degree: int = 3):
depth_estimation = DPT(image)["predicted_depth"]

image = PILtoTorch(image)

initial_point_cloud = self._get_initial_point_cloud(
image,
depth_estimation,
step=10
step=25
)

self.gaussian_model = GaussianModel(sh_degree)
Expand All @@ -38,10 +38,11 @@ def __init__(self, image, sh_degree: int = 3):

self._camera = self._get_orthogonal_camera(image)

self._iterations = 1000
self._iterations = 10000
self._lambda_dssim = 0.2

self._opacity_reset_interval = 1000
# Densification and pruning
self._opacity_reset_interval = 10001
self._min_opacity = 0.005
self._max_screen_size = 20
self._percent_dense = 0.01
Expand All @@ -58,20 +59,31 @@ def run(self):
progress_bar = tqdm(
range(self._iterations), desc="Training progress"
)

best_loss, best_iteration, losses = None, 0, []
for iteration in range(self._iterations):
# self.optimizer.update_learning_rate(iteration)
self.optimizer.update_learning_rate(iteration)
rendered_image, viewspace_point_tensor, visibility_filter, radii = render(
self._camera, self.gaussian_model
)

if iteration % 100 == 0:
plt.cla()
plt.plot(losses)
plt.yscale('log')
plt.savefig('artifacts/losses.png')

torchvision.utils.save_image(rendered_image, f"artifacts/rendered_{iteration}.png")

gt_image = self._camera.original_image.cuda()
Ll1 = l1_loss(rendered_image, gt_image)
loss = (1.0 - self._lambda_dssim) * Ll1 + self._lambda_dssim * (
1.0 - ssim(rendered_image, gt_image)
)
if best_loss is None or best_loss > loss:
best_loss = loss.cpu().item()
best_iteartion = iteration
losses.append(loss.cpu().item())

loss.backward()

Expand Down Expand Up @@ -100,10 +112,13 @@ def run(self):

progress_bar.set_postfix({
"Loss": f"{loss:.{5}f}",
"Num_visible": f"{visibility_filter.int().sum().item()}"
"Num_visible":
f"{visibility_filter.int().sum().item()}/{len(visibility_filter)}"
})
progress_bar.update(1)

print(f"Training done. Best loss = {best_loss} at iteration {best_iteration}.")

torchvision.utils.save_image(rendered_image, f"artifacts/rendered_{iteration}.png")
torchvision.utils.save_image(gt_image, f"artifacts/gt.png")

Expand Down
2 changes: 2 additions & 0 deletions gaussian_splatting/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,9 @@ def _densify_and_prune(self, prune_big_points):
# Prune transparent and large gaussians.
prune_mask = (self.gaussian_model.get_opacity < self._min_opacity).squeeze()
if prune_big_points:
# Viewspace
big_points_vs = self.gaussian_model.max_radii2D > self._max_screen_size
# World space
big_points_ws = (
self.gaussian_model.get_scaling.max(dim=1).values
> 0.1 * self.gaussian_model.camera_extent
Expand Down
2 changes: 0 additions & 2 deletions scripts/train_colmap_free.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,5 @@ def main():
local_trainer.run()


import pdb; pdb.set_trace()

if __name__ == "__main__":
main()

0 comments on commit 8cbad5f

Please sign in to comment.