Skip to content

Commit

Permalink
Reformat code
Browse files Browse the repository at this point in the history
  • Loading branch information
bouchardi committed Mar 18, 2024
1 parent 4284c06 commit 3716147
Show file tree
Hide file tree
Showing 27 changed files with 1,099 additions and 552 deletions.
22 changes: 15 additions & 7 deletions gaussian_splatting/arguments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,17 @@
# For inquiries contact [email protected]
#

from argparse import ArgumentParser, Namespace
import sys
import os
import sys
from argparse import ArgumentParser, Namespace


class GroupParams:
pass


class ParamGroup:
def __init__(self, parser: ArgumentParser = None, name : str = "", fill_none = False):
def __init__(self, parser: ArgumentParser = None, name: str = "", fill_none=False):
if parser is None:
parser = ArgumentParser()
group = parser.add_argument_group(name)
Expand All @@ -31,9 +32,13 @@ def __init__(self, parser: ArgumentParser = None, name : str = "", fill_none = F
value = value if not fill_none else None
if shorthand:
if t == bool:
group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true")
group.add_argument(
"--" + key, ("-" + key[0:1]), default=value, action="store_true"
)
else:
group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t)
group.add_argument(
"--" + key, ("-" + key[0:1]), default=value, type=t
)
else:
if t == bool:
group.add_argument("--" + key, default=value, action="store_true")
Expand All @@ -47,6 +52,7 @@ def extract(self, args):
setattr(group, arg[0], arg[1])
return group


class ModelParams(ParamGroup):
def __init__(self, parser=None, source_path="", sentinel=False):
self.sh_degree = 3
Expand All @@ -62,6 +68,7 @@ def extract(self, args):
g.source_path = os.path.abspath(g.source_path)
return g


class OptimizationParams(ParamGroup):
def __init__(self, parser=None):
self.iterations = 30_000
Expand All @@ -82,7 +89,8 @@ def __init__(self, parser=None):
self.densify_grad_threshold = 0.0002
super().__init__(parser, "Optimization Parameters")

def get_combined_args(parser : ArgumentParser):

def get_combined_args(parser: ArgumentParser):
cmdlne_string = sys.argv[1:]
cfgfile_string = "Namespace()"
args_cmdline = parser.parse_args(cmdlne_string)
Expand All @@ -99,7 +107,7 @@ def get_combined_args(parser : ArgumentParser):
args_cfgfile = eval(cfgfile_string)

merged_dict = vars(args_cfgfile).copy()
for k,v in vars(args_cmdline).items():
for k, v in vars(args_cmdline).items():
if v != None:
merged_dict[k] = v
return Namespace(**merged_dict)
54 changes: 36 additions & 18 deletions gaussian_splatting/gaussian_renderer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,36 @@
# For inquiries contact [email protected]
#

import torch
import math
from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer

import torch
from diff_gaussian_rasterization import (
GaussianRasterizationSettings,
GaussianRasterizer,
)

from gaussian_splatting.scene.gaussian_model import GaussianModel
from gaussian_splatting.utils.sh_utils import eval_sh

def render(viewpoint_camera, pc : GaussianModel, bg_color : torch.Tensor = None, scaling_modifier = 1.0):

def render(
viewpoint_camera,
pc: GaussianModel,
bg_color: torch.Tensor = None,
scaling_modifier=1.0,
):
"""
Render the scene.
Background tensor (bg_color) must be on GPU!
"""

# Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
screenspace_points = (
torch.zeros_like(
pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda"
)
+ 0
)
try:
screenspace_points.retain_grad()
except:
Expand All @@ -48,7 +63,7 @@ def render(viewpoint_camera, pc : GaussianModel, bg_color : torch.Tensor = None,
sh_degree=pc.active_sh_degree,
campos=viewpoint_camera.camera_center,
prefiltered=False,
debug=False
debug=False,
)

rasterizer = GaussianRasterizer(raster_settings=raster_settings)
Expand All @@ -68,18 +83,21 @@ def render(viewpoint_camera, pc : GaussianModel, bg_color : torch.Tensor = None,

# Rasterize visible Gaussians to image, obtain their radii (on screen).
rendered_image, radii = rasterizer(
means3D = means3D,
means2D = means2D,
shs = shs,
colors_precomp = colors_precomp,
opacities = opacity,
scales = scales,
rotations = rotations,
cov3D_precomp = cov3D_precomp)
means3D=means3D,
means2D=means2D,
shs=shs,
colors_precomp=colors_precomp,
opacities=opacity,
scales=scales,
rotations=rotations,
cov3D_precomp=cov3D_precomp,
)

