diff --git a/swirl_dynamics/projects/debiasing/rectified_flow/data_utils.py b/swirl_dynamics/projects/debiasing/rectified_flow/data_utils.py
index d97eb12..03839d7 100644
--- a/swirl_dynamics/projects/debiasing/rectified_flow/data_utils.py
+++ b/swirl_dynamics/projects/debiasing/rectified_flow/data_utils.py
@@ -707,215 +707,6 @@ def get_output_coords(self):
     return self._output_coords
 
 
-class DataSourceContiguousNonOverlappingEnsembleWithStats:
-  """A data source that loads paired daily ERA5- with ensemble LENS2 data."""
-
-  def __init__(
-      self,
-      date_range: tuple[str, str],
-      batch_size: int,
-      input_dataset: epath.PathLike,
-      input_variable_names: Sequence[str],
-      input_member_indexer: tuple[Mapping[str, Any], ...],
-      input_stats_dataset: epath.PathLike,
-      output_dataset: epath.PathLike,
-      output_variables: Mapping[str, Any],
-      dims_order_input: Sequence[str] | None = None,
-      dims_order_output: Sequence[str] | None = None,
-      dims_order_input_stats: Sequence[str] | None = (
-          "member",
-          "longitude",
-          "latitude",
-          "stats",
-      ),
-      resample_at_nan: bool = False,
-      resample_seed: int = 9999,
-      time_stamps: bool = False,
-  ):
-    """Data source constructor.
-
-    Args:
-      date_range: The date range (in days) applied. Data not falling in this
-        range is ignored.
-      batch_size: Size of the batch.
-      input_dataset: The path of a zarr dataset containing the input data.
-      input_variable_names: The names of variables (in a tuple) to yield from
-        the input dataset.
-      input_member_indexer: The name of the ensemble member to sample from.
-      input_stats_dataset: The path of a zarr dataset containing the input
-        statistics.
-      output_dataset: The path of a zarr dataset containing the output data.
-      output_variables: The variables to yield from the output dataset.
-      dims_order_input: Order of the variables for the input.
-      dims_order_output: Order of the variables for the output.
-      dims_order_input_stats: Order of the variables for the input statistics.
-      resample_at_nan: Whether to resample when NaN is detected in the data.
-      resample_seed: The random seed for resampling.
-      time_stamps: Wheter to add the time stamps to the samples.
-    """
-
-    # Using lens as input, they need to be modified.
-    input_variables = {v: input_member_indexer for v in input_variable_names}
-
-    date_range = jax.tree.map(lambda x: np.datetime64(x, "D"), date_range)
-
-    input_ds = xrts.open_zarr(input_dataset).sel(time=slice(*date_range))
-    output_ds = xrts.open_zarr(output_dataset).sel(time=slice(*date_range))
-    input_stats_ds = xrts.open_zarr(input_stats_dataset)
-
-    if dims_order_input:
-      input_ds = input_ds.transpose(*dims_order_input)
-    if dims_order_output:
-      output_ds = output_ds.transpose(*dims_order_output)
-    if dims_order_input_stats:
-      input_stats_ds = input_stats_ds.transpose(*dims_order_input_stats)
-
-    self._input_arrays = {}
-    for v, indexers in input_variables.items():
-      self._input_arrays[v] = {}
-      for index in indexers:
-        idx = tuple(index.values())[0]
-        self._input_arrays[v][idx] = input_ds[v].sel(index)
-
-    self._input_stats_arrays = {}
-    for v, indexers in input_variables.items():
-      self._input_stats_arrays[v] = {}
-      for index in indexers:
-        # print(f"index = {index}")
-        idx = tuple(index.values())[0]
-        self._input_stats_arrays[v][idx] = input_stats_ds[v].sel(index)
-
-    self._output_arrays = {}
-    for v, indexers in output_variables.items():
-      self._output_arrays[v] = output_ds[v].sel(
-          indexers
-      )  # pytype : disable=wrong-arg-types
-
-    self._output_coords = output_ds.coords
-
-    # self._dates = get_common_times(input_ds, date_range)
-    # The times can be slightly off due to the leap years.
-    self._indexes = [ind["member"] for ind in input_member_indexer]
-    self._len_time = (
-        np.min([input_ds.dims["time"], output_ds.dims["time"]]) // batch_size
-    )
-    self._len = len(self._indexes) * self._len_time
-    self._input_time_array = xrts.read(input_ds["time"]).data
-    self._output_time_array = xrts.read(output_ds["time"]).data
-    self._resample_at_nan = resample_at_nan
-    self._resample_seed = resample_seed
-    self._time_stamps = time_stamps
-    self._batch_size = batch_size
-
-  def __len__(self):
-    return self._len
-
-  def __getitem__(self, record_key: SupportsIndex) -> Mapping[str, Any]:
-    """Retrieves record and retry if NaN found."""
-    idx = record_key.__index__()
-    if not idx < self._len:
-      raise ValueError(f"Index out of range: {idx} / {self._len - 1}")
-
-    item = self.get_item(idx)
-
-    if self._resample_at_nan:
-      while np.isnan(item["input"]).any() or np.isnan(item["output"]).any():
-
-        rng = np.random.default_rng(self._resample_seed + idx)
-        resample_idx = rng.integers(0, len(self))
-        item = self.get_item(resample_idx)
-
-    return item
-
-  def get_item(self, idx: int) -> Mapping[str, Any]:
-    """Returns the data record for a given index."""
-    item = {}
-
-    idx_member = idx // self._len_time
-    idx_time = idx % self._len_time
-
-    member = self._indexes[idx_member]
-    sample_input = {}
-    mean_input = {}
-    std_input = {}
-
-    for v, da in self._input_arrays.items():
-      array = xrts.read(
-          da[member].isel(
-              time=slice(
-                  idx_time * self._batch_size, (idx_time + 1) * self._batch_size
-              )
-          )
-      ).data
-
-      assert array.ndim == 3 or array.ndim == 4
-      sample_input[v] = (
-          np.expand_dims(array, axis=-1) if array.ndim == 3 else array
-      )
-
-    for v, da in self._input_stats_arrays.items():
-      mean_array = xrts.read(da[member].sel(stats="mean")).data
-
-      # there is no time dimension yet,
-      # so the mean and std are either 2- or 3-tensors
-      assert mean_array.ndim == 2 or mean_array.ndim == 3
-      mean_array = (
-          np.expand_dims(mean_array, axis=-1)
-          if mean_array.ndim == 2
-          else mean_array
-      )
-      mean_input[v] = np.tile(
-          mean_array, (self._batch_size,) + (1,) * mean_array.ndim
-      )
-
-      std_array = xrts.read(da[member].sel(stats="std")).data
-
-      assert std_array.ndim == 2 or std_array.ndim == 3
-      std_array = (
-          np.expand_dims(std_array, axis=-1)
-          if std_array.ndim == 2
-          else std_array
-      )
-      std_input[v] = np.tile(
-          std_array, (self._batch_size,) + (1,) * std_array.ndim
-      )
-
-    item["input"] = sample_input
-    item["input_mean"] = mean_input
-    item["input_std"] = std_input
-
-    sample_output = {}
-    # TODO encapsulate these functions.
-    for v, da in self._output_arrays.items():
-      # We change the slicing to be contiguous and non-overlapping.
-      array = xrts.read(
-          da.isel(
-              time=slice(
-                  idx_time * self._batch_size, (idx_time + 1) * self._batch_size
-              )
-          )
-      ).data
-      assert array.ndim == 3 or array.ndim == 4
-      sample_output[v] = (
-          np.expand_dims(array, axis=-1) if array.ndim == 3 else array
-      )
-
-    item["output"] = sample_output
-
-    if self._time_stamps:
-      item["input_time_stamp"] = self._input_time_array[
-          idx_time * self._batch_size : (idx_time + 1) * self._batch_size
-      ]
-      item["output_time_stamp"] = self._output_time_array[
-          idx_time * self._batch_size : (idx_time + 1) * self._batch_size
-      ]
-
-    return item
-
-  def get_output_coords(self):
-    return self._output_coords
-
-
 class DataSourceContiguousEnsembleNonOverlappingWithStatsLENS2:
   """A data source that loads ensemble LENS2 data with stats."""
 
@@ -1433,40 +1224,23 @@ def create_ensemble_lens2_era5_loader_chunked_with_stats(
   Returns:
     A pygrain loader with the LENS2 and ERA5 data set.
   """
-  if overlapping_chunks:
-    source = DataSourceContiguousNonOverlappingEnsembleWithStats(
-        date_range=date_range,
-        batch_size=batch_size,
-        input_dataset=input_dataset_path,
-        input_variable_names=input_variable_names,
-        input_member_indexer=input_member_indexer,
-        input_stats_dataset=input_stats_path,
-        output_dataset=output_dataset_path,
-        output_variables=output_variables,
-        resample_at_nan=False,
-        dims_order_input=["member", "time", "longitude", "latitude"],
-        dims_order_output=["time", "longitude", "latitude", "level"],
-        dims_order_input_stats=["member", "longitude", "latitude", "stats"],
-        resample_seed=9999,
-        time_stamps=time_stamps,
-    )
-  else:
-    source = DataSourceContiguousEnsembleWithStats(
-        date_range=date_range,
-        batch_size=batch_size,
-        input_dataset=input_dataset_path,
-        input_variable_names=input_variable_names,
-        input_member_indexer=input_member_indexer,
-        input_stats_dataset=input_stats_path,
-        output_dataset=output_dataset_path,
-        output_variables=output_variables,
-        resample_at_nan=False,
-        dims_order_input=["member", "time", "longitude", "latitude"],
-        dims_order_output=["time", "longitude", "latitude", "level"],
-        dims_order_input_stats=["member", "longitude", "latitude", "stats"],
-        resample_seed=9999,
-        time_stamps=time_stamps,
-    )
+  del overlapping_chunks  # Unused.
+  source = DataSourceContiguousEnsembleWithStats(
+      date_range=date_range,
+      batch_size=batch_size,
+      input_dataset=input_dataset_path,
+      input_variable_names=input_variable_names,
+      input_member_indexer=input_member_indexer,
+      input_stats_dataset=input_stats_path,
+      output_dataset=output_dataset_path,
+      output_variables=output_variables,
+      resample_at_nan=False,
+      dims_order_input=["member", "time", "longitude", "latitude"],
+      dims_order_output=["time", "longitude", "latitude", "level"],
+      dims_order_input_stats=["member", "longitude", "latitude", "stats"],
+      resample_seed=9999,
+      time_stamps=time_stamps,
+  )
 
   member_indexer = input_member_indexer[0]
   # Just the first one to extract the statistics.
@@ -1622,42 +1396,25 @@ def create_ensemble_lens2_era5_loader_chunked_with_normalized_stats(
   Returns:
   """
+  del overlapping_chunks  # Unused.
   chunk_size = batch_size // num_chunks
-  if overlapping_chunks:
-    source = DataSourceContiguousNonOverlappingEnsembleWithStats(
-        date_range=date_range,
-        batch_size=chunk_size,
-        input_dataset=input_dataset_path,
-        input_variable_names=input_variable_names,
-        input_member_indexer=input_member_indexer,
-        input_stats_dataset=input_stats_path,
-        output_dataset=output_dataset_path,
-        output_variables=output_variables,
-        resample_at_nan=False,
-        dims_order_input=["member", "time", "longitude", "latitude"],
-        dims_order_output=["time", "longitude", "latitude", "level"],
-        dims_order_input_stats=["member", "longitude", "latitude", "stats"],
-        resample_seed=9999,
-        time_stamps=time_stamps,
-    )
-  else:
-    source = DataSourceContiguousEnsembleWithStats(
-        date_range=date_range,
-        batch_size=chunk_size,
-        input_dataset=input_dataset_path,
-        input_variable_names=input_variable_names,
-        input_member_indexer=input_member_indexer,
-        input_stats_dataset=input_stats_path,
-        output_dataset=output_dataset_path,
-        output_variables=output_variables,
-        resample_at_nan=False,
-        dims_order_input=["member", "time", "longitude", "latitude"],
-        dims_order_output=["time", "longitude", "latitude", "level"],
-        dims_order_input_stats=["member", "longitude", "latitude", "stats"],
-        resample_seed=9999,
-        time_stamps=time_stamps,
-    )
+  source = DataSourceContiguousEnsembleWithStats(
+      date_range=date_range,
+      batch_size=chunk_size,
+      input_dataset=input_dataset_path,
+      input_variable_names=input_variable_names,
+      input_member_indexer=input_member_indexer,
+      input_stats_dataset=input_stats_path,
+      output_dataset=output_dataset_path,
+      output_variables=output_variables,
+      resample_at_nan=False,
+      dims_order_input=["member", "time", "longitude", "latitude"],
+      dims_order_output=["time", "longitude", "latitude", "level"],
+      dims_order_input_stats=["member", "longitude", "latitude", "stats"],
+      resample_seed=9999,
+      time_stamps=time_stamps,
+  )
 
   member_indexer = input_member_indexer[0]
   # Just the first one to extract the statistics.
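
Note: both consolidated factories now always build a DataSourceContiguousEnsembleWithStats, which, like the deleted class, lays records out member-major: a flat index first walks the non-overlapping time chunks of one ensemble member before moving on to the next. Below is a minimal sketch of that index arithmetic; split_index and chunk_slice are illustrative names, not part of data_utils.py.

def split_index(idx: int, num_members: int, len_time: int) -> tuple[int, int]:
  """Maps a flat record index to (member, time-chunk) indices.

  Mirrors `idx // self._len_time` and `idx % self._len_time` in `get_item`.
  """
  if not idx < num_members * len_time:
    raise ValueError(f"Index out of range: {idx} / {num_members * len_time - 1}")
  return idx // len_time, idx % len_time


def chunk_slice(idx_time: int, batch_size: int) -> slice:
  """Contiguous, non-overlapping slice of `batch_size` snapshots."""
  return slice(idx_time * batch_size, (idx_time + 1) * batch_size)


# Example: 4 ensemble members, 10 chunks per member, 16 daily snapshots each.
member, idx_time = split_index(23, num_members=4, len_time=10)
assert (member, idx_time) == (2, 3)
assert chunk_slice(idx_time, batch_size=16) == slice(48, 64)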
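
In create_ensemble_lens2_era5_loader_chunked_with_normalized_stats, each record handed to the source holds batch_size // num_chunks snapshots, so a training batch is reassembled from num_chunks contiguous records. A quick sanity check of that arithmetic; the integer division appears to assume batch_size is divisible by num_chunks:

batch_size, num_chunks = 32, 4
chunk_size = batch_size // num_chunks  # 8 snapshots per record.
assert chunk_size * num_chunks == batch_size  # Holds only for exact divisibility.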