Merge pull request #32 from henryaddison/unet-paper

Tidy deterministic (U-Net) code ready for paper publishing

henryaddison authored Mar 21, 2024
2 parents 31876dd + 2bfe6a5 commit 25e1fec
Showing 7 changed files with 42 additions and 39 deletions.
26 changes: 0 additions & 26 deletions bin/deterministic/local-test-train

This file was deleted.

2 changes: 1 addition & 1 deletion bin/deterministic/main.py
@@ -13,7 +13,7 @@
"config", None, "Training configuration.", lock_config=True
)
flags.DEFINE_string("workdir", None, "Work directory.")
flags.DEFINE_enum("mode", None, ["train", "eval"], "Running mode: train or eval")
flags.DEFINE_enum("mode", None, ["train"], "Running mode: train.")
flags.mark_flags_as_required(["workdir", "config", "mode"])


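With the enum narrowed to ["train"], passing --mode eval to this script is now rejected at flag parsing time (sampling goes through the mlde evaluate CLI instead). A minimal sketch of the pattern, assuming standard absl-py flag validation:

from absl import app, flags

flags.DEFINE_enum("mode", None, ["train"], "Running mode: train.")
flags.mark_flag_as_required("mode")

FLAGS = flags.FLAGS


def main(argv):
    # "--mode eval" no longer parses: absl rejects any value outside
    # ["train"] before main() is entered.
    print("running in mode:", FLAGS.mode)


if __name__ == "__main__":
    app.run(main)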
15 changes: 8 additions & 7 deletions src/ml_downscaling_emulator/bin/evaluate.py
Expand Up @@ -58,7 +58,7 @@ def sample(
     workdir: Path,
     dataset: str = typer.Option(...),
     split: str = "val",
-    epoch: int = typer.Option(...),
+    checkpoint: str = typer.Option(...),
     batch_size: int = None,
     num_samples: int = 1,
     input_transform_dataset: str = None,
@@ -71,16 +71,17 @@ def sample(

     if batch_size is not None:
         config.eval.batch_size = batch_size
-    if input_transform_dataset is not None:
-        config.data.input_transform_dataset = input_transform_dataset
-    else:
-        config.data.input_transform_dataset = dataset
+    with config.unlocked():
+        if input_transform_dataset is not None:
+            config.data.input_transform_dataset = input_transform_dataset
+        else:
+            config.data.input_transform_dataset = dataset
     if input_transform_key is not None:
         config.data.input_transform_key = input_transform_key

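The config is locked once loaded (lock_config=True in main.py), and configs saved by older training runs may predate the input_transform_dataset key, so the assignment is wrapped in config.unlocked(). A minimal sketch of the idea, assuming ml_collections' standard locking behaviour:

import ml_collections

config = ml_collections.ConfigDict()
config.data = ml_collections.ConfigDict()
config.lock()  # mirrors lock_config=True at flag-definition time

# config.data.input_transform_dataset = "..."  # would fail: new key on a locked config
with config.unlocked():  # temporarily allow new keys, as in the diff above
    config.data.input_transform_dataset = "some-dataset"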
     output_dirpath = samples_path(
         workdir=workdir,
-        checkpoint=f"epoch-{epoch}",
+        checkpoint=checkpoint,
         dataset=dataset,
         input_xfm=f"{config.data.input_transform_dataset}-{config.data.input_transform_key}",
         split=split,
@@ -105,7 +106,7 @@ def sample(
         shuffle=False,
     )

-    ckpt_filename = os.path.join(workdir, "checkpoints", f"epoch_{epoch}.pth")
+    ckpt_filename = os.path.join(workdir, "checkpoints", f"{checkpoint}.pth")
     num_predictors = eval_dl.dataset[0][0].shape[0]
     state = load_model(config, num_predictors, ckpt_filename)
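Together these hunks change the CLI from an integer --epoch to a string --checkpoint, so samples can be drawn from any named checkpoint file (e.g. "epoch_2" resolves to checkpoints/epoch_2.pth), as the smoke tests below do. A hypothetical cut-down Typer command showing the new shape of the option:

import os
from pathlib import Path

import typer

app = typer.Typer()


@app.command()
def sample(workdir: Path, checkpoint: str = typer.Option(...)):
    # Checkpoints are addressed by file stem rather than epoch number,
    # e.g. "epoch_2" -> <workdir>/checkpoints/epoch_2.pth.
    ckpt_filename = os.path.join(workdir, "checkpoints", f"{checkpoint}.pth")
    typer.echo(ckpt_filename)


if __name__ == "__main__":
    app()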
5 changes: 4 additions & 1 deletion src/ml_downscaling_emulator/deterministic/configs/ukcp_local_12em_pr_unet.py
@@ -16,9 +16,12 @@ def get_config():
     evaluate.batch_size = 64

     config.data = data = ml_collections.ConfigDict()
-    data.dataset_name = ""
+    data.dataset_name = (
+        "bham_gcmx-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season"
+    )
     data.input_transform_key = "stan"
     data.target_transform_key = "sqrturrecen"
+    data.input_transform_dataset = None
     data.time_inputs = False

     config.model = model = ml_collections.ConfigDict()
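The new input_transform_dataset key defaults to None; the evaluate.py change above then falls back to the sampling dataset. A hypothetical helper (not in the codebase) spelling out that fallback:

def resolve_input_transform_dataset(configured, dataset):
    # Use the dataset configured for fitting the input transform if set,
    # otherwise fall back to the dataset being sampled from.
    return configured if configured is not None else dataset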
3 changes: 1 addition & 2 deletions src/ml_downscaling_emulator/deterministic/run_lib.py
@@ -13,7 +13,7 @@

 from ..training import log_epoch, track_run
 from .utils import restore_checkpoint, save_checkpoint, create_model
-from ..torch import get_dataloader
+from ..data import get_dataloader

 FLAGS = flags.FLAGS
 EXPERIMENT_NAME = os.getenv("WANDB_EXPERIMENT_NAME")
@@ -208,7 +208,6 @@ def train_step_fn(state, batch, cond):
         EXPERIMENT_NAME, run_name, run_config, [config.model.name, "baseline"], tb_dir
     ) as (wandb_run, tb_writer):
         # Fit model
-        wandb_run.watch(model, criterion=criterion, log_freq=100)

         logging.info("Starting training loop at epoch %d." % (initial_epoch,))
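The deleted line was instrumenting the model for Weights & Biases; removing it drops gradient/parameter logging from the deterministic training loop. A minimal sketch of what the call did, assuming the standard wandb API (run in disabled mode here so it works offline):

import torch
import wandb

run = wandb.init(mode="disabled")  # no-op mode so the sketch runs without a W&B account
model = torch.nn.Linear(4, 1)
criterion = torch.nn.MSELoss()
# Hook the model so gradients/parameters (and the loss module) are logged
# every 100 batches; this is the instrumentation the commit removes.
run.watch(model, criterion=criterion, log_freq=100)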
25 changes: 25 additions & 0 deletions tests/deterministic/smoke-test
@@ -0,0 +1,25 @@
+#! /usr/bin/env bash
+
+set -euo pipefail
+
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+
+cpm_dataset="bham_gcmx-4x_1em_psl-sphum4th-temp4th-vort4th_eqvt_random-season-historic"
+gcm_dataset="bham_60km-4x_1em_psl-sphum4th-temp4th-vort4th_eqvt_random-season-historic"
+workdir="output/test/unet/test-run"
+
+config_path="src/ml_downscaling_emulator/deterministic/configs/ukcp_local_12em_pr_unet.py"
+
+train_batch_size=2
+epoch=2
+
+rm -rf ${workdir}
+WANDB_DISABLE_SERVICE=True WANDB_EXPERIMENT_NAME="test" python ${SCRIPT_DIR}/../../bin/deterministic/main.py --mode train --workdir ${workdir} --config ${config_path} --config.data.dataset_name=${cpm_dataset} --config.training.batch_size=${train_batch_size} --config.training.n_epochs=${epoch} --config.data.time_inputs=True --config.model.name=debug
+
+num_samples=2
+eval_batchsize=32
+
+rm -rf "${workdir}/samples/epoch_${epoch}/${cpm_dataset}"
+mlde evaluate sample ${workdir} --dataset ${cpm_dataset} --checkpoint epoch_${epoch} --batch-size ${eval_batchsize} --num-samples ${num_samples}
+rm -rf "${workdir}/samples/epoch_${epoch}/${gcm_dataset}"
+mlde evaluate sample ${workdir} --dataset ${gcm_dataset} --checkpoint epoch_${epoch} --batch-size ${eval_batchsize} --num-samples ${num_samples}
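Note the naming round-trip this script exercises: training for n_epochs=2 is expected to leave a checkpoint whose stem is epoch_2, and sampling then passes --checkpoint epoch_${epoch}, matching the lookup of checkpoints/{checkpoint}.pth in the revised evaluate.py. The SCRIPT_DIR idiom makes the script self-locating, so it can be invoked from any working directory.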
5 changes: 3 additions & 2 deletions tests/smoke-test
@@ -12,11 +12,12 @@ config_path="src/ml_downscaling_emulator/score_sde_pytorch/configs/${sde}/${conf
 loc_spec_channels=0
 train_batch_size=2
 random_crop_size=32
+epoch=2

 rm -rf ${workdir}
-WANDB_EXPERIMENT_NAME="test" python bin/main.py --workdir ${workdir} --config ${config_path} --mode train --config.data.dataset_name=${dataset} --config.training.snapshot_freq=5 --config.training.eval_freq=100 --config.training.log_freq=50 --config.training.n_epochs=2 --config.model.loc_spec_channels=${loc_spec_channels} --config.training.random_crop_size=${random_crop_size} --config.training.batch_size=${train_batch_size} --config.data.input_transform_key=stan --config.data.target_transform_key=sqrturrecen --config.data.time_inputs=True
+WANDB_EXPERIMENT_NAME="test" python bin/main.py --workdir ${workdir} --config ${config_path} --mode train --config.data.dataset_name=${dataset} --config.training.snapshot_freq=5 --config.training.eval_freq=100 --config.training.log_freq=50 --config.training.n_epochs=${epoch} --config.model.loc_spec_channels=${loc_spec_channels} --config.training.random_crop_size=${random_crop_size} --config.training.batch_size=${train_batch_size} --config.data.input_transform_key=stan --config.data.target_transform_key=sqrturrecen --config.data.time_inputs=True


-epoch=2
 num_samples=2
 eval_batchsize=32
+checkpoint="epoch_${epoch}"
