Skip to content

Commit

Permalink
remove background color option - always black
Browse files Browse the repository at this point in the history
  • Loading branch information
bouchardi committed Mar 18, 2024
1 parent 209f93e commit 4284c06
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 22 deletions.
2 changes: 0 additions & 2 deletions gaussian_splatting/arguments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def __init__(self, parser=None, source_path="", sentinel=False):
self._model_path = ""
self._images = "images"
self._resolution = -1
self._white_background = False
self.eval = False
super().__init__(parser, "Loading Parameters", sentinel)

Expand All @@ -81,7 +80,6 @@ def __init__(self, parser=None):
self.densify_from_iter = 500
self.densify_until_iter = 15_000
self.densify_grad_threshold = 0.0002
self.random_background = False
super().__init__(parser, "Optimization Parameters")

def get_combined_args(parser : ArgumentParser):
Expand Down
5 changes: 4 additions & 1 deletion gaussian_splatting/gaussian_renderer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from gaussian_splatting.scene.gaussian_model import GaussianModel
from gaussian_splatting.utils.sh_utils import eval_sh

def render(viewpoint_camera, pc : GaussianModel, bg_color : torch.Tensor, scaling_modifier = 1.0):
def render(viewpoint_camera, pc : GaussianModel, bg_color : torch.Tensor = None, scaling_modifier = 1.0):
"""
Render the scene.
Expand All @@ -29,6 +29,9 @@ def render(viewpoint_camera, pc : GaussianModel, bg_color : torch.Tensor, scalin
except:
pass

if bg_color is None:
bg_color = torch.tensor([1, 1, 1], dtype=torch.float32, device="cuda")

# Set up rasterization configuration
tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
Expand Down
14 changes: 4 additions & 10 deletions gaussian_splatting/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,6 @@ def run(
(model_params, first_iter) = torch.load(self._checkpoint_path)
gaussians.restore(model_params, opt)

bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

iter_start = torch.cuda.Event(enable_timing = True)
iter_end = torch.cuda.Event(enable_timing = True)

Expand All @@ -79,9 +76,7 @@ def run(
viewpoint_stack = scene.getTrainCameras().copy()
viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))

bg = torch.rand((3), device="cuda") if opt.random_background else background

render_pkg = render(viewpoint_cam, gaussians, bg)
render_pkg = render(viewpoint_cam, gaussians)
image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]

# Loss
Expand Down Expand Up @@ -111,7 +106,6 @@ def run(
self._testing_iterations,
scene,
render,
(background)
)
if (iteration in self._saving_iterations):
print("\n[ITER {}] Saving Gaussians".format(iteration))
Expand All @@ -127,7 +121,7 @@ def run(
size_threshold = 20 if iteration > opt.opacity_reset_interval else None
gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)

if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
if iteration % opt.opacity_reset_interval == 0 or iteration == opt.densify_from_iter:
gaussians.reset_opacity()

# Optimizer step
Expand Down Expand Up @@ -155,7 +149,7 @@ def prepare_output_and_logger(args):
cfg_log_f.write(str(Namespace(**vars(args))))


def training_report(iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs):
def training_report(iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc):
# Report test and samples of training set
if iteration in testing_iterations:
torch.cuda.empty_cache()
Expand All @@ -167,7 +161,7 @@ def training_report(iteration, Ll1, loss, l1_loss, elapsed, testing_iterations,
l1_test = 0.0
psnr_test = 0.0
for idx, viewpoint in enumerate(config['cameras']):
image = torch.clamp(renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], 0.0, 1.0)
image = torch.clamp(renderFunc(viewpoint, scene.gaussians)["render"], 0.0, 1.0)
gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0)
l1_test += l1_loss(image, gt_image).mean().double()
psnr_test += psnr(image, gt_image).mean().double()
Expand Down
11 changes: 4 additions & 7 deletions scripts/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@
from gaussian_splatting.gaussian_renderer import GaussianModel
from gaussian_splatting.scene import Scene

def render_set(model_path, name, iteration, views, gaussians, background):
def render_set(model_path, name, iteration, views, gaussians):
render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt")

makedirs(render_path, exist_ok=True)
makedirs(gts_path, exist_ok=True)

for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
rendering = render(view, gaussians, background)["render"]
rendering = render(view, gaussians)["render"]
gt = view.original_image[0:3, :, :]
torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png"))
Expand All @@ -40,14 +40,11 @@ def render_sets(dataset : ModelParams, iteration : int, skip_train : bool, skip_
gaussians = GaussianModel(dataset.sh_degree)
scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)

bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

if not skip_train:
render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, background)
render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians)

if not skip_test:
render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, background)
render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians)

if __name__ == "__main__":
# Set up command line argument parser
Expand Down
2 changes: 0 additions & 2 deletions scripts/train_modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def __init__(self,):
self.model_path = ""
self.images = "images"
self.resolution = -1
self.white_background = False
self.eval = False

class Optimization():
Expand All @@ -78,7 +77,6 @@ def __init__(self):
self.densify_from_iter = 500
self.densify_until_iter = 15000
self.densify_grad_threshold = 0.0002
self.random_background = False


@stub.function(
Expand Down

0 comments on commit 4284c06

Please sign in to comment.