Raw gcm pr #30

Merged
22 commits merged on Feb 23, 2024

22 commits

649bd77
add a CLI command for converting all samples of a model on a dataset…
henryaddison Nov 29, 2023
f0a4364
sample files don't have sample_id but they do have ensemble_member dim
henryaddison Nov 30, 2023
dd9d3f7
update SelectGCMDomain interface
henryaddison Nov 30, 2023
c57cfc9
map doesn't behave with to_gcm_domain
henryaddison Dec 1, 2023
a2f033c
oops left a map call that should have been removed
henryaddison Dec 1, 2023
9fb2042
OK got to expand dims for entire dataset
henryaddison Dec 1, 2023
d0fc90e
ahh no it was missing square brackets to define the value
henryaddison Dec 1, 2023
70f4599
a helper script to check predictions
henryaddison Dec 6, 2023
275c988
refactor checking preds
henryaddison Dec 14, 2023
708ce08
fix the prediction files
henryaddison Dec 14, 2023
a08c581
save the fixed prediction files too
henryaddison Dec 14, 2023
b0cb375
make fixing idempotent
henryaddison Dec 14, 2023
cc39e01
disable saving fixed pred
henryaddison Dec 14, 2023
f48d2b5
re-enable saving and make sure close file after initial checks
henryaddison Dec 14, 2023
a32414d
update predict to ensure it has the correct attributes on pred_pr
henryaddison Dec 14, 2023
18b2fc2
just check for issues with prediction files in helper script
henryaddison Dec 16, 2023
2db9844
make sure no nulls when gcmifying
henryaddison Dec 16, 2023
e2b7d57
rename x to target in batch var names
henryaddison Jan 4, 2024
394454e
add script for checking number of parameters in a model
henryaddison Jan 9, 2024
8894f7f
allow for setting of which dataset a transform should be fitted to
henryaddison Feb 14, 2024
b2e5d88
set a default value for input transform dataset config
henryaddison Feb 14, 2024
3e2646a
need to unlock the config to set input_transform_dataset
henryaddison Feb 14, 2024
80 changes: 80 additions & 0 deletions bin/check_preds.py
@@ -0,0 +1,80 @@
import sys

from tqdm import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm
import xarray as xr

fpaths = sys.argv[1:]

PRED_VAR = "pred_pr"
META_VARS = (
    "rotated_latitude_longitude",
    "time_bnds",
    "grid_latitude_bnds",
    "grid_longitude_bnds",
)

PRED_PR_ATTRS = {
    "grid_mapping": "rotated_latitude_longitude",
    "standard_name": "pred_pr",
    "units": "kg m-2 s-1",
}


def fix(pred_ds):
    pred_ds[PRED_VAR] = pred_ds[PRED_VAR].assign_attrs(PRED_PR_ATTRS)
    for var in META_VARS:
        if "ensemble_member" in pred_ds[var].dims:
            pred_ds[var] = pred_ds[var].isel(ensemble_member=0)
        if "time" in pred_ds[var].dims:
            pred_ds[var] = pred_ds[var].isel(time=0)

    return pred_ds


def check(pred_ds):
    errors = []

    try:
        assert (
            pred_ds[PRED_VAR].attrs == PRED_PR_ATTRS
        ), f"Bad attrs on {PRED_VAR}: {pred_ds[PRED_VAR].attrs}"
    except AssertionError as e:
        errors.append(e)

    for var in META_VARS:
        try:
            assert ("ensemble_member" not in pred_ds[var].dims) and (
                "time" not in pred_ds[var].dims
            ), f"Bad dims on {var}: {pred_ds[var].dims}"
        except AssertionError as e:
            errors.append(e)

    return errors


with logging_redirect_tqdm():
    with tqdm(
        total=len(fpaths),
        desc="Checking prediction files",
        unit=" files",
    ) as pbar:
        for fpath in fpaths:
            with xr.open_dataset(fpath) as pred_ds:
                errors = check(pred_ds)

            if len(errors) > 0:
                print(f"Errors in {fpath}:")
                for e in errors:
                    print(e)

            # pred_ds = fix(pred_ds)
            # errors = check(pred_ds)

            # if len(errors) > 0:
            #     print(f"Errors in {fpath}:")
            #     for e in errors:
            #         print(e)
            # else:
            #     pred_ds.to_netcdf(fpath)
            pbar.update(1)
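
A minimal sketch of how the script's check and fix compose, on a toy dataset built in memory (the shapes and the stray dimensions below are illustrative assumptions, not the PR's real prediction files):

import numpy as np
import xarray as xr

# Toy dataset: pred_pr lacks the expected attrs and time_bnds wrongly
# carries the ensemble_member and time dimensions.
ds = xr.Dataset(
    {
        "pred_pr": (("ensemble_member", "time"), np.zeros((1, 3))),
        "rotated_latitude_longitude": ((), 0),
        "time_bnds": (("ensemble_member", "time", "bnds"), np.zeros((1, 3, 2))),
        "grid_latitude_bnds": (("grid_latitude", "bnds"), np.zeros((4, 2))),
        "grid_longitude_bnds": (("grid_longitude", "bnds"), np.zeros((4, 2))),
    }
)

assert len(check(ds)) == 2   # bad attrs on pred_pr, bad dims on time_bnds
assert check(fix(ds)) == []  # after fix() both checks pass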
71 changes: 71 additions & 0 deletions bin/model-size
@@ -0,0 +1,71 @@
#!/usr/bin/env python

import os
from pathlib import Path

from ml_collections import config_dict

import typer
import logging
import yaml

from ml_downscaling_emulator.score_sde_pytorch_hja22.models.location_params import (
    LocationParams,
)

from ml_downscaling_emulator.score_sde_pytorch_hja22.models import utils as mutils

# from score_sde_pytorch_hja22.models import ncsnv2
# from score_sde_pytorch_hja22.models import ncsnpp
from ml_downscaling_emulator.score_sde_pytorch_hja22.models import cncsnpp  # noqa: F401
from ml_downscaling_emulator.score_sde_pytorch_hja22.models import cunet  # noqa: F401

# from score_sde_pytorch_hja22.models import ddpm as ddpm_model
from ml_downscaling_emulator.score_sde_pytorch_hja22.models import (  # noqa: F401
    layerspp,  # noqa: F401
)  # noqa: F401
from ml_downscaling_emulator.score_sde_pytorch_hja22.models import layers  # noqa: F401
from ml_downscaling_emulator.score_sde_pytorch_hja22.models import (  # noqa: F401
    normalization,  # noqa: F401
)  # noqa: F401

logger = logging.getLogger()
logger.setLevel("INFO")

app = typer.Typer()


def load_model(config):
    logger.info("Loading model from config")
    score_model = mutils.get_model(config.model.name)(config)
    location_params = LocationParams(
        config.model.loc_spec_channels, config.data.image_size
    )

    return score_model, location_params


def load_config(config_path):
    logger.info(f"Loading config from {config_path}")
    with open(config_path) as f:
        config = config_dict.ConfigDict(yaml.unsafe_load(f))

    return config


@app.command()
def main(
    workdir: Path,
):
    config_path = os.path.join(workdir, "config.yml")
    config = load_config(config_path)
    score_model, location_params = load_model(config)
    num_score_model_parameters = sum(p.numel() for p in score_model.parameters())
    num_location_parameters = sum(p.numel() for p in location_params.parameters())

    typer.echo(f"Score model has {num_score_model_parameters} parameters")
    typer.echo(f"Location parameters have {num_location_parameters} parameters")


if __name__ == "__main__":
    app()
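
Usage is a single positional argument, the training work directory containing config.yml (the path below is a placeholder):

./bin/model-size /path/to/workdir

It builds the model from the config alone, so it reports the two parameter counts without loading any checkpoint weights.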
10 changes: 9 additions & 1 deletion bin/predict.py
@@ -165,6 +165,7 @@ def np_samples_to_xr(np_samples, target_transform, coords, cf_data_vars):
        xr.Dataset(data_vars=data_vars, coords=coords, attrs={})
    )
    samples_ds = samples_ds.rename({"target_pr": "pred_pr"})
    samples_ds["pred_pr"] = samples_ds["pred_pr"].assign_attrs(pred_pr_attrs)
    return samples_ds


@@ -239,21 +240,27 @@ def main(
    epoch: int = typer.Option(...),
    batch_size: int = None,
    num_samples: int = 3,
    input_transform_dataset: str = None,
    input_transform_key: str = None,
    ensemble_member: str = DEFAULT_ENSEMBLE_MEMBER,
):
    config_path = os.path.join(workdir, "config.yml")
    config = load_config(config_path)
    if batch_size is not None:
        config.eval.batch_size = batch_size
    with config.unlocked():
        if input_transform_dataset is not None:
            config.data.input_transform_dataset = input_transform_dataset
        else:
            config.data.input_transform_dataset = dataset
    if input_transform_key is not None:
        config.data.input_transform_key = input_transform_key

    output_dirpath = samples_path(
        workdir=workdir,
        checkpoint=f"epoch-{epoch}",
        dataset=dataset,
        input_xfm=config.data.input_transform_key,
        input_xfm=f"{config.data.input_transform_dataset}-{config.data.input_transform_key}",
        split=split,
        ensemble_member=ensemble_member,
    )
@@ -269,6 +276,7 @@ def main(
    eval_dl, _, target_transform = get_dataloader(
        dataset,
        config.data.dataset_name,
        config.data.input_transform_dataset,
        config.data.input_transform_key,
        config.data.target_transform_key,
        transform_dir,
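
A sketch of the resulting behaviour (the helper name and dataset names are hypothetical, for illustration only): the transform dataset falls back to the dataset being sampled, and the samples directory label now combines transform dataset and transform key.

def resolve_input_xfm(dataset, input_transform_dataset, input_transform_key):
    # mirrors the defaulting above: prefer the explicit transform dataset
    ds = input_transform_dataset if input_transform_dataset is not None else dataset
    return ds, f"{ds}-{input_transform_key}"

assert resolve_input_xfm("gcm-ds", None, "stan") == ("gcm-ds", "gcm-ds-stan")
assert resolve_input_xfm("gcm-ds", "cpm-ds", "stan") == ("cpm-ds", "cpm-ds-stan")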
8 changes: 7 additions & 1 deletion src/ml_downscaling_emulator/bin/evaluate.py
@@ -61,6 +61,7 @@ def sample(
    epoch: int = typer.Option(...),
    batch_size: int = None,
    num_samples: int = 1,
    input_transform_dataset: str = None,
    input_transform_key: str = None,
    ensemble_member: str = DEFAULT_ENSEMBLE_MEMBER,
):
@@ -70,14 +71,18 @@

    if batch_size is not None:
        config.eval.batch_size = batch_size
    if input_transform_dataset is not None:
        config.data.input_transform_dataset = input_transform_dataset
    else:
        config.data.input_transform_dataset = dataset
    if input_transform_key is not None:
        config.data.input_transform_key = input_transform_key

    output_dirpath = samples_path(
        workdir=workdir,
        checkpoint=f"epoch-{epoch}",
        dataset=dataset,
        input_xfm=config.data.input_transform_key,
        input_xfm=f"{config.data.input_transform_dataset}-{config.data.input_transform_key}",
        split=split,
        ensemble_member=ensemble_member,
    )
@@ -88,6 +93,7 @@
    eval_dl, _, target_transform = get_dataloader(
        dataset,
        config.data.dataset_name,
        config.data.input_transform_dataset,
        config.data.input_transform_key,
        config.data.target_transform_key,
        transform_dir,
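
Typer derives the new CLI flag from the parameter name, so a sampling run could pin the transform dataset like this (a hypothetical invocation: the module entry point, paths, and values are placeholders, and other required options are omitted):

python -m ml_downscaling_emulator.bin.evaluate sample /path/to/workdir \
    --epoch 100 --input-transform-dataset some-cpm-dataset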
120 changes: 87 additions & 33 deletions src/ml_downscaling_emulator/bin/postprocess.py
@@ -2,13 +2,14 @@
import logging
import os
from pathlib import Path
from typing import Callable
import typer
import xarray as xr

from mlde_utils import samples_path, samples_glob, TIME_PERIODS
from mlde_utils.training.dataset import open_raw_dataset_split

from ml_downscaling_emulator.postprocess import xrqm, to_gcm_domain

logging.basicConfig(
    level=logging.INFO,
@@ -25,6 +26,50 @@ def callback():
    pass


def process_each_sample(
    workdir: Path,
    checkpoint: str,
    dataset: str,
    ensemble_member: str,
    input_xfm: str,
    split: str,
    processing_func: Callable,
    new_workdir: Path,
):
    samples_dirpath = samples_path(
        workdir,
        checkpoint=checkpoint,
        input_xfm=input_xfm,
        dataset=dataset,
        split=split,
        ensemble_member=ensemble_member,
    )
    logger.info(f"Iterating on samples in {samples_dirpath}")
    for sample_filepath in samples_glob(samples_dirpath):
        logger.info(f"Working on {sample_filepath}")
        # open the samples
        samples_ds = xr.open_dataset(sample_filepath)

        processed_samples_ds = processing_func(samples_ds)

        # save output
        processed_sample_filepath = (
            samples_path(
                new_workdir,
                checkpoint=checkpoint,
                input_xfm=input_xfm,
                dataset=dataset,
                split=split,
                ensemble_member=ensemble_member,
            )
            / sample_filepath.name
        )

        logger.info(f"Saving to {processed_sample_filepath}")
        processed_sample_filepath.parent.mkdir(parents=True, exist_ok=True)
        processed_samples_ds.to_netcdf(processed_sample_filepath)


@app.command()
def filter(
    workdir: Path,
@@ -103,44 +148,53 @@ def qm(
        )[0]
    )["pred_pr"]

    ml_eval_samples_dirpath = samples_path(
        workdir,
        checkpoint=checkpoint,
        input_xfm=eval_input_xfm,
        dataset=eval_dataset,
        split=split,
        ensemble_member=ensemble_member,
    )
    logger.info(f"QMapping samples in {ml_eval_samples_dirpath}")
    for sample_filepath in samples_glob(ml_eval_samples_dirpath):
        logger.info(f"Working on {sample_filepath}")
        # open the samples to be qmapped
        ml_eval_ds = xr.open_dataset(sample_filepath)

    def process_samples(ds):
        # do the qmapping
        qmapped_eval_da = xrqm(sim_train_da, ml_train_da, ml_eval_ds["pred_pr"])
        qmapped_eval_da = xrqm(sim_train_da, ml_train_da, ds["pred_pr"])

        qmapped_eval_ds = ml_eval_ds.copy()
        qmapped_eval_ds["pred_pr"] = qmapped_eval_da
        processed_samples_ds = ds.copy()
        processed_samples_ds["pred_pr"] = qmapped_eval_da

        # save output
        new_workdir = workdir / "postprocess" / "qm"
        return processed_samples_ds

        qmapped_sample_filepath = (
            samples_path(
                new_workdir,
                checkpoint=checkpoint,
                input_xfm=eval_input_xfm,
                dataset=eval_dataset,
                split=split,
                ensemble_member=ensemble_member,
            )
            / sample_filepath.name
    process_each_sample(
        workdir,
        checkpoint,
        eval_dataset,
        ensemble_member,
        eval_input_xfm,
        split,
        process_samples,
        new_workdir=workdir / "postprocess" / "qm",
    )


@app.command()
def gcmify(
    workdir: Path,
    checkpoint: str = typer.Option(...),
    dataset: str = typer.Option(...),
    input_xfm: str = typer.Option(...),
    split: str = typer.Option(...),
    ensemble_member: str = typer.Option(...),
):
    def process_samples(ds):
        ds = to_gcm_domain(ds.sel(ensemble_member=ensemble_member))
        ds["pred_pr"] = ds["pred_pr"].expand_dims(
            {"ensemble_member": [ensemble_member]}
        )
        return ds

        logger.info(f"Saving to {qmapped_sample_filepath}")
        qmapped_sample_filepath.parent.mkdir(parents=True, exist_ok=True)
        qmapped_eval_ds.to_netcdf(qmapped_sample_filepath)
    process_each_sample(
        workdir,
        checkpoint,
        dataset,
        ensemble_member,
        input_xfm,
        split,
        process_samples,
        new_workdir=workdir / "postprocess" / "gcm-grid",
    )


@app.command()
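
A hypothetical invocation of the new gcmify command (every value below is a placeholder, and the module entry point is assumed), which maps samples onto the GCM domain via to_gcm_domain and files the output under postprocess/gcm-grid inside the work directory:

python -m ml_downscaling_emulator.bin.postprocess gcmify /path/to/workdir \
    --checkpoint epoch-100 --dataset some-dataset \
    --input-xfm some-dataset-stan --split val --ensemble-member 01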