Skip to content

Commit

Permalink
fix the output format of postprocess qm to be a samples like dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
henryaddison committed Sep 11, 2023
1 parent 49323fd commit 4483f78
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 28 deletions.
62 changes: 35 additions & 27 deletions src/ml_downscaling_emulator/bin/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,49 +72,57 @@ def filter(
@app.command()
def qm(
workdir: Path,
dataset: str = typer.Option(...),
checkpoint: str = typer.Option(...),
input_xfm: str = "stan",
train_dataset: str = typer.Option(...),
train_input_xfm: str = "stan",
eval_dataset: str = typer.Option(...),
eval_input_xfm: str = "stan",
split: str = "val",
ensemble_member: str = typer.Option(...),
):
pass
# to compute the mapping, use train split data

# open train split of dataset for the target_pr
sim_train_da = open_raw_dataset_split(dataset, "train").sel(
sim_train_da = open_raw_dataset_split(train_dataset, "train").sel(
ensemble_member=ensemble_member
)["target_pr"]

# open sample of model from train split
ml_train_da = xr.open_dataset(
samples_glob(
samples_path(
workdir,
checkpoint=checkpoint,
input_xfm=input_xfm,
dataset=dataset,
split="train",
ensemble_member=ensemble_member,
list(
samples_glob(
samples_path(
workdir,
checkpoint=checkpoint,
input_xfm=train_input_xfm,
dataset=train_dataset,
split="train",
ensemble_member=ensemble_member,
)
)
)[0]
)["pred_pr"]

for sample_path in samples_glob(
samples_path(
workdir,
checkpoint=checkpoint,
input_xfm=input_xfm,
dataset=dataset,
split=split,
ensemble_member=ensemble_member,
)
):
ml_eval_samples_dirpath = samples_path(
workdir,
checkpoint=checkpoint,
input_xfm=eval_input_xfm,
dataset=eval_dataset,
split=split,
ensemble_member=ensemble_member,
)
logger.info(f"QMapping samplesin {ml_eval_samples_dirpath}")
for sample_filepath in samples_glob(ml_eval_samples_dirpath):
logger.info(f"Working on {sample_filepath}")
# open the samples to be qmapped
ml_eval_da = xr.open_dataset(sample_path)["pred_pr"]
ml_eval_ds = xr.open_dataset(sample_filepath)

# do the qmapping
qmapped_eval_da = xrqm(sim_train_da, ml_train_da, ml_eval_da)
qmapped_eval_da = xrqm(sim_train_da, ml_train_da, ml_eval_ds["pred_pr"])

qmapped_eval_ds = ml_eval_ds.copy()
qmapped_eval_ds["pred_pr"] = qmapped_eval_da

# save output
new_workdir = workdir / "postprocess" / "qm"
Expand All @@ -123,14 +131,14 @@ def qm(
samples_path(
new_workdir,
checkpoint=checkpoint,
input_xfm=input_xfm,
dataset=dataset,
input_xfm=eval_input_xfm,
dataset=eval_dataset,
split=split,
ensemble_member=ensemble_member,
)
/ sample_path.name
/ sample_filepath.name
)

logger.info(f"Saving to {qmapped_sample_filepath}")
qmapped_sample_filepath.parent.mkdir(parents=True, exist_ok=True)
qmapped_eval_da.to_netcdf(qmapped_sample_filepath)
qmapped_eval_ds.to_netcdf(qmapped_sample_filepath)
2 changes: 1 addition & 1 deletion src/ml_downscaling_emulator/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,6 @@ def xrqm(
), # dimensions allowed to change size. Must be set!
vectorize=True,
)
.transpose("time", "grid_latitude", "grid_longitude")
.transpose("ensemble_member", "time", "grid_latitude", "grid_longitude")
.assign_coords(time=ml_eval_da["time"])
)

0 comments on commit 4483f78

Please sign in to comment.