diff --git a/swirl_dynamics/projects/debiasing/rectified_flow/dataloaders.py b/swirl_dynamics/projects/debiasing/rectified_flow/dataloaders.py
index 07ce200..cede207 100644
--- a/swirl_dynamics/projects/debiasing/rectified_flow/dataloaders.py
+++ b/swirl_dynamics/projects/debiasing/rectified_flow/dataloaders.py
@@ -15,6 +15,7 @@
 """Class with the data loaders the climatology-based models."""

 # TODO encapsulate the functionality and streamline the code.
+import abc
 from collections.abc import Callable, Mapping, Sequence
 import types
 from typing import Any, Literal, SupportsIndex
@@ -91,6 +92,210 @@ def read_stats_simple(
   return out


+class CommonSourceEnsemble(abc.ABC):
+  """A data source that loads daily ERA5 data with ensemble LENS2 data.
+
+  Here we consider a loose alignment between the ERA5 and LENS2 data using the
+  time stamps. The pairs fed to the model are selected such that the time
+  stamps are roughly the same, so the climatological statistics are roughly
+  aligned between the two datasets.
+  """
+
+  @abc.abstractmethod
+  def __len__(self):
+    pass
+
+  @abc.abstractmethod
+  def _compute_len(self, *args) -> int:
+    pass
+
+  @abc.abstractmethod
+  def _compute_indices(
+      self, idx
+  ) -> tuple[str, np.datetime64 | Sequence[np.datetime64], int | Sequence[int]]:
+    pass
+
+  @abc.abstractmethod
+  def _maybe_expands_dims(self, x: np.ndarray) -> np.ndarray:
+    pass
+
+  def __init__(
+      self,
+      date_range: tuple[str, str],
+      input_dataset: epath.PathLike,
+      input_variable_names: Sequence[str],
+      input_member_indexer: tuple[Mapping[str, Any], ...],
+      input_climatology: epath.PathLike,
+      output_dataset: epath.PathLike,
+      output_variables: Mapping[str, Any],
+      output_climatology: epath.PathLike,
+      dims_order_input: Sequence[str] | None = None,
+      dims_order_output: Sequence[str] | None = None,
+      dims_order_input_stats: Sequence[str] | None = None,
+      dims_order_output_stats: Sequence[str] | None = None,
+      resample_at_nan: bool = False,
+      resample_seed: int = 9999,
+      time_stamps: bool = False,
+  ):
+    """Data source constructor.
+
+    Args:
+      date_range: The date range (in days) applied. Data not falling in this
+        range is ignored.
+      input_dataset: The path of a zarr dataset containing the input data.
+      input_variable_names: The variables to yield from the input dataset.
+      input_member_indexer: The name of the ensemble member to sample from; it
+        should be a tuple of dictionaries with the key "member" and the value
+        the name of the member, to adhere to xarray formatting. For example:
+        [{"member": "cmip6_1001_001"}, {"member": "cmip6_1021_002"}, ...]
+      input_climatology: The path of a zarr dataset containing the input
+        statistics.
+      output_dataset: The path of a zarr dataset containing the output data.
+      output_variables: The variables to yield from the output dataset.
+      output_climatology: The path of a zarr dataset containing the output
+        statistics.
+      dims_order_input: Order of the variables for the input.
+      dims_order_output: Order of the variables for the output.
+      dims_order_input_stats: Order of the variables for the input statistics.
+      dims_order_output_stats: Order of the variables for the output
+        statistics.
+      resample_at_nan: Whether to resample when NaN is detected in the data.
+      resample_seed: The random seed for resampling.
+      time_stamps: Whether to add the time stamps to the samples.
+    """
+
+    # LENS2 is used as input; each variable is indexed by ensemble member.
+    input_variables = {v: input_member_indexer for v in input_variable_names}
+
+    # Computing the date_range.
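+    # The bounds are converted to np.datetime64 at day resolution so that both
+    # zarr datasets below are sliced over the same time window.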
+    date_range = jax.tree.map(lambda x: np.datetime64(x, "D"), date_range)
+
+    # Open the datasets.
+    input_ds = xrts.open_zarr(input_dataset).sel(time=slice(*date_range))
+    output_ds = xrts.open_zarr(output_dataset).sel(time=slice(*date_range))
+    # These contain the climatologies.
+    input_stats_ds = xrts.open_zarr(input_climatology)
+    output_stats_ds = xrts.open_zarr(output_climatology)
+
+    # Transpose the datasets if necessary.
+    if dims_order_input:
+      input_ds = input_ds.transpose(*dims_order_input)
+    if dims_order_output:
+      output_ds = output_ds.transpose(*dims_order_output)
+    if dims_order_output_stats:
+      output_stats_ds = output_stats_ds.transpose(*dims_order_output_stats)
+    if dims_order_input_stats:
+      input_stats_ds = input_stats_ds.transpose(*dims_order_input_stats)
+
+    # Selects the input arrays and builds the dictionary of xarray datasets
+    # to be used for the climatology.
+    self._input_arrays = {}
+    self._input_mean_arrays = {}
+    self._input_std_arrays = {}
+
+    for v, indexers in input_variables.items():
+      self._input_arrays[v] = {}
+      self._input_mean_arrays[v] = {}
+      self._input_std_arrays[v] = {}
+      for index in indexers:
+        idx = tuple(index.values())[0]
+        self._input_arrays[v][idx] = input_ds[v].sel(index)
+        self._input_mean_arrays[v][idx] = input_stats_ds[v].sel(index)
+        self._input_std_arrays[v][idx] = input_stats_ds[v + "_std"].sel(index)
+
+    # Build the output arrays for the different output variables and
+    # statistics.
+    self._output_arrays = {}
+    self._output_mean_arrays = {}
+    self._output_std_arrays = {}
+    for v, indexers in output_variables.items():
+      self._output_arrays[v] = output_ds[v].sel(indexers)
+      self._output_mean_arrays[v] = output_stats_ds[v].sel(indexers)
+      self._output_std_arrays[v] = output_stats_ds[v + "_std"].sel(indexers)
+
+    self._output_coords = output_ds.coords
+
+    # The times can be slightly off due to the leap years.
+    # Member index.
+    self._indexes = [ind["member"] for ind in input_member_indexer]
+    self._len_time = np.min([input_ds.dims["time"], output_ds.dims["time"]])
+    self._input_time_array = xrts.read(input_ds["time"]).data
+    self._output_time_array = xrts.read(output_ds["time"]).data
+    self._resample_at_nan = resample_at_nan
+    self._resample_seed = resample_seed
+    self._time_stamps = time_stamps
+
+  def __getitem__(self, record_key: SupportsIndex) -> Mapping[str, Any]:
+    """Retrieves a record and retries if a NaN is found."""
+    idx = record_key.__index__()
+    if not idx < self.__len__():
+      raise ValueError(f"Index out of range: {idx} / {self.__len__() - 1}")
+
+    item = self.get_item(idx)
+
+    if self._resample_at_nan:
+      # The generator is created once, so each retry draws a fresh index.
+      rng = np.random.default_rng(self._resample_seed + idx)
+      while any(
+          np.isnan(x).any()
+          for x in (*item["input"].values(), *item["output"].values())
+      ):
+        resample_idx = rng.integers(0, len(self))
+        item = self.get_item(resample_idx)
+
+    return item
+
+  def get_item(self, idx: int) -> Mapping[str, Any]:
+    """Returns the data record for a given index."""
+    item = {}
+
+    # Computes the indices.
+    member, date, dayofyear = self._compute_indices(idx)
+
+    sample_input, mean_input, std_input = {}, {}, {}
+    sample_output, mean_output, std_output = {}, {}, {}
+
+    # Loop over the variable keys, reading each input dataset and its
+    # statistics.
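+    # The raw samples are read at the selected time stamp, whereas the mean
+    # and std statistics are indexed by the corresponding day of the year.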
+    for var_key, da in self._input_arrays.items():
+      array = xrts.read(da[member].sel(time=date)).data
+      sample_input[var_key] = self._maybe_expands_dims(array)
+
+    for var_key, da in self._input_mean_arrays.items():
+      mean_array = xrts.read(da[member].sel(dayofyear=dayofyear)).data
+      mean_input[var_key] = self._maybe_expands_dims(mean_array)
+
+    for var_key, da in self._input_std_arrays.items():
+      std_array = xrts.read(da[member].sel(dayofyear=dayofyear)).data
+      std_input[var_key] = self._maybe_expands_dims(std_array)
+
+    item["input"] = sample_input
+    item["input_mean"] = mean_input
+    item["input_std"] = std_input
+
+    # Loop over the variable keys, reading each output dataset and its
+    # statistics.
+    for var_key, da in self._output_arrays.items():
+      array = xrts.read(da.sel(time=date)).data
+      sample_output[var_key] = self._maybe_expands_dims(array)
+
+    for var_key, da in self._output_mean_arrays.items():
+      mean_array = xrts.read(da.sel(dayofyear=dayofyear)).data
+      mean_output[var_key] = self._maybe_expands_dims(mean_array)
+
+    for var_key, da in self._output_std_arrays.items():
+      std_array = xrts.read(da.sel(dayofyear=dayofyear)).data
+      std_output[var_key] = self._maybe_expands_dims(std_array)
+
+    item["output"] = sample_output
+    item["output_mean"] = mean_output
+    item["output_std"] = std_output
+
+    if self._time_stamps:
+      # Adds the time and the ensemble member corresponding to the data.
+      item["input_time_stamp"] = date
+      item["input_member"] = member
+
+    return item
+
+  def get_output_coords(self):
+    """Returns the coordinates of the output dataset."""
+    return self._output_coords
+
+
 class DataSourceEnsembleWithClimatology:
   """A data source that loads paired daily ERA5- with ensemble LENS2 data."""
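Below is a minimal sketch, not part of this change, of how a concrete subclass could fill in the four abstract hooks: one sample per (member, time stamp) pair, assuming the input and output datasets share the same time stamps. The class name PairedDailySourceEnsemble and the day-of-year arithmetic are illustrative only, not the project's API.

import numpy as np

from swirl_dynamics.projects.debiasing.rectified_flow import dataloaders


class PairedDailySourceEnsemble(dataloaders.CommonSourceEnsemble):
  """Illustrative source: one (member, day) sample per index."""

  def __len__(self) -> int:
    return self._compute_len(self._indexes, self._len_time)

  def _compute_len(self, indexes, len_time) -> int:
    # One sample per ensemble member and per shared time stamp.
    return len(indexes) * int(len_time)

  def _compute_indices(self, idx):
    # Split the flat index into a member index and a time index.
    idx_member, idx_time = divmod(idx, int(self._len_time))
    member = self._indexes[idx_member]
    date = self._input_time_array[idx_time]
    # Day of the year (1-366) used to index the climatological statistics.
    days = (date - date.astype("datetime64[Y]")).astype("timedelta64[D]")
    dayofyear = int(days.astype(int)) + 1
    return member, date, dayofyear

  def _maybe_expands_dims(self, x: np.ndarray) -> np.ndarray:
    # Single-day samples are returned unchanged.
    return np.asarray(x)

With such a subclass, indexing the source directly (e.g. source[0]) returns a dictionary with the "input" and "output" samples plus their "_mean" and "_std" climatological statistics, as assembled in get_item above.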