diff --git a/bin/check_preds.py b/bin/check_preds.py
new file mode 100644
index 000000000..b408d0f4b
--- /dev/null
+++ b/bin/check_preds.py
@@ -0,0 +1,80 @@
+import sys
+
+from tqdm import tqdm
+from tqdm.contrib.logging import logging_redirect_tqdm
+import xarray as xr
+
+fpaths = sys.argv[1:]
+
+PRED_VAR = "pred_pr"
+META_VARS = (
+    "rotated_latitude_longitude",
+    "time_bnds",
+    "grid_latitude_bnds",
+    "grid_longitude_bnds",
+)
+
+PRED_PR_ATTRS = {
+    "grid_mapping": "rotated_latitude_longitude",
+    "standard_name": "pred_pr",
+    "units": "kg m-2 s-1",
+}
+
+
+def fix(pred_ds):
+    pred_ds[PRED_VAR] = pred_ds[PRED_VAR].assign_attrs(PRED_PR_ATTRS)
+    for var in META_VARS:
+        if "ensemble_member" in pred_ds[var].dims:
+            pred_ds[var] = pred_ds[var].isel(ensemble_member=0)
+        if "time" in pred_ds[var].dims:
+            pred_ds[var] = pred_ds[var].isel(time=0)
+
+    return pred_ds
+
+
+def check(pred_ds):
+    errors = []
+
+    try:
+        assert (
+            pred_ds[PRED_VAR].attrs == PRED_PR_ATTRS
+        ), f"Bad attrs on {PRED_VAR}: {pred_ds[PRED_VAR].attrs}"
+    except AssertionError as e:
+        errors.append(e)
+
+    for var in META_VARS:
+        try:
+            assert ("ensemble_member" not in pred_ds[var].dims) and (
+                "time" not in pred_ds[var].dims
+            ), f"Bad dims on {var}: {pred_ds[var].dims}"
+        except AssertionError as e:
+            errors.append(e)
+
+    return errors
+
+
+with logging_redirect_tqdm():
+    with tqdm(
+        total=len(fpaths),
+        desc=f"Checking prediction files",
+        unit=" files",
+    ) as pbar:
+        for fpath in fpaths:
+            with xr.open_dataset(fpath) as pred_ds:
+                errors = check(pred_ds)
+
+                if len(errors) > 0:
+                    print(f"Errors in {fpath}:")
+                    for e in errors:
+                        print(e)
+
+                # pred_ds = fix(pred_ds)
+                # errors = check(pred_ds)
+
+                # if len(errors) > 0:
+                #     print(f"Errors in {fpath}:")
+                #     for e in errors:
+                #         print(e)
+                # else:
+                #     pred_ds.to_netcdf(fpath)
+            pbar.update(1)
diff --git a/bin/model-size b/bin/model-size
new file mode 100755
index 000000000..38b64f7f4
--- /dev/null
+++ b/bin/model-size
@@ -0,0 +1,71 @@
+#!/usr/bin/env python
+
+import os
+from pathlib import Path
+
+from ml_collections import config_dict
+
+import typer
+import logging
+import yaml
+
+from ml_downscaling_emulator.score_sde_pytorch_hja22.models.location_params import (
+    LocationParams,
+)
+
+from ml_downscaling_emulator.score_sde_pytorch_hja22.models import utils as mutils
+
+# from score_sde_pytorch_hja22.models import ncsnv2
+# from score_sde_pytorch_hja22.models import ncsnpp
+from ml_downscaling_emulator.score_sde_pytorch_hja22.models import cncsnpp  # noqa: F401
+from ml_downscaling_emulator.score_sde_pytorch_hja22.models import cunet  # noqa: F401
+
+# from score_sde_pytorch_hja22.models import ddpm as ddpm_model
+from ml_downscaling_emulator.score_sde_pytorch_hja22.models import (  # noqa: F401
+    layerspp,  # noqa: F401
+)  # noqa: F401
+from ml_downscaling_emulator.score_sde_pytorch_hja22.models import layers  # noqa: F401
+from ml_downscaling_emulator.score_sde_pytorch_hja22.models import (  # noqa: F401
+    normalization,  # noqa: F401
+)  # noqa: F401
+
+logger = logging.getLogger()
+logger.setLevel("INFO")
+
+app = typer.Typer()
+
+
+def load_model(config):
+    logger.info(f"Loading model from config")
+    score_model = mutils.get_model(config.model.name)(config)
+    location_params = LocationParams(
+        config.model.loc_spec_channels, config.data.image_size
+    )
+
+    return score_model, location_params
+
+
+def load_config(config_path):
+    logger.info(f"Loading config from {config_path}")
+    with open(config_path) as f:
+        config = config_dict.ConfigDict(yaml.unsafe_load(f))
+
+    return config
+
+
+@app.command()
+def main(
+    workdir: Path,
+):
+    config_path = os.path.join(workdir, "config.yml")
+    config = load_config(config_path)
+    score_model, location_params = load_model(config)
+    num_score_model_parameters = sum(p.numel() for p in score_model.parameters())
+    num_location_parameters = sum(p.numel() for p in location_params.parameters())
+
+    typer.echo(f"Score model has {num_score_model_parameters} parameters")
+    typer.echo(f"Location parameters have {num_location_parameters} parameters")
+
+
+if __name__ == "__main__":
+    app()
diff --git a/bin/predict.py b/bin/predict.py
index 5859d8f92..5aa95e2ec 100644
--- a/bin/predict.py
+++ b/bin/predict.py
@@ -165,6 +165,7 @@ def np_samples_to_xr(np_samples, target_transform, coords, cf_data_vars):
         xr.Dataset(data_vars=data_vars, coords=coords, attrs={})
     )
     samples_ds = samples_ds.rename({"target_pr": "pred_pr"})
