Merge pull request #30 from henryaddison/raw-gcm-pr
Raw gcm pr
henryaddison authored Feb 23, 2024
2 parents 3824919 + 3e2646a commit 5ea8e6a
Showing 10 changed files with 290 additions and 47 deletions.
80 changes: 80 additions & 0 deletions bin/check_preds.py
@@ -0,0 +1,80 @@
import sys

from tqdm import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm
import xarray as xr

fpaths = sys.argv[1:]

PRED_VAR = "pred_pr"
META_VARS = (
    "rotated_latitude_longitude",
    "time_bnds",
    "grid_latitude_bnds",
    "grid_longitude_bnds",
)

PRED_PR_ATTRS = {
    "grid_mapping": "rotated_latitude_longitude",
    "standard_name": "pred_pr",
    "units": "kg m-2 s-1",
}


def fix(pred_ds):
    pred_ds[PRED_VAR] = pred_ds[PRED_VAR].assign_attrs(PRED_PR_ATTRS)
    for var in META_VARS:
        if "ensemble_member" in pred_ds[var].dims:
            pred_ds[var] = pred_ds[var].isel(ensemble_member=0)
        if "time" in pred_ds[var].dims:
            pred_ds[var] = pred_ds[var].isel(time=0)

    return pred_ds


def check(pred_ds):
    errors = []

    try:
        assert (
            pred_ds[PRED_VAR].attrs == PRED_PR_ATTRS
        ), f"Bad attrs on {PRED_VAR}: {pred_ds[PRED_VAR].attrs}"
    except AssertionError as e:
        errors.append(e)

    for var in META_VARS:
        try:
            assert ("ensemble_member" not in pred_ds[var].dims) and (
                "time" not in pred_ds[var].dims
            ), f"Bad dims on {var}: {pred_ds[var].dims}"
        except AssertionError as e:
            errors.append(e)

    return errors


with logging_redirect_tqdm():
    with tqdm(
        total=len(fpaths),
        desc="Checking prediction files",
        unit=" files",
    ) as pbar:
        for fpath in fpaths:
            with xr.open_dataset(fpath) as pred_ds:
                errors = check(pred_ds)

                if len(errors) > 0:
                    print(f"Errors in {fpath}:")
                    for e in errors:
                        print(e)

                # pred_ds = fix(pred_ds)
                # errors = check(pred_ds)

                # if len(errors) > 0:
                #     print(f"Errors in {fpath}:")
                #     for e in errors:
                #         print(e)
                # else:
                #     pred_ds.to_netcdf(fpath)
            pbar.update(1)
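
Note: the commented-out block above is the intended repair path: fix() rewrites the attributes and squeezes stray ensemble_member/time dimensions from the metadata variables, and a second check() validates the result before the file is overwritten. A minimal sketch of that round trip using the fix() and check() helpers defined above ("preds.nc" is a hypothetical file name; the dataset is loaded into memory so the source file can safely be rewritten):

import xarray as xr

fpath = "preds.nc"  # hypothetical prediction file
with xr.open_dataset(fpath) as pred_ds:
    pred_ds.load()  # pull data into memory so the file handle can be closed
    fixed_ds = fix(pred_ds)
assert check(fixed_ds) == [], "file still fails validation after fix()"
fixed_ds.to_netcdf(fpath)  # safe now: the original handle is closed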
71 changes: 71 additions & 0 deletions bin/model-size
@@ -0,0 +1,71 @@
#!/usr/bin/env python

import os
from pathlib import Path

from ml_collections import config_dict

import typer
import logging
import yaml

from ml_downscaling_emulator.score_sde_pytorch_hja22.models.location_params import (
    LocationParams,
)

from ml_downscaling_emulator.score_sde_pytorch_hja22.models import utils as mutils

# from score_sde_pytorch_hja22.models import ncsnv2
# from score_sde_pytorch_hja22.models import ncsnpp
from ml_downscaling_emulator.score_sde_pytorch_hja22.models import cncsnpp  # noqa: F401
from ml_downscaling_emulator.score_sde_pytorch_hja22.models import cunet  # noqa: F401

# from score_sde_pytorch_hja22.models import ddpm as ddpm_model
from ml_downscaling_emulator.score_sde_pytorch_hja22.models import (  # noqa: F401
    layerspp,  # noqa: F401
)  # noqa: F401
from ml_downscaling_emulator.score_sde_pytorch_hja22.models import layers  # noqa: F401
from ml_downscaling_emulator.score_sde_pytorch_hja22.models import (  # noqa: F401
    normalization,  # noqa: F401
)  # noqa: F401

logger = logging.getLogger()
logger.setLevel("INFO")

app = typer.Typer()


def load_model(config):
    logger.info("Loading model from config")
    score_model = mutils.get_model(config.model.name)(config)
    location_params = LocationParams(
        config.model.loc_spec_channels, config.data.image_size
    )

    return score_model, location_params


def load_config(config_path):
    logger.info(f"Loading config from {config_path}")
    with open(config_path) as f:
        config = config_dict.ConfigDict(yaml.unsafe_load(f))

    return config


@app.command()
def main(
    workdir: Path,
):
    config_path = os.path.join(workdir, "config.yml")
    config = load_config(config_path)
    score_model, location_params = load_model(config)
    num_score_model_parameters = sum(p.numel() for p in score_model.parameters())
    num_location_parameters = sum(p.numel() for p in location_params.parameters())

    typer.echo(f"Score model has {num_score_model_parameters} parameters")
    typer.echo(f"Location parameters have {num_location_parameters} parameters")


if __name__ == "__main__":
    app()
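
Note: for orientation, numel() returns the number of elements in a single parameter tensor, so each sum above counts every weight and bias in the module. A self-contained sketch with a toy module standing in for the score model:

import torch

toy = torch.nn.Linear(10, 5)  # toy stand-in for the score model
num_parameters = sum(p.numel() for p in toy.parameters())
print(num_parameters)  # 55: a 10x5 weight matrix plus 5 bias terms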
10 changes: 9 additions & 1 deletion bin/predict.py
@@ -165,6 +165,7 @@ def np_samples_to_xr(np_samples, target_transform, coords, cf_data_vars):
         xr.Dataset(data_vars=data_vars, coords=coords, attrs={})
     )
     samples_ds = samples_ds.rename({"target_pr": "pred_pr"})
+    samples_ds["pred_pr"] = samples_ds["pred_pr"].assign_attrs(pred_pr_attrs)
     return samples_ds


@@ -239,21 +240,27 @@ def main(
     epoch: int = typer.Option(...),
     batch_size: int = None,
     num_samples: int = 3,
+    input_transform_dataset: str = None,
     input_transform_key: str = None,
     ensemble_member: str = DEFAULT_ENSEMBLE_MEMBER,
 ):
     config_path = os.path.join(workdir, "config.yml")
     config = load_config(config_path)
     if batch_size is not None:
         config.eval.batch_size = batch_size
+    with config.unlocked():
+        if input_transform_dataset is not None:
+            config.data.input_transform_dataset = input_transform_dataset
+        else:
+            config.data.input_transform_dataset = dataset
     if input_transform_key is not None:
         config.data.input_transform_key = input_transform_key

     output_dirpath = samples_path(
         workdir=workdir,
         checkpoint=f"epoch-{epoch}",
         dataset=dataset,
-        input_xfm=config.data.input_transform_key,
+        input_xfm=f"{config.data.input_transform_dataset}-{config.data.input_transform_key}",
         split=split,
         ensemble_member=ensemble_member,
     )
@@ -269,6 +276,7 @@ def main(
     eval_dl, _, target_transform = get_dataloader(
         dataset,
         config.data.dataset_name,
+        config.data.input_transform_dataset,
         config.data.input_transform_key,
         config.data.target_transform_key,
         transform_dir,
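
Note: the new input_transform_dataset option decouples the dataset used to fit the input transform from the dataset being sampled, falling back to the latter when unset; the new config key is set inside config.unlocked() because it may not exist in configs saved by older runs. A minimal sketch of the defaulting rule (the helper name and dataset names are hypothetical; with typer's usual flag naming the option would be passed as --input-transform-dataset):

def resolve_input_transform_dataset(dataset, input_transform_dataset=None):
    # an explicitly requested transform dataset wins; otherwise reuse
    # the dataset being sampled (mirrors the config update above)
    if input_transform_dataset is not None:
        return input_transform_dataset
    return dataset

assert resolve_input_transform_dataset("bham_60km") == "bham_60km"
assert resolve_input_transform_dataset("bham_60km", "bham_gcmx") == "bham_gcmx"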
8 changes: 7 additions & 1 deletion src/ml_downscaling_emulator/bin/evaluate.py
@@ -61,6 +61,7 @@ def sample(
     epoch: int = typer.Option(...),
     batch_size: int = None,
     num_samples: int = 1,
+    input_transform_dataset: str = None,
     input_transform_key: str = None,
     ensemble_member: str = DEFAULT_ENSEMBLE_MEMBER,
 ):
@@ -70,14 +71,18 @@ def sample(

     if batch_size is not None:
         config.eval.batch_size = batch_size
+    if input_transform_dataset is not None:
+        config.data.input_transform_dataset = input_transform_dataset
+    else:
+        config.data.input_transform_dataset = dataset
     if input_transform_key is not None:
         config.data.input_transform_key = input_transform_key

     output_dirpath = samples_path(
         workdir=workdir,
         checkpoint=f"epoch-{epoch}",
         dataset=dataset,
-        input_xfm=config.data.input_transform_key,
+        input_xfm=f"{config.data.input_transform_dataset}-{config.data.input_transform_key}",
         split=split,
         ensemble_member=ensemble_member,
     )
@@ -88,6 +93,7 @@ def sample(
     eval_dl, _, target_transform = get_dataloader(
         dataset,
         config.data.dataset_name,
+        config.data.input_transform_dataset,
         config.data.input_transform_key,
         config.data.target_transform_key,
         transform_dir,
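
Note: a consequence of the samples_path change above is that the input-transform component of the output directory is now the pair <transform dataset>-<transform key> rather than the key alone, so samples produced with transforms fitted on different datasets no longer collide. An illustrative, self-contained sketch (the dataset and key names are hypothetical):

transform_dataset = "bham_gcmx"  # hypothetical dataset the transform was fitted on
transform_key = "stan"           # hypothetical transform key
input_xfm = f"{transform_dataset}-{transform_key}"
assert input_xfm == "bham_gcmx-stan"  # previously this component was just "stan"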
120 changes: 87 additions & 33 deletions src/ml_downscaling_emulator/bin/postprocess.py
@@ -2,13 +2,14 @@
 import logging
 import os
 from pathlib import Path
+from typing import Callable
 import typer
 import xarray as xr

 from mlde_utils import samples_path, samples_glob, TIME_PERIODS
 from mlde_utils.training.dataset import open_raw_dataset_split

-from ml_downscaling_emulator.postprocess import xrqm
+from ml_downscaling_emulator.postprocess import xrqm, to_gcm_domain

 logging.basicConfig(
     level=logging.INFO,
@@ -25,6 +26,50 @@ def callback():
     pass


+def process_each_sample(
+    workdir: Path,
+    checkpoint: str,
+    dataset: str,
+    ensemble_member: str,
+    input_xfm: str,
+    split: str,
+    processing_func: Callable,
+    new_workdir: Path,
+):
+    samples_dirpath = samples_path(
+        workdir,
+        checkpoint=checkpoint,
+        input_xfm=input_xfm,
+        dataset=dataset,
+        split=split,
+        ensemble_member=ensemble_member,
+    )
+    logger.info(f"Iterating on samples in {samples_dirpath}")
+    for sample_filepath in samples_glob(samples_dirpath):
+        logger.info(f"Working on {sample_filepath}")
+        # open the samples
+        samples_ds = xr.open_dataset(sample_filepath)
+
+        processed_samples_ds = processing_func(samples_ds)
+
+        # save output
+        processed_sample_filepath = (
+            samples_path(
+                new_workdir,
+                checkpoint=checkpoint,
+                input_xfm=input_xfm,
+                dataset=dataset,
+                split=split,
+                ensemble_member=ensemble_member,
+            )
+            / sample_filepath.name
+        )
+
+        logger.info(f"Saving to {processed_sample_filepath}")
+        processed_sample_filepath.parent.mkdir(parents=True, exist_ok=True)
+        processed_samples_ds.to_netcdf(processed_sample_filepath)
+
+
 @app.command()
 def filter(
     workdir: Path,
@@ -103,44 +148,53 @@ def qm(
         )[0]
     )["pred_pr"]

-    ml_eval_samples_dirpath = samples_path(
-        workdir,
-        checkpoint=checkpoint,
-        input_xfm=eval_input_xfm,
-        dataset=eval_dataset,
-        split=split,
-        ensemble_member=ensemble_member,
-    )
-    logger.info(f"QMapping samples in {ml_eval_samples_dirpath}")
-    for sample_filepath in samples_glob(ml_eval_samples_dirpath):
-        logger.info(f"Working on {sample_filepath}")
-        # open the samples to be qmapped
-        ml_eval_ds = xr.open_dataset(sample_filepath)
-
+    def process_samples(ds):
         # do the qmapping
-        qmapped_eval_da = xrqm(sim_train_da, ml_train_da, ml_eval_ds["pred_pr"])
+        qmapped_eval_da = xrqm(sim_train_da, ml_train_da, ds["pred_pr"])

-        qmapped_eval_ds = ml_eval_ds.copy()
-        qmapped_eval_ds["pred_pr"] = qmapped_eval_da
+        processed_samples_ds = ds.copy()
+        processed_samples_ds["pred_pr"] = qmapped_eval_da

-        # save output
-        new_workdir = workdir / "postprocess" / "qm"
+        return processed_samples_ds

-        qmapped_sample_filepath = (
-            samples_path(
-                new_workdir,
-                checkpoint=checkpoint,
-                input_xfm=eval_input_xfm,
-                dataset=eval_dataset,
-                split=split,
-                ensemble_member=ensemble_member,
-            )
-            / sample_filepath.name
-        )
-
-        logger.info(f"Saving to {qmapped_sample_filepath}")
-        qmapped_sample_filepath.parent.mkdir(parents=True, exist_ok=True)
-        qmapped_eval_ds.to_netcdf(qmapped_sample_filepath)
+    process_each_sample(
+        workdir,
+        checkpoint,
+        eval_dataset,
+        ensemble_member,
+        eval_input_xfm,
+        split,
+        process_samples,
+        new_workdir=workdir / "postprocess" / "qm",
+    )
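
Note: for orientation, quantile mapping replaces each value in the ML output with the simulation training value at the same empirical quantile, which is why qm needs both training DataArrays (sim_train_da and ml_train_da) alongside the samples being corrected. A toy 1-D numpy sketch of the general idea, for illustration only (it is not the xrqm implementation, which operates on xarray objects):

import numpy as np

def qm_1d(sim_train, ml_train, ml_eval):
    # locate each evaluation value's quantile in the ML training sample,
    # then read off the simulation training value at that quantile
    q = np.searchsorted(np.sort(ml_train), ml_eval) / len(ml_train)
    return np.quantile(sim_train, np.clip(q, 0.0, 1.0))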


+@app.command()
+def gcmify(
+    workdir: Path,
+    checkpoint: str = typer.Option(...),
+    dataset: str = typer.Option(...),
+    input_xfm: str = typer.Option(...),
+    split: str = typer.Option(...),
+    ensemble_member: str = typer.Option(...),
+):
+    def process_samples(ds):
+        ds = to_gcm_domain(ds.sel(ensemble_member=ensemble_member))
+        ds["pred_pr"] = ds["pred_pr"].expand_dims(
+            {"ensemble_member": [ensemble_member]}
+        )
+        return ds
+
+    process_each_sample(
+        workdir,
+        checkpoint,
+        dataset,
+        ensemble_member,
+        input_xfm,
+        split,
+        process_samples,
+        new_workdir=workdir / "postprocess" / "gcm-grid",
+    )
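
Note: in gcmify, process_samples drops the ensemble_member dimension with sel() before to_gcm_domain regrids the sample, then expand_dims() restores it as a length-one dimension so downstream tooling sees a consistent schema. A self-contained xarray sketch of that dimension round trip (drop=True here stands in for whatever coordinate handling to_gcm_domain performs; it is not part of the command above):

import numpy as np
import xarray as xr

ds = xr.Dataset(
    {"pred_pr": (("ensemble_member", "time"), np.zeros((1, 3)))},
    coords={"ensemble_member": ["01"], "time": np.arange(3)},
)
flat = ds.sel(ensemble_member="01", drop=True)  # dimension removed
restored = flat["pred_pr"].expand_dims({"ensemble_member": ["01"]})
assert restored.sizes["ensemble_member"] == 1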


@app.command()