Skip to content

Commit

Permalink
clean up some deterministic related code
Browse files Browse the repository at this point in the history
  • Loading branch information
henryaddison committed Mar 20, 2024
1 parent 67e2e4b commit c8fffbf
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 17 deletions.
2 changes: 1 addition & 1 deletion bin/deterministic/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"config", None, "Training configuration.", lock_config=True
)
flags.DEFINE_string("workdir", None, "Work directory.")
flags.DEFINE_enum("mode", None, ["train", "eval"], "Running mode: train or eval")
flags.DEFINE_enum("mode", None, ["train"], "Running mode: train.")
flags.mark_flags_as_required(["workdir", "config", "mode"])


Expand Down
6 changes: 3 additions & 3 deletions src/ml_downscaling_emulator/bin/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def sample(
workdir: Path,
dataset: str = typer.Option(...),
split: str = "val",
epoch: int = typer.Option(...),
checkpoint: str = typer.Option(...),
batch_size: int = None,
num_samples: int = 1,
input_transform_dataset: str = None,
Expand All @@ -81,7 +81,7 @@ def sample(

output_dirpath = samples_path(
workdir=workdir,
checkpoint=f"epoch-{epoch}",
checkpoint=checkpoint,
dataset=dataset,
input_xfm=f"{config.data.input_transform_dataset}-{config.data.input_transform_key}",
split=split,
Expand All @@ -106,7 +106,7 @@ def sample(
shuffle=False,
)

ckpt_filename = os.path.join(workdir, "checkpoints", f"epoch_{epoch}.pth")
ckpt_filename = os.path.join(workdir, "checkpoints", f"{checkpoint}.pth")
num_predictors = eval_dl.dataset[0][0].shape[0]
state = load_model(config, num_predictors, ckpt_filename)

Expand Down
21 changes: 10 additions & 11 deletions tests/deterministic/smoke-test
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,22 @@ set -euo pipefail

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )

cpm_dataset="bham_gcmx-4x_12em_psl-temp4th-vort4th_eqvt_random-season"
gcm_dataset="bham_60km-4x_12em_psl-temp4th-vort4th_eqvt_random-season"
cpm_dataset="bham_gcmx-4x_1em_psl-sphum4th-temp4th-vort4th_eqvt_random-season-historic"
gcm_dataset="bham_60km-4x_1em_psl-sphum4th-temp4th-vort4th_eqvt_random-season-historic"
workdir="output/test/unet/test-run"

config_path="src/ml_downscaling_emulator/deterministic/configs/ukcp_local_12em_pr_unet.py"

map_features=0
train_batch_size=4
train_batch_size=2
epoch=2

rm -rf ${workdir}
WANDB_EXPERIMENT_NAME="test" python ${SCRIPT_DIR}/main.py --mode train --workdir ${workdir} --config ${config_path} --config.data.dataset_name=${cpm_dataset} --config.training.batch_size=${train_batch_size} --config.training.n_epochs=2 --config.data.time_inputs=True --config.model.name=debug
WANDB_DISABLE_SERVICE=True WANDB_EXPERIMENT_NAME="test" python ${SCRIPT_DIR}/../../bin/deterministic/main.py --mode train --workdir ${workdir} --config ${config_path} --config.data.dataset_name=${cpm_dataset} --config.training.batch_size=${train_batch_size} --config.training.n_epochs=${epoch} --config.data.time_inputs=True --config.model.name=debug

epoch=2
num_samples=2
eval_batchsize=128
eval_batchsize=32

rm -rf "${workdir}/samples/epoch-${epoch}/${cpm_dataset}"
mlde evaluate sample ${workdir} --dataset ${cpm_dataset} --epoch ${epoch} --batch-size ${eval_batchsize} --num-samples ${num_samples}
rm -rf "${workdir}/samples/epoch-${epoch}/${gcm_dataset}"
mlde evaluate sample ${workdir} --dataset ${gcm_dataset} --epoch ${epoch} --batch-size ${eval_batchsize} --num-samples ${num_samples}
rm -rf "${workdir}/samples/epoch_${epoch}/${cpm_dataset}"
mlde evaluate sample ${workdir} --dataset ${cpm_dataset} --checkpoint epoch_${epoch} --batch-size ${eval_batchsize} --num-samples ${num_samples}
rm -rf "${workdir}/samples/epoch_${epoch}/${gcm_dataset}"
mlde evaluate sample ${workdir} --dataset ${gcm_dataset} --checkpoint epoch_${epoch} --batch-size ${eval_batchsize} --num-samples ${num_samples}
5 changes: 3 additions & 2 deletions tests/smoke-test
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@ config_path="src/ml_downscaling_emulator/score_sde_pytorch/configs/${sde}/${conf
loc_spec_channels=0
train_batch_size=2
random_crop_size=32
epoch=2

rm -rf ${workdir}
WANDB_EXPERIMENT_NAME="test" python bin/main.py --workdir ${workdir} --config ${config_path} --mode train --config.data.dataset_name=${dataset} --config.training.snapshot_freq=5 --config.training.eval_freq=100 --config.training.log_freq=50 --config.training.n_epochs=2 --config.model.loc_spec_channels=${loc_spec_channels} --config.training.random_crop_size=${random_crop_size} --config.training.batch_size=${train_batch_size} --config.data.input_transform_key=stan --config.data.target_transform_key=sqrturrecen --config.data.time_inputs=True
WANDB_EXPERIMENT_NAME="test" python bin/main.py --workdir ${workdir} --config ${config_path} --mode train --config.data.dataset_name=${dataset} --config.training.snapshot_freq=5 --config.training.eval_freq=100 --config.training.log_freq=50 --config.training.n_epochs=${epoch} --config.model.loc_spec_channels=${loc_spec_channels} --config.training.random_crop_size=${random_crop_size} --config.training.batch_size=${train_batch_size} --config.data.input_transform_key=stan --config.data.target_transform_key=sqrturrecen --config.data.time_inputs=True


epoch=2
num_samples=2
eval_batchsize=32
checkpoint="epoch_${epoch}"
Expand Down

0 comments on commit c8fffbf

Please sign in to comment.