diff --git a/bin/check_preds.py b/bin/check_preds.py index f68a9ea3b..0d8c35203 100644 --- a/bin/check_preds.py +++ b/bin/check_preds.py @@ -6,44 +6,46 @@ fpaths = sys.argv[1:] +PRED_VAR = "pred_pr" +META_VARS = ( + "rotated_latitude_longitude", + "time_bnds", + "grid_latitude_bnds", + "grid_longitude_bnds", +) -def check(pred_ds): +PRED_PR_ATTRS = { + "grid_mapping": "rotated_latitude_longitude", + "standard_name": "pred_pr", + "units": "kg m-2 s-1", +} - pred_var = "pred_pr" - meta_vars = ( - "rotated_latitude_longitude", - "time_bnds", - "grid_latitude_bnds", - "grid_longitude_bnds", - ) - errors = [] +def fix(pred_ds): + pred_ds[PRED_VAR] = pred_ds[PRED_VAR].assign_attrs(PRED_PR_ATTRS) + for var in META_VARS: + pred_ds[var] = pred_ds[var].isel(ensemble_member=0, time=0) + + return pred_ds + - pred_pr_attrs = { - "grid_mapping": "rotated_latitude_longitude", - "standard_name": "pred_pr", - "units": "kg m-2 s-1", - } +def check(pred_ds): + errors = [] try: assert ( - pred_ds[pred_var].attrs == pred_pr_attrs - ), f"Bad attrs on {pred_var}: {pred_ds[pred_var].attrs}" + pred_ds[PRED_VAR].attrs == PRED_PR_ATTRS + ), f"Bad attrs on {PRED_VAR}: {pred_ds[PRED_VAR].attrs}" except AssertionError as e: errors.append(e) - # pred_ds[var] = pred_ds[var].assign_attrs(pred_pr_attrs) - - for var in meta_vars: - # print(var, pred_ds[var].dims) + for var in META_VARS: try: assert ("ensemble_member" not in pred_ds[var].dims) and ( "time" not in pred_ds[var].dims ), f"Bad dims on {var}: {pred_ds[var].dims}" except AssertionError as e: errors.append(e) - # pred_ds[var] = pred_ds[var].isel(ensemble_member=0, time=0) - # print(var, pred_ds[var].dims) return errors @@ -51,7 +53,7 @@ def check(pred_ds): with logging_redirect_tqdm(): with tqdm( total=len(fpaths), - desc=f"CHecking prediction files", + desc=f"Checking prediction files", unit=" files", ) as pbar: for fpath in fpaths: @@ -65,4 +67,20 @@ def check(pred_ds): print(e) pbar.update(1) - # print(pred_ds) +# with logging_redirect_tqdm(): +# with tqdm( +# total=len(fpaths), +# desc=f"Fixing prediction files", +# unit=" files", +# ) as pbar: +# for fpath in fpaths: +# pred_ds = xr.open_dataset(fpath) +# # import pdb; pdb.set_trace() +# pred_ds = fix(pred_ds) +# errors = check(pred_ds) + +# if len(errors) > 0: +# print(f"Errors in {fpath}:") +# for e in errors: +# print(e) +# pbar.update(1)