+    samples_ds["pred_pr"] = samples_ds["pred_pr"].assign_attrs(pred_pr_attrs)
 
     return samples_ds
 
@@ -239,6 +240,7 @@ def main(
     epoch: int = typer.Option(...),
     batch_size: int = None,
     num_samples: int = 3,
+    input_transform_dataset: str = None,
     input_transform_key: str = None,
     ensemble_member: str = DEFAULT_ENSEMBLE_MEMBER,
 ):
@@ -246,6 +248,11 @@ def main(
     config = load_config(config_path)
     if batch_size is not None:
         config.eval.batch_size = batch_size
+    with config.unlocked():
+        if input_transform_dataset is not None:
+            config.data.input_transform_dataset = input_transform_dataset
+        else:
+            config.data.input_transform_dataset = dataset
     if input_transform_key is not None:
         config.data.input_transform_key = input_transform_key
 
@@ -253,7 +260,7 @@ def main(
         workdir=workdir,
         checkpoint=f"epoch-{epoch}",
         dataset=dataset,
-        input_xfm=config.data.input_transform_key,
+        input_xfm=f"{config.data.input_transform_dataset}-{config.data.input_transform_key}",
         split=split,
         ensemble_member=ensemble_member,
     )
@@ -269,6 +276,7 @@ def main(
     eval_dl, _, target_transform = get_dataloader(
         dataset,
         config.data.dataset_name,
+        config.data.input_transform_dataset,
         config.data.input_transform_key,
         config.data.target_transform_key,
         transform_dir,
diff --git a/src/ml_downscaling_emulator/bin/evaluate.py b/src/ml_downscaling_emulator/bin/evaluate.py
index e6f78f36e..5915002cc 100644
--- a/src/ml_downscaling_emulator/bin/evaluate.py
+++ b/src/ml_downscaling_emulator/bin/evaluate.py
@@ -61,6 +61,7 @@ def sample(
     epoch: int = typer.Option(...),
     batch_size: int = None,
     num_samples: int = 1,
+    input_transform_dataset: str = None,
     input_transform_key: str = None,
     ensemble_member: str = DEFAULT_ENSEMBLE_MEMBER,
 ):
@@ -70,6 +71,10 @@ def sample(
     if batch_size is not None:
         config.eval.batch_size = batch_size
 
+    if input_transform_dataset is not None:
+        config.data.input_transform_dataset = input_transform_dataset
+    else:
+        config.data.input_transform_dataset = dataset
     if input_transform_key is not None:
         config.data.input_transform_key = input_transform_key
 
@@ -77,7 +82,7 @@ def sample(
         workdir=workdir,
         checkpoint=f"epoch-{epoch}",
         dataset=dataset,
-        input_xfm=config.data.input_transform_key,
+        input_xfm=f"{config.data.input_transform_dataset}-{config.data.input_transform_key}",
         split=split,
         ensemble_member=ensemble_member,
     )
@@ -88,6 +93,7 @@ def sample(
     eval_dl, _, target_transform = get_dataloader(
         dataset,
         config.data.dataset_name,
+        config.data.input_transform_dataset,
         config.data.input_transform_key,
         config.data.target_transform_key,
         transform_dir,
diff --git a/src/ml_downscaling_emulator/bin/postprocess.py b/src/ml_downscaling_emulator/bin/postprocess.py
index 202fa4f8d..2eb4e5c2d 100644
--- a/src/ml_downscaling_emulator/bin/postprocess.py
+++ b/src/ml_downscaling_emulator/bin/postprocess.py
@@ -2,13 +2,14 @@
 import logging
 import os
 from pathlib import Path
+from typing import Callable
 
 import typer
 import xarray as xr
 from mlde_utils import samples_path, samples_glob, TIME_PERIODS
 from mlde_utils.training.dataset import open_raw_dataset_split
 
-from ml_downscaling_emulator.postprocess import xrqm
+from ml_downscaling_emulator.postprocess import xrqm, to_gcm_domain
 
 logging.basicConfig(
     level=logging.INFO,
@@ -25,6 +26,50 @@ def callback():
     pass
 
 
+def process_each_sample(
+    workdir: Path,
+    checkpoint: str,
+    dataset: str,
+    ensemble_member: str,
+    input_xfm: str,
+    split: str,
+    processing_func: Callable,
+    new_workdir: Path,
+):
+    samples_dirpath = samples_path(
+        workdir,
+        checkpoint=checkpoint,
+        input_xfm=input_xfm,
+        dataset=dataset,
+        split=split,
+        ensemble_member=ensemble_member,
+    )
+    logger.info(f"Iterating on samples in {samples_dirpath}")
+    for sample_filepath in samples_glob(samples_dirpath):
+        logger.info(f"Working on {sample_filepath}")
+        # open the samples
+        samples_ds = xr.open_dataset(sample_filepath)
+
+        processed_samples_ds = processing_func(samples_ds)
+
+        # save output
+        processed_sample_filepath = (
+            samples_path(
+                new_workdir,
+                checkpoint=checkpoint,
+                input_xfm=input_xfm,
+                dataset=dataset,
+                split=split,
+                ensemble_member=ensemble_member,
+            )
+            / sample_filepath.name
+        )
+
+        logger.info(f"Saving to {processed_sample_filepath}")
+        processed_sample_filepath.parent.mkdir(parents=True, exist_ok=True)
+        processed_samples_ds.to_netcdf(processed_sample_filepath)
+
+
 @app.command()
 def filter(
     workdir: Path,
@@ -103,44 +148,53 @@ def qm(
         )[0]
     )["pred_pr"]
 
-    ml_eval_samples_dirpath = samples_path(
-        workdir,
-        checkpoint=checkpoint,
-        input_xfm=eval_input_xfm,
-        dataset=eval_dataset,
-        split=split,
-        ensemble_member=ensemble_member,
-    )
-    logger.info(f"QMapping samplesin {ml_eval_samples_dirpath}")
-    for sample_filepath in samples_glob(ml_eval_samples_dirpath):
-        logger.info(f"Working on {sample_filepath}")
-        # open the samples to be qmapped
-        ml_eval_ds = xr.open_dataset(sample_filepath)
-
+    def process_samples(ds):
         # do the qmapping
-        qmapped_eval_da = xrqm(sim_train_da, ml_train_da, ml_eval_ds["pred_pr"])
+        qmapped_eval_da = xrqm(sim_train_da, ml_train_da, ds["pred_pr"])
 
-        qmapped_eval_ds = ml_eval_ds.copy()
-        qmapped_eval_ds["pred_pr"] = qmapped_eval_da
+        processed_samples_ds = ds.copy()
+        processed_samples_ds["pred_pr"] = qmapped_eval_da
 
-        # save output
-        new_workdir = workdir / "postprocess" / "qm"
+        return processed_samples_ds
 
-        qmapped_sample_filepath = (
-            samples_path(
-                new_workdir,
-                checkpoint=checkpoint,
-                input_xfm=eval_input_xfm,
-                dataset=eval_dataset,
-                split=split,
-                ensemble_member=ensemble_member,
-            )
-            / sample_filepath.name
+    process_each_sample(
+        workdir,
+        checkpoint,
+        eval_dataset,
+        ensemble_member,
+        eval_input_xfm,
+        split,
+        process_samples,
+        new_workdir=workdir / "postprocess" / "qm",
+    )
+
+
+@app.command()
+def gcmify(
+    workdir: Path,
+    checkpoint: str = typer.Option(...),
+    dataset: str = typer.Option(...),
+    input_xfm: str = typer.Option(...),
+    split: str = typer.Option(...),
+    ensemble_member: str = typer.Option(...),
+):
+    def process_samples(ds):
+        ds = to_gcm_domain(ds.sel(ensemble_member=ensemble_member))
+        ds["pred_pr"] = ds["pred_pr"].expand_dims(
+            {"ensemble_member": [ensemble_member]}
         )
+        return ds
 
-        logger.info(f"Saving to {qmapped_sample_filepath}")
-        qmapped_sample_filepath.parent.mkdir(parents=True, exist_ok=True)
-        qmapped_eval_ds.to_netcdf(qmapped_sample_filepath)
+    process_each_sample(
+        workdir,
+        checkpoint,
+        dataset,
+        ensemble_member,
+        input_xfm,
+        split,
+        process_samples,
+        new_workdir=workdir / "postprocess" / "gcm-grid",
+    )
 
 
 @app.command()
diff --git a/src/ml_downscaling_emulator/deterministic/run_lib.py b/src/ml_downscaling_emulator/deterministic/run_lib.py
index 407444cd8..79af70e34 100644
--- a/src/ml_downscaling_emulator/deterministic/run_lib.py
+++ b/src/ml_downscaling_emulator/deterministic/run_lib.py
@@ -77,6 +77,7 @@ def train(config, workdir):
 
     # Build dataloaders
     train_dl, _, _ = get_dataloader(
+        config.data.dataset_name,
         config.data.dataset_name,
         config.data.dataset_name,
         config.data.input_transform_key,
@@ -89,6 +90,7 @@ def train(config, workdir):
         evaluation=False,
     )
     val_dl, _, _ = get_dataloader(
+        config.data.dataset_name,
         config.data.dataset_name,
         config.data.dataset_name,
         config.data.input_transform_key,
diff --git a/src/ml_downscaling_emulator/postprocess.py b/src/ml_downscaling_emulator/postprocess.py
index 6ff41ab82..402f7468d 100644
--- a/src/ml_downscaling_emulator/postprocess.py
+++ b/src/ml_downscaling_emulator/postprocess.py
@@ -1,7 +1,12 @@
+from importlib.resources import files
 from typing import Callable
 
 import numpy as np
 import xarray as xr
 
+from mlde_utils.data.remapcon import Remapcon
+from mlde_utils.data.shift_lon_break import ShiftLonBreak
+from mlde_utils.data.select_gcm_domain import SelectGCMDomain
+
 
 def _get_cdf(x, xbins):
     pdf, _ = np.histogram(x, xbins)
@@ -68,3 +73,16 @@ def xrqm(
         .transpose("ensemble_member", "time", "grid_latitude", "grid_longitude")
         .assign_coords(time=ml_eval_da["time"])
     )
+
+
+def to_gcm_domain(ds: xr.Dataset):
+    target_grid_filepath = files("mlde_utils.data").joinpath(
+        "target_grids/60km/global/pr/moose_grid.nc"
+    )
+    ds = Remapcon(target_grid_filepath).run(ds)
+    ds = ShiftLonBreak().run(ds)
+    ds = SelectGCMDomain(subdomain="birmingham", size=9).run(ds)
+    nan_count = ds["pred_pr"].isnull().sum().values.item()
+    assert 0 == nan_count, f"nan count: {nan_count}"
+    ds = ds.drop_vars(["rotated_latitude_longitude"], errors="ignore")
+    return ds
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_xarray_configs.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_xarray_configs.py
index 89e59be57..3d4e8ac5a 100644
--- a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_xarray_configs.py
+++ b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_xarray_configs.py
@@ -46,6 +46,7 @@ def get_default_configs():
     data.random_flip = False
     data.centered = False
     data.uniform_dequantization = False
+    data.input_transform_dataset = None
     data.input_transform_key = "pixelmmsstanur"
     data.target_transform_key = "v1"
     data.time_inputs = False
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/run_lib.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/run_lib.py
index d21305b87..f82cd76ee 100644
--- a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/run_lib.py
+++ b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/run_lib.py
@@ -53,14 +53,14 @@ def val_loss(config, eval_ds, eval_step_fn, state):
 
     val_set_loss = 0.0
 
-    for eval_cond_batch, eval_x_batch, eval_time_batch in eval_ds:
-        # eval_cond_batch, eval_x_batch = next(iter(eval_ds))
-        eval_x_batch = eval_x_batch.to(config.device)
+    for eval_cond_batch, eval_target_batch, eval_time_batch in eval_ds:
+        # eval_cond_batch, eval_target_batch = next(iter(eval_ds))
+        eval_target_batch = eval_target_batch.to(config.device)
         eval_cond_batch = eval_cond_batch.to(config.device)
         # append any location-specific parameters
         eval_cond_batch = state['location_params'](eval_cond_batch)
         # eval_batch = eval_batch.permute(0, 3, 1, 2)
-        eval_loss = eval_step_fn(state, eval_x_batch, eval_cond_batch)
+        eval_loss = eval_step_fn(state, eval_target_batch, eval_cond_batch)
 
         # Progress
         val_set_loss += eval_loss.item()
@@ -110,8 +110,8 @@ def train(config, workdir):
     ) as (wandb_run, writer):
         # Build dataloaders
         dataset_meta = DatasetMetadata(config.data.dataset_name)
-        train_dl, _, _ = get_dataloader(config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, config.data.target_transform_key, transform_dir, batch_size=config.training.batch_size, split="train", ensemble_members=dataset_meta.ensemble_members(), include_time_inputs=config.data.time_inputs, evaluation=False)
-        eval_dl, _, _ = get_dataloader(config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, config.data.target_transform_key, transform_dir, batch_size=config.training.batch_size, split="val", ensemble_members=dataset_meta.ensemble_members(), include_time_inputs=config.data.time_inputs, evaluation=False)
+        train_dl, _, _ = get_dataloader(config.data.dataset_name, config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, config.data.target_transform_key, transform_dir, batch_size=config.training.batch_size, split="train", ensemble_members=dataset_meta.ensemble_members(), include_time_inputs=config.data.time_inputs, evaluation=False)
+        eval_dl, _, _ = get_dataloader(config.data.dataset_name, config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, config.data.target_transform_key, transform_dir, batch_size=config.training.batch_size, split="val", ensemble_members=dataset_meta.ensemble_members(), include_time_inputs=config.data.time_inputs, evaluation=False)
 
         # Initialize model.
         score_model = mutils.create_model(config)
@@ -171,21 +171,21 @@ def train(config, workdir):
             train_set_loss = 0.0
             with logging_redirect_tqdm():
                 with tqdm(total=len(train_dl.dataset), desc=f"Epoch {state['epoch']}", unit=' timesteps') as pbar:
-                    for cond_batch, x_batch, time_batch in train_dl:
+                    for cond_batch, target_batch, time_batch in train_dl:
 
-                        x_batch = x_batch.to(config.device)
+                        target_batch = target_batch.to(config.device)
                         cond_batch = cond_batch.to(config.device)
                         # append any location-specific parameters
                         cond_batch = state['location_params'](cond_batch)
 
                         if config.training.random_crop_size > 0:
-                            x_ch = x_batch.shape[1]
-                            cropped = random_crop(torch.cat([x_batch, cond_batch], dim=1))
-                            x_batch = cropped[:,:x_ch]
+                            x_ch = target_batch.shape[1]
+                            cropped = random_crop(torch.cat([target_batch, cond_batch], dim=1))
+                            target_batch = cropped[:,:x_ch]
                             cond_batch = cropped[:,x_ch:]
 
                         # Execute one training step
-                        loss = train_step_fn(state, x_batch, cond_batch)
+                        loss = train_step_fn(state, target_batch, cond_batch)
                         train_set_loss += loss.item()
                         if state['step'] % config.training.log_freq == 0:
                             logging.info("epoch: %d, step: %d, train_loss: %.5e" % (state['epoch'], state['step'], loss.item()))
diff --git a/src/ml_downscaling_emulator/torch.py b/src/ml_downscaling_emulator/torch.py
index aa0b6a90b..857af654c 100644
--- a/src/ml_downscaling_emulator/torch.py
+++ b/src/ml_downscaling_emulator/torch.py
@@ -112,6 +112,7 @@ def custom_collate(batch):
 def get_dataloader(
     active_dataset_name,
     model_src_dataset_name,
+    input_transform_dataset_name,
     input_transform_key,
     target_transform_key,
     transform_dir,
@@ -127,6 +128,7 @@ def get_dataloader(
     Args:
         active_dataset_name: Name of dataset from which to load data splits
         model_src_dataset_name: Name of dataset used to train the diffusion model (may be the same)
+        input_transform_dataset_name: Name of dataset to use for fitting input transform (may be the same as active_dataset_name or model_src_dataset_name)
         transform_dir: Path to where transforms should be stored
         input_transform_key: Name of input transform pipeline to use
         target_transform_key: Name of target transform pipeline to use
@@ -140,6 +142,7 @@ def get_dataloader(
     xr_data, transform, target_transform = get_dataset(
         active_dataset_name,
         model_src_dataset_name,
+        input_transform_dataset_name,
         input_transform_key,
         target_transform_key,
         transform_dir,