Skip to content

Commit

Permalink
clean training
Browse files Browse the repository at this point in the history
  • Loading branch information
bouchardi committed Mar 19, 2024
1 parent 1372015 commit 4f3f351
Show file tree
Hide file tree
Showing 9 changed files with 268 additions and 346 deletions.
1 change: 0 additions & 1 deletion cameras.json

This file was deleted.

13 changes: 6 additions & 7 deletions gaussian_splatting/gaussian_renderer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def render(
Background tensor (bg_color) must be on GPU!
"""

if bg_color is None:
bg_color = torch.tensor([1, 1, 1], dtype=torch.float32, device="cuda")

Expand Down Expand Up @@ -80,9 +79,9 @@ def render(

# Those Gaussians that were frustum culled or had a radius of 0 were not visible.
# They will be excluded from value updates used in the splitting criteria.
return {
"render": rendered_image,
"viewspace_points": screenspace_points,
"visibility_filter": radii > 0,
"radii": radii,
}
return (
rendered_image,
screenspace_points,
radii > 0,
radii,
)
9 changes: 5 additions & 4 deletions gaussian_splatting/optimizer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
from torch import nn

from gaussian_splatting.utils.general import get_expon_lr_func

Expand Down Expand Up @@ -78,11 +79,11 @@ def state_dict(self):
def load_state_dict(self, state_dict):
self._optimizer.load_state_dict(state_dict)

def replace_tensor(self, tensor, name):
def replace_points(self, tensor, name):
optimizable_tensors = {}
for group in self._optimizer.param_groups:
if group["name"] == name:
stored_state = self.optimizer.state.get(group["params"][0], None)
stored_state = self._optimizer.state.get(group["params"][0], None)
stored_state["exp_avg"] = torch.zeros_like(tensor)
stored_state["exp_avg_sq"] = torch.zeros_like(tensor)

Expand All @@ -94,7 +95,7 @@ def replace_tensor(self, tensor, name):

return optimizable_tensors

def prune(self, mask):
def prune_points(self, mask):
optimizable_tensors = {}
for group in self._optimizer.param_groups:
stored_state = self._optimizer.state.get(group["params"][0], None)
Expand All @@ -117,7 +118,7 @@ def prune(self, mask):

return optimizable_tensors

def cat_tensors(self, tensors_dict):
def concatenate_points(self, tensors_dict):
optimizable_tensors = {}
for group in self._optimizer.param_groups:
assert len(group["params"]) == 1
Expand Down
126 changes: 110 additions & 16 deletions gaussian_splatting/scene/gaussian_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
from simple_knn._C import distCUDA2
from torch import nn

from gaussian_splatting.utils.general import (build_scaling_rotation,
from gaussian_splatting.utils.general import (build_rotation,
build_scaling_rotation,
inverse_sigmoid, strip_symmetric)
from gaussian_splatting.utils.graphics import BasicPointCloud
from gaussian_splatting.utils.sh import RGB2SH
Expand All @@ -34,7 +35,7 @@ def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):

class GaussianModel:

def __init__(self, sh_degree: int = 3, percent_dense=0.01):
def __init__(self, sh_degree: int = 3):
self.active_sh_degree = 0
self.max_sh_degree = sh_degree

Expand All @@ -49,8 +50,6 @@ def __init__(self, sh_degree: int = 3, percent_dense=0.01):
self.xyz_gradient_accum = torch.empty(0)
self.denom = torch.empty(0)

self.percent_dense = percent_dense

self.scaling_activation = torch.exp
self.scaling_inverse_activation = torch.log
self.covariance_activation = build_covariance_from_scaling_rotation
Expand Down Expand Up @@ -111,8 +110,19 @@ def get_features(self):
def get_opacity(self):
return self.opacity_activation(self._opacity)

def set_opacity(self, opacity):
self._opacity = opacity
def set_optimizable_tensors(self, optimizable_tensors):
if "xyz" in optimizable_tensors:
self._xyz = optimizable_tensors["xyz"]
if "f_dc" in optimizable_tensors:
self._features_dc = optimizable_tensors["f_dc"]
if "f_rest" in optimizable_tensors:
self._features_rest = optimizable_tensors["f_rest"]
if "opacity" in optimizable_tensors:
self._opacity = optimizable_tensors["opacity"]
if "scaling" in optimizable_tensors:
self._scaling = optimizable_tensors["scaling"]
if "rotation" in optimizable_tensors:
self._rotation = optimizable_tensors["rotation"]

def get_covariance(self, scaling_modifier=1):
return self.covariance_activation(
Expand Down Expand Up @@ -167,8 +177,8 @@ def initialize(self, dataset):
self._scaling = nn.Parameter(scales.requires_grad_(True))
self._rotation = nn.Parameter(rots.requires_grad_(True))
self._opacity = nn.Parameter(opacities.requires_grad_(True))
self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")

self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")

Expand Down Expand Up @@ -303,16 +313,100 @@ def load_ply(self, path):

self.active_sh_degree = self.max_sh_degree

def set_optimizable_tensors(self, optimizable_tensors):
self.set_xyz = optimizable_tensors["xyz"]
self._features_dc = optimizable_tensors["f_dc"]
self._features_rest = optimizable_tensors["f_rest"]
self._opacity = optimizable_tensors["opacity"]
self._scaling = optimizable_tensors["scaling"]
self._rotation = optimizable_tensors["rotation"]

def add_densification_stats(self, viewspace_point_tensor, update_filter):
def reset_opacity(self):
new_opacity = inverse_sigmoid(
torch.min(self._opacity, torch.ones_like(self._opacity) * 0.01)
)

return new_opacity

def split_points(self, gradient_threshold, percent_dense):
gradients = self.xyz_gradient_accum / self.denom
gradients[gradients.isnan()] = 0.0

# Extract large Gaussians in over-reconstruction regions with high variance gradients
split_mask = torch.logical_and(
torch.where(
gradients.detach().squeeze() >= gradient_threshold, True, False
),
torch.max(self.get_scaling, dim=1).values
> percent_dense * self.camera_extent,
)

stds = self.get_scaling[split_mask].repeat(2, 1)
means = torch.zeros((stds.size(0), 3), device="cuda")
samples = torch.normal(mean=means, std=stds)
rots = build_rotation(self._rotation[split_mask]).repeat(2, 1, 1)

new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self._xyz[
split_mask
].repeat(2, 1)
new_scaling = self.scaling_inverse_activation(
self.get_scaling[split_mask].repeat(2, 1) / (0.8 * 2)
)
new_rotation = self._rotation[split_mask].repeat(2, 1)
new_features_dc = self._features_dc[split_mask].repeat(2, 1, 1)
new_features_rest = self._features_rest[split_mask].repeat(2, 1, 1)
new_opacity = self._opacity[split_mask].repeat(2, 1)

new_points = {
"xyz": new_xyz,
"f_dc": new_features_dc,
"f_rest": new_features_rest,
"opacity": new_opacity,
"scaling": new_scaling,
"rotation": new_rotation,
}

return new_points, split_mask

def clone_points(self, gradient_threshold, percent_dense):
gradients = self.xyz_gradient_accum / self.denom
gradients[gradients.isnan()] = 0.0

# Extract small Gaussians in under-reconstruction regions with high variance gradients
clone_mask = torch.logical_and(
torch.where(
torch.norm(gradients, dim=-1) >= gradient_threshold, True, False
),
torch.max(self.get_scaling, dim=1).values
<= percent_dense * self.camera_extent,
)

new_xyz = self._xyz[clone_mask]
new_features_dc = self._features_dc[clone_mask]
new_features_rest = self._features_rest[clone_mask]
new_opacity = self._opacity[clone_mask]
new_scaling = self._scaling[clone_mask]
new_rotation = self._rotation[clone_mask]

new_points = {
"xyz": new_xyz,
"f_dc": new_features_dc,
"f_rest": new_features_rest,
"opacity": new_opacity,
"scaling": new_scaling,
"rotation": new_rotation,
}

return new_points, clone_mask

def update_stats(self, viewspace_point_tensor, update_filter, radii):
self.max_radii2D[update_filter] = torch.max(
self.max_radii2D[update_filter],
radii[update_filter],
)
self.xyz_gradient_accum[update_filter] += torch.norm(
viewspace_point_tensor.grad[update_filter, :2], dim=-1, keepdim=True
)
self.denom[update_filter] += 1

def mask_stats(self, mask):
self.max_radii2D = self.max_radii2D[mask]
self.xyz_gradient_accum = self.xyz_gradient_accum[mask]
self.denom = self.denom[mask]

def reset_stats(self):
self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
self.xyz_gradient_accum = torch.zeros((self._xyz.shape[0], 1), device="cuda")
self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
Loading

0 comments on commit 4f3f351

Please sign in to comment.