Skip to content

Commit

Permalink
Transformation functional
Browse files Browse the repository at this point in the history
  • Loading branch information
bouchardi committed Apr 11, 2024
1 parent 6bfe455 commit fd1e9b9
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 20 deletions.
2 changes: 0 additions & 2 deletions gaussian_splatting/colmap_free/global_trainer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import os
import uuid
from random import randint

import torch
from tqdm import tqdm

from gaussian_splatting.optimizer import Optimizer
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import math
from pathlib import Path

import numpy as np
import torch
import torchvision
from matplotlib import pyplot as plt
from tqdm import tqdm
from transformers import pipeline
from pathlib import Path

from gaussian_splatting.dataset.cameras import Camera
from gaussian_splatting.model import GaussianModel
Expand Down Expand Up @@ -51,10 +51,9 @@ def __init__(self, image, sh_degree: int = 3, iterations: int = 10000):

safe_state(seed=2234)

self._output_path =Path(" artifacts/local/init/")
self._output_path = Path("artifacts/local/init/")
self._output_path.mkdir(exist_ok=True, parents=True)


def run(self, iterations: int = 3000):
progress_bar = tqdm(range(iterations), desc="Initialization")

Expand All @@ -78,7 +77,7 @@ def run(self, iterations: int = 3000):
losses.append(loss.cpu().item())

if iteration % 100 == 0:
self._save_artifacts(self, losses, rendered_image, iteration)
self._save_artifacts(losses, rendered_image, iteration)

with torch.no_grad():
# Densification
Expand Down
28 changes: 19 additions & 9 deletions gaussian_splatting/colmap_free/local_transformation_trainer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from pathlib import Path

import torch
import torchvision
from matplotlib import pyplot as plt
from tqdm import tqdm
from pathlib import Path

from gaussian_splatting.colmap_free.transformation_model import \
AffineTransformationModel
Expand All @@ -24,7 +25,7 @@ def __init__(self, gaussian_model):
)
self._photometric_loss = PhotometricLoss(lambda_dssim=0.2)

self._output_path = Path(" artifacts/local/transfo/")
self._output_path = Path("artifacts/local/transfo/")
self._output_path.mkdir(exist_ok=True, parents=True)

safe_state(seed=2234)
Expand All @@ -36,9 +37,11 @@ def run(self, current_camera, gt_image, iterations: int = 1000, run: int = 0):

losses = []
best_loss, best_iteration, best_xyz = None, 0, None
best_rotation, best_translation = None, None
patience = 0
initial_xyz = self.gaussian_model.get_xyz.detach()
for iteration in range(iterations):
xyz = self.transformation_model(self.gaussian_model.get_xyz.detach())
xyz = self.transformation_model(initial_xyz)
self.gaussian_model.set_optimizable_tensors({"xyz": xyz})

rendered_image, viewspace_point_tensor, visibility_filter, radii = render(
Expand All @@ -54,28 +57,35 @@ def run(self, current_camera, gt_image, iterations: int = 1000, run: int = 0):
progress_bar.set_postfix({"Loss": f"{loss:.{5}f}"})
progress_bar.update(1)

if iteration % 10 == 0 or iteration == len(iterations) - 1:
if iteration % 10 == 0 or iteration == iterations - 1:
self._save_artifacts(losses, rendered_image, iteration, run)

if best_loss is None or best_loss > loss:
best_loss = loss.cpu().item()
best_iteration = iteration
best_xyz = xyz.detach()
#elif best_loss < loss and patience > 10:

best_rotation = self.transformation_model.rotation.numpy()
best_translation = self.transformation_model.translation.numpy()

# elif best_loss < loss and patience > 10:
# self._save_artifacts(losses, rendered_image, iteration, run)
# break
#else:
# else:
# patience += 1

progress_bar.close()

torchvision.utils.save_image(gt_image, self._output_path / "gt.png")

print(f"Best loss = {best_loss:.{5}f} at iteration {best_iteration}.")
self.gaussian_model.set_optimizable_tensors({"xyz": best_xyz})

rotation = self.transformation_model.rotation.numpy()
translation = self.transformation_model.translation.numpy()
if best_rotation is None or best_translation is None:
best_rotation = self.transformation_model.rotation.numpy()
best_translation = self.transformation_model.translation.numpy()

return rotation, translation
return best_rotation, best_translation

def _save_artifacts(self, losses, rendered_image, iteration, run):
plt.cla()
Expand Down
16 changes: 12 additions & 4 deletions scripts/train_colmap_free.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torchvision
from tqdm import tqdm

from gaussian_splatting.colmap_free.global_trainer import GlobalTrainer
# from gaussian_splatting.colmap_free.global_trainer import GlobalTrainer
from gaussian_splatting.colmap_free.local_initialization_trainer import \
LocalInitializationTrainer
from gaussian_splatting.colmap_free.local_transformation_trainer import \
Expand All @@ -22,7 +22,7 @@ def main():
debug = True
iteration_step_size = 50
initialization_iterations = 250
transformation_iterations = 250
transformation_iterations = 1000
global_iterations = 5

photometric_loss = PhotometricLoss(lambda_dssim=0.2)
Expand Down Expand Up @@ -71,16 +71,24 @@ def main():

if debug:
# Save artifact
current_camera_image, _, _, _ = render(
current_camera, current_gaussian_model
)
next_camera_image, _, _, _ = render(next_camera, current_gaussian_model)
next_gaussian_image, _, _, _ = render(current_camera, next_gaussian_model)
loss = photometric_loss(next_camera_image, next_gaussian_image)

print(loss)
output_path = Path("artifacts/global")
output_path.mkdir(exist_ok=True, parents=True)
torchvision.utils.save_image(
current_camera_image, output_path / f"current_camera_{iteration}.png"
)
torchvision.utils.save_image(
next_camera_image, f"artifacts/global/next_camera_{iteration}.png"
next_camera_image, output_path / f"next_camera_{iteration}.png"
)
torchvision.utils.save_image(
next_gaussian_image, f"artifacts/global/next_gaussian_{iteration}.png"
next_gaussian_image, output_path / f"next_gaussian_{iteration}.png"
)

break
Expand Down
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
"pyyaml",
"black",
"isort",
"importchecker"
"importchecker",
"matplotlib",
"transformers"
]
)

0 comments on commit fd1e9b9

Please sign in to comment.