add support for setting up era5 initial ensemble

StevePny committed Oct 3, 2024
1 parent 93aff5c commit becfa0f
Showing 7 changed files with 290 additions and 22 deletions.
2 changes: 1 addition & 1 deletion dabench/__init__.py
@@ -1 +1 @@
from . import data, vector, model, observer, obsop, dacycler, _suppl_data
from . import data, vector, model, observer, obsop, dacycler, dasupport, _suppl_data, utils
5 changes: 5 additions & 0 deletions dabench/dasupport/__init__.py
@@ -0,0 +1,5 @@
from .generate_era5_ensemble import GenEra5Ens

__all__ = [
'GenEra5Ens',
]
Empty file added dabench/dasupport/__pycache__
47 changes: 26 additions & 21 deletions dabench/dasupport/generate_era5_ensemble.py
@@ -10,7 +10,7 @@
import xarray as xr
from dateutil.relativedelta import relativedelta

from helpers.timing import report_timing
from ..utils.timing import report_timing

# Selected vars for ERA5 ensemble
# This will reduce the number of model fields processed and stored in the ensemble
@@ -106,7 +106,7 @@ def parse_arguments():
#%% Define the initial ensemble


def define_init_ensemble(
def _define_init_ensemble(
ensemble_size, init_ensemble_start_date, init_ensemble_sample_strategy="multi_year"
):

@@ -130,14 +130,15 @@ def define_init_ensemble(
return init_ensemble_member_dates
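The body of `_define_init_ensemble` is collapsed in this view. For orientation only, a hypothetical sketch of what the two sampling strategies suggest (function name, direction, and offsets are assumptions, not taken from the commit):

```python
# Hypothetical sketch -- the real implementation is collapsed above.
# "multi_year": same calendar date across consecutive years;
# "consecutive_day": consecutive days around the start date.
from datetime import datetime
from dateutil.relativedelta import relativedelta

def sample_member_dates(start_date, ensemble_size, strategy="multi_year"):
    if strategy == "multi_year":
        return [start_date - relativedelta(years=i) for i in range(ensemble_size)]
    if strategy == "consecutive_day":
        return [start_date - relativedelta(days=i) for i in range(ensemble_size)]
    raise ValueError(f"unknown sample strategy: {strategy}")

print(sample_member_dates(datetime(1998, 12, 31), 4, "consecutive_day"))
```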


def main(
def GenEra5Ens(
date_format:str="%Y%m%dZ%H",
atmosphere_ensemble_s3_key:str=None,
target_date:datetime=datetime.strptime("19990101Z00","%Y%m%dZ%H"),
sample_strategy:str="consecutive_day",
start_date:datetime=datetime.strptime("19981231Z00","%Y%m%dZ%H"),
ensemble_size:int=4,
era5_path:str="gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3",
verbose:bool=False,
):

#%% Set up the gcp access to era5
@@ -146,7 +147,7 @@ def main(
ds_era5 = xr.open_zarr(gcs.get_mapper(era5_path), chunks=None)
else:
raise Exception("Non-GCP source for ERA5 not yet supported. EXITING...")
report_timing(timing_label="build_test_ensemble_era5:: access remote zarr store")
report_timing(timing_label="GenEra5Ens:: access remote zarr store")

#%% Reorder the latitudes
# Following:
@@ -157,7 +158,7 @@
assert ds_era5.latitude[0] < ds_era5.latitude[-1]
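The reordering line itself is collapsed above; the assert only pins down the goal (ascending latitudes). A hedged sketch of one common way to do it, assuming ARCO ERA5's descending storage order:

```python
# Assumed sketch: ARCO ERA5 stores latitude descending (90 -> -90),
# so sort ascending before the assert.
ds_era5 = ds_era5.sortby("latitude")
```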

#%% Determine dates for initial ensemble sampling
init_ensemble_member_dates = define_init_ensemble(
init_ensemble_member_dates = _define_init_ensemble(
ensemble_size=ensemble_size,
init_ensemble_start_date=start_date,
init_ensemble_sample_strategy=sample_strategy,
@@ -168,18 +169,20 @@
#%% Sample from era5
ds_init_ens = ds_era5[ERA5_CONTROL_VARIABLES].sel(time=init_ensemble_member_dates)
report_timing(
timing_label="build_test_ensemble_era5:: select time steps as ensemble members"
timing_label="GenEra5Ens:: select time steps as ensemble members"
)
print(ds_init_ens)
if verbose:
print(ds_init_ens)

#%% Update time to target and add ensemble dimension
ds_init_ens = ds_init_ens.rename_dims(dims_dict={"time": "member"})
ds_init_ens["member"] = range(ensemble_size)
ds_init_ens = ds_init_ens.drop_vars("time")
report_timing(
timing_label="build_test_ensemble_era5:: add member dimension to replace time"
timing_label="GenEra5Ens:: add member dimension to replace time"
)
print(ds_init_ens)
if verbose:
print(ds_init_ens)
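A toy sketch of the dimension swap above (synthetic data; it mirrors the `rename_dims`/`drop_vars` calls in the commit):

```python
import numpy as np
import xarray as xr

ds = xr.Dataset({"t2m": (("time", "x"), np.random.rand(4, 8))},
                coords={"time": np.arange(4)})
ds = ds.rename_dims(dims_dict={"time": "member"})  # time axis becomes member axis
ds["member"] = range(4)                            # index members 0..3
ds = ds.drop_vars("time")                          # old time labels no longer apply
print(ds.sizes)  # member: 4, x: 8
```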

#%% Select target date from era5 for recentering the ensemble
ds_target = ds_era5[ERA5_CONTROL_VARIABLES].sel(time=target_date)
@@ -190,43 +193,45 @@
ds_init_ens['ws10n'] = (ds_init_ens['10m_u_component_of_neutral_wind']**2 + ds_init_ens['10m_v_component_of_neutral_wind']**2)**(0.5)
ds_target['ws10n'] = (ds_target['10m_u_component_of_neutral_wind']**2 + ds_target['10m_v_component_of_neutral_wind']**2)**(0.5)
report_timing(
timing_label="build_test_ensemble_era5:: computing neutral wind speeds at 10m (ws10n)"
timing_label="GenEra5Ens:: computing neutral wind speeds at 10m (ws10n)"
)
if ('10m_u_component_of_wind' in ERA5_CONTROL_VARIABLES and
'10m_v_component_of_wind' in ERA5_CONTROL_VARIABLES):
ds_init_ens['ws10m'] = (ds_init_ens['10m_u_component_of_wind']**2 + ds_init_ens['10m_v_component_of_wind']**2)**(0.5)
ds_target['ws10m'] = (ds_target['10m_u_component_of_wind']**2 + ds_target['10m_v_component_of_wind']**2)**(0.5)
report_timing(
timing_label="build_test_ensemble_era5:: computing diagnostic wind speeds at 10m (ws10m)"
timing_label="GenEra5Ens:: computing diagnostic wind speeds at 10m (ws10m)"
)

#%% Recenter ensemble to target date
print(f'build_test_ensemble_era5:: re-centering ensemble with ensemble_size = {ensemble_size} to target_date = {target_date}...')
print(f'GenEra5Ens:: re-centering ensemble with ensemble_size = {ensemble_size} to target_date = {target_date}...')
ds_mean = ds_init_ens.mean(dim="member")
ds_diff = ds_target - ds_mean
ds_init_ens = ds_init_ens + ds_diff
report_timing(
timing_label="build_test_ensemble_era5:: recenter ensemble to target date"
timing_label="GenEra5Ens:: recenter ensemble to target date"
)
print(ds_init_ens)
if verbose:
print(ds_init_ens)
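Recentering here is the standard trick of adding (target - ensemble mean) to every member: the ensemble mean lands exactly on the target state while each member's perturbation about the mean is preserved. A minimal check on synthetic data:

```python
import numpy as np
import xarray as xr

ens = xr.DataArray(np.random.rand(4, 8), dims=("member", "x"))
target = xr.DataArray(np.random.rand(8), dims=("x",))

recentered = ens + (target - ens.mean(dim="member"))
assert np.allclose(recentered.mean(dim="member"), target)
```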

#%% Now add time back on as a singleton dimension
ds_init_ens = ds_init_ens.expand_dims(dim={"time": [target_date]}, axis=0)
report_timing(
timing_label="build_test_ensemble_era5:: add time dimension back on to dataset structure"
timing_label="GenEra5Ens:: add time dimension back on to dataset structure"
)
print(ds_init_ens)
if verbose:
print(ds_init_ens)

#%% Add some checks to make sure dimensions haven't changed
assert ds_era5.sizes['latitude'] == ds_init_ens.sizes['latitude']
assert ds_era5.sizes['longitude'] == ds_init_ens.sizes['longitude']
assert ds_era5.sizes['level'] == ds_init_ens.sizes['level']

#%% Upload to s3 as zarr
print('Uploading to s3 zarr...')
#%% Store to zarr (locally or on e.g. AWS s3)
print('Storing as zarr...')
ds_init_ens.to_zarr(atmosphere_ensemble_s3_key, mode="w")
report_timing(
timing_label="build_test_ensemble_era5:: upload to s3 as a new zarr store"
timing_label="GenEra5Ens:: upload to s3 as a new zarr store"
)
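Although the parameter is still named `atmosphere_ensemble_s3_key`, `to_zarr` accepts a local path as well as a remote fsspec URL; a hedged sketch (remote targets need the matching filesystem package, e.g. s3fs, and the bucket name below is hypothetical):

```python
ds_init_ens.to_zarr("./era5_init_ens.zarr", mode="w")                 # local directory store
# ds_init_ens.to_zarr("s3://my-bucket/era5_init_ens.zarr", mode="w")  # remote store via s3fs
```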


@@ -235,9 +240,9 @@
args = parse_arguments()

# %% Process input arguments
report_timing(timing_label="build_test_ensemble_era5:: initializing...")
report_timing(timing_label="GenEra5Ens:: initializing...")

main(
GenEra5Ens(
date_format=args.date_format,
atmosphere_ensemble_s3_key=args.atmosphere_ensemble_s3_key,
target_date=args.target_date,
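The rest of the `__main__` block is truncated in this view. For reference, a hedged usage sketch of the new entry point, with a local output path assumed:

```python
from datetime import datetime

from dabench.dasupport import GenEra5Ens

GenEra5Ens(
    atmosphere_ensemble_s3_key="./era5_init_ens.zarr",  # local path; s3 keys work with s3fs
    target_date=datetime.strptime("19990101Z00", "%Y%m%dZ%H"),
    sample_strategy="consecutive_day",
    start_date=datetime.strptime("19981231Z00", "%Y%m%dZ%H"),
    ensemble_size=4,
    verbose=True,
)
```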
191 changes: 191 additions & 0 deletions dabench/metrics/_ensemble.py
@@ -0,0 +1,191 @@
"""Ensemble forecast metrics"""

from functools import partial

import bottleneck as bn  # ranking kernel, following the xskillscore reference below
import jax.numpy as jnp
import numpy as np
import xarray as xr
from xhistogram.xarray import histogram  # histogram helper used by xskillscore

from dabench.metrics import _utils


__all__ = [
'rank_histogram',
'crps_ensemble',
]


def rank_histogram(observations, forecasts, dim=None, member_dim="member"):
    """Rank histogram (Talagrand diagram), following xskillscore's xarray implementation.

    Description:
        (from https://www.cawcr.gov.au/projects/verification/#Methods_for_EPS)
        Answers the question: how well does the ensemble spread of the
        forecast represent the true variability (uncertainty) of the
        observations? Also known as a "Talagrand diagram", this method checks
        where the verifying observation usually falls with respect to the
        ensemble forecast data, which is arranged in increasing order at each
        grid point. In an ensemble with perfect spread, each member represents
        an equally likely scenario, so the observation is equally likely to
        fall between any two members.

        To construct a rank histogram:
            1. At every observation (or analysis) point, rank the N ensemble
               members from lowest to highest. This gives N+1 possible bins
               that the observation could fall into, including the two
               extremes.
            2. Identify which bin the observation falls into at each point.
            3. Tally over many observations to create a histogram of rank.

    Interpretation:
        Flat - ensemble spread about right to represent forecast uncertainty.
        U-shaped - ensemble spread too small; many observations fall outside
            the extremes of the ensemble.
        Dome-shaped - ensemble spread too large; most observations fall near
            the center of the ensemble.
        Asymmetric - ensemble contains bias.

    Note: A flat rank histogram does not necessarily indicate a good forecast;
    it only measures whether the observed probability distribution is well
    represented by the ensemble.

    Args:
        observations (xarray.DataArray): Verifying observations.
        forecasts (xarray.DataArray): Ensemble forecasts with an ensemble
            dimension named ``member_dim``; must be broadcastable to the
            shape of observations.
        dim (list of str, optional): Dimensions over which to tally the
            histogram.
        member_dim (str): Name of the ensemble member dimension.

    Returns:
        xarray.DataArray of counts per rank bin (N+1 bins for N members).
    """

    # Flatness can be summarized as
    # RMSD = sqrt( 1/(N+1) * sum_k (S_k - M/(N+1))^2 ),
    # with S_k the count in rank bin k and M the total number of observations.

    # See: https://github.com/xarray-contrib/xskillscore/blob/64f17fdd1816b64b9e13c3f2febb9800a7e6ed0c/xskillscore/core/probabilistic.py#L830C20-L830C76

    def _rank_first(x, y):
        """Concatenate x and y and return the rank of the
        first element along the last axis"""
        xy = np.concatenate((x[..., np.newaxis], y), axis=-1)
        return bn.nanrankdata(xy, axis=-1)[..., 0]

    if dim is not None:
        if len(dim) == 0:
            raise ValueError(
                "At least one dimension must be supplied to compute rank histogram over"
            )
        if member_dim in dim:
            raise ValueError(f'"{member_dim}" cannot be specified as an input to dim')

    ranks = xr.apply_ufunc(
        _rank_first,
        observations,
        forecasts,
        input_core_dims=[[], [member_dim]],
        dask="parallelized",
        output_dtypes=[int],
    )

    bin_edges = np.arange(0.5, len(forecasts[member_dim]) + 2)
    return histogram(ranks, bins=[bin_edges], bin_names=["rank"], dim=dim, bin_dim_suffix="")
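A hedged usage sketch (random data for illustration; assumes the xarray, bottleneck, and xhistogram dependencies noted in the imports):

```python
import numpy as np
import xarray as xr

obs = xr.DataArray(np.random.randn(100), dims=("time",))
fcst = xr.DataArray(np.random.randn(100, 10), dims=("time", "member"))

hist = rank_histogram(obs, fcst, dim=["time"])
print(hist)  # 11 bins (N+1 for N=10 members); roughly flat if well calibrated
```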


def crps_ensemble(observations, forecasts, axis=-1):
    """JAX implementation of the Continuous Ranked Probability Score (CRPS).

    (From: https://confluence.ecmwf.int/display/FUG/Section+12.B+Statistical+Concepts+-+Probabilistic+Data#:~:text=The%20Continuous%20Ranked%20Probability%20Score,the%20forecast%20is%20wholly%20inaccurate.)
    A generalisation of the Ranked Probability Score (RPS) is the Continuous
    Ranked Probability Score (CRPS), where the thresholds are continuous
    rather than discrete (see Nurmi, 2003; Jolliffe and Stephenson, 2003;
    Wilks, 2006). The CRPS measures how well forecasts match observed
    outcomes, where:
        CRPS = 0: the forecast is wholly accurate;
        CRPS = 1: the forecast is wholly inaccurate.
    CRPS is calculated by comparing the cumulative distribution function (CDF)
    of the forecast against a reference dataset (observations, analyses, or
    climatology) over a given period.

    Args:
        observations (ndarray): Array of observations.
        forecasts (ndarray): Array of ensemble forecasts; same shape as
            observations, plus an ensemble dimension along ``axis``.
        axis (int): Axis of the ensemble dimension in forecasts.

    Returns:
        ndarray of CRPS values, one per observation.
    """

    # Integral from -inf to inf: (1/M) * sum_j[ S (P_j(x) - H(x - x_oj))^2 dx ]
    # where P_j, H, and x_oj are the predicted cumulative distribution for
    # case j, the Heaviside step function, and the observed value,
    # respectively, with M independent cases (e.g. different dates).
    # (see: https://www.ecmwf.int/sites/default/files/elibrary/2007/10729-ensemble-forecasting.pdf)

    # See: https://github.com/properscoring/properscoring/blob/a465b5578d4b661e662933e84fa7673a70e75e94/properscoring/_crps.py#L244

    # Manage input quality
    observations = jnp.asarray(observations)
    forecasts = jnp.asarray(forecasts)

    if axis != -1:
        # Move the ensemble axis to the end
        forecasts = jnp.moveaxis(forecasts, axis, -1)

    if observations.shape not in [forecasts.shape, forecasts.shape[:-1]]:
        raise ValueError('observations and forecasts must have matching '
                         'shapes or matching shapes except along `axis=%s`'
                         % axis)

    if observations.shape == forecasts.shape:
        # A deterministic (single-member) forecast: CRPS reduces to the
        # absolute error
        return jnp.abs(observations - forecasts)

    # Sort forecast members by target quantity and weight them equally
    forecasts = jnp.sort(forecasts, axis=-1)
    weights = jnp.ones_like(forecasts)

    return _crps_ensemble_vectorized(observations, forecasts, weights)


# Ported from properscoring's numba kernel:
# @guvectorize(["void(float64[:], float64[:], float64[:], float64[:])"],
#              "(),(n),(n)->()", nopython=True)
@partial(jnp.vectorize, signature='(),(n),(n)->()')
def _crps_ensemble_vectorized(observation, forecasts, weights):
    # beware: forecasts are assumed sorted in NumPy's sort order
    # note: this direct port keeps properscoring's Python control flow, so it
    # only evaluates on concrete (non-traced) inputs; a fully traceable JAX
    # version would need lax control-flow primitives

    # TODO: add input-validation asserts here

    obs = observation

    if jnp.isnan(obs):
        return jnp.nan

    total_weight = 0.0
    for n, weight in enumerate(weights):
        if jnp.isnan(forecasts[n]):
            # NumPy sorts NaN to the end
            break
        if not weight >= 0:
            # this catches NaN weights
            return jnp.nan
        total_weight += weight

    obs_cdf = 0
    forecast_cdf = 0
    prev_forecast = 0
    integral = 0

    for n, forecast in enumerate(forecasts):
        if jnp.isnan(forecast):
            # NumPy sorts NaN to the end
            if n == 0:
                integral = jnp.nan
            # reset for the sake of the conditional below
            forecast = prev_forecast
            break

        if obs_cdf == 0 and obs < forecast:
            integral += (obs - prev_forecast) * forecast_cdf ** 2
            integral += (forecast - obs) * (forecast_cdf - 1) ** 2
            obs_cdf = 1
        else:
            integral += ((forecast - prev_forecast)
                         * (forecast_cdf - obs_cdf) ** 2)

        forecast_cdf += weights[n] / total_weight
        prev_forecast = forecast

    if obs_cdf == 0:
        # forecast can be undefined here if the loop body is never executed
        # (because forecasts have size 0), but don't worry about that because
        # we want to raise an error in that case, anyways
        integral += obs - forecast

    return integral
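As a sanity check, for equal weights the CDF integral above should agree with the well-known energy form CRPS = E|X - y| - 0.5 * E|X - X'| evaluated on the empirical ensemble; a hedged sketch:

```python
import jax.numpy as jnp

y = jnp.asarray(0.3)
x = jnp.asarray([-0.2, 0.1, 0.5, 0.9])  # a single 4-member ensemble

energy_form = jnp.mean(jnp.abs(x - y)) - 0.5 * jnp.mean(
    jnp.abs(x[:, None] - x[None, :]))
print(crps_ensemble(y, x), energy_form)  # both should give 0.14375 here
```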


5 changes: 5 additions & 0 deletions dabench/utils/__init__.py
@@ -0,0 +1,5 @@
from .timing import report_timing

__all__ = [
'report_timing',
]