From 9594a2db01eeede57c2896c9b25b7ced96eb4b6e Mon Sep 17 00:00:00 2001
From: Isabelle Bouchard
Date: Tue, 26 Mar 2024 16:33:04 -0400
Subject: [PATCH] linear layer to quaternion rotation + translation

Replace the 3x3 linear layer of TransformationModel with a learnable
quaternion rotation (normalized before use) followed by a learnable
translation, and expose the learned parameters through
LocalTransformationTrainer.get_affine_transformation() together with a NumPy
helper, apply_affine_transformation(), that applies them to points.

Also:
- use `_images_paths` in the image dataset's `__len__`;
- make the iteration count of LocalTransformationTrainer a constructor
  argument (default 100) instead of the hard-coded 101;
- close the progress bar and print the best loss with 5 decimals in both
  local trainers.
---
NOTE(review): this copy was rebuilt from a newline-collapsed patch. The
indentation of context lines is a best guess (use
`git apply --ignore-whitespace` if a hunk is rejected), and the post-image
blob hashes on the `index` lines predate the following edits:
  * QuaternionRotation.forward multiplies by the transposed rotation matrix
    so that training and apply_affine_transformation() share one convention;
  * `{best_loss:.{5}f}` is written as the equivalent `{best_loss:.5f}`;
  * one-line docstrings were added to the new classes and functions.

 gaussian_splatting/dataset/image_dataset.py   |  2 +-
 .../local_initialization_trainer.py           |  5 +-
 .../local_transformation_trainer.py           | 97 +++++++++++++++++--
 .../utils/affine_transformation.py            |  8 ++
 4 files changed, 102 insertions(+), 10 deletions(-)
 create mode 100644 gaussian_splatting/utils/affine_transformation.py

diff --git a/gaussian_splatting/dataset/image_dataset.py b/gaussian_splatting/dataset/image_dataset.py
index 968a151ac..1317aa349 100644
--- a/gaussian_splatting/dataset/image_dataset.py
+++ b/gaussian_splatting/dataset/image_dataset.py
@@ -15,4 +15,4 @@ def get_frame(self, i: int):
         return image
 
     def __len__(self):
-        return len(self._image_paths)
+        return len(self._images_paths)
diff --git a/gaussian_splatting/local_initialization_trainer.py b/gaussian_splatting/local_initialization_trainer.py
index 869063d3e..ae4fe13ec 100644
--- a/gaussian_splatting/local_initialization_trainer.py
+++ b/gaussian_splatting/local_initialization_trainer.py
@@ -115,7 +115,10 @@ def run(self):
             )
             progress_bar.update(1)
 
