From 188c2867723c8551e1542321ae846bd59186c3cf Mon Sep 17 00:00:00 2001 From: Isabelle Bouchard Date: Fri, 15 Mar 2024 11:18:53 -0400 Subject: [PATCH] remove network gui --- .../gaussian_renderer/network_gui.py | 86 ------------------- gaussian_splatting/training.py | 56 +----------- scripts/train.py | 4 - 3 files changed, 3 insertions(+), 143 deletions(-) delete mode 100644 gaussian_splatting/gaussian_renderer/network_gui.py diff --git a/gaussian_splatting/gaussian_renderer/network_gui.py b/gaussian_splatting/gaussian_renderer/network_gui.py deleted file mode 100644 index 0d40d4f8b..000000000 --- a/gaussian_splatting/gaussian_renderer/network_gui.py +++ /dev/null @@ -1,86 +0,0 @@ -# -# Copyright (C) 2023, Inria -# GRAPHDECO research group, https://team.inria.fr/graphdeco -# All rights reserved. -# -# This software is free for non-commercial, research and evaluation use -# under the terms of the LICENSE.md file. -# -# For inquiries contact george.drettakis@inria.fr -# - -import torch -import traceback -import socket -import json -from gaussian_splatting.scene.cameras import MiniCam - -host = "127.0.0.1" -port = 6009 - -conn = None -addr = None - -listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - -def init(wish_host, wish_port): - global host, port, listener - host = wish_host - port = wish_port - listener.bind((host, port)) - listener.listen() - listener.settimeout(0) - -def try_connect(): - global conn, addr, listener - try: - conn, addr = listener.accept() - print(f"\nConnected by {addr}") - conn.settimeout(None) - except Exception as inst: - pass - -def read(): - global conn - messageLength = conn.recv(4) - messageLength = int.from_bytes(messageLength, 'little') - message = conn.recv(messageLength) - return json.loads(message.decode("utf-8")) - -def send(message_bytes, verify): - global conn - if message_bytes != None: - conn.sendall(message_bytes) - conn.sendall(len(verify).to_bytes(4, 'little')) - conn.sendall(bytes(verify, 'ascii')) - -def receive(): - message = read() - - width = message["resolution_x"] - height = message["resolution_y"] - - if width != 0 and height != 0: - try: - do_training = bool(message["train"]) - fovy = message["fov_y"] - fovx = message["fov_x"] - znear = message["z_near"] - zfar = message["z_far"] - do_shs_python = bool(message["shs_python"]) - do_rot_scale_python = bool(message["rot_scale_python"]) - keep_alive = bool(message["keep_alive"]) - scaling_modifier = message["scaling_modifier"] - world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda() - world_view_transform[:,1] = -world_view_transform[:,1] - world_view_transform[:,2] = -world_view_transform[:,2] - full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda() - full_proj_transform[:,1] = -full_proj_transform[:,1] - custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform) - except Exception as e: - print("") - traceback.print_exc() - raise e - return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier - else: - return None, None, None, None, None, None diff --git a/gaussian_splatting/training.py b/gaussian_splatting/training.py index 529a26d58..7844de0ec 100644 --- a/gaussian_splatting/training.py +++ b/gaussian_splatting/training.py @@ -2,7 +2,7 @@ import torch from random import randint from gaussian_splatting.utils.loss_utils import l1_loss, ssim -from gaussian_splatting.gaussian_renderer import render, network_gui +from gaussian_splatting.gaussian_renderer import render import sys from gaussian_splatting.scene import Scene, GaussianModel from gaussian_splatting.utils.general_utils import safe_state @@ -12,12 +12,6 @@ from argparse import Namespace from gaussian_splatting.arguments import ModelParams, PipelineParams, OptimizationParams -try: - from torch.utils.tensorboard import SummaryWriter - TENSORBOARD_FOUND = True -except ImportError: - TENSORBOARD_FOUND = False - class Trainer: def __init__( @@ -28,8 +22,6 @@ def __init__( checkpoint=None, debug_from=-1, quiet=False, - ip="127.0.0.1", - port=6009, detect_anomaly=False ): if testing_iterations is None: @@ -47,10 +39,8 @@ def __init__( self._checkpoint = checkpoint self._debug_from = debug_from - # Initialize system state (RNG) safe_state(quiet) - # Start GUI server, configure and run training - network_gui.init(ip, port) + torch.autograd.set_detect_anomaly(detect_anomaly) @@ -61,7 +51,6 @@ def run( pipe, ): first_iter = 0 - tb_writer = prepare_output_and_logger(dataset) gaussians = GaussianModel(dataset.sh_degree) scene = Scene(dataset, gaussians) gaussians.training_setup(opt) @@ -80,21 +69,6 @@ def run( progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress") first_iter += 1 for iteration in range(first_iter, opt.iterations + 1): - if network_gui.conn == None: - network_gui.try_connect() - while network_gui.conn != None: - try: - net_image_bytes = None - custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive() - if custom_cam != None: - net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer)["render"] - net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy()) - network_gui.send(net_image_bytes, dataset.source_path) - if do_training and ((iteration < int(opt.iterations)) or not keep_alive): - break - except Exception as e: - network_gui.conn = None - iter_start.record() gaussians.update_learning_rate(iteration) @@ -136,7 +110,6 @@ def run( # Log and save training_report( - tb_writer, iteration, Ll1, loss, @@ -188,21 +161,8 @@ def prepare_output_and_logger(args): with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f: cfg_log_f.write(str(Namespace(**vars(args)))) - # Create Tensorboard writer - tb_writer = None - if TENSORBOARD_FOUND: - tb_writer = SummaryWriter(args.model_path) - else: - print("Tensorboard not available: not logging progress") - return tb_writer - - -def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs): - if tb_writer: - tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration) - tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration) - tb_writer.add_scalar('iter_time', elapsed, iteration) +def training_report(iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs): # Report test and samples of training set if iteration in testing_iterations: torch.cuda.empty_cache() @@ -216,22 +176,12 @@ def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_i for idx, viewpoint in enumerate(config['cameras']): image = torch.clamp(renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], 0.0, 1.0) gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0) - if tb_writer and (idx < 5): - tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration) - if iteration == testing_iterations[0]: - tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image[None], global_step=iteration) l1_test += l1_loss(image, gt_image).mean().double() psnr_test += psnr(image, gt_image).mean().double() psnr_test /= len(config['cameras']) l1_test /= len(config['cameras']) print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test)) - if tb_writer: - tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration) - tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration) - if tb_writer: - tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration) - tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration) torch.cuda.empty_cache() print("\nTraining complete.") diff --git a/scripts/train.py b/scripts/train.py index 1f3d3b108..d9244c2a2 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -20,8 +20,6 @@ lp = ModelParams(parser) op = OptimizationParams(parser) pp = PipelineParams(parser) - parser.add_argument('--ip', type=str, default="127.0.0.1") - parser.add_argument('--port', type=int, default=6009) parser.add_argument('--debug_from', type=int, default=-1) parser.add_argument('--detect_anomaly', action='store_true', default=False) parser.add_argument("--test_iterations", nargs="+", type=int, default=[7_000, 30_000]) @@ -39,8 +37,6 @@ checkpoint=args.start_checkpoint, debug_from=args.debug_from, quiet=args.quiet, - ip=args.ip, - port=args.port, detect_anomaly=args.detect_anomaly ) trainer.run(