From 90f3fa390ce9d34978cae66fa321ebd2c311a31e Mon Sep 17 00:00:00 2001 From: Isabelle Bouchard Date: Mon, 25 Mar 2024 17:32:03 -0400 Subject: [PATCH] local 3dgs transformation --- gaussian_splatting/colmap_free_trainer.py | 15 +-- gaussian_splatting/dataset/image_dataset.py | 7 +- ...ner.py => local_initialization_trainer.py} | 92 +++++++++--------- .../local_transformation_trainer.py | 95 +++++++++++++++++++ gaussian_splatting/model.py | 2 +- gaussian_splatting/utils/general.py | 3 +- scripts/train_colmap_free.py | 19 +++- 7 files changed, 167 insertions(+), 66 deletions(-) rename gaussian_splatting/{local_trainer.py => local_initialization_trainer.py} (71%) create mode 100644 gaussian_splatting/local_transformation_trainer.py diff --git a/gaussian_splatting/colmap_free_trainer.py b/gaussian_splatting/colmap_free_trainer.py index 2078719fa..5d89365cd 100644 --- a/gaussian_splatting/colmap_free_trainer.py +++ b/gaussian_splatting/colmap_free_trainer.py @@ -1,11 +1,9 @@ import os import uuid -from random import randint import torch from tqdm import tqdm -from gaussian_splatting.dataset.dataset import Dataset from gaussian_splatting.model import GaussianModel from gaussian_splatting.optimizer import Optimizer from gaussian_splatting.render import render @@ -14,7 +12,6 @@ from gaussian_splatting.utils.loss import l1_loss, ssim - class ColmapFreeTrainer: def __init__( self, @@ -34,28 +31,24 @@ def __init__( safe_state() - def run(self): - progress_bar = tqdm( - range(len(self.dataset)), desc="Training progress" - ) + progress_bar = tqdm(range(len(self.dataset)), desc="Training progress") for iteration in range(len(dataset)): - I_t = self.dataset[i] I_t_plus_1 = self.dataset[i + 1] local_3DGS_trainer = LocalTrainer() - #self.optimizer.update_learning_rate(iteration) + # self.optimizer.update_learning_rate(iteration) # Every 1000 its we increase the levels of SH up to a maximum degree # if iteration % 1000 == 0: # self.gaussian_model.oneupSHdegree() # Pick a random camera - #if not cameras: + # if not cameras: # cameras = self.dataset.get_train_cameras().copy() - #camera = cameras.pop(randint(0, len(cameras) - 1)) + # camera = cameras.pop(randint(0, len(cameras) - 1)) # Render image rendered_image, viewspace_point_tensor, visibility_filter, radii = render( diff --git a/gaussian_splatting/dataset/image_dataset.py b/gaussian_splatting/dataset/image_dataset.py index 48137f12c..968a151ac 100644 --- a/gaussian_splatting/dataset/image_dataset.py +++ b/gaussian_splatting/dataset/image_dataset.py @@ -1,9 +1,9 @@ -from PIL import Image from pathlib import Path -from gaussian_splatting.utils.general import PILtoTorch -class ImageDataset: +from PIL import Image + +class ImageDataset: def __init__(self, images_path: Path): self._images_paths = [f for f in images_path.iterdir()] self._images_paths.sort(key=lambda f: int(f.stem)) @@ -16,4 +16,3 @@ def get_frame(self, i: int): def __len__(self): return len(self._image_paths) - diff --git a/gaussian_splatting/local_trainer.py b/gaussian_splatting/local_initialization_trainer.py similarity index 71% rename from gaussian_splatting/local_trainer.py rename to gaussian_splatting/local_initialization_trainer.py index 889d4dfb6..869063d3e 100644 --- a/gaussian_splatting/local_trainer.py +++ b/gaussian_splatting/local_initialization_trainer.py @@ -1,33 +1,30 @@ import math -from matplotlib import pyplot as plt -import numpy as np -from transformers import pipeline -from tqdm import tqdm +import numpy as np import torch import torchvision +from matplotlib import pyplot 
as plt +from tqdm import tqdm +from transformers import pipeline -from gaussian_splatting.utils.general import safe_state +from gaussian_splatting.dataset.cameras import Camera from gaussian_splatting.model import GaussianModel from gaussian_splatting.optimizer import Optimizer from gaussian_splatting.render import render +from gaussian_splatting.trainer import Trainer +from gaussian_splatting.utils.general import PILtoTorch, safe_state from gaussian_splatting.utils.graphics import BasicPointCloud -from gaussian_splatting.utils.general import PILtoTorch -from gaussian_splatting.dataset.cameras import Camera from gaussian_splatting.utils.loss import l1_loss, ssim -from gaussian_splatting.trainer import Trainer -class LocalTrainer(Trainer): - def __init__(self, image, sh_degree: int = 3): +class LocalInitializationTrainer(Trainer): + def __init__(self, image, sh_degree: int = 3, iterations: int = 10000): DPT = self._load_DPT() depth_estimation = DPT(image)["predicted_depth"] image = PILtoTorch(image) initial_point_cloud = self._get_initial_point_cloud( - image, - depth_estimation, - step=25 + image, depth_estimation, step=25 ) self.gaussian_model = GaussianModel(sh_degree) @@ -36,9 +33,9 @@ def __init__(self, image, sh_degree: int = 3): self.optimizer = Optimizer(self.gaussian_model) - self._camera = self._get_orthogonal_camera(image) + self.camera = self._get_orthogonal_camera(image) - self._iterations = 10000 + self._iterations = iterations self._lambda_dssim = 0.2 # Densification and pruning @@ -56,33 +53,33 @@ def __init__(self, image, sh_degree: int = 3): safe_state(seed=2234) def run(self): - progress_bar = tqdm( - range(self._iterations), desc="Training progress" - ) + progress_bar = tqdm(range(self._iterations), desc="Initialization") best_loss, best_iteration, losses = None, 0, [] for iteration in range(self._iterations): self.optimizer.update_learning_rate(iteration) rendered_image, viewspace_point_tensor, visibility_filter, radii = render( - self._camera, self.gaussian_model + self.camera, self.gaussian_model ) if iteration % 100 == 0: plt.cla() plt.plot(losses) - plt.yscale('log') - plt.savefig('artifacts/losses.png') + plt.yscale("log") + plt.savefig("artifacts/local/init/losses.png") - torchvision.utils.save_image(rendered_image, f"artifacts/rendered_{iteration}.png") + torchvision.utils.save_image( + rendered_image, f"artifacts/local/init/rendered_{iteration}.png" + ) - gt_image = self._camera.original_image.cuda() + gt_image = self.camera.original_image.cuda() Ll1 = l1_loss(rendered_image, gt_image) loss = (1.0 - self._lambda_dssim) * Ll1 + self._lambda_dssim * ( 1.0 - ssim(rendered_image, gt_image) ) if best_loss is None or best_loss > loss: best_loss = loss.cpu().item() - best_iteartion = iteration + best_iteration = iteration losses.append(loss.cpu().item()) loss.backward() @@ -110,22 +107,25 @@ def run(self): print("Reset Opacity") self._reset_opacity() - progress_bar.set_postfix({ - "Loss": f"{loss:.{5}f}", - "Num_visible": - f"{visibility_filter.int().sum().item()}/{len(visibility_filter)}" - }) + progress_bar.set_postfix( + { + "Loss": f"{loss:.{5}f}", + "Num_visible": f"{visibility_filter.int().sum().item()}/{len(visibility_filter)}", + } + ) progress_bar.update(1) print(f"Training done. 
Best loss = {best_loss} at iteration {best_iteration}.") - torchvision.utils.save_image(rendered_image, f"artifacts/rendered_{iteration}.png") - torchvision.utils.save_image(gt_image, f"artifacts/gt.png") + torchvision.utils.save_image( + rendered_image, f"artifacts/local/init/rendered_{iteration}.png" + ) + torchvision.utils.save_image(gt_image, f"artifacts/local/init/gt.png") def _get_orthogonal_camera(self, image): camera = Camera( - R=np.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]), - T=np.array([-0.5, -0.5, 1.]), + R=np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]), + T=np.array([-0.5, -0.5, 1.0]), FoVx=2 * math.atan(0.5), FoVy=2 * math.atan(0.5), image=image, @@ -150,20 +150,24 @@ def _get_initial_point_cloud(self, frame, depth_estimation, step: int = 50): for y in range(step, h - step, step): _depth = depth_estimation[0, x, y].item() # Normalized points - points.append([ - y / h, - x / w, - (_depth - _min_depth) / (_max_depth - _min_depth) - ]) + points.append( + [y / h, x / w, (_depth - _min_depth) / (_max_depth - _min_depth)] + ) # Average RGB color in the window color around selected pixel colors.append( frame[ - :, - x - half_step: x + half_step, - y - half_step: y + half_step - ].mean(axis=[1, 2]).tolist() + :, x - half_step : x + half_step, y - half_step : y + half_step + ] + .mean(axis=[1, 2]) + .tolist() + ) + normals.append( + [ + 0.0, + 0.0, + 0.0, + ] ) - normals.append([0., 0., 0.,]) point_cloud = BasicPointCloud( points=np.array(points), @@ -178,5 +182,3 @@ def _load_DPT(self): depth_estimator = pipeline("depth-estimation", model=checkpoint) return depth_estimator - - diff --git a/gaussian_splatting/local_transformation_trainer.py b/gaussian_splatting/local_transformation_trainer.py new file mode 100644 index 000000000..52f2ca59c --- /dev/null +++ b/gaussian_splatting/local_transformation_trainer.py @@ -0,0 +1,95 @@ +import torch +import torchvision +from matplotlib import pyplot as plt +from tqdm import tqdm + +from gaussian_splatting.render import render +from gaussian_splatting.trainer import Trainer +from gaussian_splatting.utils.general import PILtoTorch, safe_state +from gaussian_splatting.utils.loss import l1_loss, ssim + + +class TransformationModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(in_features=3, out_features=3) + + torch.nn.init.eye_(self.linear.weight.data) + torch.nn.init.zeros_(self.linear.bias.data) + + def forward(self, xyz): + transformed_xyz = self.linear(xyz) + + return transformed_xyz + + +class LocalTransformationTrainer(Trainer): + def __init__(self, image, camera, gaussian_model): + self.camera = camera + self.gaussian_model = gaussian_model + + self.xyz = gaussian_model.get_xyz.detach() + + self.transformation_model = TransformationModel() + self.transformation_model.to(self.xyz.device) + + self.image = PILtoTorch(image).to(self.xyz.device) + + self.optimizer = torch.optim.Adam( + self.transformation_model.parameters(), lr=0.0001 + ) + + self._iterations = 101 + self._lambda_dssim = 0.2 + + safe_state(seed=2234) + + def run(self): + progress_bar = tqdm(range(self._iterations), desc="Transformation") + + best_loss, best_iteration, losses = None, 0, [] + for iteration in range(self._iterations): + xyz = self.transformation_model(self.xyz) + self.gaussian_model.set_optimizable_tensors({"xyz": xyz}) + + rendered_image, viewspace_point_tensor, visibility_filter, radii = render( + self.camera, self.gaussian_model + ) + + if iteration % 10 == 0: + plt.cla() + plt.plot(losses) + 
plt.yscale("log") + plt.savefig("artifacts/local/transfo/losses.png") + + torchvision.utils.save_image( + rendered_image, f"artifacts/local/transfo/rendered_{iteration}.png" + ) + + gt_image = self.image + Ll1 = l1_loss(rendered_image, gt_image) + loss = (1.0 - self._lambda_dssim) * Ll1 + self._lambda_dssim * ( + 1.0 - ssim(rendered_image, gt_image) + ) + if best_loss is None or best_loss > loss: + best_loss = loss.cpu().item() + best_iteration = iteration + losses.append(loss.cpu().item()) + + loss.backward() + + self.optimizer.step() + + progress_bar.set_postfix( + { + "Loss": f"{loss:.{5}f}", + } + ) + progress_bar.update(1) + + print(f"Training done. Best loss = {best_loss} at iteration {best_iteration}.") + + torchvision.utils.save_image( + rendered_image, f"artifacts/local/transfo/rendered_{iteration}.png" + ) + torchvision.utils.save_image(gt_image, f"artifacts/local/transfo/gt.png") diff --git a/gaussian_splatting/model.py b/gaussian_splatting/model.py index 9c0dba23e..295eef978 100644 --- a/gaussian_splatting/model.py +++ b/gaussian_splatting/model.py @@ -55,7 +55,7 @@ def __init__(self, sh_degree: int = 3): self.inverse_opacity_activation = inverse_sigmoid self.rotation_activation = torch.nn.functional.normalize - self.camera_extent = 1. + self.camera_extent = 1.0 def state_dict(self): state_dict = ( diff --git a/gaussian_splatting/utils/general.py b/gaussian_splatting/utils/general.py index e64bf8b9b..1bfbb236c 100644 --- a/gaussian_splatting/utils/general.py +++ b/gaussian_splatting/utils/general.py @@ -19,7 +19,7 @@ def inverse_sigmoid(x): return torch.log(x / (1 - x)) -def PILtoTorch(pil_image, resolution = None): +def PILtoTorch(pil_image, resolution=None): if resolution is not None: pil_image = pil_image.resize(resolution) @@ -32,6 +32,7 @@ def PILtoTorch(pil_image, resolution = None): return image + def get_expon_lr_func( lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 ): diff --git a/scripts/train_colmap_free.py b/scripts/train_colmap_free.py index dd11d9047..0bb11d69e 100644 --- a/scripts/train_colmap_free.py +++ b/scripts/train_colmap_free.py @@ -1,15 +1,26 @@ from pathlib import Path -from gaussian_splatting.local_trainer import LocalTrainer from gaussian_splatting.dataset.image_dataset import ImageDataset +from gaussian_splatting.local_initialization_trainer import \ + LocalInitializationTrainer +from gaussian_splatting.local_transformation_trainer import \ + LocalTransformationTrainer def main(): dataset = ImageDataset(images_path=Path("data/phil/1/input/")) - image = dataset.get_frame(0) + image_0 = dataset.get_frame(0) + image_1 = dataset.get_frame(10) - local_trainer = LocalTrainer(image) - local_trainer.run() + local_initialization_trainer = LocalInitializationTrainer(image_0, iterations=100) + local_initialization_trainer.run() + + local_transformation_trainer = LocalTransformationTrainer( + image_1, + camera=local_initialization_trainer.camera, + gaussian_model=local_initialization_trainer.gaussian_model, + ) + local_transformation_trainer.run() if __name__ == "__main__":