Skip to content

Commit

Permalink
clean train_colmap_free training loop
Browse files Browse the repository at this point in the history
  • Loading branch information
bouchardi committed Apr 11, 2024
1 parent fd1e9b9 commit 0df8051
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 93 deletions.
26 changes: 4 additions & 22 deletions gaussian_splatting/colmap_free/local_initialization_trainer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import math
from pathlib import Path

import numpy as np
Expand All @@ -8,22 +7,20 @@
from tqdm import tqdm
from transformers import pipeline

from gaussian_splatting.dataset.cameras import Camera
from gaussian_splatting.model import GaussianModel
from gaussian_splatting.optimizer import Optimizer
from gaussian_splatting.render import render
from gaussian_splatting.trainer import Trainer
from gaussian_splatting.utils.general import PILtoTorch, safe_state
from gaussian_splatting.utils.general import TorchToPIL, safe_state
from gaussian_splatting.utils.graphics import BasicPointCloud
from gaussian_splatting.utils.loss import PhotometricLoss


class LocalInitializationTrainer(Trainer):
def __init__(self, image, sh_degree: int = 3, iterations: int = 10000):
def __init__(self, image, camera, sh_degree: int = 3, iterations: int = 10000):
DPT = self._load_DPT()
depth_estimation = DPT(image)["predicted_depth"]
depth_estimation = DPT(TorchToPIL(image))["predicted_depth"]

image = PILtoTorch(image)
initial_point_cloud = self._get_initial_point_cloud(
image, depth_estimation, step=25
)
Expand All @@ -35,7 +32,7 @@ def __init__(self, image, sh_degree: int = 3, iterations: int = 10000):
self.optimizer = Optimizer(self.gaussian_model)
self._photometric_loss = PhotometricLoss(lambda_dssim=0.2)

self.camera = self._get_orthogonal_camera(image)
self.camera = camera

# Densification and pruning
self._opacity_reset_interval = 10001
Expand Down Expand Up @@ -114,21 +111,6 @@ def run(self, iterations: int = 3000):

torchvision.utils.save_image(gt_image, self._output_path / "gt.png")

def _get_orthogonal_camera(self, image):
camera = Camera(
R=np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]),
T=np.array([-0.5, -0.5, 1.0]),
FoVx=2 * math.atan(0.5),
FoVy=2 * math.atan(0.5),
image=image,
gt_alpha_mask=None,
image_name="patate",
colmap_id=0,
uid=0,
)

return camera

def _get_initial_point_cloud(self, frame, depth_estimation, step: int = 50):
# Frame and depth_estimation width do not exactly match.
_, w, h = depth_estimation.shape
Expand Down
12 changes: 4 additions & 8 deletions gaussian_splatting/colmap_free/local_transformation_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(self, gaussian_model):
self.transformation_model.to(gaussian_model.get_xyz.device)

self.optimizer = torch.optim.Adam(
self.transformation_model.parameters(), lr=0.0001
self.transformation_model.parameters(), lr=0.0005
)
self._photometric_loss = PhotometricLoss(lambda_dssim=0.2)

Expand Down Expand Up @@ -57,14 +57,14 @@ def run(self, current_camera, gt_image, iterations: int = 1000, run: int = 0):
progress_bar.set_postfix({"Loss": f"{loss:.{5}f}"})
progress_bar.update(1)

if iteration % 10 == 0 or iteration == iterations - 1:
if iteration % 50 == 0 or iteration == iterations - 1:
self._save_artifacts(losses, rendered_image, iteration, run)

if best_loss is None or best_loss > loss:
best_loss = loss.cpu().item()
best_iteration = iteration
best_xyz = xyz.detach()

best_xyz = xyz.detach()
best_rotation = self.transformation_model.rotation.numpy()
best_translation = self.transformation_model.translation.numpy()

Expand All @@ -76,15 +76,11 @@ def run(self, current_camera, gt_image, iterations: int = 1000, run: int = 0):

progress_bar.close()

torchvision.utils.save_image(gt_image, self._output_path / "gt.png")
torchvision.utils.save_image(gt_image, self._output_path / f"{run}_gt.png")

