Skip to content

Commit

Permalink
linear layer to quaternion rotation + translation
Browse files Browse the repository at this point in the history
  • Loading branch information
bouchardi committed Apr 10, 2024
1 parent 90f3fa3 commit 9594a2d
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 10 deletions.
2 changes: 1 addition & 1 deletion gaussian_splatting/dataset/image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ def get_frame(self, i: int):
return image

def __len__(self):
return len(self._image_paths)
return len(self._images_paths)
5 changes: 4 additions & 1 deletion gaussian_splatting/local_initialization_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,10 @@ def run(self):
)
progress_bar.update(1)

print(f"Training done. Best loss = {best_loss} at iteration {best_iteration}.")
progress_bar.close()
print(
f"Training done. Best loss = {best_loss:.{5}f} at iteration {best_iteration}."
)

torchvision.utils.save_image(
rendered_image, f"artifacts/local/init/rendered_{iteration}.png"
Expand Down
82 changes: 74 additions & 8 deletions gaussian_splatting/local_transformation_trainer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
import torch.nn as nn
import torchvision
from matplotlib import pyplot as plt
from tqdm import tqdm
Expand All @@ -9,22 +10,78 @@
from gaussian_splatting.utils.loss import l1_loss, ssim


class TransformationModel(torch.nn.Module):
class QuaternionRotation(nn.Module):
    """Learnable 3D rotation parameterised by a single quaternion.

    The quaternion is stored as a ``(1, 4)`` parameter in ``(x, y, z, w)``
    order and starts at the identity rotation ``(0, 0, 0, 1)``. It is
    re-normalised every time the rotation matrix is requested, so the
    optimiser is free to move it off the unit sphere.
    """

    def __init__(self):
        super().__init__()
        self.quaternion = nn.Parameter(torch.tensor([[0.0, 0.0, 0.0, 1.0]]))

    def forward(self, input_tensor):
        """Rotate points given as row vectors (last dimension of size 3).

        Returns a tensor of the same shape as ``input_tensor``.
        """
        rotation_matrix = self.get_rotation_matrix()
        # Points are row vectors, so applying R to each point is
        # ``points @ R.T``. Multiplying by R itself would apply the inverse
        # rotation and disagree with the matrix exposed to callers.
        return torch.matmul(input_tensor, rotation_matrix.T)

    def get_rotation_matrix(self):
        """Return the ``(3, 3)`` rotation matrix of the normalised quaternion.

        The matrix follows the column-vector convention (``R @ v``) and
        shares the dtype and device of the quaternion parameter.
        """
        # Clamp the norm so a degenerate all-zero quaternion yields finite
        # values instead of NaNs that would poison every later step.
        norm = torch.norm(self.quaternion, p=2, dim=1, keepdim=True)
        unit_quaternion = self.quaternion / norm.clamp_min(1e-12)

        x, y, z, w = unit_quaternion[0]
        xx, yy, zz = x * x, y * y, z * z
        xy, xz, yz = x * y, x * z, y * z
        wx, wy, wz = w * x, w * y, w * z

        # Build the matrix functionally (no in-place writes into a
        # pre-allocated buffer) so autograd tracks every entry directly.
        rotation_matrix = torch.stack(
            [
                1 - 2 * (yy + zz),
                2 * (xy - wz),
                2 * (xz + wy),
                2 * (xy + wz),
                1 - 2 * (xx + zz),
                2 * (yz - wx),
                2 * (xz - wy),
                2 * (yz + wx),
                1 - 2 * (xx + yy),
            ]
        ).reshape(3, 3)

        return rotation_matrix


class Translation(nn.Module):
    """Learnable offset added to every input point.

    The offset is a ``(1, 3)`` parameter initialised to zero, so a freshly
    constructed module leaves its input unchanged.
    """

    def __init__(self):
        super().__init__()
        self.translation = nn.Parameter(torch.zeros(1, 3))

    def forward(self, input_tensor):
        """Return ``input_tensor`` shifted by the learned offset."""
        return self.translation + input_tensor


class TransformationModel(nn.Module):
    """Learnable transformation: a rotation followed by a translation."""

    def __init__(self):
        super().__init__()
        self._rotation = QuaternionRotation()
        self._translation = Translation()

    def forward(self, xyz):
        """Rotate ``xyz`` and then translate the rotated result."""
        # The former ``self.linear(xyz)`` call is gone: this module no
        # longer owns a linear layer, so that call could only fail.
        rotated_xyz = self._rotation(xyz)
        return self._translation(rotated_xyz)

    @property
    def rotation(self):
        """Current rotation matrix, detached from autograd and moved to CPU."""
        return self._rotation.get_rotation_matrix().detach().cpu()

    @property
    def translation(self):
        """Current translation parameter, detached from autograd and moved to CPU."""
        return self._translation.translation.detach().cpu()


class LocalTransformationTrainer(Trainer):
def __init__(self, image, camera, gaussian_model):
def __init__(self, image, camera, gaussian_model, iterations: int = 100):
self.camera = camera
self.gaussian_model = gaussian_model

Expand All @@ -39,7 +96,7 @@ def __init__(self, image, camera, gaussian_model):
self.transformation_model.parameters(), lr=0.0001
)

self._iterations = 101
self._iterations = iterations
self._lambda_dssim = 0.2

safe_state(seed=2234)
Expand Down Expand Up @@ -87,9 +144,18 @@ def run(self):
)
progress_bar.update(1)

print(f"Training done. Best loss = {best_loss} at iteration {best_iteration}.")
progress_bar.close()
print(
f"Training done. Best loss = {best_loss:.{5}f} at iteration {best_iteration}."
)

torchvision.utils.save_image(
rendered_image, f"artifacts/local/transfo/rendered_{iteration}.png"
)
torchvision.utils.save_image(gt_image, f"artifacts/local/transfo/gt.png")

def get_affine_transformation(self):
    """Return the model's rotation and translation as NumPy arrays.

    Returns:
        A ``(rotation, translation)`` tuple obtained by calling ``.numpy()``
        on the ``rotation`` and ``translation`` attributes of
        ``self.transformation_model``.
    """
    model = self.transformation_model
    return model.rotation.numpy(), model.translation.numpy()
7 changes: 7 additions & 0 deletions gaussian_splatting/utils/affine_transformation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import numpy as np


def apply_affine_transformation(xyz, rotation, translation):
    """Rotate row-vector points by ``rotation`` and then add ``translation``.

    Each row of ``xyz`` is treated as a point; right-multiplying by
    ``rotation.T`` applies ``rotation`` to every row at once.
    """
    rotated_xyz = xyz @ rotation.T
    return rotated_xyz + translation

0 comments on commit 9594a2d

Please sign in to comment.