From 4e00452122184a5952c68adeabe2c3469ea40e02 Mon Sep 17 00:00:00 2001 From: Isabelle Bouchard Date: Mon, 18 Mar 2024 13:29:14 -0400 Subject: [PATCH] Remove detect_anomaly and quiet --- gaussian_splatting/training.py | 12 +++--------- gaussian_splatting/utils/general.py | 8 ++++---- scripts/full_eval.py | 4 ++-- scripts/render.py | 4 +--- scripts/train.py | 4 ---- 5 files changed, 10 insertions(+), 22 deletions(-) diff --git a/gaussian_splatting/training.py b/gaussian_splatting/training.py index e9b8cf3b9..3233d8945 100644 --- a/gaussian_splatting/training.py +++ b/gaussian_splatting/training.py @@ -21,17 +21,15 @@ def __init__( saving_iterations=None, checkpoint_iterations=None, checkpoint_path=None, - quiet=False, - detect_anomaly=False, ): self._resolution = resolution if testing_iterations is None: - testing_iterations = [7_000, 30_000] + testing_iterations = [7000, 30000] self._testing_iterations = testing_iterations if saving_iterations is None: - saving_iterations = [7_000, 30_000] + saving_iterations = [7000, 30000] self._saving_iterations = saving_iterations if checkpoint_iterations is None: @@ -40,10 +38,6 @@ def __init__( self._checkpoint_path = checkpoint_path - safe_state(quiet) - - torch.autograd.set_detect_anomaly(detect_anomaly) - def run( self, dataset, @@ -220,4 +214,4 @@ def training_report( torch.cuda.empty_cache() - print("\nTraining complete.") + print("\nTraining complete.") diff --git a/gaussian_splatting/utils/general.py b/gaussian_splatting/utils/general.py index 3040fb29c..6ae7f15c4 100644 --- a/gaussian_splatting/utils/general.py +++ b/gaussian_splatting/utils/general.py @@ -120,7 +120,7 @@ def build_scaling_rotation(s, r): return L -def safe_state(silent): +def safe_state(silent=True, seed=0): old_f = sys.stdout class F: @@ -146,7 +146,7 @@ def flush(self): sys.stdout = F(silent) - random.seed(0) - np.random.seed(0) - torch.manual_seed(0) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) torch.cuda.set_device(torch.device("cuda:0")) diff --git a/scripts/full_eval.py b/scripts/full_eval.py index 4c83a9418..b568ab7eb 100644 --- a/scripts/full_eval.py +++ b/scripts/full_eval.py @@ -37,7 +37,7 @@ args = parser.parse_args() if not args.skip_training: - common_args = " --quiet --eval --test_iterations -1 " + common_args = " --eval --test_iterations -1 " for scene in mipnerf360_outdoor_scenes: source = args.mipnerf360 + "/" + scene os.system( @@ -94,7 +94,7 @@ for scene in deep_blending_scenes: all_sources.append(args.deepblending + "/" + scene) - common_args = " --quiet --eval --skip_train" + common_args = " --eval --skip_train" for scene, source in zip(all_scenes, all_sources): os.system( "python render.py --iteration 7000 -s " diff --git a/scripts/render.py b/scripts/render.py index 43ee759c0..1c189ca1e 100644 --- a/scripts/render.py +++ b/scripts/render.py @@ -84,11 +84,9 @@ def render_sets( parser.add_argument("--iteration", default=-1, type=int) parser.add_argument("--skip_train", action="store_true") parser.add_argument("--skip_test", action="store_true") - parser.add_argument("--quiet", action="store_true") args = get_combined_args(parser) print("Rendering " + args.model_path) - # Initialize system state (RNG) - safe_state(args.quiet) + safe_state() render_sets(model.extract(args), args.iteration, args.skip_train, args.skip_test) diff --git a/scripts/train.py b/scripts/train.py index 5e2c07db3..bd35b70a6 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -19,14 +19,12 @@ parser = ArgumentParser(description="Training script parameters") lp = ModelParams(parser) op = OptimizationParams(parser) - parser.add_argument("--detect_anomaly", action="store_true", default=False) parser.add_argument( "--test_iterations", nargs="+", type=int, default=[7_000, 30_000] ) parser.add_argument( "--save_iterations", nargs="+", type=int, default=[7_000, 30_000] ) - parser.add_argument("--quiet", action="store_true") parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) parser.add_argument("--checkpoint_path", type=str, default=None) parser.add_argument("--resolution", default=-1, type=int) @@ -39,8 +37,6 @@ saving_iterations=args.save_iterations, checkpoint_iterations=args.checkpoint_iterations, checkpoint_path=args.checkpoint_path, - quiet=args.quiet, - detect_anomaly=args.detect_anomaly, ) trainer.run( dataset=lp.extract(args),