Code update
PiperOrigin-RevId: 710178802
Forgotten authored and The swirl_dynamics Authors committed Dec 28, 2024
1 parent ec6400c commit 2a83c36
Showing 1 changed file with 205 additions and 0 deletions.
205 changes: 205 additions & 0 deletions swirl_dynamics/projects/debiasing/rectified_flow/dataloaders.py
@@ -15,6 +15,7 @@
"""Class with the data loaders the climatology-based models."""
# TODO encapsulate the functionality and streamline the code.

import abc
from collections.abc import Callable, Mapping, Sequence
import types
from typing import Any, Literal, SupportsIndex
@@ -91,6 +92,210 @@ def read_stats_simple(
return out


class CommonSourceEnsemble(abc.ABC):
"""A data source that loads daily ERA5- with ensemble LENS2 data.
Here we consider a loose alignment between the ERA5 and LENS2 data using the
time stamps. The pairs feed to the model are selected such that the time
stamps are roughly the same, so the climatological statistics are roughly
aligned between the two datasets.
"""

@abc.abstractmethod
def __len__(self):
pass

@abc.abstractmethod
def _compute_len(self, *args) -> int:
pass

@abc.abstractmethod
def _compute_indices(
self, idx
) -> tuple[str, np.datetime64 | Sequence[np.datetime64], int | Sequence[int]]:
pass

@abc.abstractmethod
def _maybe_expands_dims(self, x: np.ndarray) -> np.ndarray:
pass

def __init__(
self,
date_range: tuple[str, str],
input_dataset: epath.PathLike,
input_variable_names: Sequence[str],
input_member_indexer: tuple[Mapping[str, Any], ...],
input_climatology: epath.PathLike,
output_dataset: epath.PathLike,
output_variables: Mapping[str, Any],
output_climatology: epath.PathLike,
dims_order_input: Sequence[str] | None = None,
dims_order_output: Sequence[str] | None = None,
dims_order_input_stats: Sequence[str] | None = None,
dims_order_output_stats: Sequence[str] | None = None,
resample_at_nan: bool = False,
resample_seed: int = 9999,
time_stamps: bool = False,
):
"""Data source constructor.
Args:
date_range: The date range (in days) applied. Data not falling in this
range is ignored.
input_dataset: The path of a zarr dataset containing the input data.
input_variable_names: The variables to yield from the input dataset.
input_member_indexer: The name of the ensemble member to sample from, it
should be tuple of dictionaries with the key "member" and the value the
name of the member, to adhere to xarray formating, For example:
[{"member": "cmip6_1001_001"}, {"member": "cmip6_1021_002"}, ...]
input_climatology: The path of a zarr dataset containing the input
statistics.
output_dataset: The path of a zarr dataset containing the output data.
output_variables: The variables to yield from the output dataset.
output_climatology: The path of a zarr dataset containing the output
statistics.
dims_order_input: Order of the variables for the input.
dims_order_output: Order of the variables for the output.
dims_order_input_stats: Order of the variables for the input statistics.
dims_order_output_stats: Order of the variables for the output statistics.
resample_at_nan: Whether to resample when NaN is detected in the data.
resample_seed: The random seed for resampling.
time_stamps: Wheter to add the time stamps to the samples.
"""

# LENS2 is used as the input, so each variable needs to carry its member
# indexers.
input_variables = {v: input_member_indexer for v in input_variable_names}
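# The result maps each input variable name to the full tuple of member
# indexers, e.g. {var_name: ({"member": "cmip6_1001_001"}, ...), ...}.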

# Computing the date_range.
date_range = jax.tree.map(lambda x: np.datetime64(x, "D"), date_range)

# Open the datasets.
input_ds = xrts.open_zarr(input_dataset).sel(time=slice(*date_range))
output_ds = xrts.open_zarr(output_dataset).sel(time=slice(*date_range))
# These contain the climatologies
input_stats_ds = xrts.open_zarr(input_climatology)
output_stats_ds = xrts.open_zarr(output_climatology)

# Transpose the datasets if necessary.
if dims_order_input:
input_ds = input_ds.transpose(*dims_order_input)
if dims_order_output:
output_ds = output_ds.transpose(*dims_order_output)
if dims_order_output_stats:
output_stats_ds = output_stats_ds.transpose(*dims_order_output_stats)
if dims_order_input_stats:
input_stats_ds = input_stats_ds.transpose(*dims_order_input_stats)

# Selects the input arrays and builds the dictionaries of xarray DataArrays
# holding the data and the corresponding climatological statistics.
self._input_arrays = {}
self._input_mean_arrays = {}
self._input_std_arrays = {}

for v, indexers in input_variables.items():
self._input_arrays[v] = {}
self._input_mean_arrays[v] = {}
self._input_std_arrays[v] = {}
for index in indexers:
idx = tuple(index.values())[0]
self._input_arrays[v][idx] = input_ds[v].sel(index)
self._input_mean_arrays[v][idx] = input_stats_ds[v].sel(index)
self._input_std_arrays[v][idx] = input_stats_ds[v + "_std"].sel(index)
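# Each dictionary above is a nested mapping:
# variable name -> member name -> xarray DataArray.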

# Build the output arrays for the different output variables and statistics.
self._output_arrays = {}
self._output_mean_arrays = {}
self._output_std_arrays = {}
for v, indexers in output_variables.items():
self._output_arrays[v] = output_ds[v].sel(indexers)
self._output_mean_arrays[v] = output_stats_ds[v].sel(indexers)
self._output_std_arrays[v] = output_stats_ds[v + "_std"].sel(indexers)

self._output_coords = output_ds.coords

# The times can be slightly off due to the leap years.
# Member index.
self._indexes = [ind["member"] for ind in input_member_indexer]
self._len_time = np.min([input_ds.dims["time"], output_ds.dims["time"]])
self._input_time_array = xrts.read(input_ds["time"]).data
self._output_time_array = xrts.read(output_ds["time"]).data
self._resample_at_nan = resample_at_nan
self._resample_seed = resample_seed
self._time_stamps = time_stamps

def __getitem__(self, record_key: SupportsIndex) -> Mapping[str, Any]:
"""Retrieves record and retry if NaN found."""
idx = record_key.__index__()
if not idx < self.__len__():
raise ValueError(f"Index out of range: {idx} / {self.__len__() - 1}")

item = self.get_item(idx)

if self._resample_at_nan:
  # Draws replacement samples (deterministically seeded by the original
  # index) until one free of NaNs is found.
  rng = np.random.default_rng(self._resample_seed + idx)
  while any(
      np.isnan(field).any()
      for field in (*item["input"].values(), *item["output"].values())
  ):
    resample_idx = rng.integers(0, len(self))
    item = self.get_item(resample_idx)

return item

def get_item(self, idx: int) -> Mapping[str, Any]:
"""Returns the data record for a given index."""
item = {}

# Maps the flat index to an ensemble member, a date, and the corresponding
# day of year (used to index the climatology).
member, date, dayofyear = self._compute_indices(idx)

sample_input, mean_input, std_input = {}, {}, {}
sample_output, mean_output, std_output = {}, {}, {}

# Loops over the input variables, reading the raw data and the
# climatological statistics.
for var_key, da in self._input_arrays.items():
array = xrts.read(da[member].sel(time=date)).data
sample_input[var_key] = self._maybe_expands_dims(array)

for var_key, da in self._input_mean_arrays.items():
mean_array = xrts.read(da[member].sel(dayofyear=dayofyear)).data
mean_input[var_key] = self._maybe_expands_dims(mean_array)

for var_key, da in self._input_std_arrays.items():
std_array = xrts.read(da[member].sel(dayofyear=dayofyear)).data
std_input[var_key] = self._maybe_expands_dims(std_array)

item["input"] = sample_input
item["input_mean"] = mean_input
item["input_std"] = std_input

# Loops over the output variables, reading the raw data and the
# climatological statistics.
for var_key, da in self._output_arrays.items():
array = xrts.read(da.sel(time=date)).data
sample_output[var_key] = self._maybe_expands_dims(array)

for var_key, da in self._output_mean_arrays.items():
mean_array = xrts.read(da.sel(dayofyear=dayofyear)).data
mean_output[var_key] = self._maybe_expands_dims(mean_array)

for var_key, da in self._output_std_arrays.items():
std_array = xrts.read(da.sel(dayofyear=dayofyear)).data
std_output[var_key] = self._maybe_expands_dims(std_array)

item["output"] = sample_output
item["output_mean"] = mean_output
item["output_std"] = std_output

if self._time_stamps:
# Adds the time and the ensemble member corresponding to the data.
item["input_time_stamp"] = date
item["input_member"] = member

return item

def get_output_coords(self):
"""Returns the coordinates of the output dataset."""
return self._output_coords
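# Illustrative sketch only (not part of this commit): one plausible way a
# concrete subclass could implement the abstract hooks above, assuming one
# sample per (member, day) pair. The class name, the index arithmetic, and
# the day-of-year convention are hypothetical; actual subclasses may differ.
class _ExampleDailySource(CommonSourceEnsemble):
  """Example subclass pairing one ensemble member with one day per index."""

  def __len__(self) -> int:
    return self._compute_len(self._indexes, self._len_time)

  def _compute_len(self, indexes, len_time) -> int:
    # One sample per ensemble member and per day in the overlapping range.
    return len(indexes) * len_time

  def _compute_indices(
      self, idx
  ) -> tuple[str, np.datetime64, int]:
    # Splits the flat index into a member index and a time index.
    member_idx, time_idx = divmod(idx, self._len_time)
    member = self._indexes[member_idx]
    date = self._input_time_array[time_idx]
    # Day of year (1-366), assumed to match the "dayofyear" coordinate of
    # the climatology.
    day_offset = (date - date.astype("datetime64[Y]")).astype("timedelta64[D]")
    dayofyear = int(day_offset.astype(int)) + 1
    return member, date, dayofyear

  def _maybe_expands_dims(self, x: np.ndarray) -> np.ndarray:
    # Adds a trailing channel axis to two-dimensional (lon, lat) fields.
    return x if x.ndim == 3 else x[..., np.newaxis]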


class DataSourceEnsembleWithClimatology:
"""A data source that loads paired daily ERA5- with ensemble LENS2 data."""
