diff --git a/src/ml_downscaling_emulator/data.py b/src/ml_downscaling_emulator/data.py
index b388539a3..162d3ae7f 100644
--- a/src/ml_downscaling_emulator/data.py
+++ b/src/ml_downscaling_emulator/data.py
@@ -180,11 +180,12 @@ def np_samples_to_xr(np_samples, target_transform, target_vars, coords, cf_data_
         pred_var = (pred_dims, np_var_pred, pred_attrs)
         raw_pred_var = (
             pred_dims,
+            np_var_pred,
             {"grid_mapping": "rotated_latitude_longitude"},
         )
         data_vars.update(
             {
-                var.replace("target_", "pred_"): pred_var,
+                var: pred_var,  # don't rename pred var until after inverting target transform
                 var.replace("target_", "raw_pred_"): raw_pred_var,
             }
         )
diff --git a/tests/smoke-test b/tests/smoke-test
index 21ff64615..91dee90b2 100755
--- a/tests/smoke-test
+++ b/tests/smoke-test
@@ -2,25 +2,23 @@
 
 set -euo pipefail
 
-config_name="ukcp_local_pr_1em_cunet_continuous"
-dataset="bham_gcmx-4x_1em_psl-sphum4th-temp4th-vort4th_eqvt_random-season-historic"
-
 sde="subvpsde"
+config_name="ukcp_local_mv_12em_cncsnpp_continuous"
+dataset="bham_gcmx-4x_2em_mv"
+
 workdir="output/test/${sde}/${config_name}/test-run"
 
 config_path="src/ml_downscaling_emulator/score_sde_pytorch/configs/${sde}/${config_name}.py"
 
 loc_spec_channels=0
 train_batch_size=2
-random_crop_size=32
 epoch=2
 
 rm -rf ${workdir}
-WANDB_EXPERIMENT_NAME="test" python bin/main.py --workdir ${workdir} --config ${config_path} --mode train --config.data.dataset_name=${dataset} --config.training.snapshot_freq=5 --config.training.eval_freq=100 --config.training.log_freq=50 --config.training.n_epochs=${epoch} --config.model.loc_spec_channels=${loc_spec_channels} --config.training.random_crop_size=${random_crop_size} --config.training.batch_size=${train_batch_size} --config.data.input_transform_key=stan --config.data.target_transform_key=sqrturrecen --config.data.time_inputs=True
-
+WANDB_EXPERIMENT_NAME="test" python bin/main.py --workdir ${workdir} --config ${config_path} --mode train --config.data.dataset_name=${dataset} --config.training.snapshot_freq=5 --config.training.eval_freq=100 --config.training.log_freq=50 --config.training.n_epochs=${epoch} --config.model.loc_spec_channels=${loc_spec_channels} --config.training.batch_size=${train_batch_size} --config.data.input_transform_key=stan --config.data.target_transform_key=sqrturrecen --config.data.time_inputs=True --config.model.name=cunet --config.model.num_scales=10
 num_samples=2
-eval_batchsize=32
+eval_batch_size=32
 checkpoint="epoch_${epoch}"
 
 rm -rf "${workdir}/samples/${checkpoint}/${dataset}"
-python bin/predict.py ${workdir} --dataset ${dataset} --checkpoint ${checkpoint} --batch-size ${eval_batchsize} --num-samples ${num_samples} --ensemble-member 01
+python bin/predict.py ${workdir} --dataset ${dataset} --checkpoint ${checkpoint} --batch-size ${eval_batch_size} --num-samples ${num_samples} --ensemble-member 01
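
Note on the `var: pred_var` change in data.py: the predicted variables keep their `target_*` names so that the inverse of the target transform, which was fitted against those names, can be applied first; only afterwards are they renamed to `pred_*` (the `raw_pred_*` copies stay untransformed). A minimal sketch of that downstream step, assuming an xarray Dataset and a `target_transform` with an `invert()` method; the `invert_then_rename` helper is hypothetical and not part of this diff:

```python
import xarray as xr


def invert_then_rename(ds: xr.Dataset, target_transform, target_vars):
    # Invert while predictions still carry their target_* names, so the
    # transform can look them up under the keys it was fitted with.
    ds = target_transform.invert(ds)
    # Only after inversion rename target_* -> pred_* for downstream consumers.
    return ds.rename(
        {var: var.replace("target_", "pred_") for var in target_vars}
    )
```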