local 3dgs transformation
bouchardi committed Apr 10, 2024
1 parent 8cbad5f commit 90f3fa3
Showing 7 changed files with 167 additions and 66 deletions.
15 changes: 4 additions & 11 deletions gaussian_splatting/colmap_free_trainer.py
@@ -1,11 +1,9 @@
import os
import uuid
from random import randint

import torch
from tqdm import tqdm

from gaussian_splatting.dataset.dataset import Dataset
from gaussian_splatting.model import GaussianModel
from gaussian_splatting.optimizer import Optimizer
from gaussian_splatting.render import render
@@ -14,7 +12,6 @@
from gaussian_splatting.utils.loss import l1_loss, ssim



class ColmapFreeTrainer:
def __init__(
self,
@@ -34,28 +31,24 @@ def __init__(

safe_state()


def run(self):
progress_bar = tqdm(
range(len(self.dataset)), desc="Training progress"
)
progress_bar = tqdm(range(len(self.dataset)), desc="Training progress")
for iteration in range(len(self.dataset) - 1):

I_t = self.dataset[iteration]
I_t_plus_1 = self.dataset[iteration + 1]

local_3DGS_trainer = LocalTrainer()

#self.optimizer.update_learning_rate(iteration)
# self.optimizer.update_learning_rate(iteration)

# Every 1000 its we increase the levels of SH up to a maximum degree
# if iteration % 1000 == 0:
# self.gaussian_model.oneupSHdegree()

# Pick a random camera
#if not cameras:
# if not cameras:
# cameras = self.dataset.get_train_cameras().copy()
#camera = cameras.pop(randint(0, len(cameras) - 1))
# camera = cameras.pop(randint(0, len(cameras) - 1))

# Render image
rendered_image, viewspace_point_tensor, visibility_filter, radii = render(
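Note: the run() loop above still instantiates the old LocalTrainer and is not yet wired up to the two new trainers added in this commit. Below is a minimal sketch of the intended per-frame flow, following the pattern used in scripts/train_colmap_free.py further down; the chaining over all consecutive frame pairs is an assumption about where this is headed, not committed code.

from pathlib import Path

from gaussian_splatting.dataset.image_dataset import ImageDataset
from gaussian_splatting.local_initialization_trainer import LocalInitializationTrainer
from gaussian_splatting.local_transformation_trainer import LocalTransformationTrainer


def run_pairwise(images_path: Path, init_iterations: int = 100):
    dataset = ImageDataset(images_path=images_path)
    for t in range(len(dataset) - 1):
        I_t = dataset.get_frame(t)
        I_t_plus_1 = dataset.get_frame(t + 1)

        # Fit a local 3DGS to frame t from a monocular-depth point cloud.
        init_trainer = LocalInitializationTrainer(I_t, iterations=init_iterations)
        init_trainer.run()

        # Reuse the fitted Gaussians and camera, and learn the transformation
        # that re-renders them as frame t + 1.
        transfo_trainer = LocalTransformationTrainer(
            I_t_plus_1,
            camera=init_trainer.camera,
            gaussian_model=init_trainer.gaussian_model,
        )
        transfo_trainer.run()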
7 changes: 3 additions & 4 deletions gaussian_splatting/dataset/image_dataset.py
@@ -1,9 +1,9 @@
from PIL import Image
from pathlib import Path
from gaussian_splatting.utils.general import PILtoTorch

class ImageDataset:
from PIL import Image


class ImageDataset:
def __init__(self, images_path: Path):
self._images_paths = [f for f in images_path.iterdir()]
self._images_paths.sort(key=lambda f: int(f.stem))
@@ -16,4 +16,3 @@ def get_frame(self, i: int):

def __len__(self):
return len(self._images_paths)

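ImageDataset sorts the frame paths by their integer stem, so it assumes files whose names are frame indices (e.g. 0.png, 1.png, ...) and returns them in temporal order. A minimal usage sketch, assuming the data/phil/1/input/ layout used in scripts/train_colmap_free.py:

from pathlib import Path

from gaussian_splatting.dataset.image_dataset import ImageDataset

dataset = ImageDataset(images_path=Path("data/phil/1/input/"))
frame_0 = dataset.get_frame(0)  # a PIL image; PILtoTorch converts it to a tensor later
print(len(dataset), frame_0.size)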
gaussian_splatting/local_initialization_trainer.py
@@ -1,33 +1,30 @@
import math
from matplotlib import pyplot as plt
import numpy as np
from transformers import pipeline
from tqdm import tqdm

import numpy as np
import torch
import torchvision
from matplotlib import pyplot as plt
from tqdm import tqdm
from transformers import pipeline

from gaussian_splatting.utils.general import safe_state
from gaussian_splatting.dataset.cameras import Camera
from gaussian_splatting.model import GaussianModel
from gaussian_splatting.optimizer import Optimizer
from gaussian_splatting.render import render
from gaussian_splatting.trainer import Trainer
from gaussian_splatting.utils.general import PILtoTorch, safe_state
from gaussian_splatting.utils.graphics import BasicPointCloud
from gaussian_splatting.utils.general import PILtoTorch
from gaussian_splatting.dataset.cameras import Camera
from gaussian_splatting.utils.loss import l1_loss, ssim
from gaussian_splatting.trainer import Trainer


class LocalTrainer(Trainer):
def __init__(self, image, sh_degree: int = 3):
class LocalInitializationTrainer(Trainer):
def __init__(self, image, sh_degree: int = 3, iterations: int = 10000):
DPT = self._load_DPT()
depth_estimation = DPT(image)["predicted_depth"]

image = PILtoTorch(image)
initial_point_cloud = self._get_initial_point_cloud(
image,
depth_estimation,
step=25
image, depth_estimation, step=25
)

self.gaussian_model = GaussianModel(sh_degree)
@@ -36,9 +33,9 @@ def __init__(self, image, sh_degree: int = 3):

self.optimizer = Optimizer(self.gaussian_model)

self._camera = self._get_orthogonal_camera(image)
self.camera = self._get_orthogonal_camera(image)

self._iterations = 10000
self._iterations = iterations
self._lambda_dssim = 0.2

# Densification and pruning
@@ -56,33 +53,33 @@ def __init__(self, image, sh_degree: int = 3):
safe_state(seed=2234)

def run(self):
progress_bar = tqdm(
range(self._iterations), desc="Training progress"
)
progress_bar = tqdm(range(self._iterations), desc="Initialization")

best_loss, best_iteration, losses = None, 0, []
for iteration in range(self._iterations):
self.optimizer.update_learning_rate(iteration)
rendered_image, viewspace_point_tensor, visibility_filter, radii = render(
self._camera, self.gaussian_model
self.camera, self.gaussian_model
)

if iteration % 100 == 0:
plt.cla()
plt.plot(losses)
plt.yscale('log')
plt.savefig('artifacts/losses.png')
plt.yscale("log")
plt.savefig("artifacts/local/init/losses.png")

torchvision.utils.save_image(rendered_image, f"artifacts/rendered_{iteration}.png")
torchvision.utils.save_image(
rendered_image, f"artifacts/local/init/rendered_{iteration}.png"
)

gt_image = self._camera.original_image.cuda()
gt_image = self.camera.original_image.cuda()
Ll1 = l1_loss(rendered_image, gt_image)
loss = (1.0 - self._lambda_dssim) * Ll1 + self._lambda_dssim * (
1.0 - ssim(rendered_image, gt_image)
)
if best_loss is None or best_loss > loss:
best_loss = loss.cpu().item()
best_iteartion = iteration
best_iteration = iteration
losses.append(loss.cpu().item())

loss.backward()
@@ -110,22 +107,25 @@ def run(self):
print("Reset Opacity")
self._reset_opacity()

progress_bar.set_postfix({
"Loss": f"{loss:.{5}f}",
"Num_visible":
f"{visibility_filter.int().sum().item()}/{len(visibility_filter)}"
})
progress_bar.set_postfix(
{
"Loss": f"{loss:.{5}f}",
"Num_visible": f"{visibility_filter.int().sum().item()}/{len(visibility_filter)}",
}
)
progress_bar.update(1)

print(f"Training done. Best loss = {best_loss} at iteration {best_iteration}.")

torchvision.utils.save_image(rendered_image, f"artifacts/rendered_{iteration}.png")
torchvision.utils.save_image(gt_image, f"artifacts/gt.png")
torchvision.utils.save_image(
rendered_image, f"artifacts/local/init/rendered_{iteration}.png"
)
torchvision.utils.save_image(gt_image, f"artifacts/local/init/gt.png")

def _get_orthogonal_camera(self, image):
camera = Camera(
R=np.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]),
T=np.array([-0.5, -0.5, 1.]),
R=np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]),
T=np.array([-0.5, -0.5, 1.0]),
FoVx=2 * math.atan(0.5),
FoVy=2 * math.atan(0.5),
image=image,
@@ -150,20 +150,24 @@ def _get_initial_point_cloud(self, frame, depth_estimation, step: int = 50):
for y in range(step, h - step, step):
_depth = depth_estimation[0, x, y].item()
# Normalized points
points.append([
y / h,
x / w,
(_depth - _min_depth) / (_max_depth - _min_depth)
])
points.append(
[y / h, x / w, (_depth - _min_depth) / (_max_depth - _min_depth)]
)
# Average RGB color in the window color around selected pixel
colors.append(
frame[
:,
x - half_step: x + half_step,
y - half_step: y + half_step
].mean(axis=[1, 2]).tolist()
:, x - half_step : x + half_step, y - half_step : y + half_step
]
.mean(axis=[1, 2])
.tolist()
)
normals.append(
[
0.0,
0.0,
0.0,
]
)
normals.append([0., 0., 0.,])

point_cloud = BasicPointCloud(
points=np.array(points),
@@ -178,5 +182,3 @@ def _load_DPT(self):
depth_estimator = pipeline("depth-estimation", model=checkpoint)

return depth_estimator


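The initialization above builds the starting point cloud from a single image: DPT monocular depth is sampled on a coarse grid, min-max normalized, and paired with the average colour of a window around each sample; the orthogonal camera with FoVx = FoVy = 2*atan(0.5) and T = (-0.5, -0.5, 1.0) then frames that normalized unit square at roughly unit depth. A compact sketch of the sampling step with synthetic inputs (shapes and step=25 mirror the code above; the random arrays stand in for the DPT output and PILtoTorch(image), and are not repository data):

import numpy as np

h, w, step = 240, 320, 25
half_step = step // 2
depth = np.random.rand(1, h, w).astype(np.float32)  # stand-in for DPT "predicted_depth"
frame = np.random.rand(3, h, w).astype(np.float32)  # stand-in for PILtoTorch(image)

d_min, d_max = depth.min(), depth.max()
points, colors = [], []
for i in range(step, h - step, step):        # image rows
    for j in range(step, w - step, step):    # image columns
        z = (depth[0, i, j] - d_min) / (d_max - d_min)   # depth normalized to [0, 1]
        points.append([j / w, i / h, z])                 # normalized (x, y, depth)
        colors.append(
            frame[:, i - half_step:i + half_step, j - half_step:j + half_step]
            .mean(axis=(1, 2))
            .tolist()
        )
points, colors = np.array(points), np.array(colors)
print(points.shape, colors.shape)  # (N, 3) and (N, 3), ready for BasicPointCloud

Assuming the usual world-to-camera convention for R and T, these points sit within +/-0.5 of the optical axis at a depth of about 1, so a half-angle of atan(0.5) (FoV of roughly 53 degrees) is just wide enough for the camera to see all of them.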
95 changes: 95 additions & 0 deletions gaussian_splatting/local_transformation_trainer.py
@@ -0,0 +1,95 @@
import torch
import torchvision
from matplotlib import pyplot as plt
from tqdm import tqdm

from gaussian_splatting.render import render
from gaussian_splatting.trainer import Trainer
from gaussian_splatting.utils.general import PILtoTorch, safe_state
from gaussian_splatting.utils.loss import l1_loss, ssim


class TransformationModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(in_features=3, out_features=3)

torch.nn.init.eye_(self.linear.weight.data)
torch.nn.init.zeros_(self.linear.bias.data)

def forward(self, xyz):
transformed_xyz = self.linear(xyz)

return transformed_xyz


class LocalTransformationTrainer(Trainer):
def __init__(self, image, camera, gaussian_model):
self.camera = camera
self.gaussian_model = gaussian_model

self.xyz = gaussian_model.get_xyz.detach()

self.transformation_model = TransformationModel()
self.transformation_model.to(self.xyz.device)

self.image = PILtoTorch(image).to(self.xyz.device)

self.optimizer = torch.optim.Adam(
self.transformation_model.parameters(), lr=0.0001
)

self._iterations = 101
self._lambda_dssim = 0.2

safe_state(seed=2234)

def run(self):
progress_bar = tqdm(range(self._iterations), desc="Transformation")

best_loss, best_iteration, losses = None, 0, []
for iteration in range(self._iterations):
xyz = self.transformation_model(self.xyz)
self.gaussian_model.set_optimizable_tensors({"xyz": xyz})

rendered_image, viewspace_point_tensor, visibility_filter, radii = render(
self.camera, self.gaussian_model
)

if iteration % 10 == 0:
plt.cla()
plt.plot(losses)
plt.yscale("log")
plt.savefig("artifacts/local/transfo/losses.png")

torchvision.utils.save_image(
rendered_image, f"artifacts/local/transfo/rendered_{iteration}.png"
)

gt_image = self.image
Ll1 = l1_loss(rendered_image, gt_image)
loss = (1.0 - self._lambda_dssim) * Ll1 + self._lambda_dssim * (
1.0 - ssim(rendered_image, gt_image)
)
if best_loss is None or best_loss > loss:
best_loss = loss.cpu().item()
best_iteration = iteration
losses.append(loss.cpu().item())

loss.backward()

self.optimizer.step()

progress_bar.set_postfix(
{
"Loss": f"{loss:.{5}f}",
}
)
progress_bar.update(1)

print(f"Training done. Best loss = {best_loss} at iteration {best_iteration}.")

torchvision.utils.save_image(
rendered_image, f"artifacts/local/transfo/rendered_{iteration}.png"
)
torchvision.utils.save_image(gt_image, f"artifacts/local/transfo/gt.png")
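TransformationModel is a single nn.Linear(3, 3) whose weight is initialized to the identity and bias to zero, so at the first iteration the detached Gaussian centers map to themselves and the render matches the initialization result; training then only has to discover the global affine motion between the two frames (an unconstrained 3x3 matrix plus translation, not a rotation-constrained rigid pose). A small sanity-check sketch with synthetic points, not part of the commit:

import torch

from gaussian_splatting.local_transformation_trainer import TransformationModel

model = TransformationModel()
xyz = torch.rand(1000, 3)               # stand-in for gaussian_model.get_xyz.detach()
assert torch.allclose(model(xyz), xyz)  # identity weight, zero bias: a no-op at step 0

# After LocalTransformationTrainer.run(), the learned motion is read off as one
# matrix and one translation applied to every Gaussian center.
print(model.linear.weight.data, model.linear.bias.data)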
2 changes: 1 addition & 1 deletion gaussian_splatting/model.py
@@ -55,7 +55,7 @@ def __init__(self, sh_degree: int = 3):
self.inverse_opacity_activation = inverse_sigmoid
self.rotation_activation = torch.nn.functional.normalize

self.camera_extent = 1.
self.camera_extent = 1.0

def state_dict(self):
state_dict = (
3 changes: 2 additions & 1 deletion gaussian_splatting/utils/general.py
@@ -19,7 +19,7 @@ def inverse_sigmoid(x):
return torch.log(x / (1 - x))


def PILtoTorch(pil_image, resolution = None):
def PILtoTorch(pil_image, resolution=None):
if resolution is not None:
pil_image = pil_image.resize(resolution)

@@ -32,6 +32,7 @@ def PILtoTorch(pil_image, resolution = None):

return image


def get_expon_lr_func(
lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
):
19 changes: 15 additions & 4 deletions scripts/train_colmap_free.py
@@ -1,15 +1,26 @@
from pathlib import Path

from gaussian_splatting.local_trainer import LocalTrainer
from gaussian_splatting.dataset.image_dataset import ImageDataset
from gaussian_splatting.local_initialization_trainer import \
LocalInitializationTrainer
from gaussian_splatting.local_transformation_trainer import \
LocalTransformationTrainer


def main():
dataset = ImageDataset(images_path=Path("data/phil/1/input/"))
image = dataset.get_frame(0)
image_0 = dataset.get_frame(0)
image_1 = dataset.get_frame(10)

local_trainer = LocalTrainer(image)
local_trainer.run()
local_initialization_trainer = LocalInitializationTrainer(image_0, iterations=100)
local_initialization_trainer.run()

local_transformation_trainer = LocalTransformationTrainer(
image_1,
camera=local_initialization_trainer.camera,
gaussian_model=local_initialization_trainer.gaussian_model,
)
local_transformation_trainer.run()


if __name__ == "__main__":
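Both trainers write diagnostics (loss curves and intermediate renders) under artifacts/local/init/ and artifacts/local/transfo/. torchvision.utils.save_image does not create missing directories, so on a fresh checkout it is worth creating them before running the script; a small guard like this reflects an assumption about how you invoke it, not repository code:

from pathlib import Path

for d in ("artifacts/local/init", "artifacts/local/transfo"):
    Path(d).mkdir(parents=True, exist_ok=True)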
