Skip to content

Commit

Permalink
temp
Browse files Browse the repository at this point in the history
  • Loading branch information
bouchardi committed Apr 10, 2024
1 parent 1660517 commit 84afe4b
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 108 deletions.
26 changes: 11 additions & 15 deletions gaussian_splatting/local_initialization_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from gaussian_splatting.trainer import Trainer
from gaussian_splatting.utils.general import PILtoTorch, safe_state
from gaussian_splatting.utils.graphics import BasicPointCloud
from gaussian_splatting.utils.loss import l1_loss, ssim
from gaussian_splatting.utils.loss import PhotometricLoss


class LocalInitializationTrainer(Trainer):
Expand All @@ -32,12 +32,10 @@ def __init__(self, image, sh_degree: int = 3, iterations: int = 10000):
# TODO: set camera extent???

self.optimizer = Optimizer(self.gaussian_model)
self._photometric_loss = PhotometricLoss(lambda_dssim=0.2)

self.camera = self._get_orthogonal_camera(image)

self._iterations = iterations
self._lambda_dssim = 0.2

# Densification and pruning
self._opacity_reset_interval = 10001
self._min_opacity = 0.005
Expand All @@ -52,11 +50,11 @@ def __init__(self, image, sh_degree: int = 3, iterations: int = 10000):

safe_state(seed=2234)

def run(self):
progress_bar = tqdm(range(self._iterations), desc="Initialization")
def run(self, iterations: int = 3000):
progress_bar = tqdm(range(iterations), desc="Initialization")

best_loss, best_iteration, losses = None, 0, []
for iteration in range(self._iterations):
for iteration in range(iterations):
self.optimizer.update_learning_rate(iteration)
rendered_image, viewspace_point_tensor, visibility_filter, radii = render(
self.camera, self.gaussian_model
Expand All @@ -73,19 +71,17 @@ def run(self):
)

gt_image = self.camera.original_image.cuda()
Ll1 = l1_loss(rendered_image, gt_image)
loss = (1.0 - self._lambda_dssim) * Ll1 + self._lambda_dssim * (
1.0 - ssim(rendered_image, gt_image)
)
loss = self._photometric_loss(rendered_image, gt_image)
loss.backward()

self.optimizer.step()
self.optimizer.zero_grad(set_to_none=True)

if best_loss is None or best_loss > loss:
best_loss = loss.cpu().item()
best_iteration = iteration
losses.append(loss.cpu().item())

loss.backward()

self.optimizer.step()
self.optimizer.zero_grad(set_to_none=True)

with torch.no_grad():
# Densification
Expand Down
99 changes: 44 additions & 55 deletions gaussian_splatting/local_transformation_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

from gaussian_splatting.render import render
from gaussian_splatting.trainer import Trainer
from gaussian_splatting.utils.general import PILtoTorch, safe_state
from gaussian_splatting.utils.loss import l1_loss, ssim
from gaussian_splatting.utils.general import safe_state
from gaussian_splatting.utils.loss import PhotometricLoss


class QuaternionRotation(nn.Module):
Expand Down Expand Up @@ -81,84 +81,73 @@ def translation(self):


class LocalTransformationTrainer(Trainer):
def __init__(self, image, camera, gaussian_model, iterations: int = 100):
self.camera = camera
def __init__(self, gaussian_model):
self.gaussian_model = gaussian_model

self.xyz = gaussian_model.get_xyz.detach()

self.transformation_model = TransformationModel()
self.transformation_model.to(self.xyz.device)

self.image = PILtoTorch(image).to(self.xyz.device)
self.transformation_model.to(gaussian_model.get_xyz.device)

self.optimizer = torch.optim.Adam(
self.transformation_model.parameters(), lr=0.0001
)

self._iterations = iterations
self._lambda_dssim = 0.2
self._photometric_loss = PhotometricLoss(lambda_dssim=0.2)

safe_state(seed=2234)

def run(self):
progress_bar = tqdm(range(self._iterations), desc="Transformation")
def run(self, current_camera, gt_image, iterations: int = 1000):
gt_image = gt_image.to(self.gaussian_model.get_xyz.device)

progress_bar = tqdm(range(iterations), desc="Transformation")

best_loss, best_iteration, losses = None, 0, []
best_xyz = None
for iteration in range(self._iterations):
xyz = self.transformation_model(self.xyz)
losses = []
best_loss, best_iteration, best_xyz = None, 0, None
patience = 0
for iteration in range(iterations):
xyz = self.transformation_model(self.gaussian_model.get_xyz.detach())
self.gaussian_model.set_optimizable_tensors({"xyz": xyz})

rendered_image, viewspace_point_tensor, visibility_filter, radii = render(
self.camera, self.gaussian_model
current_camera, self.gaussian_model
)

if iteration % 10 == 0:
plt.cla()
plt.plot(losses)
plt.yscale("log")
plt.savefig("artifacts/local/transfo/losses.png")

