Skip to content

Commit

Permalink
tune parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
bouchardi committed Apr 23, 2024
1 parent abd82ff commit 05c703f
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 28 deletions.
1 change: 0 additions & 1 deletion gaussian_splatting/pose_free/depth_estimator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import torch
from transformers import pipeline

from gaussian_splatting.dataset.image_dataset import ImageDataset
from gaussian_splatting.utils.general import TorchToPIL


Expand Down
21 changes: 10 additions & 11 deletions gaussian_splatting/pose_free/global_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,15 @@


class GlobalTrainer(Trainer):
def __init__(self, gaussian_model, iterations: int = 100, output_path=None):
def __init__(self, gaussian_model, output_path=None):
self._model_path = self._prepare_model_path(output_path)

self.gaussian_model = gaussian_model

self.optimizer = Optimizer(self.gaussian_model)
self._photometric_loss = PhotometricLoss(lambda_dssim=0.2)

self._iterations = iterations

self._debug = False
self._debug = True

# Densification and pruning
self._min_opacity = 0.005
Expand All @@ -28,9 +26,9 @@ def __init__(self, gaussian_model, iterations: int = 100, output_path=None):

safe_state()

def run(self, current_camera, next_camera, progress_bar=None, run_id: int = 0):
def run(self, current_camera, next_camera, iterations: int = 100, progress_bar=None, run_id: int = 0):
cameras = (current_camera, next_camera)
for iteration in range(self._iterations):
for iteration in range(iterations):
self.optimizer.update_learning_rate(iteration)

# Every 1000 its we increase the levels of SH up to a maximum degree
Expand Down Expand Up @@ -58,7 +56,7 @@ def run(self, current_camera, next_camera, progress_bar=None, run_id: int = 0):
progress_bar.set_postfix(
{
"stage": "global",
"iteration": f"{iteration}/{self._iterations}",
"iteration": f"{iteration}/{iterations}",
"loss": f"{loss_value:.5f}",
}
)
Expand All @@ -68,8 +66,9 @@ def run(self, current_camera, next_camera, progress_bar=None, run_id: int = 0):
)

# Densification
self.gaussian_model.update_stats(
viewspace_point_tensor, visibility_filter, radii
)
self._densify_and_prune(True)
#self.gaussian_model.update_stats(
# viewspace_point_tensor, visibility_filter, radii
#)
#self._split_points()
# self._densify_and_prune(True)
# self._reset_opacity()
27 changes: 17 additions & 10 deletions gaussian_splatting/pose_free/local_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,22 @@ def __init__(
debug: bool = False,
):
self._depth_estimator = DepthEstimator()
self._point_cloud_step = 25
self._point_cloud_step = 2
self._sh_degree = sh_degree

self._photometric_loss = PhotometricLoss(lambda_dssim=0.2)
self._init_loss = PhotometricLoss(lambda_dssim=0.2)
self._transfo_loss = PhotometricLoss(lambda_dssim=0.2)

self._init_iterations = init_iterations
self._init_early_stopper = EarlyStopper(patience=10)
self._init_early_stopper = EarlyStopper(patience=50)
self._init_save_artifacts_iterations = 100

self._transfo_lr = 0.00001
self._transfo_lr = 1.e-5
self._transfo_iterations = transfo_iterations
self._transfo_early_stopper = EarlyStopper(
patience=100,
patience=25,
)
self._transfo_save_artifacts_iterations = 10
self._transfo_save_artifacts_iterations = 25

self._debug = debug

Expand All @@ -63,7 +64,7 @@ def run_init(self, image, camera, progress_bar=None, run_id: int = 0):

rendered_image, _, _, _ = render(camera, gaussian_model)

