From 4483f780a1db0f24f50bb1d3f7e9a738d2b6a7d9 Mon Sep 17 00:00:00 2001 From: Henry Addison Date: Mon, 11 Sep 2023 20:53:09 +0100 Subject: [PATCH] fix the output format of postprocess qm to be a samples like dataset --- .../bin/postprocess.py | 62 +++++++++++-------- src/ml_downscaling_emulator/postprocess.py | 2 +- 2 files changed, 36 insertions(+), 28 deletions(-) diff --git a/src/ml_downscaling_emulator/bin/postprocess.py b/src/ml_downscaling_emulator/bin/postprocess.py index 0918f80c8..08a3e4978 100644 --- a/src/ml_downscaling_emulator/bin/postprocess.py +++ b/src/ml_downscaling_emulator/bin/postprocess.py @@ -72,9 +72,11 @@ def filter( @app.command() def qm( workdir: Path, - dataset: str = typer.Option(...), checkpoint: str = typer.Option(...), - input_xfm: str = "stan", + train_dataset: str = typer.Option(...), + train_input_xfm: str = "stan", + eval_dataset: str = typer.Option(...), + eval_input_xfm: str = "stan", split: str = "val", ensemble_member: str = typer.Option(...), ): @@ -82,39 +84,45 @@ def qm( # to compute the mapping, use train split data # open train split of dataset for the target_pr - sim_train_da = open_raw_dataset_split(dataset, "train").sel( + sim_train_da = open_raw_dataset_split(train_dataset, "train").sel( ensemble_member=ensemble_member )["target_pr"] # open sample of model from train split ml_train_da = xr.open_dataset( - samples_glob( - samples_path( - workdir, - checkpoint=checkpoint, - input_xfm=input_xfm, - dataset=dataset, - split="train", - ensemble_member=ensemble_member, + list( + samples_glob( + samples_path( + workdir, + checkpoint=checkpoint, + input_xfm=train_input_xfm, + dataset=train_dataset, + split="train", + ensemble_member=ensemble_member, + ) ) )[0] )["pred_pr"] - for sample_path in samples_glob( - samples_path( - workdir, - checkpoint=checkpoint, - input_xfm=input_xfm, - dataset=dataset, - split=split, - ensemble_member=ensemble_member, - ) - ): + ml_eval_samples_dirpath = samples_path( + workdir, + checkpoint=checkpoint, + input_xfm=eval_input_xfm, + dataset=eval_dataset, + split=split, + ensemble_member=ensemble_member, + ) + logger.info(f"QMapping samplesin {ml_eval_samples_dirpath}") + for sample_filepath in samples_glob(ml_eval_samples_dirpath): + logger.info(f"Working on {sample_filepath}") # open the samples to be qmapped - ml_eval_da = xr.open_dataset(sample_path)["pred_pr"] + ml_eval_ds = xr.open_dataset(sample_filepath) # do the qmapping - qmapped_eval_da = xrqm(sim_train_da, ml_train_da, ml_eval_da) + qmapped_eval_da = xrqm(sim_train_da, ml_train_da, ml_eval_ds["pred_pr"]) + + qmapped_eval_ds = ml_eval_ds.copy() + qmapped_eval_ds["pred_pr"] = qmapped_eval_da # save output new_workdir = workdir / "postprocess" / "qm" @@ -123,14 +131,14 @@ def qm( samples_path( new_workdir, checkpoint=checkpoint, - input_xfm=input_xfm, - dataset=dataset, + input_xfm=eval_input_xfm, + dataset=eval_dataset, split=split, ensemble_member=ensemble_member, ) - / sample_path.name + / sample_filepath.name ) logger.info(f"Saving to {qmapped_sample_filepath}") qmapped_sample_filepath.parent.mkdir(parents=True, exist_ok=True) - qmapped_eval_da.to_netcdf(qmapped_sample_filepath) + qmapped_eval_ds.to_netcdf(qmapped_sample_filepath) diff --git a/src/ml_downscaling_emulator/postprocess.py b/src/ml_downscaling_emulator/postprocess.py index 283c68f51..6ff41ab82 100644 --- a/src/ml_downscaling_emulator/postprocess.py +++ b/src/ml_downscaling_emulator/postprocess.py @@ -65,6 +65,6 @@ def xrqm( ), # dimensions allowed to change size. Must be set! vectorize=True, ) - .transpose("time", "grid_latitude", "grid_longitude") + .transpose("ensemble_member", "time", "grid_latitude", "grid_longitude") .assign_coords(time=ml_eval_da["time"]) )