move merge cli into postprocess
and recognise when random ids are the same across files being merged
henryaddison committed Sep 12, 2023
1 parent d00566a commit cdee189
Showing 2 changed files with 35 additions and 33 deletions.
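
For anyone scripting against the CLI, the only change needed is which module exposes the command. A minimal sketch of invoking the relocated command, assuming the typer app object in postprocess.py is named app (as the @app.command() decorator in the diff below suggests) and using typer's testing runner; the sample directory paths are made up:

from typer.testing import CliRunner

from ml_downscaling_emulator.bin.postprocess import app  # previously: bin.evaluate

runner = CliRunner()
result = runner.invoke(app, ["merge", "samples/run1", "samples/run2", "samples/merged"])
print(result.output)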
31 changes: 0 additions & 31 deletions src/ml_downscaling_emulator/bin/evaluate.py
@@ -1,5 +1,4 @@
 from codetiming import Timer
-import glob
 import logging
 from knockknock import slack_sender
 from ml_collections import config_dict
@@ -8,8 +7,6 @@
 import shortuuid
 import torch
 import typer
-from typing import List
-import xarray as xr
 import yaml
 
 from mlde_utils import samples_path, DEFAULT_ENSEMBLE_MEMBER
@@ -146,31 +143,3 @@ def sample_id(
 
     logger.info(f"Saving predictions to {output_filepath}")
     xr_samples.to_netcdf(output_filepath)
-
-
-@app.command()
-def merge(
-    input_dirs: List[Path],
-    output_dir: Path,
-):
-    pred_file_globs = [
-        glob.glob(os.path.join(samples_dir, "*.nc")) for samples_dir in input_dirs
-    ]
-    # there should be the same number of samples in each input dir
-    assert 1 == len(set(map(len, pred_file_globs)))
-
-    for pred_file_group in zip(*pred_file_globs):
-        typer.echo(f"Concat {pred_file_group}")
-
-        # take a bit of the random id in each sample file's name
-        random_ids = [fn[-25:-20] for fn in pred_file_group]
-        # join those partial random ids together for the output filepath in the train directory (rather than one of the subset train dirs)
-        output_filepath = os.path.join(
-            output_dir, f"predictions-{'-'.join(random_ids)}.nc"
-        )
-
-        typer.echo(f"save to {output_filepath}")
-        os.makedirs(os.path.dirname(output_filepath), exist_ok=True)
-        xr.concat([xr.open_dataset(f) for f in pred_file_group], dim="time").to_netcdf(
-            output_filepath
-        )
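
The slice indices in the removed code are easier to follow with a concrete filename. A minimal sketch, assuming sample files are named predictions-<22-character shortuuid>.nc (the naming scheme is an assumption; only the slicing appears in this diff):

fn = "samples/predictions-ABCDEFGHIJKLMNOPQRSTUV.nc"  # hypothetical path
fn[-25:]     # "ABCDEFGHIJKLMNOPQRSTUV.nc": the 22-char random id plus ".nc"
fn[-25:-20]  # "ABCDE": the first five characters of the random id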
37 changes: 35 additions & 2 deletions src/ml_downscaling_emulator/bin/postprocess.py
@@ -1,3 +1,4 @@
+import glob
 import logging
 import os
 from pathlib import Path
@@ -80,9 +81,7 @@ def qm(
     split: str = "val",
     ensemble_member: str = typer.Option(...),
 ):
-    pass
-    # to compute the mapping, use train split data
 
     # open train split of dataset for the target_pr
     sim_train_da = open_raw_dataset_split(train_dataset, "train").sel(
         ensemble_member=ensemble_member
@@ -142,3 +141,37 @@ def qm(
     logger.info(f"Saving to {qmapped_sample_filepath}")
     qmapped_sample_filepath.parent.mkdir(parents=True, exist_ok=True)
     qmapped_eval_ds.to_netcdf(qmapped_sample_filepath)
+
+
+@app.command()
+def merge(
+    input_dirs: list[Path],
+    output_dir: Path,
+):
+    pred_file_globs = [
+        sorted(glob.glob(os.path.join(samples_dir, "*.nc")))
+        for samples_dir in input_dirs
+    ]
+    # there should be the same number of samples in each input dir
+    assert 1 == len(set(map(len, pred_file_globs)))
+
+    for pred_file_group in zip(*pred_file_globs):
+        typer.echo(f"Concat {pred_file_group}")
+
+        # take the trailing chunk of each sample file's name, which contains its random id
+        random_ids = [fn[-25:] for fn in pred_file_group]
+        if len(set(random_ids)) == 1:
+            # all the random ids match (the files come from the same sampling run), so use one of them for the output filepath
+            output_filepath = os.path.join(output_dir, f"predictions-{random_ids[0]}")
+        else:
+            # otherwise join a prefix of each random id for the output filepath in the merged directory (rather than one of the input dirs)
+            random_ids = [rid[:5] for rid in random_ids]
+            output_filepath = os.path.join(
+                output_dir, f"predictions-{'-'.join(random_ids)}.nc"
+            )
+
+        typer.echo(f"save to {output_filepath}")
+        os.makedirs(os.path.dirname(output_filepath), exist_ok=True)
+        xr.concat([xr.open_dataset(f) for f in pred_file_group], dim="time").to_netcdf(
+            output_filepath
+        )
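
To make the new branching concrete, a toy reimplementation of just the renaming rule, with a hypothetical helper name and made-up ids and paths (the real command also concatenates the datasets along the time dimension before writing):

import os

def merged_name(pred_file_group, output_dir):
    # mirrors the naming logic of merge above, for illustration only
    random_ids = [fn[-25:] for fn in pred_file_group]
    if len(set(random_ids)) == 1:
        # same sampling run: keep the shared id whole
        return os.path.join(output_dir, f"predictions-{random_ids[0]}")
    # different runs: join five-character prefixes of each id
    prefixes = [rid[:5] for rid in random_ids]
    return os.path.join(output_dir, f"predictions-{'-'.join(prefixes)}.nc")

print(merged_name([
    "em01/predictions-ABCDEFGHIJKLMNOPQRSTUV.nc",
    "em02/predictions-ABCDEFGHIJKLMNOPQRSTUV.nc",
], "merged"))
# -> merged/predictions-ABCDEFGHIJKLMNOPQRSTUV.nc

print(merged_name([
    "run1/predictions-ABCDEFGHIJKLMNOPQRSTUV.nc",
    "run2/predictions-VUTSRQPONMLKJIHGFEDCBA.nc",
], "merged"))
# -> merged/predictions-ABCDE-VUTSR.nc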
