Skip to content

Commit

Permalink
remove twice tensorboard log
Browse files Browse the repository at this point in the history
  • Loading branch information
blaisewf committed Dec 21, 2024
1 parent 403ddae commit 7a8e12c
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions rvc/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,10 +308,9 @@ def run(
smoothed_value_disc = 0

if rank == 0:
writer = SummaryWriter(log_dir=experiment_dir)
writer_eval = SummaryWriter(log_dir=os.path.join(experiment_dir, "eval"))
else:
writer, writer_eval = None, None
writer_eval = None

dist.init_process_group(
backend="gloo",
Expand Down Expand Up @@ -489,7 +488,7 @@ def run(
[optim_g, optim_d],
scaler,
[train_loader, None],
[writer, writer_eval],
[writer_eval],
cache,
custom_save_every_weights,
custom_total_epoch,
Expand Down Expand Up @@ -529,15 +528,14 @@ def train_and_evaluate(
optims (list): List of optimizers [optim_g, optim_d].
scaler (GradScaler): Gradient scaler for mixed precision training.
loaders (list): List of dataloaders [train_loader, eval_loader].
writers (list): List of TensorBoard writers [writer, writer_eval].
writers (list): List of TensorBoard writers [writer_eval].
cache (list): List to cache data in GPU memory.
use_cpu (bool): Whether to use CPU for training.
"""
global global_step, lowest_value, loss_disc, consecutive_increases_gen, consecutive_increases_disc, smoothed_value_gen, smoothed_value_disc

if epoch == 1:
lowest_value = {"step": 0, "value": float("inf"), "epoch": 0}
last_loss_gen_all = 0.0
consecutive_increases_gen = 0
consecutive_increases_disc = 0

Expand Down

0 comments on commit 7a8e12c

Please sign in to comment.