From e6f309745e5a70ed386667303df76b3e38917cb9 Mon Sep 17 00:00:00 2001
From: Isabelle Bouchard
Date: Fri, 12 Apr 2024 17:15:22 -0400
Subject: [PATCH] gradual global training

---
 .gitignore                                    |   1 +
 gaussian_splatting/dataset/image_dataset.py   |  21 +++-
 gaussian_splatting/model.py                   |   2 -
 .../pose_free/global_trainer.py               |  49 ++++-----
 gaussian_splatting/pose_free/local_trainer.py | 102 ++++++++++--------
 .../pose_free/pose_free_trainer.py            |  73 +++++--------
 gaussian_splatting/utils/early_stopper.py     |  21 ++--
 7 files changed, 134 insertions(+), 135 deletions(-)

diff --git a/.gitignore b/.gitignore
index f587c3713..f9fe57bc7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,3 +8,4 @@ tensorboard_3d
 screenshots
 data/
 gaussian_splatting.egg-info/
+artifacts/
diff --git a/gaussian_splatting/dataset/image_dataset.py b/gaussian_splatting/dataset/image_dataset.py
index ad6b80c5b..424fa99fd 100644
--- a/gaussian_splatting/dataset/image_dataset.py
+++ b/gaussian_splatting/dataset/image_dataset.py
@@ -6,19 +6,34 @@
 
 
 class ImageDataset:
-    def __init__(self, images_path: Path, step_size: int = 1):
+    def __init__(
+        self, images_path: Path, step_size: int = 1, downscale_factor: int = 1
+    ):
         self._images_paths = [
             f for i, f in enumerate(images_path.iterdir()) if i % step_size == 0
         ]
         self._images_paths.sort(key=lambda f: int(f.stem))
 
+        self._downscale_factor = downscale_factor
+
+    def __len__(self):
+        return len(self._images_paths)
+
     def get_frame(self, i: int):
         image_path = self._images_paths[i]
 
         image = Image.open(image_path)
+        if self._downscale_factor > 1:
+            image = self._downscale(image)
+
         image = PILtoTorch(image)
 
         return image
 
-    def __len__(self):
-        return len(self._images_paths)
+    def _downscale(self, image):
+        w, h = image.size
+        image = image.resize(
+            (w // self._downscale_factor, h // self._downscale_factor), Image.LANCZOS
+        )
+
+        return image
diff --git a/gaussian_splatting/model.py b/gaussian_splatting/model.py
index 295eef978..c8533d8d5 100644
--- a/gaussian_splatting/model.py
+++ b/gaussian_splatting/model.py
@@ -153,8 +153,6 @@ def initialize_from_point_cloud(self, point_cloud):
         features[:, :3, 0] = fused_color
         features[:, 3:, 1:] = 0.0
 
-        print("Number of points at initialisation : ", fused_point_cloud.shape[0])
-
         dist2 = torch.clamp_min(
             distCUDA2(torch.from_numpy(np.asarray(point_cloud.points)).float().cuda()),
             0.0000001,
diff --git a/gaussian_splatting/pose_free/global_trainer.py b/gaussian_splatting/pose_free/global_trainer.py
index 5146510ba..95edca85e 100644
--- a/gaussian_splatting/pose_free/global_trainer.py
+++ b/gaussian_splatting/pose_free/global_trainer.py
@@ -1,7 +1,4 @@
-import os
-from random import randint
-
-from tqdm import tqdm
+from pathlib import Path
 
 from gaussian_splatting.optimizer import Optimizer
 from gaussian_splatting.render import render
@@ -11,15 +8,16 @@
 
 
 class GlobalTrainer(Trainer):
-    def __init__(self, gaussian_model, output_path=None):
+    def __init__(self, gaussian_model, iterations: int = 100, output_path=None):
         self._model_path = self._prepare_model_path(output_path)
 
         self.gaussian_model = gaussian_model
-        self.cameras = []
 
         self.optimizer = Optimizer(self.gaussian_model)
         self._photometric_loss = PhotometricLoss(lambda_dssim=0.2)
 
+        self._iterations = iterations
+
         self._debug = False
 
         # Densification and pruning
@@ -30,25 +28,16 @@ def __init__(self, gaussian_model, output_path=None):
 
         safe_state()
 
-    def add_camera(self, camera):
-        self.cameras.append(camera)
-
-    def run(self, iterations: int = 1000):
-        ema_loss_for_log = 0.0
-        cameras = None
-        first_iter = 1
-        progress_bar = tqdm(range(first_iter, iterations), desc="Training progress")
-        for iteration in range(first_iter, iterations + 1):
+    def run(self, current_camera, next_camera, progress_bar=None, run_id: int = 0):
+        cameras = (current_camera, next_camera)
+        for iteration in range(self._iterations):
             self.optimizer.update_learning_rate(iteration)
 
             # Every 1000 its we increase the levels of SH up to a maximum degree
             if iteration % 1000 == 0:
                 self.gaussian_model.oneupSHdegree()
 
-            # Pick a random camera
-            if not cameras:
-                cameras = self.cameras.copy()
-            camera = cameras.pop(randint(0, len(cameras) - 1))
+            camera = cameras[iteration % 2]
 
             # Render image
             rendered_image, viewspace_point_tensor, visibility_filter, radii = render(
@@ -59,22 +48,24 @@ def run(self, iterations: int = 1000):
             gt_image = camera.original_image.cuda()
             loss = self._photometric_loss(rendered_image, gt_image)
             loss.backward()
+            loss_value = loss.cpu().item()
 
             # Optimizer step
             self.optimizer.step()
             self.optimizer.zero_grad(set_to_none=True)
 
-            # Progress bar
-            ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
-            progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
-            progress_bar.update(1)
-
-        progress_bar.close()
-
-        point_cloud_path = os.path.join(
-            self._model_path, "point_cloud/iteration_{}".format(iteration)
+            if progress_bar is not None:
+                progress_bar.set_postfix(
+                    {
+                        "stage": "global",
+                        "iteration": f"{iteration}/{self._iterations}",
+                        "loss": f"{loss_value:.5f}",
+                    }
+                )
+
+        self.gaussian_model.save_ply(
+            Path(self._model_path) / "point_cloud" / str(run_id) / "point_cloud.ply"
         )
-        self.gaussian_model.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))
 
         # Densification
         self.gaussian_model.update_stats(
diff --git a/gaussian_splatting/pose_free/local_trainer.py b/gaussian_splatting/pose_free/local_trainer.py
index d79e92b9d..a9be7478b 100644
--- a/gaussian_splatting/pose_free/local_trainer.py
+++ b/gaussian_splatting/pose_free/local_trainer.py
@@ -21,10 +21,11 @@ class LocalTrainer:
     def __init__(
         self,
         sh_degree: int = 3,
-        init_iterations: int = 250,
-        transfo_iterations: int = 250,
+        init_iterations: int = 1000,
+        transfo_iterations: int = 1000,
+        debug: bool = False,
     ):
-        self._depth_estimator = self._load_depth_estimator()
+        self._depth_estimator = pipeline("depth-estimation", model="vinvino02/glpn-nyu")
 
         self._point_cloud_step = 25
         self._sh_degree = sh_degree
@@ -32,26 +33,28 @@ def __init__(
 
         self._init_iterations = init_iterations
         self._init_early_stopper = EarlyStopper(patience=10)
-        self._init_save_artifacts_iterations = 50
+        self._init_save_artifacts_iterations = 100
 
-        self._transfo_lr = 0.0001
+        self._transfo_lr = 0.00001
         self._transfo_iterations = transfo_iterations
-        self._transfo_early_stopper = EarlyStopper(patience=10)
-        self._transfo_save_artifacts_iterations = 100
+        self._transfo_early_stopper = EarlyStopper(
+            patience=100,
+        )
+        self._transfo_save_artifacts_iterations = 10
 
-        self._debug = True
+        self._debug = debug
 
         self._output_path = Path("artifacts/local/")
-        self._output_path.mkdir(exist_ok=True, parents=True)
 
         safe_state(seed=2234)
 
-    def run_init(self, image, camera, run_id: int = 0):
-        output_path = self._output_path / "init"
+    def run_init(self, image, camera, progress_bar=None, run_id: int = 0):
+        output_path = self._output_path / "init" / str(run_id)
         output_path.mkdir(exist_ok=True, parents=True)
 
-        gaussian_model = self._get_initial_gaussian_model(image)
+        gaussian_model = self.get_initial_gaussian_model(image, output_path)
         optimizer = Optimizer(gaussian_model)
+        self._init_early_stopper.reset()
 
         image = image.cuda()
         losses = []
@@ -69,31 +72,37 @@ def run_init(self, image, camera, run_id: int = 0):
             optimizer.zero_grad(set_to_none=True)
 
             if self._init_early_stopper.step(loss_value):
-                self._init_early_stopper.print_early_stop()
                 break
 
-            if (
-                self._debug
-                or iteration % self._init_save_artifacts_iterations == 0
-                or iteration == self._init_iterations - 1
-            ):
-                self._save_artifacts(
-                    losses, rendered_image, output_path / str(run_id), iteration
+            if self._debug and (iteration % self._init_save_artifacts_iterations == 0):
+                self._save_artifacts(losses, rendered_image, output_path, iteration)
+
+            if progress_bar is not None:
+                progress_bar.set_postfix(
+                    {
+                        "stage": "init",
+                        "iteration": f"{iteration}/{self._init_iterations}",
+                        "loss": f"{loss_value:.5f}",
+                    }
                 )
 
         if self._debug:
-            save_image(image, output_path / f"{run_id}_ground_truth.png")
+            self._save_artifacts(losses, rendered_image, output_path, "best")
+            save_image(image, output_path / "ground_truth.png")
 
         return gaussian_model
 
-    def run_transfo(self, image, camera, gaussian_model, run_id: int = 0):
-        output_path = self._output_path / "transfo"
+    def run_transfo(
+        self, image, camera, gaussian_model, progress_bar=None, run_id: int = 0
+    ):
+        output_path = self._output_path / "transfo" / str(run_id)
         output_path.mkdir(exist_ok=True, parents=True)
 
         transformation_model = AffineTransformationModel()
         optimizer = torch.optim.Adam(
             transformation_model.parameters(), lr=self._transfo_lr
         )
+        self._transfo_early_stopper.reset()
 
         image = image.cuda()
         transformation_model = transformation_model.cuda()
@@ -114,36 +123,48 @@ def run_transfo(self, image, camera, gaussian_model, run_id: int = 0):
             optimizer.step()
 
             if self._transfo_early_stopper.step(loss_value):
-                self._transfo_early_stopper.print_early_stop()
-                transformation = self._transfo_early_stopper.get_best_params(
-                    transformation
-                )
+                transformation = self._transfo_early_stopper.get_best_params()
                 break
             else:
                 transformation = transformation_model.transformation
-                self._init_early_stopper.set_best_params(transformation)
+                self._transfo_early_stopper.set_best_params(transformation)
 
-            if (
-                self._debug
-                or iteration % self._transfo_save_artifacts_iterations == 0
-                or iteration == self._transfo_iterations - 1
+            if self._debug and (
+                iteration % self._transfo_save_artifacts_iterations == 0
             ):
                 self._save_artifacts(
                     losses,
                     rendered_image,
-                    output_path / str(run_id),
+                    output_path,
                     iteration,
                 )
 
+            if progress_bar is not None:
+                progress_bar.set_postfix(
+                    {
+                        "stage": "transfo",
+                        "iteration": f"{iteration}/{self._transfo_iterations}",
+                        "loss": f"{loss_value:.5f}",
+                    }
+                )
+
         if self._debug:
-            save_image(image, output_path / f"{run_id}_ground_truth.png")
+            self._save_artifacts(losses, rendered_image, output_path, "best")
+            save_image(image, output_path / f"ground_truth.png")
 
         return transformation
 
-    def _get_initial_gaussian_model(self, image):
+    def get_initial_gaussian_model(self, image, output_folder: Path = None):
         PIL_image = TorchToPIL(image)
-
         depth_estimation = self._depth_estimator(PIL_image)["predicted_depth"]
+
+        if self._debug and output_folder is not None:
+            _min, _max = depth_estimation.min().item(), depth_estimation.max().item()
+            save_image(
+                (depth_estimation - _min) / (_max - _min),
+                output_folder / f"depth_estimation_{_min:.3f}_{_max:.3f}.png",
+            )
+
         point_cloud = self._get_initial_point_cloud_from_depth_estimation(
             image, depth_estimation, step=self._point_cloud_step
         )
@@ -195,17 +216,10 @@ def _get_initial_point_cloud_from_depth_estimation(
 
         return point_cloud
 
-    def _load_depth_estimator(self):
-        checkpoint = "vinvino02/glpn-nyu"
-        depth_estimator = pipeline("depth-estimation", model=checkpoint)
-
-        return depth_estimator
-
     def _save_artifacts(self, losses, rendered_image, output_path, iteration):
-        output_path.mkdir(exist_ok=True, parents=True)
         plt.cla()
         plt.plot(losses)
         plt.yscale("log")
         plt.savefig(output_path / "losses.png")
 
-        save_image(rendered_image, self._output_path / f"rendered_{iteration}.png")
+        save_image(rendered_image, output_path / f"rendered_{iteration}.png")
diff --git a/gaussian_splatting/pose_free/pose_free_trainer.py b/gaussian_splatting/pose_free/pose_free_trainer.py
index ae864d858..85b6d4c85 100644
--- a/gaussian_splatting/pose_free/pose_free_trainer.py
+++ b/gaussian_splatting/pose_free/pose_free_trainer.py
@@ -1,7 +1,6 @@
-import copy
 from pathlib import Path
 
-import torchvision
+from torchvision.utils import save_image
 from tqdm import tqdm
 
 from gaussian_splatting.dataset.image_dataset import ImageDataset
@@ -10,81 +9,57 @@
 from gaussian_splatting.render import render
 from gaussian_splatting.utils.camera import (get_orthogonal_camera,
                                              transform_camera)
-from gaussian_splatting.utils.loss import PhotometricLoss
 
 
 class PoseFreeTrainer:
     def __init__(self, source_path: Path):
         self._debug = True
 
-        self._initialization_iterations = 1000
-        self._transformation_iterations = 250
-        self._global_iterations = 50
+        self._dataset = ImageDataset(
+            images_path=source_path, step_size=5, downscale_factor=1
+        )
 
-        self._photometric_loss = PhotometricLoss(lambda_dssim=0.2)
-        self._dataset = ImageDataset(images_path=source_path, step_size=5)
+        self._local_trainer = LocalTrainer(
+            init_iterations=1000, transfo_iterations=1000, debug=True
+        )
 
         self._output_path = Path("artifacts/global")
         self._output_path.mkdir(exist_ok=True, parents=True)
 
-        self._local_trainer = LocalTrainer(init_iterations=25, transfo_iterations=25)
-
     def run(self):
         current_image = self._dataset.get_frame(0)
+        initial_gaussian_model = self._local_trainer.get_initial_gaussian_model(
+            current_image, self._output_path
+        )
+        global_trainer = GlobalTrainer(initial_gaussian_model, iterations=1000)
+
         current_camera = get_orthogonal_camera(current_image)
 
-        global_trainer = self._initialize_global_trainer(current_image, current_camera)
-        for i in tqdm(range(1, len(self._dataset))):
+        progress_bar = tqdm(range(1, len(self._dataset)))
+        for i in progress_bar:
             next_image = self._dataset.get_frame(i)
 
             gaussian_model = self._local_trainer.run_init(
-                current_image, current_camera, run_id=i
+                current_image, current_camera, progress_bar, run_id=i
             )
             rotation, translation = self._local_trainer.run_transfo(
                 next_image,
                 current_camera,
                 gaussian_model,
+                progress_bar,
                 run_id=i,
             )
             next_camera = transform_camera(
                 current_camera, rotation, translation, next_image, _id=i
             )
-            global_trainer.add_camera(next_camera)
+            global_trainer.run(current_camera, next_camera, progress_bar, run_id=i)
+
+            if self._debug:
+                rendered_image, _, _, _ = render(next_camera, gaussian_model)
+                save_image(
+                    rendered_image, self._output_path / f"{i}_rendered_image.png"
+                )
+                save_image(next_image, self._output_path / f"{i}_image.png")
 
             current_image = next_image
             current_camera = next_camera
-
-        global_trainer.run(self._global_iterations)
-
-    def _initialize_global_trainer(self, initial_image, initial_camera):
-        initial_gaussian_model = self._local_trainer.run_init(
-            initial_image, initial_camera, run_id=0
-        )
-
-        global_gaussian_model = copy.deepcopy(initial_gaussian_model)
-        global_trainer = GlobalTrainer(global_gaussian_model)
-
-        return global_trainer
-
-    def _save_artifacts(
-        self,
-        current_camera,
-        current_gaussian_model,
-        current_image,
-        next_camera,
-        next_gaussian_model,
-        next_image,
-        iteration,
-    ):
-        current_camera_image, _, _, _ = render(current_camera, current_gaussian_model)
-        next_camera_image, _, _, _ = render(next_camera, current_gaussian_model)
-        next_gaussian_image, _, _, _ = render(current_camera, next_gaussian_model)
-
-        for image, filename in [
-            (current_camera_image, f"{iteration}_current_camera.png"),
-            (next_camera_image, f"{iteration}_next_camera.png"),
-            (next_gaussian_image, f"{iteration}_next_gaussian.png"),
-            (current_image, f"{iteration}_current_image.png"),
-            (next_image, f"{iteration}_next_image.png"),
-        ]:
-            torchvision.utils.save_image(image, self._output_path / filename)
diff --git a/gaussian_splatting/utils/early_stopper.py b/gaussian_splatting/utils/early_stopper.py
index cbeeeaa05..1d0cfa072 100644
--- a/gaussian_splatting/utils/early_stopper.py
+++ b/gaussian_splatting/utils/early_stopper.py
@@ -1,9 +1,5 @@
 class EarlyStopper:
-    def __init__(
-        self,
-        patience: int,
-        tolerance: float = 0.0,
-    ):
+    def __init__(self, patience: int, tolerance: float = 0.0, verbose: bool = False):
         self.patience = patience
         self.current_epoch = 0
         self.epochs_since_best_loss = 0
@@ -13,6 +9,8 @@ def __init__(
 
         self._tolerance = tolerance
 
+        self._verbose = verbose
+
     def step(self, loss: float) -> bool:
         self.current_epoch += 1
 
@@ -24,13 +22,20 @@ def step(self, loss: float) -> bool:
 
         stop = self.epochs_since_best_loss == self.patience
 
+        if stop and self._verbose:
+            print(f"Early stopping after {self.current_epoch} epochs.")
+
         return stop
 
+    def reset(self):
+        self.current_epoch = 0
+        self.epochs_since_best_loss = 0
+        self.best_loss = float("inf")
+
+        self._best_params = None
+
     def set_best_params(self, best_params):
         self._best_params = best_params
 
     def get_best_params(self):
         return self._best_params
-
-    def print_early_stop(self):
-        print(f"Early stopping after {self.current_epoch} epochs.")
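-- 
Usage sketch (illustrative only, not part of the diff above): one way to drive the
gradual global training loop end to end. Only PoseFreeTrainer(source_path) and its
run() method come from the patched modules; the frames directory name is a
placeholder, and the package is assumed to be importable as gaussian_splatting.

    # Hypothetical driver script. ImageDataset sorts frames by int(f.stem), so the
    # folder is expected to contain integer-named images such as 0.png, 1.png, ...
    from pathlib import Path

    from gaussian_splatting.pose_free.pose_free_trainer import PoseFreeTrainer

    if __name__ == "__main__":
        frames_dir = Path("data/frames")  # placeholder path, adjust to your data
        trainer = PoseFreeTrainer(source_path=frames_dir)
        # Per frame pair: local init on the current frame, an affine transfo fit
        # to the next frame, then a gradual global pass over the two cameras.
        trainer.run()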