Skip to content

Commit

Permalink
correct a couple of bugs from multivariate samples
Browse files Browse the repository at this point in the history
  • Loading branch information
henryaddison committed Mar 29, 2024
1 parent e4ee079 commit d99c9d9
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 9 deletions.
3 changes: 2 additions & 1 deletion src/ml_downscaling_emulator/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,11 +180,12 @@ def np_samples_to_xr(np_samples, target_transform, target_vars, coords, cf_data_
pred_var = (pred_dims, np_var_pred, pred_attrs)
raw_pred_var = (
pred_dims,
np_var_pred,
{"grid_mapping": "rotated_latitude_longitude"},
)
data_vars.update(
{
var.replace("target_", "pred_"): pred_var,
var: pred_var, # don't rename pred var until after inverting target transform
var.replace("target_", "raw_pred_"): raw_pred_var,
}
)
Expand Down
14 changes: 6 additions & 8 deletions tests/smoke-test
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,23 @@

set -euo pipefail

config_name="ukcp_local_pr_1em_cunet_continuous"
dataset="bham_gcmx-4x_1em_psl-sphum4th-temp4th-vort4th_eqvt_random-season-historic"

sde="subvpsde"
config_name="ukcp_local_mv_12em_cncsnpp_continuous"
dataset="bham_gcmx-4x_2em_mv"

workdir="output/test/${sde}/${config_name}/test-run"
config_path="src/ml_downscaling_emulator/score_sde_pytorch/configs/${sde}/${config_name}.py"

loc_spec_channels=0
train_batch_size=2
random_crop_size=32
epoch=2

rm -rf ${workdir}
WANDB_EXPERIMENT_NAME="test" python bin/main.py --workdir ${workdir} --config ${config_path} --mode train --config.data.dataset_name=${dataset} --config.training.snapshot_freq=5 --config.training.eval_freq=100 --config.training.log_freq=50 --config.training.n_epochs=${epoch} --config.model.loc_spec_channels=${loc_spec_channels} --config.training.random_crop_size=${random_crop_size} --config.training.batch_size=${train_batch_size} --config.data.input_transform_key=stan --config.data.target_transform_key=sqrturrecen --config.data.time_inputs=True

WANDB_EXPERIMENT_NAME="test" python bin/main.py --workdir ${workdir} --config ${config_path} --mode train --config.data.dataset_name=${dataset} --config.training.snapshot_freq=5 --config.training.eval_freq=100 --config.training.log_freq=50 --config.training.n_epochs=${epoch} --config.model.loc_spec_channels=${loc_spec_channels} --config.training.batch_size=${train_batch_size} --config.data.input_transform_key=stan --config.data.target_transform_key=sqrturrecen --config.data.time_inputs=True --config.model.name=cunet --config.model.num_scales=10

num_samples=2
eval_batchsize=32
eval_batch_size=32
checkpoint="epoch_${epoch}"

rm -rf "${workdir}/samples/${checkpoint}/${dataset}"
python bin/predict.py ${workdir} --dataset ${dataset} --checkpoint ${checkpoint} --batch-size ${eval_batchsize} --num-samples ${num_samples} --ensemble-member 01
python bin/predict.py ${workdir} --dataset ${dataset} --checkpoint ${checkpoint} --batch-size ${eval_batch_size} --num-samples ${num_samples} --ensemble-member 01

0 comments on commit d99c9d9

Please sign in to comment.