Skip to content

Commit

Permalink
add a CLI command for converting all samples of a model on a dataset …
Browse files Browse the repository at this point in the history
…split to 60km grid
  • Loading branch information
henryaddison committed Nov 29, 2023
1 parent 51ec14d commit 649bd77
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 34 deletions.
118 changes: 84 additions & 34 deletions src/ml_downscaling_emulator/bin/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
import logging
import os
from pathlib import Path
from typing import Callable
import typer
import xarray as xr

from mlde_utils import samples_path, samples_glob, TIME_PERIODS
from mlde_utils.training.dataset import open_raw_dataset_split

from ml_downscaling_emulator.postprocess import xrqm
from ml_downscaling_emulator.postprocess import xrqm, to_gcm_domain

logging.basicConfig(
level=logging.INFO,
Expand All @@ -25,6 +26,50 @@ def callback():
pass


def process_each_sample(
workdir: Path,
checkpoint: str,
dataset: str,
ensemble_member: str,
input_xfm: str,
split: str,
processing_func: Callable,
new_workdir: Path,
):
samples_dirpath = samples_path(
workdir,
checkpoint=checkpoint,
input_xfm=input_xfm,
dataset=dataset,
split=split,
ensemble_member=ensemble_member,
)
logger.info(f"Iterating on samples in {samples_dirpath}")
for sample_filepath in samples_glob(samples_dirpath):
logger.info(f"Working on {sample_filepath}")
# open the samples
samples_ds = xr.open_dataset(sample_filepath)

processed_samples_ds = processing_func(samples_ds)

# save output
processed_sample_filepath = (
samples_path(
new_workdir,
checkpoint=checkpoint,
input_xfm=input_xfm,
dataset=dataset,
split=split,
ensemble_member=ensemble_member,
)
/ sample_filepath.name
)

logger.info(f"Saving to {processed_sample_filepath}")
processed_sample_filepath.parent.mkdir(parents=True, exist_ok=True)
processed_samples_ds.to_netcdf(processed_sample_filepath)


@app.command()
def filter(
workdir: Path,
Expand Down Expand Up @@ -103,44 +148,49 @@ def qm(
)[0]
)["pred_pr"]

ml_eval_samples_dirpath = samples_path(
workdir,
checkpoint=checkpoint,
input_xfm=eval_input_xfm,
dataset=eval_dataset,
split=split,
ensemble_member=ensemble_member,
)
logger.info(f"QMapping samplesin {ml_eval_samples_dirpath}")
for sample_filepath in samples_glob(ml_eval_samples_dirpath):
logger.info(f"Working on {sample_filepath}")
# open the samples to be qmapped
ml_eval_ds = xr.open_dataset(sample_filepath)

def process_samples(ds):
# do the qmapping
qmapped_eval_da = xrqm(sim_train_da, ml_train_da, ml_eval_ds["pred_pr"])
qmapped_eval_da = xrqm(sim_train_da, ml_train_da, ds["pred_pr"])

qmapped_eval_ds = ml_eval_ds.copy()
qmapped_eval_ds["pred_pr"] = qmapped_eval_da
processed_samples_ds = ds.copy()
processed_samples_ds["pred_pr"] = qmapped_eval_da

# save output
new_workdir = workdir / "postprocess" / "qm"
return processed_samples_ds

process_each_sample(
workdir,
checkpoint,
eval_dataset,
ensemble_member,
eval_input_xfm,
split,
process_samples,
new_workdir=workdir / "postprocess" / "qm",
)

qmapped_sample_filepath = (
samples_path(
new_workdir,
checkpoint=checkpoint,
input_xfm=eval_input_xfm,
dataset=eval_dataset,
split=split,
ensemble_member=ensemble_member,
)
/ sample_filepath.name
)

logger.info(f"Saving to {qmapped_sample_filepath}")
qmapped_sample_filepath.parent.mkdir(parents=True, exist_ok=True)
qmapped_eval_ds.to_netcdf(qmapped_sample_filepath)
@app.command()
def gcmify(
workdir: Path,
checkpoint: str = typer.Option(...),
dataset: str = typer.Option(...),
input_xfm: str = typer.Option(...),
split: str = typer.Option(...),
ensemble_member: str = typer.Option(...),
):
def process_samples(ds):
return ds.groupby("sample_id").map(to_gcm_domain)

process_each_sample(
workdir,
checkpoint,
dataset,
ensemble_member,
input_xfm,
split,
process_samples,
new_workdir=workdir / "postprocess" / "gcm-grid",
)


@app.command()
Expand Down
16 changes: 16 additions & 0 deletions src/ml_downscaling_emulator/postprocess.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from importlib.resources import files
from typing import Callable
import numpy as np
import xarray as xr

from mlde_utils.data.remapcon import Remapcon
from mlde_utils.data.shift_lon_break import ShiftLonBreak
from mlde_utils.data.select_gcm_domain import SelectGCMDomain


def _get_cdf(x, xbins):
pdf, _ = np.histogram(x, xbins)
Expand Down Expand Up @@ -68,3 +73,14 @@ def xrqm(
.transpose("ensemble_member", "time", "grid_latitude", "grid_longitude")
.assign_coords(time=ml_eval_da["time"])
)


def to_gcm_domain(ds: xr.Dataset):
target_grid_filepath = files("mlde_utils.data").joinpath(
"target_grids/60km/global/pr/moose_grid.nc"
)
xr.open_dataset(target_grid_filepath)
ds = Remapcon(target_grid_filepath).run(ds)
ds = ShiftLonBreak().run(ds)
ds = SelectGCMDomain(subdomain="2.2km-coarsened-4x_bham-64").run(ds)
return ds

0 comments on commit 649bd77

Please sign in to comment.