Skip to content

Commit

Permalink
update sample helper used for bilinear interpolation
Browse files Browse the repository at this point in the history
  • Loading branch information
henryaddison committed Oct 22, 2024
1 parent f76547f commit 4df883a
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 6 deletions.
4 changes: 2 additions & 2 deletions src/ml_downscaling_emulator/bin/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import typer

from . import evaluate, postprocess
from . import postprocess, sample

app = typer.Typer()
app.add_typer(evaluate.app, name="evaluate")
app.add_typer(sample.app, name="sample")
app.add_typer(postprocess.app, name="postprocess")


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,18 @@ def _sample_id(variable: str, eval_ds: xr.Dataset) -> xr.Dataset:
@app.command()
@Timer(name="sample", text="{name}: {minutes:.1f} minutes", logger=logging.info)
@slack_sender(webhook_url=os.getenv("KK_SLACK_WH_URL"), channel="general")
def sample_id(
def as_input(
workdir: Path,
dataset: str = typer.Option(...),
variable: str = "pr",
split: str = "val",
ensemble_member: str = DEFAULT_ENSEMBLE_MEMBER,
):
"""
Use a given variable from the dataset to create a file of prediction samples.
Commonly used to create samples based on an already processed variable like using a bilinearly interpolated coarse resolution variable as the predicted "high-resolution" value directly.
"""
output_dirpath = samples_path(
workdir=workdir,
checkpoint=f"epoch-0",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import xarray as xr

from ml_downscaling_emulator.deterministic.sampling import sample_id
from ml_downscaling_emulator.bin.sample import _sample_id


def test_sample_id(dataset: xr.Dataset):
"""Ensure the sample_id function creates a set of predictions using the values of the given variable."""
"""Ensure the _sample_id bin function creates a set of predictions using the values of the given variable."""

variable = "linpr"
em_dataset = dataset.sel(ensemble_member=["01"])
xr_samples = sample_id(variable, em_dataset)
xr_samples = _sample_id(variable, em_dataset)

assert (xr_samples["pred_pr"].values == em_dataset["linpr"].values).all()
for dim in ["time", "grid_latitude", "grid_longitude"]:
Expand Down

0 comments on commit 4df883a

Please sign in to comment.