Skip to content

Commit

Permalink
refactor checking preds
Browse files Browse the repository at this point in the history
  • Loading branch information
henryaddison committed Dec 14, 2023
1 parent 70f4599 commit 275c988
Showing 1 changed file with 42 additions and 24 deletions.
66 changes: 42 additions & 24 deletions bin/check_preds.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,52 +6,54 @@

fpaths = sys.argv[1:]

PRED_VAR = "pred_pr"
META_VARS = (
"rotated_latitude_longitude",
"time_bnds",
"grid_latitude_bnds",
"grid_longitude_bnds",
)

def check(pred_ds):
PRED_PR_ATTRS = {
"grid_mapping": "rotated_latitude_longitude",
"standard_name": "pred_pr",
"units": "kg m-2 s-1",
}

pred_var = "pred_pr"
meta_vars = (
"rotated_latitude_longitude",
"time_bnds",
"grid_latitude_bnds",
"grid_longitude_bnds",
)

errors = []
def fix(pred_ds):
pred_ds[PRED_VAR] = pred_ds[PRED_VAR].assign_attrs(PRED_PR_ATTRS)
for var in META_VARS:
pred_ds[var] = pred_ds[var].isel(ensemble_member=0, time=0)

return pred_ds


pred_pr_attrs = {
"grid_mapping": "rotated_latitude_longitude",
"standard_name": "pred_pr",
"units": "kg m-2 s-1",
}
def check(pred_ds):
errors = []

try:
assert (
pred_ds[pred_var].attrs == pred_pr_attrs
), f"Bad attrs on {pred_var}: {pred_ds[pred_var].attrs}"
pred_ds[PRED_VAR].attrs == PRED_PR_ATTRS
), f"Bad attrs on {PRED_VAR}: {pred_ds[PRED_VAR].attrs}"
except AssertionError as e:
errors.append(e)

# pred_ds[var] = pred_ds[var].assign_attrs(pred_pr_attrs)

for var in meta_vars:
# print(var, pred_ds[var].dims)
for var in META_VARS:
try:
assert ("ensemble_member" not in pred_ds[var].dims) and (
"time" not in pred_ds[var].dims
), f"Bad dims on {var}: {pred_ds[var].dims}"
except AssertionError as e:
errors.append(e)
# pred_ds[var] = pred_ds[var].isel(ensemble_member=0, time=0)
# print(var, pred_ds[var].dims)

return errors


with logging_redirect_tqdm():
with tqdm(
total=len(fpaths),
desc=f"CHecking prediction files",
desc=f"Checking prediction files",
unit=" files",
) as pbar:
for fpath in fpaths:
Expand All @@ -65,4 +67,20 @@ def check(pred_ds):
print(e)
pbar.update(1)

# print(pred_ds)
# with logging_redirect_tqdm():
# with tqdm(
# total=len(fpaths),
# desc=f"Fixing prediction files",
# unit=" files",
# ) as pbar:
# for fpath in fpaths:
# pred_ds = xr.open_dataset(fpath)
# # import pdb; pdb.set_trace()
# pred_ds = fix(pred_ds)
# errors = check(pred_ds)

# if len(errors) > 0:
# print(f"Errors in {fpath}:")
# for e in errors:
# print(e)
# pbar.update(1)

0 comments on commit 275c988

Please sign in to comment.