diff --git a/gaussian_splatting/colmap_free/local_initialization_trainer.py b/gaussian_splatting/colmap_free/local_initialization_trainer.py index 1c19b8ef2..2267a1ccf 100644 --- a/gaussian_splatting/colmap_free/local_initialization_trainer.py +++ b/gaussian_splatting/colmap_free/local_initialization_trainer.py @@ -6,6 +6,7 @@ from matplotlib import pyplot as plt from tqdm import tqdm from transformers import pipeline +from pathlib import Path from gaussian_splatting.dataset.cameras import Camera from gaussian_splatting.model import GaussianModel @@ -50,6 +51,10 @@ def __init__(self, image, sh_degree: int = 3, iterations: int = 10000): safe_state(seed=2234) + self._output_path =Path(" artifacts/local/init/") + self._output_path.mkdir(exist_ok=True, parents=True) + + def run(self, iterations: int = 3000): progress_bar = tqdm(range(iterations), desc="Initialization") @@ -60,16 +65,6 @@ def run(self, iterations: int = 3000): self.camera, self.gaussian_model ) - if iteration % 100 == 0: - plt.cla() - plt.plot(losses) - plt.yscale("log") - plt.savefig("artifacts/local/init/losses.png") - - torchvision.utils.save_image( - rendered_image, f"artifacts/local/init/rendered_{iteration}.png" - ) - gt_image = self.camera.original_image.cuda() loss = self._photometric_loss(rendered_image, gt_image) loss.backward() @@ -82,6 +77,9 @@ def run(self, iterations: int = 3000): best_iteration = iteration losses.append(loss.cpu().item()) + if iteration % 100 == 0: + self._save_artifacts(self, losses, rendered_image, iteration) + with torch.no_grad(): # Densification if iteration < self._densification_iteration_stop: @@ -115,10 +113,7 @@ def run(self, iterations: int = 3000): f"Training done. Best loss = {best_loss:.{5}f} at iteration {best_iteration}." ) - torchvision.utils.save_image( - rendered_image, f"artifacts/local/init/rendered_{iteration}.png" - ) - torchvision.utils.save_image(gt_image, f"artifacts/local/init/gt.png") + torchvision.utils.save_image(gt_image, self._output_path / "gt.png") def _get_orthogonal_camera(self, image): camera = Camera( @@ -180,3 +175,13 @@ def _load_DPT(self): depth_estimator = pipeline("depth-estimation", model=checkpoint) return depth_estimator + + def _save_artifacts(self, losses, rendered_image, iteration): + plt.cla() + plt.plot(losses) + plt.yscale("log") + plt.savefig(self._output_path / "losses.png") + + torchvision.utils.save_image( + rendered_image, self._output_path / f"rendered_{iteration}.png" + ) diff --git a/gaussian_splatting/colmap_free/local_transformation_trainer.py b/gaussian_splatting/colmap_free/local_transformation_trainer.py index 724ee16d1..6aff10e0b 100644 --- a/gaussian_splatting/colmap_free/local_transformation_trainer.py +++ b/gaussian_splatting/colmap_free/local_transformation_trainer.py @@ -2,6 +2,7 @@ import torchvision from matplotlib import pyplot as plt from tqdm import tqdm +from pathlib import Path from gaussian_splatting.colmap_free.transformation_model import \ AffineTransformationModel @@ -23,9 +24,12 @@ def __init__(self, gaussian_model): ) self._photometric_loss = PhotometricLoss(lambda_dssim=0.2) + self._output_path = Path(" artifacts/local/transfo/") + self._output_path.mkdir(exist_ok=True, parents=True) + safe_state(seed=2234) - def run(self, current_camera, gt_image, iterations: int = 1000): + def run(self, current_camera, gt_image, iterations: int = 1000, run: int = 0): gt_image = gt_image.to(self.gaussian_model.get_xyz.device) progress_bar = tqdm(range(iterations), desc="Transformation") @@ -50,18 +54,18 @@ def run(self, current_camera, gt_image, iterations: int = 1000): progress_bar.set_postfix({"Loss": f"{loss:.{5}f}"}) progress_bar.update(1) - if iteration % 10 == 0: - self._save_artifacts(losses, rendered_image, iteration) + if iteration % 10 == 0 or iteration == len(iterations) - 1: + self._save_artifacts(losses, rendered_image, iteration, run) if best_loss is None or best_loss > loss: best_loss = loss.cpu().item() best_iteration = iteration best_xyz = xyz.detach() - elif best_loss < loss and patience > 10: - self._save_artifacts(losses, rendered_image, iteration) - break - else: - patience += 1 + #elif best_loss < loss and patience > 10: + # self._save_artifacts(losses, rendered_image, iteration, run) + # break + #else: + # patience += 1 progress_bar.close() @@ -73,12 +77,12 @@ def run(self, current_camera, gt_image, iterations: int = 1000): return rotation, translation - def _save_artifacts(self, losses, rendered_image, iteration): + def _save_artifacts(self, losses, rendered_image, iteration, run): plt.cla() plt.plot(losses) plt.yscale("log") - plt.savefig("artifacts/local/transfo/losses.png") + plt.savefig(self._output_path / "losses.png") torchvision.utils.save_image( - rendered_image, f"artifacts/local/transfo/rendered_{iteration}.png" + rendered_image, self._output_path / f"{run}_rendered_{iteration}.png" ) diff --git a/scripts/train_colmap_free.py b/scripts/train_colmap_free.py index 97ce97137..da1c7a731 100644 --- a/scripts/train_colmap_free.py +++ b/scripts/train_colmap_free.py @@ -20,10 +20,10 @@ def main(): debug = True - iteration_step_size = 5 - initialization_iterations = 50 - transformation_iterations = 50 - global_iterations = 50 + iteration_step_size = 50 + initialization_iterations = 250 + transformation_iterations = 250 + global_iterations = 5 photometric_loss = PhotometricLoss(lambda_dssim=0.2) dataset = ImageDataset(images_path=Path("data/phil/1/input/")) @@ -53,6 +53,7 @@ def main(): current_camera, next_image, iterations=transformation_iterations, + run=iteration, ) # Add new camera to Global3DGS training cameras @@ -73,8 +74,8 @@ def main(): next_camera_image, _, _, _ = render(next_camera, current_gaussian_model) next_gaussian_image, _, _, _ = render(current_camera, next_gaussian_model) loss = photometric_loss(next_camera_image, next_gaussian_image) - assert loss < 0.01 + print(loss) torchvision.utils.save_image( next_camera_image, f"artifacts/global/next_camera_{iteration}.png" ) @@ -82,6 +83,7 @@ def main(): next_gaussian_image, f"artifacts/global/next_gaussian_{iteration}.png" ) + break # global_trainer.add_camera(next_camera) # global_trainer.run(global_iterations)