torchvision.utils.save_image(
rendered_image, f"artifacts/local/transfo/rendered_{iteration}.png"
)

gt_image = self.image
Ll1 = l1_loss(rendered_image, gt_image)
loss = (1.0 - self._lambda_dssim) * Ll1 + self._lambda_dssim * (
1.0 - ssim(rendered_image, gt_image)
)
if best_loss is None or best_loss > loss:
best_loss = loss.cpu().item()
best_iteration = iteration
best_xyz = xyz
losses.append(loss.cpu().item())

loss = self._photometric_loss(rendered_image, gt_image)
loss.backward()

self.optimizer.step()

progress_bar.set_postfix(
{
"Loss": f"{loss:.{5}f}",
}
)
losses.append(loss.cpu().item())

progress_bar.set_postfix({"Loss": f"{loss:.{5}f}"})
progress_bar.update(1)

if iteration % 10 == 0:
self._save_artifacts(losses, rendered_image, iteration)

if best_loss is None or best_loss > loss:
best_loss = loss.cpu().item()
best_iteration = iteration
best_xyz = xyz.detach()
elif best_loss < loss and patience > 10:
self._save_artifacts(losses, rendered_image, iteration)
break
else:
patience += 1

progress_bar.close()
print(
f"Training done. Best loss = {best_loss:.{5}f} at iteration {best_iteration}."
)
self.gaussian_model.set_optimizable_tensors({"xyz": best_xyz})

torchvision.utils.save_image(
rendered_image, f"artifacts/local/transfo/rendered_best.png"
)
torchvision.utils.save_image(gt_image, f"artifacts/local/transfo/gt.png")
print(f"Best loss = {best_loss:.{5}f} at iteration {best_iteration}.")
self.gaussian_model.set_optimizable_tensors({"xyz": best_xyz})

def get_affine_transformation(self):
    """Return the learned rotation and translation as NumPy arrays.

    Reads ``rotation`` and ``translation`` from ``self.transformation_model``
    and converts each one with ``.numpy()``.

    Returns:
        tuple: ``(rotation, translation)``.
    """
    # NOTE(review): both attributes are presumably torch tensors derived from
    # trainable parameters that may live on the GPU -- confirm against
    # TransformationModel. Tensor.numpy() refuses tensors that require grad
    # or sit on a non-CPU device, so detach and copy to host memory first.
    # Both calls are no-ops for tensors that are already detached / on CPU.
    rotation = self.transformation_model.rotation.detach().cpu().numpy()
    translation = self.transformation_model.translation.detach().cpu().numpy()

    return rotation, translation

def _save_artifacts(self, losses, rendered_image, iteration):
    """Write the loss curve and the current render to the artifacts folder.

    Args:
        losses: Loss values collected so far; plotted on a log-scaled y axis.
        rendered_image: Image passed to ``torchvision.utils.save_image``.
        iteration: Iteration index, used in the rendered image's file name.
    """
    # Local import so this method does not depend on the module's import block.
    from pathlib import Path

    # Create the output folder up front; otherwise the first save fails on a
    # fresh checkout where the directory does not exist yet.
    output_dir = Path("artifacts/local/transfo")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Clear the current axes so successive calls do not stack curves.
    plt.cla()
    plt.plot(losses)
    plt.yscale("log")
    plt.savefig(str(output_dir / "losses.png"))

    torchvision.utils.save_image(
        rendered_image, str(output_dir / f"rendered_{iteration}.png")
    )
12 changes: 12 additions & 0 deletions gaussian_splatting/utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,18 @@
from torch.autograd import Variable


class PhotometricLoss:
    """Callable photometric loss mixing an L1 term with a D-SSIM term.

    The value returned is
    ``(1 - lambda_dssim) * l1_loss(output, gt) + lambda_dssim * (1 - ssim(output, gt))``.
    """

    def __init__(self, lambda_dssim: float = 0.2):
        """Store the weight given to the ``1 - ssim`` term.

        Args:
            lambda_dssim: Weight of the ``1 - ssim`` term; the L1 term is
                weighted by ``1 - lambda_dssim``.
        """
        self._lambda_dssim = lambda_dssim

    def __call__(self, network_output, gt):
        """Compute the weighted loss between a rendered image and its target.

        Args:
            network_output: Image produced by the renderer.
            gt: Ground-truth image compared against ``network_output``.

        Returns:
            The weighted sum of the L1 term and the ``1 - ssim`` term.
        """
        l1_value = l1_loss(network_output, gt)
        # NOTE(review): ssim is defined elsewhere in this module; presumably it
        # returns a similarity score where higher is better, hence 1 - ssim.
        ssim_value = ssim(network_output, gt)

        l1_weight = 1.0 - self._lambda_dssim
        return l1_weight * l1_value + self._lambda_dssim * (1.0 - ssim_value)

