Skip to content

Commit

Permalink
Remove detect_anomaly and quiet
Browse files Browse the repository at this point in the history
  • Loading branch information
bouchardi committed Mar 18, 2024
1 parent a655198 commit 4e00452
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 22 deletions.
12 changes: 3 additions & 9 deletions gaussian_splatting/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,15 @@ def __init__(
saving_iterations=None,
checkpoint_iterations=None,
checkpoint_path=None,
quiet=False,
detect_anomaly=False,
):
self._resolution = resolution

if testing_iterations is None:
testing_iterations = [7_000, 30_000]
testing_iterations = [7000, 30000]
self._testing_iterations = testing_iterations

if saving_iterations is None:
saving_iterations = [7_000, 30_000]
saving_iterations = [7000, 30000]
self._saving_iterations = saving_iterations

if checkpoint_iterations is None:
Expand All @@ -40,10 +38,6 @@ def __init__(

self._checkpoint_path = checkpoint_path

safe_state(quiet)

torch.autograd.set_detect_anomaly(detect_anomaly)

def run(
self,
dataset,
Expand Down Expand Up @@ -220,4 +214,4 @@ def training_report(

torch.cuda.empty_cache()

print("\nTraining complete.")
print("\nTraining complete.")
8 changes: 4 additions & 4 deletions gaussian_splatting/utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def build_scaling_rotation(s, r):
return L


def safe_state(silent):
def safe_state(silent=True, seed=0):
old_f = sys.stdout

class F:
Expand All @@ -146,7 +146,7 @@ def flush(self):

sys.stdout = F(silent)

random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.set_device(torch.device("cuda:0"))
4 changes: 2 additions & 2 deletions scripts/full_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
args = parser.parse_args()

if not args.skip_training:
common_args = " --quiet --eval --test_iterations -1 "
common_args = " --eval --test_iterations -1 "
for scene in mipnerf360_outdoor_scenes:
source = args.mipnerf360 + "/" + scene
os.system(
Expand Down Expand Up @@ -94,7 +94,7 @@
for scene in deep_blending_scenes:
all_sources.append(args.deepblending + "/" + scene)

common_args = " --quiet --eval --skip_train"
common_args = " --eval --skip_train"
for scene, source in zip(all_scenes, all_sources):
os.system(
"python render.py --iteration 7000 -s "
Expand Down
4 changes: 1 addition & 3 deletions scripts/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,9 @@ def render_sets(
parser.add_argument("--iteration", default=-1, type=int)
parser.add_argument("--skip_train", action="store_true")
parser.add_argument("--skip_test", action="store_true")
parser.add_argument("--quiet", action="store_true")
args = get_combined_args(parser)
print("Rendering " + args.model_path)

# Initialize system state (RNG)
safe_state(args.quiet)
safe_state()

render_sets(model.extract(args), args.iteration, args.skip_train, args.skip_test)
4 changes: 0 additions & 4 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,12 @@
parser = ArgumentParser(description="Training script parameters")
lp = ModelParams(parser)
op = OptimizationParams(parser)
parser.add_argument("--detect_anomaly", action="store_true", default=False)
parser.add_argument(
"--test_iterations", nargs="+", type=int, default=[7_000, 30_000]
)
parser.add_argument(
"--save_iterations", nargs="+", type=int, default=[7_000, 30_000]
)
parser.add_argument("--quiet", action="store_true")
parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
parser.add_argument("--checkpoint_path", type=str, default=None)
parser.add_argument("--resolution", default=-1, type=int)
Expand All @@ -39,8 +37,6 @@
saving_iterations=args.save_iterations,
checkpoint_iterations=args.checkpoint_iterations,
checkpoint_path=args.checkpoint_path,
quiet=args.quiet,
detect_anomaly=args.detect_anomaly,
)
trainer.run(
dataset=lp.extract(args),
Expand Down

0 comments on commit 4e00452

Please sign in to comment.