print(f"Best loss = {best_loss:.{5}f} at iteration {best_iteration}.")
self.gaussian_model.set_optimizable_tensors({"xyz": best_xyz})

if best_rotation is None or best_translation is None:
best_rotation = self.transformation_model.rotation.numpy()
best_translation = self.transformation_model.translation.numpy()

return best_rotation, best_translation

def _save_artifacts(self, losses, rendered_image, iteration, run):
Expand Down
4 changes: 4 additions & 0 deletions gaussian_splatting/dataset/image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from PIL import Image

from gaussian_splatting.utils.general import PILtoTorch


class ImageDataset:
def __init__(self, images_path: Path):
Expand All @@ -12,6 +14,8 @@ def get_frame(self, i: int):
image_path = self._images_paths[i]
image = Image.open(image_path)

image = PILtoTorch(image)

return image

def __len__(self):
Expand Down
7 changes: 7 additions & 0 deletions gaussian_splatting/utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import numpy as np
import torch
from torchvision.transforms.functional import to_pil_image


def inverse_sigmoid(x):
Expand All @@ -33,6 +34,12 @@ def PILtoTorch(pil_image, resolution=None):
return image


def TorchToPIL(torch_image):
image = to_pil_image(torch_image)

return image


def get_expon_lr_func(
lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
):
Expand Down
189 changes: 126 additions & 63 deletions scripts/train_colmap_free.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import copy
import math
from pathlib import Path

import numpy as np
import torch
import torchvision
from tqdm import tqdm

# from gaussian_splatting.colmap_free.global_trainer import GlobalTrainer
from gaussian_splatting.colmap_free.local_initialization_trainer import \
Expand All @@ -14,88 +13,152 @@
from gaussian_splatting.dataset.cameras import Camera
from gaussian_splatting.dataset.image_dataset import ImageDataset
from gaussian_splatting.render import render
from gaussian_splatting.utils.general import PILtoTorch
from gaussian_splatting.utils.loss import PhotometricLoss


def main():
debug = True
iteration_step_size = 50
initialization_iterations = 250
transformation_iterations = 1000
global_iterations = 5

photometric_loss = PhotometricLoss(lambda_dssim=0.2)
dataset = ImageDataset(images_path=Path("data/phil/1/input/"))

# Initialize Local3DGS gaussians
current_image = dataset.get_frame(0)
local_initialization_trainer = LocalInitializationTrainer(current_image)
local_initialization_trainer.run(iterations=initialization_iterations)
global_trainer = None
for iteration in range(0, len(dataset), iteration_step_size):
print(f">>> Current: {iteration} / Next: {iteration + iteration_step_size}")

# We set a copy of the initialized model to both the local transformation and the
# global models.
next_gaussian_model = local_initialization_trainer.gaussian_model
local_transformation_trainer = LocalTransformationTrainer(next_gaussian_model)
# global_trainer = GlobalTrainer(copy.deepcopy(next_gaussian_model))
current_image = dataset.get_frame(iteration)
next_image = dataset.get_frame(iteration + iteration_step_size)

current_camera = local_initialization_trainer.camera
for iteration in tqdm(
range(iteration_step_size, len(dataset), iteration_step_size)
):
# Keep a copy of current gaussians to compare
with torch.no_grad():
current_gaussian_model = copy.deepcopy(next_gaussian_model)

# Find transformation from current to next camera poses
next_image = PILtoTorch(dataset.get_frame(iteration))
rotation, translation = local_transformation_trainer.run(
current_camera,
next_image,
iterations=transformation_iterations,
run=iteration,
)
if iteration == 0:
current_camera = _get_orthogonal_camera(current_image)
else:
current_camera = next_camera

