From cdee1892a80f27aa990339b2a82c7f502a43811f Mon Sep 17 00:00:00 2001
From: Henry Addison
Date: Tue, 12 Sep 2023 14:03:07 +0100
Subject: [PATCH] move merge cli into postprocess and recognise when random ids are the same across files being merged

---
 src/ml_downscaling_emulator/bin/evaluate.py | 31 ----------------
 .../bin/postprocess.py                      | 37 ++++++++++++++++++-
 2 files changed, 35 insertions(+), 33 deletions(-)

diff --git a/src/ml_downscaling_emulator/bin/evaluate.py b/src/ml_downscaling_emulator/bin/evaluate.py
index 674cf0bf8..e6f78f36e 100644
--- a/src/ml_downscaling_emulator/bin/evaluate.py
+++ b/src/ml_downscaling_emulator/bin/evaluate.py
@@ -1,5 +1,4 @@
 from codetiming import Timer
-import glob
 import logging
 from knockknock import slack_sender
 from ml_collections import config_dict
@@ -8,8 +7,6 @@
 import shortuuid
 import torch
 import typer
-from typing import List
-import xarray as xr
 import yaml
 
 from mlde_utils import samples_path, DEFAULT_ENSEMBLE_MEMBER
@@ -146,31 +143,3 @@ def sample_id(
 
     logger.info(f"Saving predictions to {output_filepath}")
     xr_samples.to_netcdf(output_filepath)
-
-
-@app.command()
-def merge(
-    input_dirs: List[Path],
-    output_dir: Path,
-):
-    pred_file_globs = [
-        glob.glob(os.path.join(samples_dir, "*.nc")) for samples_dir in input_dirs
-    ]
-    # there should be the same number of samples in each input dir
-    assert 1 == len(set(map(len, pred_file_globs)))
-
-    for pred_file_group in zip(*pred_file_globs):
-        typer.echo(f"Concat {pred_file_group}")
-
-        # take a bit of the random id in each sample file's name
-        random_ids = [fn[-25:-20] for fn in pred_file_group]
-        # join those partial random ids together for the output filepath in the train directory (rather than one of the subset train dirs)
-        output_filepath = os.path.join(
-            output_dir, f"predictions-{'-'.join(random_ids)}.nc"
-        )
-
-        typer.echo(f"save to {output_filepath}")
-        os.makedirs(os.path.dirname(output_filepath), exist_ok=True)
-        xr.concat([xr.open_dataset(f) for f in pred_file_group], dim="time").to_netcdf(
-            output_filepath
-        )
diff --git a/src/ml_downscaling_emulator/bin/postprocess.py b/src/ml_downscaling_emulator/bin/postprocess.py
index 08a3e4978..202fa4f8d 100644
--- a/src/ml_downscaling_emulator/bin/postprocess.py
+++ b/src/ml_downscaling_emulator/bin/postprocess.py
@@ -1,3 +1,4 @@
+import glob
 import logging
 import os
 from pathlib import Path
@@ -80,9 +81,7 @@ def qm(
     split: str = "val",
     ensemble_member: str = typer.Option(...),
 ):
-    pass
     # to compute the mapping, use train split data
-    # open train split of dataset for the target_pr
     sim_train_da = open_raw_dataset_split(train_dataset, "train").sel(
         ensemble_member=ensemble_member
@@ -142,3 +141,37 @@ def qm(
     logger.info(f"Saving to {qmapped_sample_filepath}")
     qmapped_sample_filepath.parent.mkdir(parents=True, exist_ok=True)
     qmapped_eval_ds.to_netcdf(qmapped_sample_filepath)
+
+
+@app.command()
+def merge(
+    input_dirs: list[Path],
+    output_dir: Path,
+):
+    pred_file_globs = [
+        sorted(glob.glob(os.path.join(samples_dir, "*.nc")))
+        for samples_dir in input_dirs
+    ]
+    # there should be the same number of samples in each input dir
+    assert 1 == len(set(map(len, pred_file_globs)))
+
+    for pred_file_group in zip(*pred_file_globs):
+        typer.echo(f"Concat {pred_file_group}")
+
+        # take the random id portion (and .nc extension) from the end of each sample file's name
+        random_ids = [fn[-25:] for fn in pred_file_group]
+        if len(set(random_ids)) == 1:
+            # all files share the same random id (they come from the same sampling run), so reuse it for the output filepath
+            output_filepath = os.path.join(output_dir, f"predictions-{random_ids[0]}")
+        else:
+            # otherwise shorten the random ids and join them for the output filepath in output_dir (rather than one of the input dirs)
+            random_ids = [rid[:5] for rid in random_ids]
+            output_filepath = os.path.join(
+                output_dir, f"predictions-{'-'.join(random_ids)}.nc"
+            )
+
+        typer.echo(f"save to {output_filepath}")
+        os.makedirs(os.path.dirname(output_filepath), exist_ok=True)
+        xr.concat([xr.open_dataset(f) for f in pred_file_group], dim="time").to_netcdf(
+            output_filepath
+        )
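
A minimal, standalone sketch of the output-naming rule the new merge command applies. The directories, file names and the helper name merged_filename below are hypothetical, and it assumes the usual fixed-width "predictions-<random id>.nc" pattern with a 22-character id, so that fn[-25:] captures the id plus extension:

import os


def merged_filename(output_dir, pred_file_group):
    # trailing 25 characters: the random id plus the ".nc" extension
    random_ids = [fn[-25:] for fn in pred_file_group]
    if len(set(random_ids)) == 1:
        # every input dir holds samples from the same sampling run: reuse the shared id
        return os.path.join(output_dir, f"predictions-{random_ids[0]}")
    # different runs: join shortened ids so the merged name reflects all of them
    random_ids = [rid[:5] for rid in random_ids]
    return os.path.join(output_dir, f"predictions-{'-'.join(random_ids)}.nc")


# same random id across input dirs -> the shared id is kept as-is
same_run = [
    "dirA/predictions-AbCdEfGhIjKlMnOpQrStUv.nc",
    "dirB/predictions-AbCdEfGhIjKlMnOpQrStUv.nc",
]
print(merged_filename("merged", same_run))
# merged/predictions-AbCdEfGhIjKlMnOpQrStUv.nc

# different random ids -> 5-character prefixes are joined
mixed_runs = [
    "dirA/predictions-AbCdEfGhIjKlMnOpQrStUv.nc",
    "dirB/predictions-ZyXwVuTsRqPoNmLkJiHgFe.nc",
]
print(merged_filename("merged", mixed_runs))
# merged/predictions-AbCdE-ZyXwV.nc

Keeping the shared id intact when every input directory holds files from the same sampling run means the merged file carries the same name as its inputs; only genuinely different runs fall back to the joined 5-character prefixes.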