From c8fffbfa6efc858b79098e2c28b0c948a87ff55c Mon Sep 17 00:00:00 2001 From: Henry Addison Date: Wed, 20 Mar 2024 15:11:30 +0000 Subject: [PATCH] clean up some deterministic related code --- bin/deterministic/main.py | 2 +- src/ml_downscaling_emulator/bin/evaluate.py | 6 +++--- tests/deterministic/smoke-test | 21 ++++++++++----------- tests/smoke-test | 5 +++-- 4 files changed, 17 insertions(+), 17 deletions(-) diff --git a/bin/deterministic/main.py b/bin/deterministic/main.py index 5c409b000..801531607 100644 --- a/bin/deterministic/main.py +++ b/bin/deterministic/main.py @@ -13,7 +13,7 @@ "config", None, "Training configuration.", lock_config=True ) flags.DEFINE_string("workdir", None, "Work directory.") -flags.DEFINE_enum("mode", None, ["train", "eval"], "Running mode: train or eval") +flags.DEFINE_enum("mode", None, ["train"], "Running mode: train.") flags.mark_flags_as_required(["workdir", "config", "mode"]) diff --git a/src/ml_downscaling_emulator/bin/evaluate.py b/src/ml_downscaling_emulator/bin/evaluate.py index bbf4dae23..5146ddbd1 100644 --- a/src/ml_downscaling_emulator/bin/evaluate.py +++ b/src/ml_downscaling_emulator/bin/evaluate.py @@ -58,7 +58,7 @@ def sample( workdir: Path, dataset: str = typer.Option(...), split: str = "val", - epoch: int = typer.Option(...), + checkpoint: str = typer.Option(...), batch_size: int = None, num_samples: int = 1, input_transform_dataset: str = None, @@ -81,7 +81,7 @@ def sample( output_dirpath = samples_path( workdir=workdir, - checkpoint=f"epoch-{epoch}", + checkpoint=checkpoint, dataset=dataset, input_xfm=f"{config.data.input_transform_dataset}-{config.data.input_transform_key}", split=split, @@ -106,7 +106,7 @@ def sample( shuffle=False, ) - ckpt_filename = os.path.join(workdir, "checkpoints", f"epoch_{epoch}.pth") + ckpt_filename = os.path.join(workdir, "checkpoints", f"{checkpoint}.pth") num_predictors = eval_dl.dataset[0][0].shape[0] state = load_model(config, num_predictors, ckpt_filename) diff --git a/tests/deterministic/smoke-test b/tests/deterministic/smoke-test index 9e4789ae6..d0fdc140e 100755 --- a/tests/deterministic/smoke-test +++ b/tests/deterministic/smoke-test @@ -4,23 +4,22 @@ set -euo pipefail SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) -cpm_dataset="bham_gcmx-4x_12em_psl-temp4th-vort4th_eqvt_random-season" -gcm_dataset="bham_60km-4x_12em_psl-temp4th-vort4th_eqvt_random-season" +cpm_dataset="bham_gcmx-4x_1em_psl-sphum4th-temp4th-vort4th_eqvt_random-season-historic" +gcm_dataset="bham_60km-4x_1em_psl-sphum4th-temp4th-vort4th_eqvt_random-season-historic" workdir="output/test/unet/test-run" config_path="src/ml_downscaling_emulator/deterministic/configs/ukcp_local_12em_pr_unet.py" -map_features=0 -train_batch_size=4 +train_batch_size=2 +epoch=2 rm -rf ${workdir} -WANDB_EXPERIMENT_NAME="test" python ${SCRIPT_DIR}/main.py --mode train --workdir ${workdir} --config ${config_path} --config.data.dataset_name=${cpm_dataset} --config.training.batch_size=${train_batch_size} --config.training.n_epochs=2 --config.data.time_inputs=True --config.model.name=debug +WANDB_DISABLE_SERVICE=True WANDB_EXPERIMENT_NAME="test" python ${SCRIPT_DIR}/../../bin/deterministic/main.py --mode train --workdir ${workdir} --config ${config_path} --config.data.dataset_name=${cpm_dataset} --config.training.batch_size=${train_batch_size} --config.training.n_epochs=${epoch} --config.data.time_inputs=True --config.model.name=debug -epoch=2 num_samples=2 -eval_batchsize=128 +eval_batchsize=32 -rm -rf "${workdir}/samples/epoch-${epoch}/${cpm_dataset}" -mlde evaluate sample ${workdir} --dataset ${cpm_dataset} --epoch ${epoch} --batch-size ${eval_batchsize} --num-samples ${num_samples} -rm -rf "${workdir}/samples/epoch-${epoch}/${gcm_dataset}" -mlde evaluate sample ${workdir} --dataset ${gcm_dataset} --epoch ${epoch} --batch-size ${eval_batchsize} --num-samples ${num_samples} +rm -rf "${workdir}/samples/epoch_${epoch}/${cpm_dataset}" +mlde evaluate sample ${workdir} --dataset ${cpm_dataset} --checkpoint epoch_${epoch} --batch-size ${eval_batchsize} --num-samples ${num_samples} +rm -rf "${workdir}/samples/epoch_${epoch}/${gcm_dataset}" +mlde evaluate sample ${workdir} --dataset ${gcm_dataset} --checkpoint epoch_${epoch} --batch-size ${eval_batchsize} --num-samples ${num_samples} diff --git a/tests/smoke-test b/tests/smoke-test index fb4e766e0..21ff64615 100755 --- a/tests/smoke-test +++ b/tests/smoke-test @@ -12,11 +12,12 @@ config_path="src/ml_downscaling_emulator/score_sde_pytorch/configs/${sde}/${conf loc_spec_channels=0 train_batch_size=2 random_crop_size=32 +epoch=2 rm -rf ${workdir} -WANDB_EXPERIMENT_NAME="test" python bin/main.py --workdir ${workdir} --config ${config_path} --mode train --config.data.dataset_name=${dataset} --config.training.snapshot_freq=5 --config.training.eval_freq=100 --config.training.log_freq=50 --config.training.n_epochs=2 --config.model.loc_spec_channels=${loc_spec_channels} --config.training.random_crop_size=${random_crop_size} --config.training.batch_size=${train_batch_size} --config.data.input_transform_key=stan --config.data.target_transform_key=sqrturrecen --config.data.time_inputs=True +WANDB_EXPERIMENT_NAME="test" python bin/main.py --workdir ${workdir} --config ${config_path} --mode train --config.data.dataset_name=${dataset} --config.training.snapshot_freq=5 --config.training.eval_freq=100 --config.training.log_freq=50 --config.training.n_epochs=${epoch} --config.model.loc_spec_channels=${loc_spec_channels} --config.training.random_crop_size=${random_crop_size} --config.training.batch_size=${train_batch_size} --config.data.input_transform_key=stan --config.data.target_transform_key=sqrturrecen --config.data.time_inputs=True + -epoch=2 num_samples=2 eval_batchsize=32 checkpoint="epoch_${epoch}"