# Those Gaussians that were frustum culled or had a radius of 0 were not visible.
# They will be excluded from value updates used in the splitting criteria.
return {"render": rendered_image,
"viewspace_points": screenspace_points,
"visibility_filter" : radii > 0,
"radii": radii}
return {
"render": rendered_image,
"viewspace_points": screenspace_points,
"visibility_filter": radii > 0,
"radii": radii,
}
7 changes: 3 additions & 4 deletions gaussian_splatting/lpipsPyTorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
from gaussian_splatting.lpipsPyTorch.modules.lpips import LPIPS


def lpips(x: torch.Tensor,
y: torch.Tensor,
net_type: str = 'alex',
version: str = '0.1'):
def lpips(
x: torch.Tensor, y: torch.Tensor, net_type: str = "alex", version: str = "0.1"
):
r"""Function that measures
Learned Perceptual Image Patch Similarity (LPIPS).
Expand Down
7 changes: 4 additions & 3 deletions gaussian_splatting/lpipsPyTorch/modules/lpips.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import torch.nn as nn

from gaussian_splatting.lpipsPyTorch.modules.networks import get_network, LinLayers
from gaussian_splatting.lpipsPyTorch.modules.networks import LinLayers, get_network
from gaussian_splatting.lpipsPyTorch.modules.utils import get_state_dict


Expand All @@ -14,9 +14,10 @@ class LPIPS(nn.Module):
'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
version (str): the version of LPIPS. Default: 0.1.
"""
def __init__(self, net_type: str = 'alex', version: str = '0.1'):

assert version in ['0.1'], 'v0.1 is only supported now'
def __init__(self, net_type: str = "alex", version: str = "0.1"):

assert version in ["0.1"], "v0.1 is only supported now"

super(LPIPS, self).__init__()

Expand Down
29 changes: 15 additions & 14 deletions gaussian_splatting/lpipsPyTorch/modules/networks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Sequence

from itertools import chain
from typing import Sequence

import torch
import torch.nn as nn
Expand All @@ -10,24 +9,24 @@


def get_network(net_type: str):
if net_type == 'alex':
if net_type == "alex":
return AlexNet()
elif net_type == 'squeeze':
elif net_type == "squeeze":
return SqueezeNet()
elif net_type == 'vgg':
elif net_type == "vgg":
return VGG16()
else:
raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
raise NotImplementedError("choose net_type from [alex, squeeze, vgg].")


class LinLayers(nn.ModuleList):
def __init__(self, n_channels_list: Sequence[int]):
super(LinLayers, self).__init__([
nn.Sequential(
nn.Identity(),
nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
) for nc in n_channels_list
])
super(LinLayers, self).__init__(
[
nn.Sequential(nn.Identity(), nn.Conv2d(nc, 1, 1, 1, 0, bias=False))
for nc in n_channels_list
]
)

for param in self.parameters():
param.requires_grad = False
Expand All @@ -39,9 +38,11 @@ def __init__(self):

# register buffer
self.register_buffer(
'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
"mean", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
)
self.register_buffer(
'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
"std", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
)

def set_requires_grad(self, state: bool):
for param in chain(self.parameters(), self.buffers()):
Expand Down
19 changes: 11 additions & 8 deletions gaussian_splatting/lpipsPyTorch/modules/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,30 @@


def normalize_activation(x, eps=1e-10):
norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
return x / (norm_factor + eps)


def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
def get_state_dict(net_type: str = "alex", version: str = "0.1"):
# build url
url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
+ f'master/lpips/weights/v{version}/{net_type}.pth'
url = (
"https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/"
+ f"master/lpips/weights/v{version}/{net_type}.pth"
)

# download
old_state_dict = torch.hub.load_state_dict_from_url(
url, progress=True,
map_location=None if torch.cuda.is_available() else torch.device('cpu')
url,
progress=True,
map_location=None if torch.cuda.is_available() else torch.device("cpu"),
)

# rename keys
new_state_dict = OrderedDict()
for key, val in old_state_dict.items():
new_key = key
new_key = new_key.replace('lin', '')
new_key = new_key.replace('model.', '')
new_key = new_key.replace("lin", "")
new_key = new_key.replace("model.", "")
new_state_dict[new_key] = val

return new_state_dict
Loading

0 comments on commit 3716147

Please sign in to comment.