loss = self._photometric_loss(rendered_image, image)
loss = self._init_loss(rendered_image, image)
loss.backward()
loss_value = loss.cpu().item()
losses.append(loss_value)
Expand Down Expand Up @@ -115,19 +116,24 @@ def run_transfo(

rendered_image, _, _, _ = render(camera, gaussian_model)

loss = self._photometric_loss(rendered_image, image)
loss = self._transfo_loss(rendered_image, image)
loss.backward()
loss_value = loss.cpu().item()
losses.append(loss_value)

optimizer.step()

if self._transfo_early_stopper.step(loss_value):
transformation = self._transfo_early_stopper.get_best_params()
best_params = self._transfo_early_stopper.get_best_params()
gaussian_model.set_optimizable_tensors({"xyz": best_params["xyz"]})
transformation = best_params["transformation"]
break
else:
transformation = transformation_model.transformation
self._transfo_early_stopper.set_best_params(transformation)
self._transfo_early_stopper.set_best_params({
"transformation": transformation,
"xyz": xyz
})

if self._debug and (
iteration % self._transfo_save_artifacts_iterations == 0
Expand All @@ -149,6 +155,7 @@ def run_transfo(
)

if self._debug:
rendered_image, _, _, _ = render(camera, gaussian_model)
self._save_artifacts(losses, rendered_image, output_path, "best")
save_image(image, output_path / f"ground_truth.png")

Expand Down
24 changes: 19 additions & 5 deletions gaussian_splatting/pose_free/pose_free_trainer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pathlib import Path
import copy

from torchvision.utils import save_image
from tqdm import tqdm
Expand All @@ -16,7 +17,7 @@ def __init__(self, source_path: Path):
self._debug = True

self._dataset = ImageDataset(
images_path=source_path, step_size=5, downscale_factor=1
images_path=source_path, step_size=10, downscale_factor=1
)

self._local_trainer = LocalTrainer(
Expand All @@ -31,7 +32,7 @@ def run(self):
initial_gaussian_model = self._local_trainer.get_initial_gaussian_model(
current_image, self._output_path
)
global_trainer = GlobalTrainer(initial_gaussian_model, iterations=1000)
global_trainer = GlobalTrainer(initial_gaussian_model)

current_camera = get_orthogonal_camera(current_image)

Expand All @@ -42,6 +43,9 @@ def run(self):
gaussian_model = self._local_trainer.run_init(
current_image, current_camera, progress_bar, run_id=i
)
if self._debug:
current_gaussian_model = copy.deepcopy(gaussian_model)

rotation, translation = self._local_trainer.run_transfo(
next_image,
current_camera,
Expand All @@ -52,12 +56,22 @@ def run(self):
next_camera = transform_camera(
current_camera, rotation, translation, next_image, _id=i
)
global_trainer.run(current_camera, next_camera, progress_bar, run_id=i)
global_trainer.run(
current_camera,
next_camera,
iterations=(1000 if i == 0 else 100),
progress_bar=progress_bar,
run_id=i
)

if self._debug:
rendered_image, _, _, _ = render(next_camera, gaussian_model)
rendered_image, _, _, _ = render(next_camera, current_gaussian_model)
save_image(
rendered_image, self._output_path / f"{i}_camera_rendered_image.png"
)
rendered_image, _, _, _ = render(current_camera, gaussian_model)
save_image(
rendered_image, self._output_path / f"{i}_rendered_image.png"
rendered_image, self._output_path / f"{i}_gaussian_rendered_image.png"
)
save_image(next_image, self._output_path / f"{i}_image.png")

Expand Down
8 changes: 7 additions & 1 deletion gaussian_splatting/utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,16 @@


class PhotometricLoss:
def __init__(self, lambda_dssim: float = 0.2):
def __init__(self, lambda_dssim: float = 0.2, mask_white_pixels: bool = False):
self._lambda_dssim = lambda_dssim
self._mask_white_pixels = mask_white_pixels

def __call__(self, network_output, gt):
if self._mask_white_pixels:
mask = (network_output != 1.0).int()
network_output = network_output * mask
gt = gt * mask

l1_value = l1_loss(network_output, gt)
ssim_value = ssim(network_output, gt)

Expand Down

0 comments on commit 05c703f

Please sign in to comment.