# gaussian_splatting/colmap_free_trainer.py
import os
import uuid
from pathlib import Path
from random import randint

import torch
from tqdm import tqdm

from gaussian_splatting.dataset.dataset import Dataset
from gaussian_splatting.dataset.image_dataset import ImageDataset
from gaussian_splatting.local_trainer import LocalTrainer
from gaussian_splatting.model import GaussianModel
from gaussian_splatting.optimizer import Optimizer
from gaussian_splatting.render import render
from gaussian_splatting.utils.general import safe_state
from gaussian_splatting.utils.image import psnr
from gaussian_splatting.utils.loss import l1_loss, ssim


class ColmapFreeTrainer:
    """Train a global Gaussian model from an ordered folder of frames.

    For every pair of consecutive frames a ``LocalTrainer`` is built from the
    first frame; the global model is then optimised against a camera obtained
    from ``_estimate_camera``.  That camera estimation step is not implemented
    yet, so ``run`` currently stops there with ``NotImplementedError``.
    """

    def __init__(
        self,
        source_path,
        keep_eval=False,
        resolution=-1,
        sh_degree=3,
        checkpoint_path=None,
    ):
        """Create the output folder, the frame dataset and the global model.

        Args:
            source_path: Folder containing the input frames.
            keep_eval: Currently unused; kept for interface compatibility.
            resolution: Currently unused; kept for interface compatibility.
            sh_degree: Spherical-harmonics degree passed to ``GaussianModel``.
            checkpoint_path: Currently unused; kept for interface
                compatibility.
        """
        self._model_path = self._prepare_model_path()

        self.dataset = ImageDataset(images_path=Path(source_path))

        self.global_3DGS = GaussianModel(sh_degree)
        # NOTE(review): GaussianModel.initialize reads ``dataset.scene_info``,
        # which ImageDataset does not define -- the global model probably has
        # to be seeded via ``initialize_from_point_cloud`` instead. Confirm.
        self.global_3DGS.initialize(self.dataset)
        self.global_3DGS_optimizer = Optimizer(self.global_3DGS)

        # The densification / pruning helpers below address the model and the
        # optimizer under these names.
        self.gaussian_model = self.global_3DGS
        self.optimizer = self.global_3DGS_optimizer

        # Hyper-parameters read by the helpers; values mirror LocalTrainer.
        self._lambda_dssim = 0.2
        self._opacity_reset_interval = 3000
        self._min_opacity = 0.005
        self._max_screen_size = 20
        self._percent_dense = 0.01
        self._densification_interval = 100
        self._densification_iteration_start = 500
        self._densification_iteration_stop = 15000
        self._densification_grad_threshold = 0.0002

        # Iterations at which to evaluate / save; empty disables the feature.
        self._testing_iterations = []
        self._saving_iterations = []
        self._checkpoint_iterations = []

        self._debug = True

        safe_state()

    def run(self):
        """Iterate over consecutive frame pairs and optimise the global model.

        Raises:
            NotImplementedError: from ``_estimate_camera`` until camera
                estimation is implemented.
        """
        # Each step consumes frames t and t + 1, hence one step fewer than
        # there are frames.
        num_pairs = max(len(self.dataset) - 1, 0)
        ema_loss_for_log = 0.0

        progress_bar = tqdm(range(num_pairs), desc="Training progress")
        try:
            for iteration in range(num_pairs):
                I_t = self.dataset.get_frame(iteration)
                I_t_plus_1 = self.dataset.get_frame(iteration + 1)

                local_3DGS_trainer = LocalTrainer(I_t)
                camera = self._estimate_camera(local_3DGS_trainer, I_t_plus_1)

                loss_value = self._optimize_global(camera, iteration)

                ema_loss_for_log = 0.4 * loss_value + 0.6 * ema_loss_for_log
                if iteration % 10 == 0:
                    progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
                progress_bar.update(1)
        finally:
            progress_bar.close()

    def _estimate_camera(self, local_trainer, next_frame):
        """Return the camera used to render the global model for a frame pair.

        Not implemented: the original loop rendered with a ``camera`` that was
        never defined.
        """
        raise NotImplementedError(
            "Camera estimation between consecutive frames is not implemented."
        )

    def _optimize_global(self, camera, iteration):
        """Run one optimisation step of the global model; return the loss."""
        # Render image
        rendered_image, viewspace_point_tensor, visibility_filter, radii = render(
            camera, self.gaussian_model
        )

        # Loss
        gt_image = camera.original_image.cuda()
        Ll1 = l1_loss(rendered_image, gt_image)
        loss = (1.0 - self._lambda_dssim) * Ll1 + self._lambda_dssim * (
            1.0 - ssim(rendered_image, gt_image)
        )
        loss.backward()

        with torch.no_grad():
            # Log and save
            if iteration in self._testing_iterations:
                self._report(iteration)

            if iteration in self._saving_iterations:
                print("\n[ITER {}] Saving Gaussians".format(iteration))
                point_cloud_path = os.path.join(
                    self._model_path, "point_cloud/iteration_{}".format(iteration)
                )
                self.gaussian_model.save_ply(
                    os.path.join(point_cloud_path, "point_cloud.ply")
                )

            # Densification
            if iteration < self._densification_iteration_stop:
                self.gaussian_model.update_stats(
                    viewspace_point_tensor, visibility_filter, radii
                )

                if (
                    iteration >= self._densification_iteration_start
                    and iteration % self._densification_interval == 0
                ):
                    self._densify_and_prune(
                        iteration > self._opacity_reset_interval
                    )

                # Reset opacity interval
                if iteration % self._opacity_reset_interval == 0:
                    self._reset_opacity()

            # Optimizer step
            self.optimizer.step()
            self.optimizer.zero_grad(set_to_none=True)

            # Save checkpoint
            if iteration in self._checkpoint_iterations:
                print("\n[ITER {}] Saving Checkpoint".format(iteration))
                torch.save(
                    (
                        self.gaussian_model.state_dict(),
                        self.optimizer.state_dict(),
                        iteration,
                    ),
                    os.path.join(self._model_path, "chkpnt" + str(iteration) + ".pth"),
                )

        return loss.item()

    def _prepare_model_path(self):
        """Create and return a fresh output folder under ``./output/``."""
        unique_str = str(uuid.uuid4())
        model_path = os.path.join("./output/", unique_str[0:10])

        # Set up output folder
        print("Output folder: {}".format(model_path))
        os.makedirs(model_path, exist_ok=True)

        return model_path

    def _report(self, iteration):
        """Print mean L1 and PSNR over the test cameras and a train sample."""
        # NOTE(review): ImageDataset defines neither get_test_cameras nor
        # get_train_cameras; this only works with a camera-based dataset.
        torch.cuda.empty_cache()
        train_cameras = self.dataset.get_train_cameras()
        validation_configs = {
            "test": self.dataset.get_test_cameras(),
            # Guard the modulo against an empty training set.
            "train": [
                train_cameras[idx % len(train_cameras)] for idx in range(5, 30, 5)
            ]
            if train_cameras
            else [],
        }

        for config_name, cameras in validation_configs.items():
            if not cameras:
                continue

            l1_test, psnr_test = 0.0, 0.0
            for camera in cameras:
                rendered_image, _, _, _ = render(camera, self.gaussian_model)
                gt_image = camera.original_image.to("cuda")

                rendered_image = torch.clamp(rendered_image, 0.0, 1.0)
                gt_image = torch.clamp(gt_image, 0.0, 1.0)

                l1_test += l1_loss(rendered_image, gt_image).mean().double()
                psnr_test += psnr(rendered_image, gt_image).mean().double()

            psnr_test /= len(cameras)
            l1_test /= len(cameras)

            print(
                f"\n[ITER {iteration}] Evaluating {config_name}: L1 {l1_test} PSNR {psnr_test}"
            )

        torch.cuda.empty_cache()

    def _densify_and_prune(self, prune_big_points):
        """Clone and split Gaussians, then prune transparent / oversized ones.

        Args:
            prune_big_points: When true, also prune points whose screen-space
                radius or world-space scale exceeds the configured limits.
        """
        self._clone_points()
        self._split_points()

        prune_mask = (self.gaussian_model.get_opacity < self._min_opacity).squeeze()
        if prune_big_points:
            big_points_vs = self.gaussian_model.max_radii2D > self._max_screen_size
            big_points_ws = (
                self.gaussian_model.get_scaling.max(dim=1).values
                > 0.1 * self.gaussian_model.camera_extent
            )
            prune_mask = torch.logical_or(
                torch.logical_or(prune_mask, big_points_vs), big_points_ws
            )
        if self._debug:
            print(f"Pruning: {prune_mask.sum().item()} points.")
        self._prune_points(valid_mask=~prune_mask)

        torch.cuda.empty_cache()

    def _split_points(self):
        """Append the split points and remove the points they were split from."""
        new_points, split_mask = self.gaussian_model.split_points(
            self._densification_grad_threshold, self._percent_dense
        )
        self._concatenate_points(new_points)

        # The appended points are kept: the mask is extended by two False
        # entries per split point.
        num_new_points = 2 * int(split_mask.sum().item())
        prune_mask = torch.cat(
            (
                split_mask,
                torch.zeros(num_new_points, device="cuda", dtype=bool),
            )
        )
        if self._debug:
            print(f"Densification: split {split_mask.sum().item()} points.")
        self._prune_points(valid_mask=~prune_mask)

    def _clone_points(self):
        """Append the points returned by ``GaussianModel.clone_points``."""
        new_points, clone_mask = self.gaussian_model.clone_points(
            self._densification_grad_threshold, self._percent_dense
        )
        if self._debug:
            print(f"Densification: clone {clone_mask.sum().item()} points.")
        self._concatenate_points(new_points)

    def _reset_opacity(self):
        """Replace the optimised opacity tensor with the model's reset value."""
        new_opacity = self.gaussian_model.reset_opacity()
        optimizable_tensors = self.optimizer.replace_points(new_opacity, "opacity")
        self.gaussian_model.set_optimizable_tensors(optimizable_tensors)

    def _prune_points(self, valid_mask):
        """Keep only the points selected by ``valid_mask``."""
        optimizable_tensors = self.optimizer.prune_points(valid_mask)
        self.gaussian_model.set_optimizable_tensors(optimizable_tensors)
        self.gaussian_model.mask_stats(valid_mask)

    def _concatenate_points(self, new_tensors):
        """Append ``new_tensors`` to the optimised tensors and reset the stats."""
        optimizable_tensors = self.optimizer.concatenate_points(new_tensors)
        self.gaussian_model.set_optimizable_tensors(optimizable_tensors)
        self.gaussian_model.reset_stats()
