Skip to content

Commit

Permalink
add an implementation of quantile mapping cli
Browse files Browse the repository at this point in the history
  • Loading branch information
henryaddison committed Sep 11, 2023
1 parent 1df23fc commit 49323fd
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 0 deletions.
70 changes: 70 additions & 0 deletions src/ml_downscaling_emulator/bin/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import xarray as xr

from mlde_utils import samples_path, samples_glob, TIME_PERIODS
from mlde_utils.training.dataset import open_raw_dataset_split

from ml_downscaling_emulator.postprocess import xrqm

logging.basicConfig(
level=logging.INFO,
Expand Down Expand Up @@ -64,3 +67,70 @@ def filter(
samples_ds.sel(time=slice(*TIME_PERIODS[time_period])).to_netcdf(
filtered_samples_filepath
)


@app.command()
def qm(
    workdir: Path,
    dataset: str = typer.Option(...),
    checkpoint: str = typer.Option(...),
    input_xfm: str = "stan",
    split: str = "val",
    ensemble_member: str = typer.Option(...),
):
    """Quantile-map predicted precipitation samples and save the results.

    The mapping is computed from train split data: the dataset's "target_pr"
    for the chosen ensemble member together with one file of model samples
    ("pred_pr") from the train split. It is then applied to every sample file
    found for ``split`` and each result is written under
    ``workdir / "postprocess" / "qm"`` using the same file name as its input.

    Args:
        workdir: Directory under which the model samples are located and
            under which the quantile-mapped output is written.
        dataset: Name of the dataset, used both to open the raw train split
            and to locate the sample files.
        checkpoint: Checkpoint identifier used to locate the sample files.
        input_xfm: Input transform identifier used to locate the sample files.
        split: Dataset split whose samples should be quantile-mapped.
        ensemble_member: Ensemble member selected from the train dataset and
            used to locate the sample files.

    Raises:
        FileNotFoundError: If no train split sample file exists from which to
            compute the mapping.
    """
    # Everything except the root directory and the split is shared by all the
    # sample locations used below.
    sample_kwargs = dict(
        checkpoint=checkpoint,
        input_xfm=input_xfm,
        dataset=dataset,
        ensemble_member=ensemble_member,
    )

    # to compute the mapping, use train split data

    # open train split of dataset for the target_pr
    sim_train_da = open_raw_dataset_split(dataset, "train").sel(
        ensemble_member=ensemble_member
    )["target_pr"]

    # open sample of model from train split
    train_samples_dir = samples_path(workdir, split="train", **sample_kwargs)
    # sorted() accepts either a list or a lazy glob iterator (which cannot be
    # subscripted) and makes the choice of reference file deterministic.
    train_sample_files = sorted(samples_glob(train_samples_dir))
    if not train_sample_files:
        raise FileNotFoundError(
            f"No train split samples found in {train_samples_dir}"
        )
    ml_train_da = xr.open_dataset(train_sample_files[0])["pred_pr"]

    # The output location is the same for every sample file, so work it out
    # once rather than on each iteration.
    new_workdir = workdir / "postprocess" / "qm"
    qmapped_samples_dir = samples_path(new_workdir, split=split, **sample_kwargs)

    for sample_path in samples_glob(
        samples_path(workdir, split=split, **sample_kwargs)
    ):
        # open the samples to be qmapped
        ml_eval_da = xr.open_dataset(sample_path)["pred_pr"]

        # do the qmapping
        qmapped_eval_da = xrqm(sim_train_da, ml_train_da, ml_eval_da)

        # save output, keeping the input's file name
        qmapped_sample_filepath = qmapped_samples_dir / sample_path.name

        logger.info(f"Saving to {qmapped_sample_filepath}")
        # Created lazily so that no empty output directory is left behind
        # when there are no samples to process.
        qmapped_sample_filepath.parent.mkdir(parents=True, exist_ok=True)
        qmapped_eval_da.to_netcdf(qmapped_sample_filepath)

0 comments on commit 49323fd

Please sign in to comment.