diff --git a/bin/deterministic/local-test-train b/bin/deterministic/local-test-train
deleted file mode 100755
index ecaf40d9a..000000000
--- a/bin/deterministic/local-test-train
+++ /dev/null
@@ -1,26 +0,0 @@
-#! /usr/bin/env bash
-
-set -euo pipefail
-
-SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
-
-cpm_dataset="bham_gcmx-4x_12em_psl-temp4th-vort4th_eqvt_random-season"
-gcm_dataset="bham_60km-4x_12em_psl-temp4th-vort4th_eqvt_random-season"
-workdir="output/test/unet/test-run"
-
-config_path="src/ml_downscaling_emulator/deterministic/configs/default.py"
-
-map_features=1
-train_batch_size=32
-
-rm -rf ${workdir}
-WANDB_EXPERIMENT_NAME="test" python ${SCRIPT_DIR}/main.py --mode train --workdir ${workdir} --config ${config_path} --config.data.dataset_name=${cpm_dataset} --config.training.batch_size=${train_batch_size} --config.training.n_epochs=2 --config.data.time_inputs=True --config.model.name=debug
-
-epoch=2
-num_samples=2
-eval_batchsize=128
-
-rm -rf "${workdir}/samples/epoch-${epoch}/${cpm_dataset}"
-mlde evaluate sample ${workdir} --dataset ${cpm_dataset} --epoch ${epoch} --batch-size ${eval_batchsize} --num-samples ${num_samples}
-rm -rf "${workdir}/samples/epoch-${epoch}/${gcm_dataset}"
-mlde evaluate sample ${workdir} --dataset ${gcm_dataset} --epoch ${epoch} --batch-size ${eval_batchsize} --num-samples ${num_samples}
diff --git a/bin/deterministic/main.py b/bin/deterministic/main.py
index 5c409b000..801531607 100644
--- a/bin/deterministic/main.py
+++ b/bin/deterministic/main.py
@@ -13,7 +13,7 @@
     "config", None, "Training configuration.", lock_config=True
 )
 flags.DEFINE_string("workdir", None, "Work directory.")
-flags.DEFINE_enum("mode", None, ["train", "eval"], "Running mode: train or eval")
+flags.DEFINE_enum("mode", None, ["train"], "Running mode: train.")
 flags.mark_flags_as_required(["workdir", "config", "mode"])
 
 
diff --git a/src/ml_downscaling_emulator/bin/evaluate.py b/src/ml_downscaling_emulator/bin/evaluate.py
index a5d4b3eea..5146ddbd1 100644
--- a/src/ml_downscaling_emulator/bin/evaluate.py
+++ b/src/ml_downscaling_emulator/bin/evaluate.py
@@ -58,7 +58,7 @@ def sample(
     workdir: Path,
     dataset: str = typer.Option(...),
     split: str = "val",
-    epoch: int = typer.Option(...),
+    checkpoint: str = typer.Option(...),
     batch_size: int = None,
     num_samples: int = 1,
     input_transform_dataset: str = None,
@@ -71,16 +71,17 @@
     if batch_size is not None:
         config.eval.batch_size = batch_size
-    if input_transform_dataset is not None:
-        config.data.input_transform_dataset = input_transform_dataset
-    else:
-        config.data.input_transform_dataset = dataset
+    with config.unlocked():
+        if input_transform_dataset is not None:
+            config.data.input_transform_dataset = input_transform_dataset
+        else:
+            config.data.input_transform_dataset = dataset
 
     if input_transform_key is not None:
         config.data.input_transform_key = input_transform_key
 
     output_dirpath = samples_path(
         workdir=workdir,
-        checkpoint=f"epoch-{epoch}",
+        checkpoint=checkpoint,
         dataset=dataset,
         input_xfm=f"{config.data.input_transform_dataset}-{config.data.input_transform_key}",
         split=split,
@@ -105,7 +106,7 @@
         shuffle=False,
     )
 
-    ckpt_filename = os.path.join(workdir, "checkpoints", f"epoch_{epoch}.pth")
+    ckpt_filename = os.path.join(workdir, "checkpoints", f"{checkpoint}.pth")
 
     num_predictors = eval_dl.dataset[0][0].shape[0]
     state = load_model(config, num_predictors, ckpt_filename)
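For reference, sampling is now addressed by checkpoint name rather than epoch number, so non-epoch-numbered checkpoints can be sampled from too. An illustrative invocation (workdir, dataset, and checkpoint values borrowed from the smoke test below, not new defaults):

    mlde evaluate sample output/test/unet/test-run \
        --dataset bham_gcmx-4x_1em_psl-sphum4th-temp4th-vort4th_eqvt_random-season-historic \
        --checkpoint epoch_2 --batch-size 32 --num-samples 2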
diff --git a/src/ml_downscaling_emulator/deterministic/configs/default.py b/src/ml_downscaling_emulator/deterministic/configs/ukcp_local_12em_pr_unet.py
similarity index 86%
rename from src/ml_downscaling_emulator/deterministic/configs/default.py
rename to src/ml_downscaling_emulator/deterministic/configs/ukcp_local_12em_pr_unet.py
index 709be5b81..1a710d81f 100644
--- a/src/ml_downscaling_emulator/deterministic/configs/default.py
+++ b/src/ml_downscaling_emulator/deterministic/configs/ukcp_local_12em_pr_unet.py
@@ -16,9 +16,12 @@ def get_config():
     evaluate.batch_size = 64
 
     config.data = data = ml_collections.ConfigDict()
-    data.dataset_name = ""
+    data.dataset_name = (
+        "bham_gcmx-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season"
+    )
     data.input_transform_key = "stan"
     data.target_transform_key = "sqrturrecen"
+    data.input_transform_dataset = None
     data.time_inputs = False
 
     config.model = model = ml_collections.ConfigDict()
diff --git a/src/ml_downscaling_emulator/deterministic/run_lib.py b/src/ml_downscaling_emulator/deterministic/run_lib.py
index 79af70e34..4f0b3c04d 100644
--- a/src/ml_downscaling_emulator/deterministic/run_lib.py
+++ b/src/ml_downscaling_emulator/deterministic/run_lib.py
@@ -13,7 +13,7 @@
 from ..training import log_epoch, track_run
 from .utils import restore_checkpoint, save_checkpoint, create_model
-from ..torch import get_dataloader
+from ..data import get_dataloader
 
 FLAGS = flags.FLAGS
 EXPERIMENT_NAME = os.getenv("WANDB_EXPERIMENT_NAME")
 
@@ -208,7 +208,6 @@
         EXPERIMENT_NAME, run_name, run_config, [config.model.name, "baseline"], tb_dir
     ) as (wandb_run, tb_writer):
         # Fit model
-        wandb_run.watch(model, criterion=criterion, log_freq=100)
 
         logging.info("Starting training loop at epoch %d." % (initial_epoch,))
 
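With default.py renamed, deterministic training runs must point at the new config module explicitly. A minimal sketch of such a run (flags mirror the smoke test below; any flag not shown keeps its config-file value):

    python bin/deterministic/main.py --mode train \
        --workdir output/test/unet/test-run \
        --config src/ml_downscaling_emulator/deterministic/configs/ukcp_local_12em_pr_unet.py \
        --config.training.n_epochs=2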
diff --git a/tests/deterministic/smoke-test b/tests/deterministic/smoke-test
new file mode 100755
index 000000000..d0fdc140e
--- /dev/null
+++ b/tests/deterministic/smoke-test
@@ -0,0 +1,25 @@
+#! /usr/bin/env bash
+
+set -euo pipefail
+
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+
+cpm_dataset="bham_gcmx-4x_1em_psl-sphum4th-temp4th-vort4th_eqvt_random-season-historic"
+gcm_dataset="bham_60km-4x_1em_psl-sphum4th-temp4th-vort4th_eqvt_random-season-historic"
+workdir="output/test/unet/test-run"
+
+config_path="src/ml_downscaling_emulator/deterministic/configs/ukcp_local_12em_pr_unet.py"
+
+train_batch_size=2
+epoch=2
+
+rm -rf ${workdir}
+WANDB_DISABLE_SERVICE=True WANDB_EXPERIMENT_NAME="test" python ${SCRIPT_DIR}/../../bin/deterministic/main.py --mode train --workdir ${workdir} --config ${config_path} --config.data.dataset_name=${cpm_dataset} --config.training.batch_size=${train_batch_size} --config.training.n_epochs=${epoch} --config.data.time_inputs=True --config.model.name=debug
+
+num_samples=2
+eval_batchsize=32
+
+rm -rf "${workdir}/samples/epoch_${epoch}/${cpm_dataset}"
+mlde evaluate sample ${workdir} --dataset ${cpm_dataset} --checkpoint epoch_${epoch} --batch-size ${eval_batchsize} --num-samples ${num_samples}
+rm -rf "${workdir}/samples/epoch_${epoch}/${gcm_dataset}"
+mlde evaluate sample ${workdir} --dataset ${gcm_dataset} --checkpoint epoch_${epoch} --batch-size ${eval_batchsize} --num-samples ${num_samples}
diff --git a/tests/smoke-test b/tests/smoke-test
index fb4e766e0..21ff64615 100755
--- a/tests/smoke-test
+++ b/tests/smoke-test
@@ -12,11 +12,12 @@
 config_path="src/ml_downscaling_emulator/score_sde_pytorch/configs/${sde}/${config_name}.py"
 
 loc_spec_channels=0
 train_batch_size=2
 random_crop_size=32
+epoch=2
 
 rm -rf ${workdir}
-WANDB_EXPERIMENT_NAME="test" python bin/main.py --workdir ${workdir} --config ${config_path} --mode train --config.data.dataset_name=${dataset} --config.training.snapshot_freq=5 --config.training.eval_freq=100 --config.training.log_freq=50 --config.training.n_epochs=2 --config.model.loc_spec_channels=${loc_spec_channels} --config.training.random_crop_size=${random_crop_size} --config.training.batch_size=${train_batch_size} --config.data.input_transform_key=stan --config.data.target_transform_key=sqrturrecen --config.data.time_inputs=True
+WANDB_EXPERIMENT_NAME="test" python bin/main.py --workdir ${workdir} --config ${config_path} --mode train --config.data.dataset_name=${dataset} --config.training.snapshot_freq=5 --config.training.eval_freq=100 --config.training.log_freq=50 --config.training.n_epochs=${epoch} --config.model.loc_spec_channels=${loc_spec_channels} --config.training.random_crop_size=${random_crop_size} --config.training.batch_size=${train_batch_size} --config.data.input_transform_key=stan --config.data.target_transform_key=sqrturrecen --config.data.time_inputs=True
+
-epoch=2
 num_samples=2
 eval_batchsize=32
 checkpoint="epoch_${epoch}"
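Both smoke tests are executable (mode 100755) and clear out their own workdir before running. Assuming they are invoked from the repository root with the mlde entry point installed in the active environment:

    tests/deterministic/smoke-test
    tests/smoke-test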