# gaussian_splatting/dataset/image_dataset.py
from pathlib import Path


class ImageDataset:
    """Ordered collection of frames stored as numerically named image files."""

    def __init__(self, images_path: Path):
        """List the files in ``images_path`` and order them by numeric stem.

        Args:
            images_path: Folder whose files are named ``<integer>.<ext>``.
                A plain string is accepted as well.

        Raises:
            ValueError: If a file stem is not an integer.
        """
        images_path = Path(images_path)
        # Sub-directories are not frames; sorting by int keeps "10" after "2".
        self._images_paths = sorted(
            (f for f in images_path.iterdir() if f.is_file()),
            key=lambda f: int(f.stem),
        )

    def get_frame(self, i: int):
        """Open and return the ``i``-th frame as a PIL image."""
        # Imported here so that listing frames does not require Pillow.
        from PIL import Image

        image_path = self._images_paths[i]
        image = Image.open(image_path)

        return image

    def __getitem__(self, i: int):
        """Allow ``dataset[i]`` as a shorthand for ``get_frame(i)``."""
        return self.get_frame(i)

    def __len__(self):
        """Return the number of frames."""
        return len(self._images_paths)
# gaussian_splatting/local_trainer.py
import numpy as np
from transformers import pipeline
from tqdm import tqdm

