diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py b/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py index 21c51d8e..9283fc3a 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py @@ -100,15 +100,15 @@ def train(config, workdir): tb_dir = os.path.join(workdir, "tensorboard") os.makedirs(tb_dir, exist_ok=True) + run_name = os.path.basename(workdir) run_config = dict( dataset=config.data.dataset_name, input_transform_key=config.data.input_transform_key, target_transform_key=config.data.target_transform_key, architecture=config.model.name, sde=config.training.sde, - name=os.path.basename(workdir), + name=run_name, ) - run_name = os.path.basename(workdir) with track_run( EXPERIMENT_NAME, run_name, run_config, ["score_sde"], tb_dir