cleaning
bouchardi committed Apr 10, 2024
1 parent 84afe4b commit 78fd273
Showing 8 changed files with 190 additions and 290 deletions.
Binary file added gaussian_splatting/.DS_Store
86 changes: 86 additions & 0 deletions gaussian_splatting/colmap_free/global_trainer.py
@@ -0,0 +1,86 @@
import os
import uuid
from random import randint

import torch
from tqdm import tqdm

from gaussian_splatting.optimizer import Optimizer
from gaussian_splatting.render import render
from gaussian_splatting.trainer import Trainer
from gaussian_splatting.utils.general import safe_state
from gaussian_splatting.utils.loss import PhotometricLoss


class GlobalTrainer(Trainer):
    def __init__(self, gaussian_model):
        self._model_path = self._prepare_model_path()

        self.gaussian_model = gaussian_model
        self.cameras = []

        self.optimizer = Optimizer(self.gaussian_model)
        self._photometric_loss = PhotometricLoss(lambda_dssim=0.2)

        self._debug = False

        # Densification and pruning
        self._min_opacity = 0.005
        self._max_screen_size = 20
        self._percent_dense = 0.01
        self._densification_grad_threshold = 0.0002

        safe_state()

    def add_camera(self, camera):
        self.cameras.append(camera)

    def run(self, iterations: int = 1000):
        ema_loss_for_log = 0.0
        cameras = None
        first_iter = 1
        progress_bar = tqdm(range(first_iter, iterations), desc="Training progress")
        for iteration in range(first_iter, iterations + 1):
            self.optimizer.update_learning_rate(iteration)

            # Every 1000 its we increase the levels of SH up to a maximum degree
            if iteration % 1000 == 0:
                self.gaussian_model.oneupSHdegree()

            # Pick a random camera
            if not cameras:
                cameras = self.cameras.copy()
            camera = cameras.pop(randint(0, len(cameras) - 1))

            # Render image
            rendered_image, viewspace_point_tensor, visibility_filter, radii = render(
                camera, self.gaussian_model
            )

            # Loss
            gt_image = camera.original_image.cuda()
            loss = self._photometric_loss(rendered_image, gt_image)
            loss.backward()

            # Optimizer step
            self.optimizer.step()
            self.optimizer.zero_grad(set_to_none=True)

            # Progress bar
            ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
            progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
            progress_bar.update(1)

        progress_bar.close()

        point_cloud_path = os.path.join(
            self._model_path, "point_cloud/iteration_{}".format(iteration)
        )
        self.gaussian_model.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))

        # Densification
        self.gaussian_model.update_stats(
            viewspace_point_tensor, visibility_filter, radii
        )
        self._densify_and_prune(True)
        self._reset_opacity()
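
For orientation, a minimal, hypothetical usage sketch of the new GlobalTrainer: cameras are registered up front and run() then samples one at random per iteration. The GaussianModel and camera objects are assumed to come from the rest of the repository and are not part of this diff.

# Hypothetical usage sketch; GaussianModel / camera construction is assumed,
# not shown in this commit.
from gaussian_splatting.colmap_free.global_trainer import GlobalTrainer

def refine_globally(gaussian_model, cameras, iterations=1000):
    trainer = GlobalTrainer(gaussian_model)
    for camera in cameras:
        # Each registered camera becomes a candidate view for the random
        # per-iteration sampling inside run().
        trainer.add_camera(camera)
    trainer.run(iterations=iterations)
    return gaussian_model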
@@ -82,7 +82,6 @@ def run(self, iterations: int = 3000):
                best_iteration = iteration
            losses.append(loss.cpu().item())


            with torch.no_grad():
                # Densification
                if iteration < self._densification_iteration_stop:
@@ -1,90 +1,21 @@
import torch
import torch.nn as nn
import torchvision
from matplotlib import pyplot as plt
from tqdm import tqdm

from gaussian_splatting.colmap_free.transformation_model import \
    AffineTransformationModel
from gaussian_splatting.render import render
from gaussian_splatting.trainer import Trainer
from gaussian_splatting.utils.general import safe_state
from gaussian_splatting.utils.loss import PhotometricLoss


class QuaternionRotation(nn.Module):
    def __init__(self):
        super().__init__()
        self.quaternion = nn.Parameter(torch.Tensor([[0, 0, 0, 1]]))

    def forward(self, input_tensor):
        rotation_matrix = self.get_rotation_matrix()
        rotated_tensor = torch.matmul(input_tensor, rotation_matrix)

        return rotated_tensor

    def get_rotation_matrix(self):
        # Normalize quaternion to ensure unit magnitude
        quaternion_norm = torch.norm(self.quaternion, p=2, dim=1, keepdim=True)
        normalized_quaternion = self.quaternion / quaternion_norm

        x, y, z, w = normalized_quaternion[0]
        rotation_matrix = torch.zeros(
            3, 3, dtype=torch.float32, device=self.quaternion.device
        )
        rotation_matrix[0, 0] = 1 - 2 * (y**2 + z**2)
        rotation_matrix[0, 1] = 2 * (x * y - w * z)
        rotation_matrix[0, 2] = 2 * (x * z + w * y)
        rotation_matrix[1, 0] = 2 * (x * y + w * z)
        rotation_matrix[1, 1] = 1 - 2 * (x**2 + z**2)
        rotation_matrix[1, 2] = 2 * (y * z - w * x)
        rotation_matrix[2, 0] = 2 * (x * z - w * y)
        rotation_matrix[2, 1] = 2 * (y * z + w * x)
        rotation_matrix[2, 2] = 1 - 2 * (x**2 + y**2)

        return rotation_matrix


class Translation(nn.Module):
    def __init__(self):
        super().__init__()
        self.translation = nn.Parameter(torch.Tensor(1, 3))
        nn.init.zeros_(self.translation)

    def forward(self, input_tensor):
        translated_tensor = torch.add(self.translation, input_tensor)

        return translated_tensor


class TransformationModel(nn.Module):
    def __init__(self):
        super().__init__()
        self._rotation = QuaternionRotation()
        self._translation = Translation()

    def forward(self, xyz):
        transformed_xyz = self._rotation(xyz)
        transformed_xyz = self._translation(transformed_xyz)

        return transformed_xyz

    @property
    def rotation(self):
        rotation = self._rotation.get_rotation_matrix().detach().cpu()

        return rotation

    @property
    def translation(self):
        translation = self._translation.translation.detach().cpu()

        return translation


class LocalTransformationTrainer(Trainer):
    def __init__(self, gaussian_model):
        self.gaussian_model = gaussian_model

        self.transformation_model = TransformationModel()
        self.transformation_model = AffineTransformationModel()
        self.transformation_model.to(gaussian_model.get_xyz.device)

        self.optimizer = torch.optim.Adam(
72 changes: 72 additions & 0 deletions gaussian_splatting/colmap_free/transformation_model.py
@@ -0,0 +1,72 @@
import torch
from torch import nn


class QuaternionRotationLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.quaternion = nn.Parameter(torch.Tensor([[0, 0, 0, 1]]))

    def forward(self, input_tensor):
        rotation_matrix = self.get_rotation_matrix()
        rotated_tensor = torch.matmul(input_tensor, rotation_matrix)

        return rotated_tensor

    def get_rotation_matrix(self):
        # Normalize quaternion to ensure unit magnitude
        quaternion_norm = torch.norm(self.quaternion, p=2, dim=1, keepdim=True)
        normalized_quaternion = self.quaternion / quaternion_norm

        x, y, z, w = normalized_quaternion[0]
        rotation_matrix = torch.zeros(
            3, 3, dtype=torch.float32, device=self.quaternion.device
        )
        rotation_matrix[0, 0] = 1 - 2 * (y**2 + z**2)
        rotation_matrix[0, 1] = 2 * (x * y - w * z)
        rotation_matrix[0, 2] = 2 * (x * z + w * y)
        rotation_matrix[1, 0] = 2 * (x * y + w * z)
        rotation_matrix[1, 1] = 1 - 2 * (x**2 + z**2)
        rotation_matrix[1, 2] = 2 * (y * z - w * x)
        rotation_matrix[2, 0] = 2 * (x * z - w * y)
        rotation_matrix[2, 1] = 2 * (y * z + w * x)
        rotation_matrix[2, 2] = 1 - 2 * (x**2 + y**2)

        return rotation_matrix


class TranslationLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.translation = nn.Parameter(torch.Tensor(1, 3))
        nn.init.zeros_(self.translation)

    def forward(self, input_tensor):
        translated_tensor = torch.add(self.translation, input_tensor)

        return translated_tensor


class AffineTransformationModel(nn.Module):
    def __init__(self):
        super().__init__()
        self._rotation = QuaternionRotationLayer()
        self._translation = TranslationLayer()

    def forward(self, xyz):
        transformed_xyz = self._rotation(xyz)
        transformed_xyz = self._translation(transformed_xyz)

        return transformed_xyz

    @property
    def rotation(self):
        rotation = self._rotation.get_rotation_matrix().detach().cpu()

        return rotation

    @property
    def translation(self):
        translation = self._translation.translation.detach().cpu()

        return translation
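
A quick, hypothetical sanity check of the extracted module (not part of the commit): with the default parameters, the quaternion [0, 0, 0, 1] and the zero translation, the model should act as the identity map, and the rotation / translation properties should reproduce forward() under the same row-vector convention (points multiplied on the right by the 3x3 matrix).

# Hypothetical sanity check; assumes the module above is importable as-is.
import torch

from gaussian_splatting.colmap_free.transformation_model import \
    AffineTransformationModel

model = AffineTransformationModel()
xyz = torch.randn(100, 3)

# The default parameters encode the identity transform.
assert torch.allclose(model(xyz), xyz, atol=1e-6)

# The exposed pose matches forward(): row vectors times the 3x3 matrix,
# plus the 1x3 translation.
R, t = model.rotation, model.translation
assert torch.allclose(xyz @ R + t, model(xyz).detach(), atol=1e-5)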