Commit

training() -> Trainer.run()
bouchardi committed Mar 15, 2024
1 parent 708d0d5 commit 5a1a3c3
Showing 3 changed files with 166 additions and 148 deletions.
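
In short: the module-level training() entry point becomes a Trainer class. Run configuration (test/save/checkpoint schedules, checkpoint path, GUI server address, anomaly detection) moves into the constructor, while the scene-specific dataset, optimization, and pipeline parameters move to a run() method. A minimal sketch of the new call pattern, using only names visible in the diff below; dataset, opt, and pipe are placeholders for the extracted ModelParams, OptimizationParams, and PipelineParams objects that scripts/train.py builds from argparse:

    from gaussian_splatting.training import Trainer

    # Configure once; the values shown are the constructor defaults.
    trainer = Trainer(
        testing_iterations=[7_000, 30_000],
        saving_iterations=[7_000, 30_000],
        checkpoint_iterations=[],
        detect_anomaly=False,
    )
    # Train a scene; dataset/opt/pipe come from ModelParams,
    # OptimizationParams, and PipelineParams as in scripts/train.py below.
    trainer.run(dataset=dataset, opt=opt, pipe=pipe)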
295 changes: 155 additions & 140 deletions gaussian_splatting/training.py
@@ -19,146 +19,160 @@
TENSORBOARD_FOUND = False


def training(
dataset,
opt,
pipe,
testing_iterations=None,
saving_iterations=None,
checkpoint_iterations=None,
checkpoint=None,
debug_from=-1,
quiet=False,
ip="127.0.0.1",
port=6009,
detect_anomaly=False
):
if dataset is None:
dataset = ModelParams(source_path=source_path)

if pipe is None:
pipe = PipelineParams

if opt is None:
opt = OptimizationParams()

if testing_iterations is None:
testing_iterations = [7_000, 30_000]

if saving_iterations is None:
saving_iterations = [7_000, 30_000]

if checkpoint_iterations is None:
checkpoint_iterations = []

# Initialize system state (RNG)
safe_state(quiet)

# Start GUI server, configure and run training
network_gui.init(ip, port)
torch.autograd.set_detect_anomaly(detect_anomaly)

first_iter = 0
tb_writer = prepare_output_and_logger(dataset)
gaussians = GaussianModel(dataset.sh_degree)
scene = Scene(dataset, gaussians)
gaussians.training_setup(opt)
if checkpoint:
(model_params, first_iter) = torch.load(checkpoint)
gaussians.restore(model_params, opt)

bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

iter_start = torch.cuda.Event(enable_timing = True)
iter_end = torch.cuda.Event(enable_timing = True)

viewpoint_stack = None
ema_loss_for_log = 0.0
progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
first_iter += 1
for iteration in range(first_iter, opt.iterations + 1):
if network_gui.conn == None:
network_gui.try_connect()
while network_gui.conn != None:
try:
net_image_bytes = None
custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive()
if custom_cam != None:
net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer)["render"]
net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy())
network_gui.send(net_image_bytes, dataset.source_path)
if do_training and ((iteration < int(opt.iterations)) or not keep_alive):
break
except Exception as e:
network_gui.conn = None

iter_start.record()

gaussians.update_learning_rate(iteration)

        # Every 1000 iterations we increase the levels of SH up to a maximum degree
if iteration % 1000 == 0:
gaussians.oneupSHdegree()

# Pick a random Camera
if not viewpoint_stack:
viewpoint_stack = scene.getTrainCameras().copy()
viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))

# Render
if (iteration - 1) == debug_from:
pipe.debug = True

bg = torch.rand((3), device="cuda") if opt.random_background else background

render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]

# Loss
gt_image = viewpoint_cam.original_image.cuda()
Ll1 = l1_loss(image, gt_image)
loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
loss.backward()

iter_end.record()

with torch.no_grad():
# Progress bar
ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
if iteration % 10 == 0:
progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
progress_bar.update(10)
if iteration == opt.iterations:
progress_bar.close()

# Log and save
training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))
if (iteration in saving_iterations):
print("\n[ITER {}] Saving Gaussians".format(iteration))
scene.save(iteration)

# Densification
if iteration < opt.densify_until_iter:
# Keep track of max radii in image-space for pruning
gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)

if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
size_threshold = 20 if iteration > opt.opacity_reset_interval else None
gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)

if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
gaussians.reset_opacity()

# Optimizer step
if iteration < opt.iterations:
gaussians.optimizer.step()
gaussians.optimizer.zero_grad(set_to_none = True)

if (iteration in checkpoint_iterations):
print("\n[ITER {}] Saving Checkpoint".format(iteration))
torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")
class Trainer:
def __init__(
self,
testing_iterations=None,
saving_iterations=None,
checkpoint_iterations=None,
checkpoint=None,
debug_from=-1,
quiet=False,
ip="127.0.0.1",
port=6009,
detect_anomaly=False
):
if testing_iterations is None:
testing_iterations = [7_000, 30_000]
self._testing_iterations = testing_iterations

if saving_iterations is None:
saving_iterations = [7_000, 30_000]
self._saving_iterations = saving_iterations

if checkpoint_iterations is None:
checkpoint_iterations = []
self._checkpoint_iterations = checkpoint_iterations

self._checkpoint = checkpoint
self._debug_from = debug_from

# Initialize system state (RNG)
safe_state(quiet)
# Start GUI server, configure and run training
network_gui.init(ip, port)
torch.autograd.set_detect_anomaly(detect_anomaly)


def run(
self,
dataset,
opt,
pipe,
):
first_iter = 0
tb_writer = prepare_output_and_logger(dataset)
gaussians = GaussianModel(dataset.sh_degree)
scene = Scene(dataset, gaussians)
gaussians.training_setup(opt)
if self._checkpoint:
(model_params, first_iter) = torch.load(self._checkpoint)
gaussians.restore(model_params, opt)

bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

iter_start = torch.cuda.Event(enable_timing = True)
iter_end = torch.cuda.Event(enable_timing = True)

viewpoint_stack = None
ema_loss_for_log = 0.0
progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
first_iter += 1
for iteration in range(first_iter, opt.iterations + 1):
if network_gui.conn == None:
network_gui.try_connect()
while network_gui.conn != None:
try:
net_image_bytes = None
custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive()
if custom_cam != None:
net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer)["render"]
net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy())
network_gui.send(net_image_bytes, dataset.source_path)
if do_training and ((iteration < int(opt.iterations)) or not keep_alive):
break
except Exception as e:
network_gui.conn = None

iter_start.record()

gaussians.update_learning_rate(iteration)

            # Every 1000 iterations we increase the levels of SH up to a maximum degree
if iteration % 1000 == 0:
gaussians.oneupSHdegree()

# Pick a random Camera
if not viewpoint_stack:
viewpoint_stack = scene.getTrainCameras().copy()
viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))

# Render
if (iteration - 1) == self._debug_from:
pipe.debug = True

bg = torch.rand((3), device="cuda") if opt.random_background else background

render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]

# Loss
gt_image = viewpoint_cam.original_image.cuda()
Ll1 = l1_loss(image, gt_image)
loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
loss.backward()

iter_end.record()

with torch.no_grad():
# Progress bar
ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
if iteration % 10 == 0:
progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
progress_bar.update(10)
if iteration == opt.iterations:
progress_bar.close()

# Log and save
training_report(
tb_writer,
iteration,
Ll1,
loss,
l1_loss,
iter_start.elapsed_time(iter_end),
self._testing_iterations,
scene,
render,
(pipe, background)
)
if (iteration in self._saving_iterations):
print("\n[ITER {}] Saving Gaussians".format(iteration))
scene.save(iteration)

# Densification
if iteration < opt.densify_until_iter:
# Keep track of max radii in image-space for pruning
gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)

if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
size_threshold = 20 if iteration > opt.opacity_reset_interval else None
gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)

if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
gaussians.reset_opacity()

# Optimizer step
if iteration < opt.iterations:
gaussians.optimizer.step()
gaussians.optimizer.zero_grad(set_to_none = True)

if (iteration in self._checkpoint_iterations):
print("\n[ITER {}] Saving Checkpoint".format(iteration))
torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")


def prepare_output_and_logger(args):
if not args.model_path:
@@ -182,6 +196,7 @@ def prepare_output_and_logger(args):
print("Tensorboard not available: not logging progress")
return tb_writer


def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs):
if tb_writer:
tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration)
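
A detail visible on both sides of the diff: checkpoints are plain (model_params, iteration) tuples, written as torch.save((gaussians.capture(), iteration), ...) and restored with gaussians.restore(model_params, opt). A hedged sketch of resuming from one; the path here is hypothetical, and run() writes the real one as scene.model_path + "/chkpnt" + str(iteration) + ".pth":

    # Hypothetical checkpoint path, written by an earlier run at iteration 7000.
    trainer = Trainer(checkpoint="output/model/chkpnt7000.pth")
    # torch.load() recovers (model_params, first_iter); the loop resumes at first_iter + 1.
    trainer.run(dataset=dataset, opt=opt, pipe=pipe)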
12 changes: 7 additions & 5 deletions scripts/train.py
@@ -11,7 +11,7 @@
import sys
from argparse import ArgumentParser
from gaussian_splatting.arguments import ModelParams, PipelineParams, OptimizationParams
from gaussian_splatting.training import training
from gaussian_splatting.training import Trainer


if __name__ == "__main__":
@@ -32,10 +32,7 @@
args = parser.parse_args(sys.argv[1:])
args.save_iterations.append(args.iterations)

training(
dataset=lp.extract(args),
opt=op.extract(args),
pipe=pp.extract(args),
trainer = Trainer(
testing_iterations=args.test_iterations,
saving_iterations=args.save_iterations,
checkpoint_iterations=args.checkpoint_iterations,
@@ -46,3 +43,8 @@
port=args.port,
detect_anomaly=args.detect_anomaly
)
trainer.run(
dataset=lp.extract(args),
opt=op.extract(args),
pipe=pp.extract(args),
)
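
Note the split this gives the script: Trainer() now binds the iteration schedules and server options once, and its constructor already has side effects (safe_state() seeds the RNG, network_gui.init() starts the GUI server), while the ModelParams, OptimizationParams, and PipelineParams extracted from argparse are only supplied at run() time. Setup therefore happens before any scene data is loaded.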
7 changes: 4 additions & 3 deletions scripts/train_modal.py
@@ -53,7 +53,7 @@
class Dataset():
def __init__(self,):
self.sh_degree = 3
self.source_path = "/workspace/data/train/"
self.source_path = "/workspace/data/phil_open/5"
self.model_path = ""
self.images = "images"
self.resolution = -1
@@ -97,9 +97,10 @@ def __init__(self):
timeout=10800
)
def f():
from gaussian_splatting.training import training
from gaussian_splatting.training import Trainer

training(
trainer = Trainer()
trainer.run(
dataset=Dataset(),
opt=Optimization(),
pipe=Pipeline()
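
In the Modal entry point the Trainer takes all constructor defaults, so the test/save schedules fall back to [7_000, 30_000] and network_gui.init() binds 127.0.0.1:6009 inside the container; only the Dataset, Optimization, and Pipeline config objects are passed to run().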
