From 6405bd84777993f4057302518f8ece5381e79f4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leonardo=20Zepeda-N=C3=BA=C3=B1ez?= Date: Fri, 27 Dec 2024 19:11:54 -0800 Subject: [PATCH] Remove redundant DataSourceEnsembleWithClimatologyInference data source PiperOrigin-RevId: 710185289 --- .../debiasing/rectified_flow/dataloaders.py | 200 +----------------- 1 file changed, 1 insertion(+), 199 deletions(-) diff --git a/swirl_dynamics/projects/debiasing/rectified_flow/dataloaders.py b/swirl_dynamics/projects/debiasing/rectified_flow/dataloaders.py index d5193b2..895ff44 100644 --- a/swirl_dynamics/projects/debiasing/rectified_flow/dataloaders.py +++ b/swirl_dynamics/projects/debiasing/rectified_flow/dataloaders.py @@ -504,205 +504,7 @@ def _maybe_expands_dims(self, x: np.ndarray) -> np.ndarray: ) -# TODO: This class is redundant remove. -class DataSourceEnsembleWithClimatologyInference: - """An inferece data source that loads ensemble LENS2 data.""" - - def __init__( - self, - date_range: tuple[str, str], - input_dataset: epath.PathLike, - input_variable_names: Sequence[str], - input_member_indexer: tuple[Mapping[str, Any], ...], - input_climatology: epath.PathLike, - output_variables: Mapping[str, Any], - output_climatology: epath.PathLike, - dims_order_input: Sequence[str] | None = None, - dims_order_input_stats: Sequence[str] | None = None, - dims_order_output_stats: Sequence[str] | None = None, - resample_at_nan: bool = False, - resample_seed: int = 9999, - time_stamps: bool = False, - ): - """Data source constructor. - - Args: - date_range: The date range (in days) applied. Data not falling in this - range is ignored. - input_dataset: The path of a zarr dataset containing the input data. - input_variable_names: The variables to yield from the input dataset. 
- input_member_indexer: The name of the ensemble member to sample from, it - should be tuple of dictionaries with the key "member" and the value the - name of the member, to adhere to xarray formating, For example: - [{"member": "cmip6_1001_001"}, {"member": "cmip6_1021_002"}, ...] - input_climatology: The path of a zarr dataset containing the input - statistics. - output_variables: The variables to yield from the output dataset. - output_climatology: The path of a zarr dataset containing the output - statistics. - dims_order_input: Order of the dimensions (time, member, lat, lon, fields) - of the input variables. If None, the dimensions are not changed from the - order in the xarray dataset - dims_order_input_stats: Order of the dimensions (member, lat, lon, fields) - of the statistics of the input variables. If None, the dimensions are - not changed from the order in the xarray dataset - dims_order_output_stats: Order of the dimensions (member, lat, lon, field) - of the statistics of the output variables. If None, the dimensions are - not changed from the order in the xarray dataset - resample_at_nan: Whether to resample when NaN is detected in the data. - resample_seed: The random seed for resampling. - time_stamps: Wheter to add the time stamps to the samples. - """ - - # Using lens as input, they need to be modified. 
- input_variables = {v: input_member_indexer for v in input_variable_names} - - # Computing the date_range - date_range = jax.tree.map(lambda x: np.datetime64(x, "D"), date_range) - - # Open the datasets - input_ds = xrts.open_zarr(input_dataset).sel(time=slice(*date_range)) - # These contain the climatologies - input_stats_ds = xrts.open_zarr(input_climatology) - output_stats_ds = xrts.open_zarr(output_climatology) - - # Transpose the datasets if necessary - if dims_order_input: - input_ds = input_ds.transpose(*dims_order_input) - if dims_order_output_stats: - output_stats_ds = output_stats_ds.transpose(*dims_order_output_stats) - if dims_order_input_stats: - input_stats_ds = input_stats_ds.transpose(*dims_order_input_stats) - - # selecting the input_arrays - self._input_arrays = {} - for v, indexers in input_variables.items(): - self._input_arrays[v] = {} - for index in indexers: - idx = tuple(index.values())[0] - self._input_arrays[v][idx] = input_ds[v].sel(index) - - # Building the dictionary of xarray datasets to be used for the climatology. - # Climatological mean of LENS2. - self._input_mean_arrays = {} - for v, indexers in input_variables.items(): - self._input_mean_arrays[v] = {} - for index in indexers: - idx = tuple(index.values())[0] - self._input_mean_arrays[v][idx] = input_stats_ds[v].sel(index) - - # Climatological std of LENS2. - self._input_std_arrays = {} - for v, indexers in input_variables.items(): - self._input_std_arrays[v] = {} - for index in indexers: - idx = tuple(index.values())[0] - self._input_std_arrays[v][idx] = input_stats_ds[v + "_std"].sel(index) - - # Build the output arrays for the different output variables. - # Climatological mean of ERA5. - self._output_mean_arrays = {} - for v, indexers in output_variables.items(): - self._output_mean_arrays[v] = output_stats_ds[v].sel( - indexers - ) # pytype : disable=wrong-arg-types - - # Climatological std of ERA5. 
- self._output_std_arrays = {} - for v, indexers in output_variables.items(): - self._output_std_arrays[v] = output_stats_ds[v + "_std"].sel( - indexers - ) # pytype : disable=wrong-arg-types - - # The times can be slightly off due to the leap years. - # Member index. - self._indexes = [ind["member"] for ind in input_member_indexer] - self._len_time = input_ds.dims["time"] - self._len = len(self._indexes) * self._len_time - self._input_time_array = xrts.read(input_ds["time"]).data - self._resample_at_nan = resample_at_nan - self._resample_seed = resample_seed - self._time_stamps = time_stamps - - def __len__(self): - return self._len - - def __getitem__(self, record_key: SupportsIndex) -> Mapping[str, Any]: - """Retrieves record and retry if NaN found.""" - idx = record_key.__index__() - if not idx < self._len: - raise ValueError(f"Index out of range: {idx} / {self._len - 1}") - - item = self.get_item(idx) - - if self._resample_at_nan: - while np.isnan(item["input"]).any() or np.isnan(item["output"]).any(): - - rng = np.random.default_rng(self._resample_seed + idx) - resample_idx = rng.integers(0, len(self)) - item = self.get_item(resample_idx) - - return item - - def get_item(self, idx: int) -> Mapping[str, Any]: - """Returns the data record for a given index.""" - item = {} - - # Checking the index for the member and time (of the year) - idx_member = idx // self._len_time - idx_time = idx % self._len_time - - member = self._indexes[idx_member] - date = self._input_time_array[idx_time] - # computing day of the year. - dayofyear = ( - date - np.datetime64(str(date.astype("datetime64[Y]"))) - ) / np.timedelta64(1, "D") + 1 - if dayofyear <= 0 or dayofyear > 366: - raise ValueError(f"Invalid day of the year: {dayofyear}") - - sample_input = {} - mean_input = {} - std_input = {} - mean_output = {} - std_output = {} - - for v, da in self._input_arrays.items(): - - array = xrts.read(da[member].sel(time=date)).data - # Array is either two-or three-dimensional. 
- sample_input[v] = data_utils.maybe_expand_dims(array, (2, 3), 2) - - for v, da in self._input_mean_arrays.items(): - mean_array = xrts.read(da[member].sel(dayofyear=int(dayofyear))).data - mean_input[v] = data_utils.maybe_expand_dims(mean_array, (2, 3), 2) - - for v, da in self._input_std_arrays.items(): - std_array = xrts.read(da[member].sel(dayofyear=int(dayofyear))).data - std_input[v] = data_utils.maybe_expand_dims(std_array, (2, 3), 2) - - item["input"] = sample_input - item["input_mean"] = mean_input - item["input_std"] = std_input - - for v, da in self._output_mean_arrays.items(): - mean_array = xrts.read(da.sel(dayofyear=int(dayofyear))).data - mean_output[v] = data_utils.maybe_expand_dims(mean_array, (2, 3), 2) - - for v, da in self._output_std_arrays.items(): - std_array = xrts.read(da.sel(dayofyear=int(dayofyear))).data - std_output[v] = data_utils.maybe_expand_dims(std_array, (2, 3), 2) - - item["output_mean"] = mean_output - item["output_std"] = std_output - - if self._time_stamps: - item["input_time_stamp"] = date - item["input_member"] = member - - return item - - +# TODO: Merge this loader with the one below and add a flag. def create_ensemble_lens2_era5_loader_with_climatology( date_range: tuple[str, str], input_dataset_path: epath.PathLike = _LENS2_DATASET_PATH, # pylint: disable=dangerous-default-value