forked from yang-song/score_sde_pytorch
-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #30 from henryaddison/raw-gcm-pr
Raw gcm pr
- Loading branch information
Showing
10 changed files
with
290 additions
and
47 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import sys | ||
|
||
from tqdm import tqdm | ||
from tqdm.contrib.logging import logging_redirect_tqdm | ||
import xarray as xr | ||
|
||
fpaths = sys.argv[1:] | ||
|
||
PRED_VAR = "pred_pr" | ||
META_VARS = ( | ||
"rotated_latitude_longitude", | ||
"time_bnds", | ||
"grid_latitude_bnds", | ||
"grid_longitude_bnds", | ||
) | ||
|
||
PRED_PR_ATTRS = { | ||
"grid_mapping": "rotated_latitude_longitude", | ||
"standard_name": "pred_pr", | ||
"units": "kg m-2 s-1", | ||
} | ||
|
||
|
||
def fix(pred_ds): | ||
pred_ds[PRED_VAR] = pred_ds[PRED_VAR].assign_attrs(PRED_PR_ATTRS) | ||
for var in META_VARS: | ||
if "ensemble_member" in pred_ds[var].dims: | ||
pred_ds[var] = pred_ds[var].isel(ensemble_member=0) | ||
if "time" in pred_ds[var].dims: | ||
pred_ds[var] = pred_ds[var].isel(time=0) | ||
|
||
return pred_ds | ||
|
||
|
||
def check(pred_ds): | ||
errors = [] | ||
|
||
try: | ||
assert ( | ||
pred_ds[PRED_VAR].attrs == PRED_PR_ATTRS | ||
), f"Bad attrs on {PRED_VAR}: {pred_ds[PRED_VAR].attrs}" | ||
except AssertionError as e: | ||
errors.append(e) | ||
|
||
for var in META_VARS: | ||
try: | ||
assert ("ensemble_member" not in pred_ds[var].dims) and ( | ||
"time" not in pred_ds[var].dims | ||
), f"Bad dims on {var}: {pred_ds[var].dims}" | ||
except AssertionError as e: | ||
errors.append(e) | ||
|
||
return errors | ||
|
||
|
||
with logging_redirect_tqdm(): | ||
with tqdm( | ||
total=len(fpaths), | ||
desc=f"Checking prediction files", | ||
unit=" files", | ||
) as pbar: | ||
for fpath in fpaths: | ||
with xr.open_dataset(fpath) as pred_ds: | ||
errors = check(pred_ds) | ||
|
||
if len(errors) > 0: | ||
print(f"Errors in {fpath}:") | ||
for e in errors: | ||
print(e) | ||
|
||
# pred_ds = fix(pred_ds) | ||
# errors = check(pred_ds) | ||
|
||
# if len(errors) > 0: | ||
# print(f"Errors in {fpath}:") | ||
# for e in errors: | ||
# print(e) | ||
# else: | ||
# pred_ds.to_netcdf(fpath) | ||
pbar.update(1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
#!/usr/bin/env python | ||
|
||
import os | ||
from pathlib import Path | ||
|
||
from ml_collections import config_dict | ||
|
||
import typer | ||
import logging | ||
import yaml | ||
|
||
from ml_downscaling_emulator.score_sde_pytorch_hja22.models.location_params import ( | ||
LocationParams, | ||
) | ||
|
||
from ml_downscaling_emulator.score_sde_pytorch_hja22.models import utils as mutils | ||
|
||
# from score_sde_pytorch_hja22.models import ncsnv2 | ||
# from score_sde_pytorch_hja22.models import ncsnpp | ||
from ml_downscaling_emulator.score_sde_pytorch_hja22.models import cncsnpp # noqa: F401 | ||
from ml_downscaling_emulator.score_sde_pytorch_hja22.models import cunet # noqa: F401 | ||
|
||
# from score_sde_pytorch_hja22.models import ddpm as ddpm_model | ||
from ml_downscaling_emulator.score_sde_pytorch_hja22.models import ( # noqa: F401 | ||
layerspp, # noqa: F401 | ||
) # noqa: F401 | ||
from ml_downscaling_emulator.score_sde_pytorch_hja22.models import layers # noqa: F401 | ||
from ml_downscaling_emulator.score_sde_pytorch_hja22.models import ( # noqa: F401 | ||
normalization, # noqa: F401 | ||
) # noqa: F401 | ||
|
||
logger = logging.getLogger() | ||
logger.setLevel("INFO") | ||
|
||
app = typer.Typer() | ||
|
||
|
||
def load_model(config): | ||
logger.info(f"Loading model from config") | ||
score_model = mutils.get_model(config.model.name)(config) | ||
location_params = LocationParams( | ||
config.model.loc_spec_channels, config.data.image_size | ||
) | ||
|
||
return score_model, location_params | ||
|
||
|
||
def load_config(config_path): | ||
logger.info(f"Loading config from {config_path}") | ||
with open(config_path) as f: | ||
config = config_dict.ConfigDict(yaml.unsafe_load(f)) | ||
|
||
return config | ||
|
||
|
||
@app.command() | ||
def main( | ||
workdir: Path, | ||
): | ||
config_path = os.path.join(workdir, "config.yml") | ||
config = load_config(config_path) | ||
score_model, location_params = load_model(config) | ||
num_score_model_parameters = sum(p.numel() for p in score_model.parameters()) | ||
num_location_parameters = sum(p.numel() for p in location_params.parameters()) | ||
|
||
typer.echo(f"Score model has {num_score_model_parameters} parameters") | ||
typer.echo(f"Location parameters have {num_location_parameters} parameters") | ||
|
||
|
||
if __name__ == "__main__": | ||
app() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.