def l1_loss(network_output, gt):
    """Return the mean absolute difference between ``network_output`` and ``gt``.

    Args:
        network_output: Predicted tensor.
        gt: Target tensor; subtracted element-wise from ``network_output``.

    Returns:
        The mean of ``|network_output - gt|`` over all elements.
    """
    return torch.abs(network_output - gt).mean()

Expand Down
89 changes: 51 additions & 38 deletions scripts/train_colmap_free.py
Original file line number Diff line number Diff line change
@@ -1,72 +1,85 @@
import copy
from pathlib import Path

import copy
import numpy as np
import torchvision
import torch
from tqdm import tqdm

from gaussian_splatting.dataset.cameras import Camera
from gaussian_splatting.render import render
from gaussian_splatting.dataset.image_dataset import ImageDataset
from gaussian_splatting.global_trainer import GlobalTrainer
from gaussian_splatting.local_initialization_trainer import \
LocalInitializationTrainer
from gaussian_splatting.local_transformation_trainer import \
LocalTransformationTrainer
from gaussian_splatting.render import render
from gaussian_splatting.global_trainer import GlobalTrainer
from gaussian_splatting.utils.general import PILtoTorch
from gaussian_splatting.utils.loss import PhotometricLoss


def main():
debug = True
iteration_step_size = 5
initialization_iterations = 50
transformation_iterations = 50

photometric_loss = PhotometricLoss(lambda_dssim=0.2)
dataset = ImageDataset(images_path=Path("data/phil/1/input/"))
first_image = dataset.get_frame(0)

local_initialization_trainer = LocalInitializationTrainer(
first_image, iterations=500
)
local_initialization_trainer.run()
# Initialize Local3DGS gaussians
current_image = dataset.get_frame(0)
local_initialization_trainer = LocalInitializationTrainer(current_image)
local_initialization_trainer.run(iterations=initialization_iterations)

next_gaussian_model = local_initialization_trainer.gaussian_model
local_transformation_trainer = LocalTransformationTrainer(next_gaussian_model)
#global_trainer = GlobalTrainer(
# gaussian_model=gaussian_model,
# cameras=[camera],
# iterations=100
#)

current_camera = local_initialization_trainer.camera
for iteration in tqdm(range(iteration_step_size, len(dataset), iteration_step_size)):
# Keep a copy of current gaussians to compare
with torch.no_grad():
current_gaussian_model = copy.deepcopy(next_gaussian_model)

step = 5
gaussian_model = local_initialization_trainer.gaussian_model
camera = local_initialization_trainer.camera
global_trainer = GlobalTrainer(
gaussian_model=gaussian_model, cameras=[camera], iterations=100
)
for iteration in range(step, len(dataset), step):
next_image = dataset.get_frame(iteration)
local_transformation_trainer = LocalTransformationTrainer(
# Find transformation from current to next camera poses
next_image = PILtoTorch(dataset.get_frame(iteration))
rotation, translation = local_transformation_trainer.run(
current_camera,
next_image,
camera=camera,
gaussian_model=copy.deepcopy(gaussian_model),
iterations=250,
iterations=transformation_iterations,
)
local_transformation_trainer.run()
rotation, translation = local_transformation_trainer.get_affine_transformation()

next_image = PILtoTorch(next_image)
# Add new camera to Global3DGS training cameras
next_camera = Camera(
R=np.matmul(camera.R, rotation),
T=camera.T + translation,
FoVx=camera.FoVx,
FoVy=camera.FoVy,
R=np.matmul(current_camera.R, rotation),
T=current_camera.T + translation,
FoVx=current_camera.FoVx,
FoVy=current_camera.FoVy,
image=next_image,
gt_alpha_mask=None,
image_name="patate",
colmap_id=iteration,
uid=iteration,
)
global_trainer.add_camera(next_camera)
global_trainer.run()

rendered_image, viewspace_point_tensor, visibility_filter, radii = render(
next_camera,
local_initialization_trainer.gaussian_model,
)
torchvision.utils.save_image(
rendered_image, f"artifacts/global/rendered_{iteration}.png"
)
torchvision.utils.save_image(next_image, f"artifacts/global/gt_{iteration}.png")
if debug:
# Save artifact
next_camera_image, _, _, _ = render(next_camera, current_gaussian_model)
next_gaussian_image, _, _, _ = render(current_camera, next_gaussian_model)
loss = photometric_loss(next_camera_image, next_gaussian_image)
assert loss < 0.01

print(f">>> Iteration {iteration} / {len(dataset) // step}")
torchvision.utils.save_image(next_camera_image, f"artifacts/global/next_camera_{iteration}.png")
torchvision.utils.save_image(next_gaussian_image, f"artifacts/global/next_gaussian_{iteration}.png")
#global_trainer.add_camera(next_camera)
#global_trainer.run()

current_camera = next_camera

if __name__ == "__main__":
main()

0 comments on commit 84afe4b

Please sign in to comment.