Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 710185289
  • Loading branch information
Forgotten authored and The swirl_dynamics Authors committed Dec 28, 2024
1 parent 7ea2644 commit 6405bd8
Showing 1 changed file with 1 addition and 199 deletions.
200 changes: 1 addition & 199 deletions swirl_dynamics/projects/debiasing/rectified_flow/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,205 +504,7 @@ def _maybe_expands_dims(self, x: np.ndarray) -> np.ndarray:
)


# TODO: This class is redundant remove.
class DataSourceEnsembleWithClimatologyInference:
  """An inference data source that loads ensemble LENS2 data.

  Each record pairs a raw LENS2 snapshot with the climatological mean and
  standard deviation of both the input (LENS2) and output (ERA5) datasets,
  jointly indexed by ensemble member and time.
  """

  def __init__(
      self,
      date_range: tuple[str, str],
      input_dataset: epath.PathLike,
      input_variable_names: Sequence[str],
      input_member_indexer: tuple[Mapping[str, Any], ...],
      input_climatology: epath.PathLike,
      output_variables: Mapping[str, Any],
      output_climatology: epath.PathLike,
      dims_order_input: Sequence[str] | None = None,
      dims_order_input_stats: Sequence[str] | None = None,
      dims_order_output_stats: Sequence[str] | None = None,
      resample_at_nan: bool = False,
      resample_seed: int = 9999,
      time_stamps: bool = False,
  ):
    """Data source constructor.

    Args:
      date_range: The date range (in days) applied. Data not falling in this
        range is ignored.
      input_dataset: The path of a zarr dataset containing the input data.
      input_variable_names: The variables to yield from the input dataset.
      input_member_indexer: The name of the ensemble member to sample from, it
        should be a tuple of dictionaries with the key "member" and the value
        the name of the member, to adhere to xarray formatting. For example:
        [{"member": "cmip6_1001_001"}, {"member": "cmip6_1021_002"}, ...]
      input_climatology: The path of a zarr dataset containing the input
        statistics.
      output_variables: The variables to yield from the output dataset.
      output_climatology: The path of a zarr dataset containing the output
        statistics.
      dims_order_input: Order of the dimensions (time, member, lat, lon, fields)
        of the input variables. If None, the dimensions are not changed from the
        order in the xarray dataset.
      dims_order_input_stats: Order of the dimensions (member, lat, lon, fields)
        of the statistics of the input variables. If None, the dimensions are
        not changed from the order in the xarray dataset.
      dims_order_output_stats: Order of the dimensions (member, lat, lon, field)
        of the statistics of the output variables. If None, the dimensions are
        not changed from the order in the xarray dataset.
      resample_at_nan: Whether to resample when NaN is detected in the data.
      resample_seed: The random seed for resampling.
      time_stamps: Whether to add the time stamps to the samples.
    """

    # LENS2 inputs are selected per ensemble member, so every input variable
    # shares the same member indexers.
    input_variables = {v: input_member_indexer for v in input_variable_names}

    # Normalize the date range to day resolution.
    date_range = jax.tree.map(lambda x: np.datetime64(x, "D"), date_range)

    # Open the datasets; the *_climatology stores contain the statistics.
    input_ds = xrts.open_zarr(input_dataset).sel(time=slice(*date_range))
    input_stats_ds = xrts.open_zarr(input_climatology)
    output_stats_ds = xrts.open_zarr(output_climatology)

    # Transpose the datasets if a dimension order was requested.
    if dims_order_input:
      input_ds = input_ds.transpose(*dims_order_input)
    if dims_order_output_stats:
      output_stats_ds = output_stats_ds.transpose(*dims_order_output_stats)
    if dims_order_input_stats:
      input_stats_ds = input_stats_ds.transpose(*dims_order_input_stats)

    # Raw LENS2 snapshots, keyed by variable and then by member.
    self._input_arrays = self._select_member_arrays(input_ds, input_variables)
    # Climatological mean and std of LENS2, keyed the same way.
    self._input_mean_arrays = self._select_member_arrays(
        input_stats_ds, input_variables
    )
    self._input_std_arrays = self._select_member_arrays(
        input_stats_ds, input_variables, suffix="_std"
    )

    # Climatological mean and std of ERA5, keyed by variable only.
    self._output_mean_arrays = {}
    self._output_std_arrays = {}
    for v, indexers in output_variables.items():
      self._output_mean_arrays[v] = output_stats_ds[v].sel(
          indexers
      )  # pytype : disable=wrong-arg-types
      self._output_std_arrays[v] = output_stats_ds[v + "_std"].sel(
          indexers
      )  # pytype : disable=wrong-arg-types

    # The times can be slightly off due to the leap years.
    # Member index.
    self._indexes = [ind["member"] for ind in input_member_indexer]
    # `.sizes` avoids the FutureWarning raised by `Dataset.dims["time"]` in
    # recent xarray versions; the returned value is identical.
    self._len_time = input_ds.sizes["time"]
    self._len = len(self._indexes) * self._len_time
    self._input_time_array = xrts.read(input_ds["time"]).data
    self._resample_at_nan = resample_at_nan
    self._resample_seed = resample_seed
    self._time_stamps = time_stamps

  @staticmethod
  def _select_member_arrays(
      ds: Any,
      variables: Mapping[str, tuple[Mapping[str, Any], ...]],
      suffix: str = "",
  ) -> dict[str, dict[str, Any]]:
    """Selects per-member arrays from `ds` for each variable.

    Args:
      ds: An xarray dataset (data or climatology statistics).
      variables: Mapping from variable name to its member indexers, e.g.
        ({"member": "cmip6_1001_001"}, ...).
      suffix: Suffix appended to the variable name when indexing `ds` (e.g.
        "_std" to select the standard-deviation fields).

    Returns:
      A nested dict: variable name -> member name -> selected DataArray.
    """
    arrays = {}
    for v, indexers in variables.items():
      arrays[v] = {}
      for index in indexers:
        # Each indexer is a single-entry dict; its value is the member name.
        member = tuple(index.values())[0]
        arrays[v][member] = ds[v + suffix].sel(index)
    return arrays

  def __len__(self):
    return self._len

  @staticmethod
  def _item_has_nan(item: Mapping[str, Any]) -> bool:
    """Returns True if any array in the item's data fields contains a NaN."""
    for field in (
        "input",
        "input_mean",
        "input_std",
        "output_mean",
        "output_std",
    ):
      # Each of these fields is a dict of variable name -> ndarray.
      for array in item[field].values():
        if np.isnan(array).any():
          return True
    return False

  def __getitem__(self, record_key: SupportsIndex) -> Mapping[str, Any]:
    """Retrieves a record and retries if NaN is found.

    Args:
      record_key: Flat index in [0, len(self)); encodes (member, time).

    Returns:
      The record dict produced by `get_item`.

    Raises:
      ValueError: If the index is out of range.
    """
    idx = record_key.__index__()
    if not idx < self._len:
      raise ValueError(f"Index out of range: {idx} / {self._len - 1}")

    item = self.get_item(idx)

    if self._resample_at_nan:
      # Create the generator once so successive draws differ. Re-seeding
      # inside the loop (as before) would redraw the same index forever
      # whenever the resampled item also contained NaNs.
      rng = np.random.default_rng(self._resample_seed + idx)
      while self._item_has_nan(item):
        resample_idx = int(rng.integers(0, len(self)))
        item = self.get_item(resample_idx)

    return item

  def get_item(self, idx: int) -> Mapping[str, Any]:
    """Returns the data record for a given index."""
    item = {}

    # Decode the flat index into the member and time (of the year) indices.
    idx_member = idx // self._len_time
    idx_time = idx % self._len_time

    member = self._indexes[idx_member]
    date = self._input_time_array[idx_time]
    # Day of the year, 1-based; used to index the daily climatology.
    dayofyear = (
        date - np.datetime64(str(date.astype("datetime64[Y]")))
    ) / np.timedelta64(1, "D") + 1
    if dayofyear <= 0 or dayofyear > 366:
      raise ValueError(f"Invalid day of the year: {dayofyear}")

    sample_input = {}
    mean_input = {}
    std_input = {}
    mean_output = {}
    std_output = {}

    for v, da in self._input_arrays.items():
      array = xrts.read(da[member].sel(time=date)).data
      # Array is either two-or three-dimensional.
      sample_input[v] = data_utils.maybe_expand_dims(array, (2, 3), 2)

    for v, da in self._input_mean_arrays.items():
      mean_array = xrts.read(da[member].sel(dayofyear=int(dayofyear))).data
      mean_input[v] = data_utils.maybe_expand_dims(mean_array, (2, 3), 2)

    for v, da in self._input_std_arrays.items():
      std_array = xrts.read(da[member].sel(dayofyear=int(dayofyear))).data
      std_input[v] = data_utils.maybe_expand_dims(std_array, (2, 3), 2)

    item["input"] = sample_input
    item["input_mean"] = mean_input
    item["input_std"] = std_input

    for v, da in self._output_mean_arrays.items():
      mean_array = xrts.read(da.sel(dayofyear=int(dayofyear))).data
      mean_output[v] = data_utils.maybe_expand_dims(mean_array, (2, 3), 2)

    for v, da in self._output_std_arrays.items():
      std_array = xrts.read(da.sel(dayofyear=int(dayofyear))).data
      std_output[v] = data_utils.maybe_expand_dims(std_array, (2, 3), 2)

    item["output_mean"] = mean_output
    item["output_std"] = std_output

    if self._time_stamps:
      item["input_time_stamp"] = date
      item["input_member"] = member

    return item


# TODO: Merge this loader with the one below and add a flag.
def create_ensemble_lens2_era5_loader_with_climatology(
date_range: tuple[str, str],
input_dataset_path: epath.PathLike = _LENS2_DATASET_PATH, # pylint: disable=dangerous-default-value
Expand Down

0 comments on commit 6405bd8

Please sign in to comment.