-        print(f"Training done. Best loss = {best_loss} at iteration {best_iteration}.")
+        progress_bar.close()
+        print(
+            f"Training done. Best loss = {best_loss:.5f} at iteration {best_iteration}."
+        )
 
         torchvision.utils.save_image(
             rendered_image, f"artifacts/local/init/rendered_{iteration}.png"
diff --git a/gaussian_splatting/local_transformation_trainer.py b/gaussian_splatting/local_transformation_trainer.py
index 52f2ca59c..72814f2cf 100644
--- a/gaussian_splatting/local_transformation_trainer.py
+++ b/gaussian_splatting/local_transformation_trainer.py
@@ -1,4 +1,5 @@
 import torch
+import torch.nn as nn
 import torchvision
 from matplotlib import pyplot as plt
 from tqdm import tqdm
@@ -9,22 +10,92 @@
 from gaussian_splatting.utils.loss import l1_loss, ssim
 
 
-class TransformationModel(torch.nn.Module):
+class QuaternionRotation(nn.Module):
+    """Learnable 3D rotation stored as a quaternion in (x, y, z, w) order."""
+
     def __init__(self):
         super().__init__()
-        self.linear = torch.nn.Linear(in_features=3, out_features=3)
+        self.quaternion = nn.Parameter(torch.Tensor([[0, 0, 0, 1]]))
 
-        torch.nn.init.eye_(self.linear.weight.data)
-        torch.nn.init.zeros_(self.linear.bias.data)
+    def forward(self, input_tensor):
+        """Rotate points given as row vectors (trailing dimension of 3)."""
+        # Row-vector points are rotated as x @ R.T, the same convention that
+        # apply_affine_transformation uses for the exported matrix.
+        rotation_matrix = self.get_rotation_matrix()
+        rotated_tensor = torch.matmul(input_tensor, rotation_matrix.T)
+
+        return rotated_tensor
+
+    def get_rotation_matrix(self):
+        """Return the 3x3 rotation matrix of the normalized quaternion."""
+        # Normalize quaternion to ensure unit magnitude
+        quaternion_norm = torch.norm(self.quaternion, p=2, dim=1, keepdim=True)
+        normalized_quaternion = self.quaternion / quaternion_norm
+
+        x, y, z, w = normalized_quaternion[0]
+        rotation_matrix = torch.zeros(
+            3, 3, dtype=torch.float32, device=self.quaternion.device
+        )
+        rotation_matrix[0, 0] = 1 - 2 * (y**2 + z**2)
+        rotation_matrix[0, 1] = 2 * (x * y - w * z)
+        rotation_matrix[0, 2] = 2 * (x * z + w * y)
+        rotation_matrix[1, 0] = 2 * (x * y + w * z)
+        rotation_matrix[1, 1] = 1 - 2 * (x**2 + z**2)
+        rotation_matrix[1, 2] = 2 * (y * z - w * x)
+        rotation_matrix[2, 0] = 2 * (x * z - w * y)
+        rotation_matrix[2, 1] = 2 * (y * z + w * x)
+        rotation_matrix[2, 2] = 1 - 2 * (x**2 + y**2)
+
+        return rotation_matrix
+
+
+class Translation(nn.Module):
+    """Learnable translation of shape (1, 3), initialized to zero."""
+
+    def __init__(self):
+        super().__init__()
+        self.translation = nn.Parameter(torch.Tensor(1, 3))
+        nn.init.zeros_(self.translation)
+
+    def forward(self, input_tensor):
+        """Add the learned translation to input_tensor."""
+        translated_tensor = torch.add(self.translation, input_tensor)
+
+        return translated_tensor
+
+
+class TransformationModel(nn.Module):
+    """Rigid transformation: quaternion rotation followed by a translation."""
+
+    def __init__(self):
+        super().__init__()
+        self._rotation = QuaternionRotation()
+        self._translation = Translation()
 
     def forward(self, xyz):
-        transformed_xyz = self.linear(xyz)
+        """Rotate xyz, then translate the result."""
+        transformed_xyz = self._rotation(xyz)
+        transformed_xyz = self._translation(transformed_xyz)
 
         return transformed_xyz
 
+    @property
+    def rotation(self):
+        """Current rotation matrix, detached and moved to the CPU."""
+        rotation = self._rotation.get_rotation_matrix().detach().cpu()
+
+        return rotation
+
+    @property
+    def translation(self):
+        """Current translation parameter, detached and moved to the CPU."""
+        translation = self._translation.translation.detach().cpu()
+
+        return translation
+
 
 class LocalTransformationTrainer(Trainer):
-    def __init__(self, image, camera, gaussian_model):
+    def __init__(self, image, camera, gaussian_model, iterations: int = 100):
         self.camera = camera
         self.gaussian_model = gaussian_model
 
@@ -39,7 +110,7 @@ def __init__(self, image, camera, gaussian_model):
             self.transformation_model.parameters(), lr=0.0001
         )
 
-        self._iterations = 101
+        self._iterations = iterations
         self._lambda_dssim = 0.2
 
         safe_state(seed=2234)
@@ -87,9 +158,19 @@ def run(self):
             )
             progress_bar.update(1)
 
-        print(f"Training done. Best loss = {best_loss} at iteration {best_iteration}.")
+        progress_bar.close()
+        print(
+            f"Training done. Best loss = {best_loss:.5f} at iteration {best_iteration}."
+        )
 
         torchvision.utils.save_image(
             rendered_image, f"artifacts/local/transfo/rendered_{iteration}.png"
         )
         torchvision.utils.save_image(gt_image, f"artifacts/local/transfo/gt.png")
+
+    def get_affine_transformation(self):
+        """Return the learned (rotation, translation) pair as NumPy arrays."""
+        rotation = self.transformation_model.rotation.numpy()
+        translation = self.transformation_model.translation.numpy()
+
+        return rotation, translation
diff --git a/gaussian_splatting/utils/affine_transformation.py b/gaussian_splatting/utils/affine_transformation.py
new file mode 100644
index 000000000..27637ebab
--- /dev/null
+++ b/gaussian_splatting/utils/affine_transformation.py
@@ -0,0 +1,8 @@
+import numpy as np
+
+
+def apply_affine_transformation(xyz, rotation, translation):
+    """Rotate row-vector points xyz by rotation, then add translation."""
+    transformed_xyz = np.matmul(xyz, rotation.T) + translation
+
+    return transformed_xyz