Skip to content

Commit

Permalink
gradual global training
Browse files Browse the repository at this point in the history
  • Loading branch information
bouchardi committed Apr 12, 2024
1 parent 3ae97ce commit e6f3097
Show file tree
Hide file tree
Showing 7 changed files with 134 additions and 135 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ tensorboard_3d
screenshots
data/
gaussian_splatting.egg-info/
artifacts/
21 changes: 18 additions & 3 deletions gaussian_splatting/dataset/image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,34 @@


class ImageDataset:
def __init__(self, images_path: Path, step_size: int = 1):
def __init__(
self, images_path: Path, step_size: int = 1, downscale_factor: int = 1
):
self._images_paths = [
f for i, f in enumerate(images_path.iterdir()) if i % step_size == 0
]
self._images_paths.sort(key=lambda f: int(f.stem))

self._downscale_factor = downscale_factor

def __len__(self):
return len(self._images_paths)

def get_frame(self, i: int):
image_path = self._images_paths[i]
image = Image.open(image_path)

if self._downscale_factor > 1:
image = self._downscale(image)

image = PILtoTorch(image)

return image

def __len__(self):
return len(self._images_paths)
def _downscale(self, image):
h, w = image.size
image = image.resize(
(h // self._downscale_factor, w // self._downscale_factor), Image.LANCZOS
)

return image
2 changes: 0 additions & 2 deletions gaussian_splatting/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,6 @@ def initialize_from_point_cloud(self, point_cloud):
features[:, :3, 0] = fused_color
features[:, 3:, 1:] = 0.0

print("Number of points at initialisation : ", fused_point_cloud.shape[0])

dist2 = torch.clamp_min(
distCUDA2(torch.from_numpy(np.asarray(point_cloud.points)).float().cuda()),
0.0000001,
Expand Down
49 changes: 20 additions & 29 deletions gaussian_splatting/pose_free/global_trainer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
import os
from random import randint

from tqdm import tqdm
from pathlib import Path

from gaussian_splatting.optimizer import Optimizer
from gaussian_splatting.render import render
Expand All @@ -11,15 +8,16 @@


class GlobalTrainer(Trainer):
def __init__(self, gaussian_model, output_path=None):
def __init__(self, gaussian_model, iterations: int = 100, output_path=None):
self._model_path = self._prepare_model_path(output_path)

self.gaussian_model = gaussian_model
self.cameras = []

self.optimizer = Optimizer(self.gaussian_model)
self._photometric_loss = PhotometricLoss(lambda_dssim=0.2)

self._iterations = iterations

self._debug = False

# Densification and pruning
Expand All @@ -30,25 +28,16 @@ def __init__(self, gaussian_model, output_path=None):

safe_state()

def add_camera(self, camera):
self.cameras.append(camera)

def run(self, iterations: int = 1000):
ema_loss_for_log = 0.0
cameras = None
first_iter = 1
progress_bar = tqdm(range(first_iter, iterations), desc="Training progress")
for iteration in range(first_iter, iterations + 1):
def run(self, current_camera, next_camera, progress_bar=None, run_id: int = 0):
cameras = (current_camera, next_camera)
for iteration in range(self._iterations):
self.optimizer.update_learning_rate(iteration)

# Every 1000 its we increase the levels of SH up to a maximum degree
if iteration % 1000 == 0:
self.gaussian_model.oneupSHdegree()

# Pick a random camera
if not cameras:
cameras = self.cameras.copy()
camera = cameras.pop(randint(0, len(cameras) - 1))
camera = cameras[iteration % 2]

# Render image
rendered_image, viewspace_point_tensor, visibility_filter, radii = render(
Expand All @@ -59,22 +48,24 @@ def run(self, iterations: int = 1000):
gt_image = camera.original_image.cuda()
loss = self._photometric_loss(rendered_image, gt_image)
loss.backward()
loss_value = loss.cpu().item()

# Optimizer step
self.optimizer.step()
self.optimizer.zero_grad(set_to_none=True)

# Progress bar
ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
progress_bar.update(1)

progress_bar.close()

point_cloud_path = os.path.join(
self._model_path, "point_cloud/iteration_{}".format(iteration)
if progress_bar is not None:
progress_bar.set_postfix(
{
"stage": "global",
"iteration": f"{iteration}/{self._iterations}",
"loss": f"{loss_value:.5f}",
}
)

self.gaussian_model.save_ply(
Path(self._model_path) / "point_cloud" / str(run_id) / "point_cloud.ply"
)
self.gaussian_model.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))

# Densification
self.gaussian_model.update_stats(
Expand Down
102 changes: 58 additions & 44 deletions gaussian_splatting/pose_free/local_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,37 +21,40 @@ class LocalTrainer:
def __init__(
self,
sh_degree: int = 3,
init_iterations: int = 250,
transfo_iterations: int = 250,
init_iterations: int = 1000,
transfo_iterations: int = 1000,
debug: bool = False,
):
self._depth_estimator = self._load_depth_estimator()
self._depth_estimator = pipeline("depth-estimation", model="vinvino02/glpn-nyu")
self._point_cloud_step = 25
self._sh_degree = sh_degree

self._photometric_loss = PhotometricLoss(lambda_dssim=0.2)

self._init_iterations = init_iterations
self._init_early_stopper = EarlyStopper(patience=10)
self._init_save_artifacts_iterations = 50
self._init_save_artifacts_iterations = 100

self._transfo_lr = 0.0001
self._transfo_lr = 0.00001
self._transfo_iterations = transfo_iterations
self._transfo_early_stopper = EarlyStopper(patience=10)
self._transfo_save_artifacts_iterations = 100
self._transfo_early_stopper = EarlyStopper(
patience=100,
)
self._transfo_save_artifacts_iterations = 10

self._debug = True
self._debug = debug

self._output_path = Path("artifacts/local/")
self._output_path.mkdir(exist_ok=True, parents=True)

safe_state(seed=2234)

def run_init(self, image, camera, run_id: int = 0):
output_path = self._output_path / "init"
def run_init(self, image, camera, progress_bar=None, run_id: int = 0):
output_path = self._output_path / "init" / str(run_id)
output_path.mkdir(exist_ok=True, parents=True)

gaussian_model = self._get_initial_gaussian_model(image)
gaussian_model = self.get_initial_gaussian_model(image, output_path)
optimizer = Optimizer(gaussian_model)
self._init_early_stopper.reset()

image = image.cuda()
losses = []
Expand All @@ -69,31 +72,37 @@ def run_init(self, image, camera, run_id: int = 0):
optimizer.zero_grad(set_to_none=True)

if self._init_early_stopper.step(loss_value):
self._init_early_stopper.print_early_stop()
break

if (
self._debug
or iteration % self._init_save_artifacts_iterations == 0
or iteration == self._init_iterations - 1
):
self._save_artifacts(
losses, rendered_image, output_path / str(run_id), iteration
if self._debug and (iteration % self._init_save_artifacts_iterations == 0):
self._save_artifacts(losses, rendered_image, output_path, iteration)

if progress_bar is not None:
progress_bar.set_postfix(
{
"stage": "init",
"iteration": f"{iteration}/{self._init_iterations}",
"loss": f"{loss_value:.5f}",
}
)

if self._debug:
save_image(image, output_path / f"{run_id}_ground_truth.png")
self._save_artifacts(losses, rendered_image, output_path, "best")
save_image(image, output_path / "ground_truth.png")

return gaussian_model

def run_transfo(self, image, camera, gaussian_model, run_id: int = 0):
output_path = self._output_path / "transfo"
def run_transfo(
self, image, camera, gaussian_model, progress_bar=None, run_id: int = 0
):
output_path = self._output_path / "transfo" / str(run_id)
output_path.mkdir(exist_ok=True, parents=True)

transformation_model = AffineTransformationModel()
optimizer = torch.optim.Adam(
transformation_model.parameters(), lr=self._transfo_lr
)
self._transfo_early_stopper.reset()

image = image.cuda()
transformation_model = transformation_model.cuda()
Expand All @@ -114,36 +123,48 @@ def run_transfo(self, image, camera, gaussian_model, run_id: int = 0):
optimizer.step()

if self._transfo_early_stopper.step(loss_value):
self._transfo_early_stopper.print_early_stop()
transformation = self._transfo_early_stopper.get_best_params(
transformation
)
transformation = self._transfo_early_stopper.get_best_params()
break
else:
transformation = transformation_model.transformation
self._init_early_stopper.set_best_params(transformation)
self._transfo_early_stopper.set_best_params(transformation)

if (
self._debug
or iteration % self._transfo_save_artifacts_iterations == 0
or iteration == self._transfo_iterations - 1
if self._debug and (
iteration % self._transfo_save_artifacts_iterations == 0
):
self._save_artifacts(
losses,
rendered_image,
output_path / str(run_id),
output_path,
iteration,
)

if progress_bar is not None:
progress_bar.set_postfix(
{
"stage": "transfo",
"iteration": f"{iteration}/{self._transfo_iterations}",
"loss": f"{loss_value:.5f}",
}
)

if self._debug:
save_image(image, output_path / f"{run_id}_ground_truth.png")
self._save_artifacts(losses, rendered_image, output_path, "best")
save_image(image, output_path / f"ground_truth.png")

return transformation

def _get_initial_gaussian_model(self, image):
def get_initial_gaussian_model(self, image, output_folder: Path = None):
PIL_image = TorchToPIL(image)

depth_estimation = self._depth_estimator(PIL_image)["predicted_depth"]

if self._debug and output_folder is not None:
_min, _max = depth_estimation.min().item(), depth_estimation.max().item()
save_image(
(depth_estimation - _min) / (_max - _min),
output_folder / f"depth_estimation_{_min:.3f}_{_max:.3f}.png",
)

point_cloud = self._get_initial_point_cloud_from_depth_estimation(
image, depth_estimation, step=self._point_cloud_step
)
Expand Down Expand Up @@ -195,17 +216,10 @@ def _get_initial_point_cloud_from_depth_estimation(

return point_cloud

def _load_depth_estimator(self):
checkpoint = "vinvino02/glpn-nyu"
depth_estimator = pipeline("depth-estimation", model=checkpoint)

return depth_estimator

def _save_artifacts(self, losses, rendered_image, output_path, iteration):
output_path.mkdir(exist_ok=True, parents=True)
plt.cla()
plt.plot(losses)
plt.yscale("log")
plt.savefig(output_path / "losses.png")

save_image(rendered_image, self._output_path / f"rendered_{iteration}.png")
save_image(rendered_image, output_path / f"rendered_{iteration}.png")
Loading

0 comments on commit e6f3097

Please sign in to comment.