import torch
import torchvision

from gaussian_splatting.model import GaussianModel
from gaussian_splatting.optimizer import Optimizer
from gaussian_splatting.render import render
from gaussian_splatting.utils.graphics import BasicPointCloud
from gaussian_splatting.utils.general import PILtoTorch
from gaussian_splatting.dataset.cameras import Camera
from gaussian_splatting.utils.loss import l1_loss, ssim
from gaussian_splatting.trainer import Trainer


class LocalTrainer(Trainer):
    """Fit a Gaussian model to a single image seeded by an estimated depth map."""

    def __init__(self, image, sh_degree: int = 3):
        """Build the initial point cloud, model, optimizer and camera.

        Args:
            image: Input frame as a PIL image (it is fed to the depth
                pipeline and to ``PILtoTorch``).
            sh_degree: Spherical-harmonics degree passed to ``GaussianModel``.
        """
        # NOTE(review): Trainer.__init__ is not called; every attribute used
        # by this class is set explicitly below -- confirm that is intended.
        DPT = self._load_DPT()
        depth_estimation = DPT(image)["predicted_depth"]

        image = PILtoTorch(image)

        initial_point_cloud = self._get_initial_point_cloud(image, depth_estimation)

        self.gaussian_model = GaussianModel(sh_degree)
        self.gaussian_model.initialize_from_point_cloud(initial_point_cloud)
        # TODO: set camera extent???

        self.optimizer = Optimizer(self.gaussian_model)

        self._camera = self._get_orthogonal_camera(image)

        self._iterations = 10000
        self._lambda_dssim = 0.2

        self._opacity_reset_interval = 3000
        self._min_opacity = 0.005
        self._max_screen_size = 20
        self._percent_dense = 0.01
        self._densification_interval = 100
        self._densification_iteration_start = 500
        self._densification_iteration_stop = 15000
        self._densification_grad_threshold = 0.0002

        self._debug = True

    def run(self):
        """Optimise the model against the input image.

        Writes ``rendered_0.png`` after the first render and, once training
        has finished, the last render plus ``gt.png`` into the working
        directory.
        """
        iteration = None
        rendered_image = None
        gt_image = None

        progress_bar = tqdm(range(self._iterations), desc="Training progress")
        try:
            for iteration in range(self._iterations):
                self.optimizer.update_learning_rate(iteration)
                rendered_image, viewspace_point_tensor, visibility_filter, radii = render(
                    self._camera, self.gaussian_model
                )

                if iteration == 0:
                    torchvision.utils.save_image(
                        rendered_image, f"rendered_{iteration}.png"
                    )

                gt_image = self._camera.original_image.cuda()
                Ll1 = l1_loss(rendered_image, gt_image)
                loss = (1.0 - self._lambda_dssim) * Ll1 + self._lambda_dssim * (
                    1.0 - ssim(rendered_image, gt_image)
                )

                loss.backward()

                self.optimizer.step()
                self.optimizer.zero_grad(set_to_none=True)

                with torch.no_grad():
                    # Densification itself is currently disabled; only the
                    # statistics it would rely on are accumulated.
                    if iteration < self._densification_iteration_stop:
                        self.gaussian_model.update_stats(
                            viewspace_point_tensor, visibility_filter, radii
                        )

                progress_bar.set_postfix({"Loss": f"{loss.item():.{5}f}"})
                progress_bar.update(1)
        finally:
            progress_bar.close()

        # Nothing was rendered when the loop did not run at all.
        if rendered_image is not None:
            torchvision.utils.save_image(rendered_image, f"rendered_{iteration}.png")
            torchvision.utils.save_image(gt_image, "gt.png")

    def _get_orthogonal_camera(self, image):
        """Return a fixed camera (identity rotation) holding ``image``."""
        camera = Camera(
            R=np.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]),
            T=np.array([0.5, 0.5, -1]),
            FoVx=1.,
            FoVy=1.,
            image=image,
            gt_alpha_mask=None,
            image_name="patate",
            colmap_id=0,
            uid=0,
        )

        return camera

    def _get_initial_point_cloud(self, frame, depth_estimation, step: int = 50):
        """Sample a regular grid of coloured 3D points from a depth map.

        Args:
            frame: Image tensor; its first axis is kept whole and the last
                two are sliced with the grid coordinates.
            depth_estimation: 3-D depth tensor; only index 0 of its first
                axis is read.
            step: Grid spacing in pixels; the colour window spans ``step``
                pixels around each sample.

        Returns:
            A ``BasicPointCloud`` with normalised grid positions plus depth as
            points, window-averaged colours and all-zero normals.
        """
        # Frame and depth_estimation width do not exactly match.
        # NOTE(review): the grid is laid out on the depth map but also used to
        # slice ``frame``; confirm that both have compatible sizes.
        _, size_x, size_y = depth_estimation.shape

        half_step = step // 2
        points, colors, normals = [], [], []
        for x in range(step, size_x - step, step):
            for y in range(step, size_y - step, step):
                # Grid position normalised by the depth-map size.
                points.append([
                    x / size_x,
                    y / size_y,
                    depth_estimation[0, x, y].item(),
                ])
                # Average color in the window around the selected pixel.
                window = frame[
                    :,
                    x - half_step: x + half_step,
                    y - half_step: y + half_step,
                ]
                colors.append(window.mean(dim=(1, 2)).tolist())
                normals.append([0., 0., 0.])

        point_cloud = BasicPointCloud(
            points=np.array(points),
            colors=np.array(colors),
            normals=np.array(normals),
        )

        return point_cloud

    def _load_DPT(self):
        """Return a Hugging Face depth-estimation pipeline."""
        checkpoint = "vinvino02/glpn-nyu"
        depth_estimator = pipeline("depth-estimation", model=checkpoint)

        return depth_estimator


