From 649bd77f572b5f52f62938f8202f84d7b571b410 Mon Sep 17 00:00:00 2001
From: Henry Addison
Date: Wed, 29 Nov 2023 21:52:02 +0000
Subject: [PATCH 01/22] add a CLI command for converting all samples of a
 model on a dataset split to a 60km grid

---
 .../bin/postprocess.py                     | 118 +++++++++++++-----
 src/ml_downscaling_emulator/postprocess.py |  16 +++
 2 files changed, 100 insertions(+), 34 deletions(-)

diff --git a/src/ml_downscaling_emulator/bin/postprocess.py b/src/ml_downscaling_emulator/bin/postprocess.py
index 202fa4f8d..d9161277a 100644
--- a/src/ml_downscaling_emulator/bin/postprocess.py
+++ b/src/ml_downscaling_emulator/bin/postprocess.py
@@ -2,13 +2,14 @@
 import logging
 import os
 from pathlib import Path
+from typing import Callable
 
 import typer
 import xarray as xr
 
 from mlde_utils import samples_path, samples_glob, TIME_PERIODS
 from mlde_utils.training.dataset import open_raw_dataset_split
-from ml_downscaling_emulator.postprocess import xrqm
+from ml_downscaling_emulator.postprocess import xrqm, to_gcm_domain
 
 logging.basicConfig(
     level=logging.INFO,
@@ -25,6 +26,50 @@ def callback():
     pass
 
 
+def process_each_sample(
+    workdir: Path,
+    checkpoint: str,
+    dataset: str,
+    ensemble_member: str,
+    input_xfm: str,
+    split: str,
+    processing_func: Callable,
+    new_workdir: Path,
+):
+    samples_dirpath = samples_path(
+        workdir,
+        checkpoint=checkpoint,
+        input_xfm=input_xfm,
+        dataset=dataset,
+        split=split,
+        ensemble_member=ensemble_member,
+    )
+    logger.info(f"Iterating on samples in {samples_dirpath}")
+    for sample_filepath in samples_glob(samples_dirpath):
+        logger.info(f"Working on {sample_filepath}")
+        # open the samples
+        samples_ds = xr.open_dataset(sample_filepath)
+
+        processed_samples_ds = processing_func(samples_ds)
+
+        # save output
+        processed_sample_filepath = (
+            samples_path(
+                new_workdir,
+                checkpoint=checkpoint,
+                input_xfm=input_xfm,
+                dataset=dataset,
+                split=split,
+                ensemble_member=ensemble_member,
+            )
+            / sample_filepath.name
+        )
+
+        logger.info(f"Saving to {processed_sample_filepath}")
+        processed_sample_filepath.parent.mkdir(parents=True, exist_ok=True)
+        processed_samples_ds.to_netcdf(processed_sample_filepath)
+
+
 @app.command()
 def filter(
     workdir: Path,
@@ -103,44 +148,49 @@ def qm(
         )[0]
     )["pred_pr"]
 
-    ml_eval_samples_dirpath = samples_path(
-        workdir,
-        checkpoint=checkpoint,
-        input_xfm=eval_input_xfm,
-        dataset=eval_dataset,
-        split=split,
-        ensemble_member=ensemble_member,
-    )
-    logger.info(f"QMapping samplesin {ml_eval_samples_dirpath}")
-    for sample_filepath in samples_glob(ml_eval_samples_dirpath):
-        logger.info(f"Working on {sample_filepath}")
-        # open the samples to be qmapped
-        ml_eval_ds = xr.open_dataset(sample_filepath)
-
+    def process_samples(ds):
         # do the qmapping
-        qmapped_eval_da = xrqm(sim_train_da, ml_train_da, ml_eval_ds["pred_pr"])
+        qmapped_eval_da = xrqm(sim_train_da, ml_train_da, ds["pred_pr"])
 
-        qmapped_eval_ds = ml_eval_ds.copy()
-        qmapped_eval_ds["pred_pr"] = qmapped_eval_da
+        processed_samples_ds = ds.copy()
+        processed_samples_ds["pred_pr"] = qmapped_eval_da
 
-        # save output
-        new_workdir = workdir / "postprocess" / "qm"
+        return processed_samples_ds
+
+    process_each_sample(
+        workdir,
+        checkpoint,
+        eval_dataset,
+        ensemble_member,
+        eval_input_xfm,
+        split,
+        process_samples,
+        new_workdir=workdir / "postprocess" / "qm",
+    )
 
-        qmapped_sample_filepath = (
-            samples_path(
-                new_workdir,
-                checkpoint=checkpoint,
-                input_xfm=eval_input_xfm,
-                dataset=eval_dataset,
-                split=split,
-                ensemble_member=ensemble_member,
-            )
-            / sample_filepath.name
-        )
 
-        logger.info(f"Saving to {qmapped_sample_filepath}")
-        qmapped_sample_filepath.parent.mkdir(parents=True, exist_ok=True)
-        qmapped_eval_ds.to_netcdf(qmapped_sample_filepath)
+
+@app.command()
+def gcmify(
+    workdir: Path,
+    checkpoint: str = typer.Option(...),
+    dataset: str = typer.Option(...),
+    input_xfm: str = typer.Option(...),
+    split: str = typer.Option(...),
+    ensemble_member: str = typer.Option(...),
+):
+    def process_samples(ds):
+        return ds.groupby("sample_id").map(to_gcm_domain)
+
+    process_each_sample(
+        workdir,
+        checkpoint,
+        dataset,
+        ensemble_member,
+        input_xfm,
+        split,
+        process_samples,
+        new_workdir=workdir / "postprocess" / "gcm-grid",
+    )
 
 
 @app.command()
diff --git a/src/ml_downscaling_emulator/postprocess.py b/src/ml_downscaling_emulator/postprocess.py
index 6ff41ab82..50783522e 100644
--- a/src/ml_downscaling_emulator/postprocess.py
+++ b/src/ml_downscaling_emulator/postprocess.py
@@ -1,7 +1,12 @@
+from importlib.resources import files
 from typing import Callable
 
 import numpy as np
 import xarray as xr
 
+from mlde_utils.data.remapcon import Remapcon
+from mlde_utils.data.shift_lon_break import ShiftLonBreak
+from mlde_utils.data.select_gcm_domain import SelectGCMDomain
+
 
 def _get_cdf(x, xbins):
     pdf, _ = np.histogram(x, xbins)
@@ -68,3 +73,14 @@ def xrqm(
         .transpose("ensemble_member", "time", "grid_latitude", "grid_longitude")
         .assign_coords(time=ml_eval_da["time"])
     )
+
+
+def to_gcm_domain(ds: xr.Dataset):
+    target_grid_filepath = files("mlde_utils.data").joinpath(
+        "target_grids/60km/global/pr/moose_grid.nc"
+    )
+    xr.open_dataset(target_grid_filepath)
+    ds = Remapcon(target_grid_filepath).run(ds)
+    ds = ShiftLonBreak().run(ds)
+    ds = SelectGCMDomain(subdomain="2.2km-coarsened-4x_bham-64").run(ds)
+    return ds

From f0a43641af57d215d2ce6a8d220b912e6561fd26 Mon Sep 17 00:00:00 2001
From: Henry Addison
Date: Thu, 30 Nov 2023 11:45:04 +0000
Subject: [PATCH 02/22] sample files don't have sample_id but they do have
 ensemble_member dim

---
 src/ml_downscaling_emulator/bin/postprocess.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/ml_downscaling_emulator/bin/postprocess.py b/src/ml_downscaling_emulator/bin/postprocess.py
index d9161277a..ca1bf9f3c 100644
--- a/src/ml_downscaling_emulator/bin/postprocess.py
+++ b/src/ml_downscaling_emulator/bin/postprocess.py
@@ -179,7 +179,7 @@ def gcmify(
     ensemble_member: str = typer.Option(...),
 ):
     def process_samples(ds):
-        return ds.groupby("sample_id").map(to_gcm_domain)
+        return ds.groupby("ensemble_member").map(to_gcm_domain)
 
     process_each_sample(
         workdir,

From dd9d3f7a8c3887180ebe321cdc8142cd1b261c37 Mon Sep 17 00:00:00 2001
From: Henry Addison
Date: Thu, 30 Nov 2023 11:46:43 +0000
Subject: [PATCH 03/22] update SelectGCMDomain interface

---
 src/ml_downscaling_emulator/postprocess.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/ml_downscaling_emulator/postprocess.py b/src/ml_downscaling_emulator/postprocess.py
index 50783522e..495c5d312 100644
--- a/src/ml_downscaling_emulator/postprocess.py
+++ b/src/ml_downscaling_emulator/postprocess.py
@@ -82,5 +82,5 @@ def to_gcm_domain(ds: xr.Dataset):
     xr.open_dataset(target_grid_filepath)
     ds = Remapcon(target_grid_filepath).run(ds)
     ds = ShiftLonBreak().run(ds)
-    ds = SelectGCMDomain(subdomain="2.2km-coarsened-4x_bham-64").run(ds)
+    ds = SelectGCMDomain(subdomain="birmingham", size=9).run(ds)
     return ds

From c57cfc9324f461594feeb5c0baa0debefa9d2c50 Mon Sep 17 00:00:00 2001
From: Henry Addison
Date: Fri, 1 Dec 2023 16:04:11 +0000
Subject: [PATCH 04/22] map doesn't behave with to_gcm_domain

---
 src/ml_downscaling_emulator/bin/postprocess.py | 4 +++-
 src/ml_downscaling_emulator/postprocess.py     | 1 +
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/ml_downscaling_emulator/bin/postprocess.py b/src/ml_downscaling_emulator/bin/postprocess.py
index ca1bf9f3c..f41d487e9 100644
--- a/src/ml_downscaling_emulator/bin/postprocess.py
+++ b/src/ml_downscaling_emulator/bin/postprocess.py
@@ -179,7 +179,9 @@ def gcmify(
     ensemble_member: str = typer.Option(...),
 ):
     def process_samples(ds):
-        return ds.groupby("ensemble_member").map(to_gcm_domain)
+        ds = to_gcm_domain(ds.sel(ensemble_member=ensemble_member).map(to_gcm_domain))
+        ds["pred_pr"] = ds["pred_pr"].expand_dims({"ensemble_member": ensemble_member})
+        return ds
 
     process_each_sample(
         workdir,
diff --git a/src/ml_downscaling_emulator/postprocess.py b/src/ml_downscaling_emulator/postprocess.py
index 495c5d312..b20b6a2b7 100644
--- a/src/ml_downscaling_emulator/postprocess.py
+++ b/src/ml_downscaling_emulator/postprocess.py
@@ -83,4 +83,5 @@ def to_gcm_domain(ds: xr.Dataset):
     ds = Remapcon(target_grid_filepath).run(ds)
     ds = ShiftLonBreak().run(ds)
     ds = SelectGCMDomain(subdomain="birmingham", size=9).run(ds)
+    ds = ds.drop_vars(["rotated_latitude_longitude"], errors="ignore")
     return ds

From a2f033cebbb601a16451d857c85844043954bfc4 Mon Sep 17 00:00:00 2001
From: Henry Addison
Date: Fri, 1 Dec 2023 16:07:07 +0000
Subject: [PATCH 05/22] oops left a map call that should have been removed

---
 src/ml_downscaling_emulator/bin/postprocess.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/ml_downscaling_emulator/bin/postprocess.py b/src/ml_downscaling_emulator/bin/postprocess.py
index f41d487e9..476093eaa 100644
--- a/src/ml_downscaling_emulator/bin/postprocess.py
+++ b/src/ml_downscaling_emulator/bin/postprocess.py
@@ -179,7 +179,7 @@ def gcmify(
     ensemble_member: str = typer.Option(...),
 ):
     def process_samples(ds):
-        ds = to_gcm_domain(ds.sel(ensemble_member=ensemble_member).map(to_gcm_domain))
+        ds = to_gcm_domain(ds.sel(ensemble_member=ensemble_member))
         ds["pred_pr"] = ds["pred_pr"].expand_dims({"ensemble_member": ensemble_member})
         return ds
 

From 9fb2042e17e8f23bf378eb6f52dcb0f4a8bc9375 Mon Sep 17 00:00:00 2001
From: Henry Addison
Date: Fri, 1 Dec 2023 16:11:20 +0000
Subject: [PATCH 06/22] OK got to expand dims for entire dataset

---
 src/ml_downscaling_emulator/bin/postprocess.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/ml_downscaling_emulator/bin/postprocess.py b/src/ml_downscaling_emulator/bin/postprocess.py
index 476093eaa..75a0da97c 100644
--- a/src/ml_downscaling_emulator/bin/postprocess.py
+++ b/src/ml_downscaling_emulator/bin/postprocess.py
@@ -180,7 +180,7 @@ def gcmify(
 ):
     def process_samples(ds):
         ds = to_gcm_domain(ds.sel(ensemble_member=ensemble_member))
-        ds["pred_pr"] = ds["pred_pr"].expand_dims({"ensemble_member": ensemble_member})
+        ds = ds.expand_dims({"ensemble_member": ensemble_member})
         return ds
 

From d0fc90eb1de90b54bb6ca8d087cdb771acfa24ed Mon Sep 17 00:00:00 2001
From: Henry Addison
Date: Fri, 1 Dec 2023 16:12:58 +0000
Subject: [PATCH 07/22] ahh no it was missing square brackets to define the
 value

---
 src/ml_downscaling_emulator/bin/postprocess.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/ml_downscaling_emulator/bin/postprocess.py b/src/ml_downscaling_emulator/bin/postprocess.py
index 75a0da97c..2eb4e5c2d 100644
--- a/src/ml_downscaling_emulator/bin/postprocess.py
+++ b/src/ml_downscaling_emulator/bin/postprocess.py
@@ -180,7 +180,9 @@ def gcmify(
 ):
     def process_samples(ds):
         ds = to_gcm_domain(ds.sel(ensemble_member=ensemble_member))
-        ds = ds.expand_dims({"ensemble_member": ensemble_member})
+        ds["pred_pr"] = ds["pred_pr"].expand_dims(
+            {"ensemble_member": [ensemble_member]}
+        )
         return ds
 
     process_each_sample(

From 70f45999d7be0437503466bb2b9d64f426f216cf Mon Sep 17 00:00:00 2001
From: Henry Addison
Date: Wed, 6 Dec 2023 11:44:31 +0000
Subject: [PATCH 08/22] a helper script to check predictions

---
 bin/check_preds.py | 68 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 68 insertions(+)
 create mode 100644 bin/check_preds.py

diff --git a/bin/check_preds.py b/bin/check_preds.py
new file mode 100644
index 000000000..f68a9ea3b
--- /dev/null
+++ b/bin/check_preds.py
@@ -0,0 +1,68 @@
+import sys
+
+from tqdm import tqdm
+from tqdm.contrib.logging import logging_redirect_tqdm
+import xarray as xr
+
+fpaths = sys.argv[1:]
+
+
+def check(pred_ds):
+
+    pred_var = "pred_pr"
+    meta_vars = (
+        "rotated_latitude_longitude",
+        "time_bnds",
+        "grid_latitude_bnds",
+        "grid_longitude_bnds",
+    )
+
+    errors = []
+
+    pred_pr_attrs = {
+        "grid_mapping": "rotated_latitude_longitude",
+        "standard_name": "pred_pr",
+        "units": "kg m-2 s-1",
+    }
+
+    try:
+        assert (
+            pred_ds[pred_var].attrs == pred_pr_attrs
+        ), f"Bad attrs on {pred_var}: {pred_ds[pred_var].attrs}"
+    except AssertionError as e:
+        errors.append(e)
+
+        # pred_ds[var] = pred_ds[var].assign_attrs(pred_pr_attrs)
+
+    for var in meta_vars:
+        # print(var, pred_ds[var].dims)
+        try:
+            assert ("ensemble_member" not in pred_ds[var].dims) and (
+                "time" not in pred_ds[var].dims
+            ), f"Bad dims on {var}: {pred_ds[var].dims}"
+        except AssertionError as e:
+            errors.append(e)
+            # pred_ds[var] = pred_ds[var].isel(ensemble_member=0, time=0)
+            # print(var, pred_ds[var].dims)
+
+    return errors
+
+
+with logging_redirect_tqdm():
+    with tqdm(
+        total=len(fpaths),
+        desc=f"CHecking prediction files",
+        unit=" files",
+    ) as pbar:
+        for fpath in fpaths:
+            pred_ds = xr.open_dataset(fpath)
+            # import pdb; pdb.set_trace()
+            errors = check(pred_ds)
+
+            if len(errors) != 5:
+                print(f"Errors in {fpath}:")
+                for e in errors:
+                    print(e)
+            pbar.update(1)
+
+    # print(pred_ds)

From 275c9885a56e986c5b08429f13264695e4b2480a Mon Sep 17 00:00:00 2001
From: Henry Addison
Date: Thu, 14 Dec 2023 16:29:47 -0600
Subject: [PATCH 09/22] refactor checking preds

---
 bin/check_preds.py | 66 +++++++++++++++++++++++++++-----------------
 1 file changed, 42 insertions(+), 24 deletions(-)

diff --git a/bin/check_preds.py b/bin/check_preds.py
index f68a9ea3b..0d8c35203 100644
--- a/bin/check_preds.py
+++ b/bin/check_preds.py
@@ -6,44 +6,46 @@
 
 fpaths = sys.argv[1:]
 
+PRED_VAR = "pred_pr"
+META_VARS = (
+    "rotated_latitude_longitude",
+    "time_bnds",
+    "grid_latitude_bnds",
+    "grid_longitude_bnds",
+)
 
-def check(pred_ds):
+PRED_PR_ATTRS = {
+    "grid_mapping": "rotated_latitude_longitude",
+    "standard_name": "pred_pr",
+    "units": "kg m-2 s-1",
+}
 
-    pred_var = "pred_pr"
-    meta_vars = (
-        "rotated_latitude_longitude",
-        "time_bnds",
-        "grid_latitude_bnds",
-        "grid_longitude_bnds",
-    )
 
-    errors = []
+def fix(pred_ds):
+    pred_ds[PRED_VAR] = pred_ds[PRED_VAR].assign_attrs(PRED_PR_ATTRS)
+    for var in META_VARS:
+        pred_ds[var] = pred_ds[var].isel(ensemble_member=0, time=0)
+
+    return pred_ds
+
 
-    pred_pr_attrs = {
-        "grid_mapping": "rotated_latitude_longitude",
-        "standard_name": "pred_pr",
-        "units": "kg m-2 s-1",
-    }
+def check(pred_ds):
+    errors = []
 
     try:
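         # compare pred_pr's attrs against the expected CF-style metadata;
         # failures are collected into `errors` rather than raised immediately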
         assert (
-            pred_ds[pred_var].attrs == pred_pr_attrs
-        ), f"Bad attrs on {pred_var}: {pred_ds[pred_var].attrs}"
+            pred_ds[PRED_VAR].attrs == PRED_PR_ATTRS
+        ), f"Bad attrs on {PRED_VAR}: {pred_ds[PRED_VAR].attrs}"
     except AssertionError as e:
         errors.append(e)
 
-        # pred_ds[var] = pred_ds[var].assign_attrs(pred_pr_attrs)
-
-    for var in meta_vars:
-        # print(var, pred_ds[var].dims)
+    for var in META_VARS:
         try:
             assert ("ensemble_member" not in pred_ds[var].dims) and (
                 "time" not in pred_ds[var].dims
             ), f"Bad dims on {var}: {pred_ds[var].dims}"
         except AssertionError as e:
             errors.append(e)
-            # pred_ds[var] = pred_ds[var].isel(ensemble_member=0, time=0)
-            # print(var, pred_ds[var].dims)
 
     return errors
 
 
 with logging_redirect_tqdm():
     with tqdm(
         total=len(fpaths),
-        desc=f"CHecking prediction files",
+        desc=f"Checking prediction files",
         unit=" files",
     ) as pbar:
         for fpath in fpaths:
             pred_ds = xr.open_dataset(fpath)
             # import pdb; pdb.set_trace()
             errors = check(pred_ds)
 
             if len(errors) != 5:
                 print(f"Errors in {fpath}:")
                 for e in errors:
                     print(e)
             pbar.update(1)
 
-    # print(pred_ds)
+# with logging_redirect_tqdm():
+#     with tqdm(
+#         total=len(fpaths),
+#         desc=f"Fixing prediction files",
+#         unit=" files",
+#     ) as pbar:
+#         for fpath in fpaths:
+#             pred_ds = xr.open_dataset(fpath)
+#             # import pdb; pdb.set_trace()
+#             pred_ds = fix(pred_ds)
+#             errors = check(pred_ds)
+
+#             if len(errors) > 0:
+#                 print(f"Errors in {fpath}:")
+#                 for e in errors:
+#                     print(e)
+#             pbar.update(1)

From 708ce080275ff1eba82d39b69b4514bb99749689 Mon Sep 17 00:00:00 2001
From: Henry Addison
Date: Thu, 14 Dec 2023 16:32:17 -0600
Subject: [PATCH 10/22] fix the prediction files

---
 bin/check_preds.py | 34 +++++++++++++++++-----------------
 1 file changed, 17 insertions(+), 17 deletions(-)

diff --git a/bin/check_preds.py b/bin/check_preds.py
index 0d8c35203..a7e710b33 100644
--- a/bin/check_preds.py
+++ b/bin/check_preds.py
@@ -67,20 +67,20 @@ def check(pred_ds):
                     print(e)
             pbar.update(1)
 
-# with logging_redirect_tqdm():
-#     with tqdm(
-#         total=len(fpaths),
-#         desc=f"Fixing prediction files",
-#         unit=" files",
-#     ) as pbar:
-#         for fpath in fpaths:
-#             pred_ds = xr.open_dataset(fpath)
-#             # import pdb; pdb.set_trace()
-#             pred_ds = fix(pred_ds)
-#             errors = check(pred_ds)
-
-#             if len(errors) > 0:
-#                 print(f"Errors in {fpath}:")
-#                 for e in errors:
-#                     print(e)
-#             pbar.update(1)
+with logging_redirect_tqdm():
+    with tqdm(
+        total=len(fpaths),
+        desc=f"Fixing prediction files",
+        unit=" files",
+    ) as pbar:
+        for fpath in fpaths:
+            pred_ds = xr.open_dataset(fpath)
+            # import pdb; pdb.set_trace()
+            pred_ds = fix(pred_ds)
+            errors = check(pred_ds)
+
+            if len(errors) > 0:
+                print(f"Errors in {fpath}:")
+                for e in errors:
+                    print(e)
+            pbar.update(1)

From a08c5815cb2807662d1bf9dd71b3e07051ea4a8b Mon Sep 17 00:00:00 2001
From: Henry Addison
Date: Thu, 14 Dec 2023 16:38:55 -0600
Subject: [PATCH 11/22] save the fixed prediction files too

---
 bin/check_preds.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/bin/check_preds.py b/bin/check_preds.py
index a7e710b33..00c689ab9 100644
--- a/bin/check_preds.py
+++ b/bin/check_preds.py
@@ -74,7 +74,7 @@ def check(pred_ds):
         unit=" files",
     ) as pbar:
         for fpath in fpaths:
-            pred_ds = xr.open_dataset(fpath)
+            pred_ds = xr.load_dataset(fpath)
             # import pdb; pdb.set_trace()
             pred_ds = fix(pred_ds)
             errors = check(pred_ds)
@@ -83,4 +83,6 @@ def check(pred_ds):
                 print(f"Errors in {fpath}:")
                 for e in errors:
                     print(e)
+            else:
+                pred_ds.to_netcdf(fpath)
             pbar.update(1)

From b0cb3758a254573c073310325745275edc6edd45 Mon Sep 17 00:00:00 2001
From: Henry Addison
Date: Thu, 14 Dec 2023 16:43:21 -0600
Subject: [PATCH 12/22] make fixing idempotent

---
 bin/check_preds.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/bin/check_preds.py b/bin/check_preds.py
index 00c689ab9..c885dbee4 100644
--- a/bin/check_preds.py
+++ b/bin/check_preds.py
@@ -24,7 +24,10 @@ def fix(pred_ds):
     pred_ds[PRED_VAR] = pred_ds[PRED_VAR].assign_attrs(PRED_PR_ATTRS)
 
     for var in META_VARS:
-        pred_ds[var] = pred_ds[var].isel(ensemble_member=0, time=0)
+        if "ensemble_member" in pred_ds[var].dims:
+            pred_ds[var] = pred_ds[var].isel(ensemble_member=0)
+        if "time" in pred_ds[var].dims:
+            pred_ds[var] = pred_ds[var].isel(time=0)
 
     return pred_ds

From cc39e01e9823f9d1a40438a9f36b38a46949fc77 Mon Sep 17 00:00:00 2001
From: Henry Addison
Date: Thu, 14 Dec 2023 16:48:15 -0600
Subject: [PATCH 13/22] disable saving fixed pred

---
 bin/check_preds.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/bin/check_preds.py b/bin/check_preds.py
index c885dbee4..4c1feec2e 100644
--- a/bin/check_preds.py
+++ b/bin/check_preds.py
@@ -86,6 +86,6 @@ def check(pred_ds):
                 print(f"Errors in {fpath}:")
                 for e in errors:
                     print(e)
-            else:
-                pred_ds.to_netcdf(fpath)
+            # else:
+            #     pred_ds.to_netcdf(fpath)
             pbar.update(1)

From f48d2b5450abc15bed87e053a26d0f7c33063d81 Mon Sep 17 00:00:00 2001
From: Henry Addison
Date: Thu, 14 Dec 2023 16:53:46 -0600
Subject: [PATCH 14/22] re-enable saving and make sure the file is closed
 after the initial checks

---
 bin/check_preds.py | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/bin/check_preds.py b/bin/check_preds.py
index 4c1feec2e..c2e00b2dc 100644
--- a/bin/check_preds.py
+++ b/bin/check_preds.py
@@ -60,14 +60,13 @@ def check(pred_ds):
         unit=" files",
     ) as pbar:
         for fpath in fpaths:
-            pred_ds = xr.open_dataset(fpath)
-            # import pdb; pdb.set_trace()
-            errors = check(pred_ds)
+            with xr.open_dataset(fpath) as pred_ds:
+                errors = check(pred_ds)
 
-            if len(errors) != 5:
-                print(f"Errors in {fpath}:")
-                for e in errors:
-                    print(e)
+                if len(errors) != 5:
+                    print(f"Errors in {fpath}:")
+                    for e in errors:
+                        print(e)
             pbar.update(1)
 
 with logging_redirect_tqdm():
@@ -78,7 +77,7 @@ def check(pred_ds):
     ) as pbar:
         for fpath in fpaths:
             pred_ds = xr.load_dataset(fpath)
-            # import pdb; pdb.set_trace()
+
             pred_ds = fix(pred_ds)
             errors = check(pred_ds)

From a32414d83a7ced6bb44ab4557c4727c39fde1284 Mon Sep 17 00:00:00 2001
From: Henry Addison
Date: Thu, 14 Dec 2023 17:22:39 -0600
Subject: [PATCH 15/22] update predict to ensure it has the correct attributes
 on pred_pr

---
 bin/predict.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/bin/predict.py b/bin/predict.py
index 5859d8f92..def13b88f 100644
--- a/bin/predict.py
+++ b/bin/predict.py
@@ -165,6 +165,7 @@ def np_samples_to_xr(np_samples, target_transform, coords, cf_data_vars):
         xr.Dataset(data_vars=data_vars, coords=coords, attrs={})
     )
     samples_ds = samples_ds.rename({"target_pr": "pred_pr"})
+    samples_ds["pred_pr"] = samples_ds["pred_pr"].assign_attrs(pred_pr_attrs)
 
     return samples_ds

From 18b2fc2f60e64ea2861a19172b8cb8f83a72f132 Mon Sep 17 00:00:00 2001
From: Henry Addison
Date: Fri, 15 Dec 2023 19:35:11 -0600
Subject: [PATCH 16/22] just check for issues with prediction files in helper
 script

---
 bin/check_preds.py | 28 +++++++++-------------------
 1 file changed, 9 insertions(+), 19 deletions(-)

diff --git a/bin/check_preds.py b/bin/check_preds.py
index c2e00b2dc..b408d0f4b 100644
--- a/bin/check_preds.py
+++ b/bin/check_preds.py
@@ -63,28 +63,18 @@ def check(pred_ds):
             with xr.open_dataset(fpath) as pred_ds:
                 errors = check(pred_ds)
 
-            if len(errors) != 5:
+            if len(errors) > 0:
                 print(f"Errors in {fpath}:")
                 for e in errors:
                     print(e)
-            pbar.update(1)
-
-with logging_redirect_tqdm():
-    with tqdm(
-        total=len(fpaths),
-        desc=f"Fixing prediction files",
-        unit=" files",
-    ) as pbar:
-        for fpath in fpaths:
-            pred_ds = xr.load_dataset(fpath)
 
-            pred_ds = fix(pred_ds)
-            errors = check(pred_ds)
+            # pred_ds = fix(pred_ds)
+            # errors = check(pred_ds)
 
-            if len(errors) > 0:
-                print(f"Errors in {fpath}:")
-                for e in errors:
-                    print(e)
-            else:
-                pred_ds.to_netcdf(fpath)
+            # if len(errors) > 0:
+            #     print(f"Errors in {fpath}:")
+            #     for e in errors:
+            #         print(e)
+            # else:
+            #     pred_ds.to_netcdf(fpath)
             pbar.update(1)

From 2db984443184b88047b19e1544fd95f8a21081da Mon Sep 17 00:00:00 2001
From: Henry Addison
Date: Fri, 15 Dec 2023 19:36:52 -0600
Subject: [PATCH 17/22] make sure no nulls when gcmifying

---
 src/ml_downscaling_emulator/postprocess.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/ml_downscaling_emulator/postprocess.py b/src/ml_downscaling_emulator/postprocess.py
index b20b6a2b7..402f7468d 100644
--- a/src/ml_downscaling_emulator/postprocess.py
+++ b/src/ml_downscaling_emulator/postprocess.py
@@ -79,9 +79,10 @@ def to_gcm_domain(ds: xr.Dataset):
     target_grid_filepath = files("mlde_utils.data").joinpath(
         "target_grids/60km/global/pr/moose_grid.nc"
     )
-    xr.open_dataset(target_grid_filepath)
     ds = Remapcon(target_grid_filepath).run(ds)
     ds = ShiftLonBreak().run(ds)
     ds = SelectGCMDomain(subdomain="birmingham", size=9).run(ds)
+    nan_count = ds["pred_pr"].isnull().sum().values.item()
+    assert 0 == nan_count, f"nan count: {nan_count}"
     ds = ds.drop_vars(["rotated_latitude_longitude"], errors="ignore")
     return ds

From e2b7d5710480d2bb1c008a6a5b67411446175681 Mon Sep 17 00:00:00 2001
From: Henry Addison
Date: Thu, 4 Jan 2024 09:00:07 +0000
Subject: [PATCH 18/22] rename x to target in batch var names, since x is
 confusing in a conditional setting, where it is often assumed to be the
 conditioning for a target y rather than the target itself

---
 .../score_sde_pytorch_hja22/run_lib.py | 20 +++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/run_lib.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/run_lib.py
index d21305b87..b51672bd5 100644
--- a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/run_lib.py
+++ b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/run_lib.py
@@ -53,14 +53,14 @@ def val_loss(config, eval_ds, eval_step_fn, state):
     val_set_loss = 0.0
-    for eval_cond_batch, eval_x_batch, eval_time_batch in eval_ds:
-        # eval_cond_batch, eval_x_batch = next(iter(eval_ds))
-        eval_x_batch = eval_x_batch.to(config.device)
+    for eval_cond_batch, eval_target_batch, eval_time_batch in eval_ds:
+        # eval_cond_batch, eval_target_batch = next(iter(eval_ds))
+        eval_target_batch = eval_target_batch.to(config.device)
         eval_cond_batch = eval_cond_batch.to(config.device)
         # append any location-specific parameters
         eval_cond_batch = state['location_params'](eval_cond_batch)
         # eval_batch = eval_batch.permute(0, 3, 1, 2)
-        eval_loss = eval_step_fn(state, eval_x_batch, eval_cond_batch)
+        eval_loss = eval_step_fn(state, eval_target_batch, eval_cond_batch)
 
         # Progress
         val_set_loss += eval_loss.item()
@@ -171,21 +171,21 @@ def train(config, workdir):
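         # running total of the training loss accumulated across the epoch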
         train_set_loss = 0.0
         with logging_redirect_tqdm():
             with tqdm(total=len(train_dl.dataset), desc=f"Epoch {state['epoch']}", unit=' timesteps') as pbar:
-                for cond_batch, x_batch, time_batch in train_dl:
+                for cond_batch, target_batch, time_batch in train_dl:
 
-                    x_batch = x_batch.to(config.device)
+                    target_batch = target_batch.to(config.device)
                     cond_batch = cond_batch.to(config.device)
                     # append any location-specific parameters
                     cond_batch = state['location_params'](cond_batch)
 
                     if config.training.random_crop_size > 0:
-                        x_ch = x_batch.shape[1]
-                        cropped = random_crop(torch.cat([x_batch, cond_batch], dim=1))
-                        x_batch = cropped[:,:x_ch]
+                        x_ch = target_batch.shape[1]
+                        cropped = random_crop(torch.cat([target_batch, cond_batch], dim=1))
+                        target_batch = cropped[:,:x_ch]
                         cond_batch = cropped[:,x_ch:]
 
                     # Execute one training step
-                    loss = train_step_fn(state, x_batch, cond_batch)
+                    loss = train_step_fn(state, target_batch, cond_batch)
                     train_set_loss += loss.item()
                     if state['step'] % config.training.log_freq == 0:
                         logging.info("epoch: %d, step: %d, train_loss: %.5e" % (state['epoch'], state['step'], loss.item()))

From 394454eb1f230d425b89d5d4f45250173211c66c Mon Sep 17 00:00:00 2001
From: Henry Addison
Date: Tue, 9 Jan 2024 11:45:42 +0000
Subject: [PATCH 19/22] add script for checking number of parameters in a
 model

---
 bin/model-size | 71 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 71 insertions(+)
 create mode 100755 bin/model-size

diff --git a/bin/model-size b/bin/model-size
new file mode 100755
index 000000000..38b64f7f4
--- /dev/null
+++ b/bin/model-size
@@ -0,0 +1,71 @@
+#!/usr/bin/env python
+
+import os
+from pathlib import Path
+
+from ml_collections import config_dict
+
+import typer
+import logging
+import yaml
+
+from ml_downscaling_emulator.score_sde_pytorch_hja22.models.location_params import (
+    LocationParams,
+)
+
+from ml_downscaling_emulator.score_sde_pytorch_hja22.models import utils as mutils
+
+# from score_sde_pytorch_hja22.models import ncsnv2
+# from score_sde_pytorch_hja22.models import ncsnpp
+from ml_downscaling_emulator.score_sde_pytorch_hja22.models import cncsnpp  # noqa: F401
+from ml_downscaling_emulator.score_sde_pytorch_hja22.models import cunet  # noqa: F401
+
+# from score_sde_pytorch_hja22.models import ddpm as ddpm_model
+from ml_downscaling_emulator.score_sde_pytorch_hja22.models import (  # noqa: F401
+    layerspp,  # noqa: F401
+)  # noqa: F401
+from ml_downscaling_emulator.score_sde_pytorch_hja22.models import layers  # noqa: F401
+from ml_downscaling_emulator.score_sde_pytorch_hja22.models import (  # noqa: F401
+    normalization,  # noqa: F401
+)  # noqa: F401
+
+logger = logging.getLogger()
+logger.setLevel("INFO")
+
+app = typer.Typer()
+
+
+def load_model(config):
+    logger.info(f"Loading model from config")
+    score_model = mutils.get_model(config.model.name)(config)
+    location_params = LocationParams(
+        config.model.loc_spec_channels, config.data.image_size
+    )
+
+    return score_model, location_params
+
+
+def load_config(config_path):
+    logger.info(f"Loading config from {config_path}")
+    with open(config_path) as f:
+        config = config_dict.ConfigDict(yaml.unsafe_load(f))
+
+    return config
+
+
+@app.command()
+def main(
+    workdir: Path,
+):
+    config_path = os.path.join(workdir, "config.yml")
+    config = load_config(config_path)
+    score_model, location_params = load_model(config)
+    num_score_model_parameters = sum(p.numel() for p in score_model.parameters())
+    num_location_parameters = sum(p.numel() for p in location_params.parameters())
+
+    typer.echo(f"Score model has {num_score_model_parameters} parameters")
+    typer.echo(f"Location parameters have {num_location_parameters} parameters")
+
+
+if __name__ == "__main__":
+    app()

From 8894f7f99f4f24e2512b66be1035263088611032 Mon Sep 17 00:00:00 2001
From: Henry Addison
Date: Wed, 14 Feb 2024 14:37:21 +0000
Subject: [PATCH 20/22] allow for setting which dataset a transform should be
 fitted to, so it doesn't have to be the same as the dataset it is applied to
 during sampling; during training it is always the same dataset for both
 (though potentially even this could be uncoupled)

---
 bin/predict.py                                        | 8 +++++++-
 src/ml_downscaling_emulator/bin/evaluate.py           | 8 +++++++-
 src/ml_downscaling_emulator/deterministic/run_lib.py  | 2 ++
 .../score_sde_pytorch_hja22/run_lib.py                | 4 ++--
 src/ml_downscaling_emulator/torch.py                  | 3 +++
 5 files changed, 21 insertions(+), 4 deletions(-)

diff --git a/bin/predict.py b/bin/predict.py
index def13b88f..79ec5ab35 100644
--- a/bin/predict.py
+++ b/bin/predict.py
@@ -240,6 +240,7 @@ def main(
     epoch: int = typer.Option(...),
     batch_size: int = None,
     num_samples: int = 3,
+    input_transform_dataset: str = None,
     input_transform_key: str = None,
     ensemble_member: str = DEFAULT_ENSEMBLE_MEMBER,
 ):
@@ -247,6 +248,10 @@ def main(
     config = load_config(config_path)
     if batch_size is not None:
         config.eval.batch_size = batch_size
+    if input_transform_dataset is not None:
+        config.data.input_transform_dataset = input_transform_dataset
+    else:
+        config.data.input_transform_dataset = dataset
     if input_transform_key is not None:
         config.data.input_transform_key = input_transform_key
 
@@ -254,7 +259,7 @@ def main(
         workdir=workdir,
         checkpoint=f"epoch-{epoch}",
         dataset=dataset,
-        input_xfm=config.data.input_transform_key,
+        input_xfm=f"{config.data.input_transform_dataset}-{config.data.input_transform_key}",
         split=split,
         ensemble_member=ensemble_member,
     )
@@ -270,6 +275,7 @@ def main(
     eval_dl, _, target_transform = get_dataloader(
         dataset,
         config.data.dataset_name,
+        config.data.input_transform_dataset,
         config.data.input_transform_key,
         config.data.target_transform_key,
         transform_dir,
diff --git a/src/ml_downscaling_emulator/bin/evaluate.py b/src/ml_downscaling_emulator/bin/evaluate.py
index e6f78f36e..5915002cc 100644
--- a/src/ml_downscaling_emulator/bin/evaluate.py
+++ b/src/ml_downscaling_emulator/bin/evaluate.py
@@ -61,6 +61,7 @@ def sample(
     epoch: int = typer.Option(...),
     batch_size: int = None,
     num_samples: int = 1,
+    input_transform_dataset: str = None,
     input_transform_key: str = None,
     ensemble_member: str = DEFAULT_ENSEMBLE_MEMBER,
 ):
@@ -70,6 +71,10 @@ def sample(
 
     if batch_size is not None:
         config.eval.batch_size = batch_size
+    if input_transform_dataset is not None:
+        config.data.input_transform_dataset = input_transform_dataset
+    else:
+        config.data.input_transform_dataset = dataset
     if input_transform_key is not None:
         config.data.input_transform_key = input_transform_key
 
@@ -77,7 +82,7 @@ def sample(
         workdir=workdir,
         checkpoint=f"epoch-{epoch}",
         dataset=dataset,
-        input_xfm=config.data.input_transform_key,
+        input_xfm=f"{config.data.input_transform_dataset}-{config.data.input_transform_key}",
         split=split,
         ensemble_member=ensemble_member,
     )
@@ -88,6 +93,7 @@ def sample(
     eval_dl, _, target_transform = get_dataloader(
         dataset,
         config.data.dataset_name,
+        config.data.input_transform_dataset,
         config.data.input_transform_key,
         config.data.target_transform_key,
         transform_dir,
diff --git a/src/ml_downscaling_emulator/deterministic/run_lib.py b/src/ml_downscaling_emulator/deterministic/run_lib.py
index 407444cd8..79af70e34 100644
--- a/src/ml_downscaling_emulator/deterministic/run_lib.py
+++ b/src/ml_downscaling_emulator/deterministic/run_lib.py
@@ -77,6 +77,7 @@ def train(config, workdir):
 
     # Build dataloaders
     train_dl, _, _ = get_dataloader(
+        config.data.dataset_name,
         config.data.dataset_name,
         config.data.dataset_name,
         config.data.input_transform_key,
@@ -89,6 +90,7 @@ def train(config, workdir):
         evaluation=False,
     )
     val_dl, _, _ = get_dataloader(
+        config.data.dataset_name,
         config.data.dataset_name,
         config.data.dataset_name,
         config.data.input_transform_key,
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/run_lib.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/run_lib.py
index b51672bd5..f82cd76ee 100644
--- a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/run_lib.py
+++ b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/run_lib.py
@@ -110,8 +110,8 @@ def train(config, workdir):
     ) as (wandb_run, writer):
         # Build dataloaders
         dataset_meta = DatasetMetadata(config.data.dataset_name)
-        train_dl, _, _ = get_dataloader(config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, config.data.target_transform_key, transform_dir, batch_size=config.training.batch_size, split="train", ensemble_members=dataset_meta.ensemble_members(), include_time_inputs=config.data.time_inputs, evaluation=False)
-        eval_dl, _, _ = get_dataloader(config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, config.data.target_transform_key, transform_dir, batch_size=config.training.batch_size, split="val", ensemble_members=dataset_meta.ensemble_members(), include_time_inputs=config.data.time_inputs, evaluation=False)
+        train_dl, _, _ = get_dataloader(config.data.dataset_name, config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, config.data.target_transform_key, transform_dir, batch_size=config.training.batch_size, split="train", ensemble_members=dataset_meta.ensemble_members(), include_time_inputs=config.data.time_inputs, evaluation=False)
+        eval_dl, _, _ = get_dataloader(config.data.dataset_name, config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, config.data.target_transform_key, transform_dir, batch_size=config.training.batch_size, split="val", ensemble_members=dataset_meta.ensemble_members(), include_time_inputs=config.data.time_inputs, evaluation=False)
 
         # Initialize model.
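         # mutils.create_model instantiates the score network registered under config.model.name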
         score_model = mutils.create_model(config)
diff --git a/src/ml_downscaling_emulator/torch.py b/src/ml_downscaling_emulator/torch.py
index aa0b6a90b..857af654c 100644
--- a/src/ml_downscaling_emulator/torch.py
+++ b/src/ml_downscaling_emulator/torch.py
@@ -112,6 +112,7 @@ def custom_collate(batch):
 def get_dataloader(
     active_dataset_name,
     model_src_dataset_name,
+    input_transform_dataset_name,
     input_transform_key,
     target_transform_key,
     transform_dir,
@@ -127,6 +128,7 @@ def get_dataloader(
     Args:
         active_dataset_name: Name of dataset from which to load data splits
         model_src_dataset_name: Name of dataset used to train the diffusion model (may be the same)
+        input_transform_dataset_name: Name of dataset to use for fitting input transform (may be the same as active_dataset_name or model_src_dataset_name)
        transform_dir: Path to where transforms should be stored
        input_transform_key: Name of input transform pipeline to use
        target_transform_key: Name of target transform pipeline to use
@@ -140,6 +142,7 @@ def get_dataloader(
     xr_data, transform, target_transform = get_dataset(
         active_dataset_name,
         model_src_dataset_name,
+        input_transform_dataset_name,
         input_transform_key,
         target_transform_key,
         transform_dir,

From b2e5d885d210ba9e61785403847e62f13a13821c Mon Sep 17 00:00:00 2001
From: Henry Addison
Date: Wed, 14 Feb 2024 14:47:46 +0000
Subject: [PATCH 21/22] set a default value for input transform dataset config

---
 .../score_sde_pytorch_hja22/configs/default_xarray_configs.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_xarray_configs.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_xarray_configs.py
index 89e59be57..3d4e8ac5a 100644
--- a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_xarray_configs.py
+++ b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_xarray_configs.py
@@ -46,6 +46,7 @@ def get_default_configs():
     data.random_flip = False
     data.centered = False
     data.uniform_dequantization = False
+    data.input_transform_dataset = None
     data.input_transform_key = "pixelmmsstanur"
     data.target_transform_key = "v1"
     data.time_inputs = False

From 3e2646a1c4db61072ede18b8c938a06389be31cd Mon Sep 17 00:00:00 2001
From: Henry Addison
Date: Wed, 14 Feb 2024 14:55:02 +0000
Subject: [PATCH 22/22] need to unlock the config to set
 input_transform_dataset

---
 bin/predict.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/bin/predict.py b/bin/predict.py
index 79ec5ab35..5aa95e2ec 100644
--- a/bin/predict.py
+++ b/bin/predict.py
@@ -248,10 +248,11 @@ def main(
     config = load_config(config_path)
     if batch_size is not None:
         config.eval.batch_size = batch_size
-    if input_transform_dataset is not None:
-        config.data.input_transform_dataset = input_transform_dataset
-    else:
-        config.data.input_transform_dataset = dataset
+    with config.unlocked():
+        if input_transform_dataset is not None:
+            config.data.input_transform_dataset = input_transform_dataset
+        else:
+            config.data.input_transform_dataset = dataset
     if input_transform_key is not None:
         config.data.input_transform_key = input_transform_key
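Two short sketches to accompany the patches above; they are illustrations, not part of the series.

On patches 06 and 07 (the expand_dims fixes): in xarray's expand_dims, the value in the mapping controls what the new dimension gets — an int gives a plain size, while a sequence gives coordinate values — so wrapping ensemble_member in square brackets is what makes it a valid length-1 coordinate, per patch 07's subject. A minimal sketch of the working form (the variable names follow the patches; the array data is made up):

    import numpy as np
    import xarray as xr

    da = xr.DataArray(
        np.zeros((2, 2)), dims=["grid_latitude", "grid_longitude"]
    )
    ensemble_member = "01"

    # A length-1 list gives the new dimension a coordinate holding that value,
    # matching what patch 07 does with [ensemble_member].
    da = da.expand_dims({"ensemble_member": [ensemble_member]})
    assert da.sizes["ensemble_member"] == 1

On patch 22 (the unlock fix): ml_collections ConfigDict objects can be locked, after which assigning a key that did not exist at lock time raises, and config.unlocked() temporarily lifts that restriction. A minimal sketch of the behaviour the patch works around, assuming ml_collections' usual locking semantics (the key names mirror this repo's config; the dataset name is hypothetical):

    from ml_collections import config_dict

    config = config_dict.ConfigDict(
        {"data": {"input_transform_key": "pixelmmsstanur"}}
    )
    config.lock()

    # Assigning a brand-new key to a locked config raises an error...
    try:
        config.data.input_transform_dataset = "my-dataset"  # hypothetical name
    except (AttributeError, KeyError):
        pass

    # ...but is permitted inside the unlocked() context manager,
    # which is why patch 22 wraps the assignment in `with config.unlocked():`.
    with config.unlocked():
        config.data.input_transform_dataset = "my-dataset"

This also explains why patch 21 pre-declares data.input_transform_dataset = None in the default config: keys present before locking can be reassigned freely later.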