From e36cd31698244133e564fe24b15a73e576c37257 Mon Sep 17 00:00:00 2001 From: Henry Addison Date: Fri, 1 Sep 2023 13:22:41 +0100 Subject: [PATCH] add a CLI for filtering samples by time period --- src/ml_downscaling_emulator/bin/__init__.py | 3 +- .../bin/postprocess.py | 63 +++++++++++++++++++ 2 files changed, 65 insertions(+), 1 deletion(-) create mode 100644 src/ml_downscaling_emulator/bin/postprocess.py diff --git a/src/ml_downscaling_emulator/bin/__init__.py b/src/ml_downscaling_emulator/bin/__init__.py index 55a07625b..dba70cbcf 100644 --- a/src/ml_downscaling_emulator/bin/__init__.py +++ b/src/ml_downscaling_emulator/bin/__init__.py @@ -1,9 +1,10 @@ import typer -from . import evaluate +from . import evaluate, postprocess app = typer.Typer() app.add_typer(evaluate.app, name="evaluate") +app.add_typer(postprocess.app, name="postprocess") if __name__ == "__main__": diff --git a/src/ml_downscaling_emulator/bin/postprocess.py b/src/ml_downscaling_emulator/bin/postprocess.py new file mode 100644 index 000000000..9fb3d5c5f --- /dev/null +++ b/src/ml_downscaling_emulator/bin/postprocess.py @@ -0,0 +1,63 @@ +import logging +import os +from pathlib import Path +import typer +import xarray as xr + +from mlde_utils import samples_path, samples_glob, TIME_PERIODS + +logging.basicConfig( + level=logging.INFO, + format="%(levelname)s - %(filename)s - %(asctime)s - %(message)s", +) +logger = logging.getLogger() +logger.setLevel("INFO") + +app = typer.Typer() + + +@app.callback() +def callback(): + pass + + +@app.command() +def filter( + workdir: Path, + dataset: str = typer.Option(...), + time_period: str = typer.Option(...), + checkpoint: str = typer.Option(...), + input_xfm: str = "stan", + split: str = "val", + ensemble_member: str = typer.Option(...), +): + """Filter a set of samples based on time period.""" + + new_dataset = f"{dataset}-{time_period}" + filtered_samples_dirpath = samples_path( + workdir, + checkpoint=checkpoint, + input_xfm=input_xfm, + dataset=new_dataset, + split=split, + ensemble_member=ensemble_member, + ) + os.makedirs(filtered_samples_dirpath, exist_ok=False) + + for sample_filepath in samples_glob( + samples_path( + workdir, + checkpoint=checkpoint, + input_xfm=input_xfm, + dataset=dataset, + split=split, + ensemble_member=ensemble_member, + ) + ): + samples_ds = xr.open_dataset(sample_filepath) + + filtered_samples_filepath = filtered_samples_dirpath / sample_filepath.name + + samples_ds.sel(time=slice(*TIME_PERIODS[time_period])).to_netcdf( + filtered_samples_filepath + )