Skip to content

Commit

Permalink
switched depth estimator for a faster implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
bouchardi committed Apr 18, 2024
1 parent 204908a commit abd82ff
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 253 deletions.
234 changes: 0 additions & 234 deletions gaussian_splatting/colmap_free_trainer.py

This file was deleted.

29 changes: 29 additions & 0 deletions gaussian_splatting/pose_free/depth_estimator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import torch
from transformers import pipeline

from gaussian_splatting.dataset.image_dataset import ImageDataset
from gaussian_splatting.utils.general import TorchToPIL


class DepthEstimator:
def __init__(self, model: str = "Intel/dpt-large"):
self._model = pipeline("depth-estimation", model=model)

def run(self, image):
PIL_image = TorchToPIL(image)
depth_estimation = self._model(PIL_image)["predicted_depth"]

depth_estimation = torch.nn.functional.interpolate(
depth_estimation.unsqueeze(1),
size=PIL_image.size[::-1],
mode="bicubic",
align_corners=False,
).squeeze()

_min = depth_estimation.min()
_max = depth_estimation.max()
depth_estimation = (depth_estimation - _min) / (_max - _min)

depth_estimation = -1 * (depth_estimation - 1)

return depth_estimation
28 changes: 9 additions & 19 deletions gaussian_splatting/pose_free/local_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
import torch
from matplotlib import pyplot as plt
from torchvision.utils import save_image
from transformers import pipeline

from gaussian_splatting.model import GaussianModel
from gaussian_splatting.optimizer import Optimizer
from gaussian_splatting.pose_free.depth_estimator import DepthEstimator
from gaussian_splatting.pose_free.transformation_model import \
AffineTransformationModel
from gaussian_splatting.render import render
from gaussian_splatting.utils.early_stopper import EarlyStopper
from gaussian_splatting.utils.general import TorchToPIL, safe_state
from gaussian_splatting.utils.general import safe_state
from gaussian_splatting.utils.graphics import BasicPointCloud
from gaussian_splatting.utils.loss import PhotometricLoss

Expand All @@ -25,7 +25,7 @@ def __init__(
transfo_iterations: int = 1000,
debug: bool = False,
):
self._depth_estimator = pipeline("depth-estimation", model="vinvino02/glpn-nyu")
self._depth_estimator = DepthEstimator()
self._point_cloud_step = 25
self._sh_degree = sh_degree

Expand Down Expand Up @@ -155,14 +155,11 @@ def run_transfo(
return transformation

def get_initial_gaussian_model(self, image, output_folder: Path = None):
PIL_image = TorchToPIL(image)
depth_estimation = self._depth_estimator(PIL_image)["predicted_depth"]

depth_estimation = self._depth_estimator.run(image)
if self._debug and output_folder is not None:
_min, _max = depth_estimation.min().item(), depth_estimation.max().item()
save_image(
(depth_estimation - _min) / (_max - _min),
output_folder / f"depth_estimation_{_min:.3f}_{_max:.3f}.png",
depth_estimation,
output_folder / f"depth_estimation.png",
)

point_cloud = self._get_initial_point_cloud_from_depth_estimation(
Expand All @@ -177,21 +174,14 @@ def get_initial_gaussian_model(self, image, output_folder: Path = None):
def _get_initial_point_cloud_from_depth_estimation(
self, frame, depth_estimation, step: int = 50
):
# Frame and depth_estimation width do not exactly match.
_, w, h = depth_estimation.shape

_min_depth = depth_estimation.min()
_max_depth = depth_estimation.max()

w, h = depth_estimation.shape
half_step = step // 2
points, colors, normals = [], [], []
for x in range(step, w - step, step):
for y in range(step, h - step, step):
_depth = depth_estimation[0, x, y].item()
_depth = depth_estimation[x, y].item()
# Normalized points
points.append(
[y / h, x / w, (_depth - _min_depth) / (_max_depth - _min_depth)]
)
points.append([y / h, x / w, _depth])
# Average RGB color in the window color around selected pixel
colors.append(
frame[
Expand Down
Loading

0 comments on commit abd82ff

Please sign in to comment.