# (The patch continues with gaussian_splatting/model.py, whose first hunk adds
# `self.camera_extent = 1.` to GaussianModel.__init__.)
+ def state_dict(self): state_dict = ( self.active_sh_degree, @@ -67,6 +69,7 @@ def state_dict(self): self.max_radii2D, self.xyz_gradient_accum, self.denom, + self.camera_extent, ) return state_dict @@ -133,9 +136,11 @@ def oneupSHdegree(self): def initialize(self, dataset): self.camera_extent = dataset.scene_info.nerf_normalization["radius"] - point_cloud = dataset.scene_info.point_cloud + self.initialize_from_point_cloud(point_cloud) + + def initialize_from_point_cloud(self, point_cloud): fused_point_cloud = torch.tensor(np.asarray(point_cloud.points)).float().cuda() fused_color = RGB2SH( torch.tensor(np.asarray(point_cloud.colors)).float().cuda() diff --git a/gaussian_splatting/trainer.py b/gaussian_splatting/trainer.py index fd322869b..c110bd324 100644 --- a/gaussian_splatting/trainer.py +++ b/gaussian_splatting/trainer.py @@ -89,7 +89,7 @@ def run(self): if not cameras: cameras = self.dataset.get_train_cameras().copy() camera = cameras.pop(randint(0, len(cameras) - 1)) - + import pdb; pdb.set_trace() # Render image rendered_image, viewspace_point_tensor, visibility_filter, radii = render( camera, self.gaussian_model @@ -102,10 +102,7 @@ def run(self): 1.0 - ssim(rendered_image, gt_image) ) - # try: loss.backward() - # except Exception: - # import pdb; pdb.set_trace() if iteration in self._saving_iterations: print("\n[ITER {}] Saving Gaussians".format(iteration)) diff --git a/gaussian_splatting/utils/general.py b/gaussian_splatting/utils/general.py index c55064af1..e64bf8b9b 100644 --- a/gaussian_splatting/utils/general.py +++ b/gaussian_splatting/utils/general.py @@ -19,14 +19,18 @@ def inverse_sigmoid(x): return torch.log(x / (1 - x)) -def PILtoTorch(pil_image, resolution): - resized_image_PIL = pil_image.resize(resolution) - resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 - if len(resized_image.shape) == 3: - return resized_image.permute(2, 0, 1) - else: - return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) +def 
# gaussian_splatting/utils/general.py
def PILtoTorch(pil_image, resolution=None):
    """Convert an image to a channel-first float tensor scaled by 1/255.

    Args:
        pil_image: PIL image (or any array-like accepted by ``np.array``).
        resolution: Optional size passed to ``pil_image.resize``; the image is
            left at its own size when ``None``.

    Returns:
        Tensor of shape ``(channels, height, width)``; 2-D (grayscale) input
        gets a single channel.
    """
    if resolution is not None:
        pil_image = pil_image.resize(resolution)

    image = torch.from_numpy(np.array(pil_image)) / 255.0

    # Grayscale images have no channel axis; add one so permute() below works.
    if image.dim() == 2:
        image = image.unsqueeze(dim=-1)

    return image.permute(2, 0, 1)


# (general.py continues unchanged with `def get_expon_lr_func(...)`.)


# scripts/train_colmap_free.py
from pathlib import Path


def main(images_path="data/phil/1/input/"):
    """Train a LocalTrainer on the first frame found in ``images_path``."""
    # Imported lazily: these modules pull in the heavy training dependencies.
    from gaussian_splatting.dataset.image_dataset import ImageDataset
    from gaussian_splatting.local_trainer import LocalTrainer

    dataset = ImageDataset(images_path=Path(images_path))
    image = dataset.get_frame(0)

    local_trainer = LocalTrainer(image)
    local_trainer.run()


if __name__ == "__main__":
    main()