# Add new camera to Global3DGS training cameras
next_camera = Camera(
R=np.matmul(current_camera.R, rotation),
T=current_camera.T + translation,
FoVx=current_camera.FoVx,
FoVy=current_camera.FoVy,
image=next_image,
gt_alpha_mask=None,
image_name="patate",
colmap_id=iteration,
uid=iteration,
current_camera, current_gaussian_model = _initialize_local_gaussian_model(
current_image, current_camera, run=iteration
)

# if iteration == 0:
# global_gaussian_model = copy.deepcopy(current_gaussian_model)
# global_trainer = GlobalTrainer(global_gaussian_model)

if debug:
# Save artifact
current_camera_image, _, _, _ = render(
current_camera, current_gaussian_model
)
next_camera_image, _, _, _ = render(next_camera, current_gaussian_model)
next_gaussian_image, _, _, _ = render(current_camera, next_gaussian_model)
loss = photometric_loss(next_camera_image, next_gaussian_image)

print(loss)
output_path = Path("artifacts/global")
output_path.mkdir(exist_ok=True, parents=True)
torchvision.utils.save_image(
current_camera_image, output_path / f"current_camera_{iteration}.png"
)
torchvision.utils.save_image(
next_camera_image, output_path / f"next_camera_{iteration}.png"
)
torchvision.utils.save_image(
next_gaussian_image, output_path / f"next_gaussian_{iteration}.png"
)
current_gaussian_model_copy = copy.deepcopy(current_gaussian_model)

next_camera, next_gaussian_model = _transform_local_gaussian_model(
next_image, current_camera, current_gaussian_model, run=iteration
)

break
# global_trainer.add_camera(next_camera)
# global_trainer.run(global_iterations)

current_camera = next_camera
if debug:
save_artifacts(
current_camera,
current_gaussian_model_copy,
current_image,
next_camera,
next_gaussian_model,
next_image,
iteration,
)
if iteration >= 50:
break


def _initialize_local_gaussian_model(
image, camera, iterations: int = 250, run: int = 0
):
local_initialization_trainer = LocalInitializationTrainer(image, camera)
local_initialization_trainer.run(iterations=iterations)

current_camera = local_initialization_trainer.camera
gaussian_model = local_initialization_trainer.gaussian_model

return camera, gaussian_model


def _transform_local_gaussian_model(
image, camera, gaussian_model, iterations: int = 250, run: int = 0
):
local_transformation_trainer = LocalTransformationTrainer(gaussian_model)
rotation, translation = local_transformation_trainer.run(
camera,
image,
iterations=iterations,
run=run,
)

transformed_camera = _transform_camera(camera, rotation, translation, image, run)
transformed_gaussian_model = local_transformation_trainer.gaussian_model

return transformed_camera, transformed_gaussian_model


def _get_orthogonal_camera(image):
camera = Camera(
R=np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]),
T=np.array([-0.5, -0.5, 1.0]),
FoVx=2 * math.atan(0.5),
FoVy=2 * math.atan(0.5),
image=image,
gt_alpha_mask=None,
image_name="patate",
colmap_id=0,
uid=0,
)

return camera


def _transform_camera(camera, rotation, translation, image, _id, image_name=""):
transformed_camera = Camera(
R=np.matmul(camera.R, rotation),
T=(camera.T + translation),
FoVx=camera.FoVx,
FoVy=camera.FoVy,
image=image,
gt_alpha_mask=None,
image_name=image_name,
colmap_id=_id,
uid=_id,
)
return transformed_camera


def save_artifacts(
current_camera,
current_gaussian_model,
current_image,
next_camera,
next_gaussian_model,
next_image,
iteration,
):
# Save artifact
current_camera_image, _, _, _ = render(current_camera, current_gaussian_model)
next_camera_image, _, _, _ = render(next_camera, current_gaussian_model)
next_gaussian_image, _, _, _ = render(current_camera, next_gaussian_model)

output_path = Path("artifacts/global")
output_path.mkdir(exist_ok=True, parents=True)
torchvision.utils.save_image(
current_camera_image, output_path / f"{iteration}_current_camera.png"
)
torchvision.utils.save_image(
next_camera_image, output_path / f"{iteration}_next_camera.png"
)
torchvision.utils.save_image(
next_gaussian_image, output_path / f"{iteration}_next_gaussian.png"
)
torchvision.utils.save_image(
current_image, output_path / f"{iteration}_current_image.png"
)
torchvision.utils.save_image(
next_image, output_path / f"{iteration}_next_image.png"
)


if __name__ == "__main__":
Expand Down

0 comments on commit 0df8051

Please sign in to comment.