diff --git a/src/ml_downscaling_emulator/bin/__init__.py b/src/ml_downscaling_emulator/bin/__init__.py index dba70cbcf..7fd0be43c 100644 --- a/src/ml_downscaling_emulator/bin/__init__.py +++ b/src/ml_downscaling_emulator/bin/__init__.py @@ -1,9 +1,9 @@ import typer -from . import evaluate, postprocess +from . import postprocess, sample app = typer.Typer() -app.add_typer(evaluate.app, name="evaluate") +app.add_typer(sample.app, name="sample") app.add_typer(postprocess.app, name="postprocess") diff --git a/src/ml_downscaling_emulator/bin/evaluate.py b/src/ml_downscaling_emulator/bin/sample.py similarity index 90% rename from src/ml_downscaling_emulator/bin/evaluate.py rename to src/ml_downscaling_emulator/bin/sample.py index f38ec3d7c..1eea615c2 100644 --- a/src/ml_downscaling_emulator/bin/evaluate.py +++ b/src/ml_downscaling_emulator/bin/sample.py @@ -72,14 +72,18 @@ def _sample_id(variable: str, eval_ds: xr.Dataset) -> xr.Dataset: @app.command() @Timer(name="sample", text="{name}: {minutes:.1f} minutes", logger=logging.info) @slack_sender(webhook_url=os.getenv("KK_SLACK_WH_URL"), channel="general") -def sample_id( +def as_input( workdir: Path, dataset: str = typer.Option(...), variable: str = "pr", split: str = "val", ensemble_member: str = DEFAULT_ENSEMBLE_MEMBER, ): + """ + Use a given variable from the dataset to create a file of prediction samples. + Commonly used to create samples based on an already processed variable like using a bilinearly interpolated coarse resolution variable as the predicted "high-resolution" value directly. + """ output_dirpath = samples_path( workdir=workdir, checkpoint=f"epoch-0", diff --git a/tests/ml_downscaling_emulator/deterministic/test_sampling.py b/tests/ml_downscaling_emulator/bin/test_sample.py similarity index 60% rename from tests/ml_downscaling_emulator/deterministic/test_sampling.py rename to tests/ml_downscaling_emulator/bin/test_sample.py index 4c633a96c..574b47101 100644 --- a/tests/ml_downscaling_emulator/deterministic/test_sampling.py +++ b/tests/ml_downscaling_emulator/bin/test_sample.py @@ -1,14 +1,14 @@ import xarray as xr -from ml_downscaling_emulator.deterministic.sampling import sample_id +from ml_downscaling_emulator.bin.sample import _sample_id def test_sample_id(dataset: xr.Dataset): - """Ensure the sample_id function creates a set of predictions using the values of the given variable.""" + """Ensure the _sample_id bin function creates a set of predictions using the values of the given variable.""" variable = "linpr" em_dataset = dataset.sel(ensemble_member=["01"]) - xr_samples = sample_id(variable, em_dataset) + xr_samples = _sample_id(variable, em_dataset) assert (xr_samples["pred_pr"].values == em_dataset["linpr"].values).all() for dim in ["time", "grid_latitude", "grid_longitude"]: