
Commit 0ffca00

Code update

PiperOrigin-RevId: 708957573

Forgotten authored and The swirl_dynamics Authors committed Dec 23, 2024
1 parent d8e3c6e commit 0ffca00
Showing 1 changed file with 34 additions and 277 deletions.

swirl_dynamics/projects/debiasing/rectified_flow/data_utils.py (34 additions, 277 deletions)
@@ -707,215 +707,6 @@ def get_output_coords(self):
     return self._output_coords


-class DataSourceContiguousNonOverlappingEnsembleWithStats:
-  """A data source that loads paired daily ERA5 and ensemble LENS2 data."""
-
-  def __init__(
-      self,
-      date_range: tuple[str, str],
-      batch_size: int,
-      input_dataset: epath.PathLike,
-      input_variable_names: Sequence[str],
-      input_member_indexer: tuple[Mapping[str, Any], ...],
-      input_stats_dataset: epath.PathLike,
-      output_dataset: epath.PathLike,
-      output_variables: Mapping[str, Any],
-      dims_order_input: Sequence[str] | None = None,
-      dims_order_output: Sequence[str] | None = None,
-      dims_order_input_stats: Sequence[str] | None = (
-          "member",
-          "longitude",
-          "latitude",
-          "stats",
-      ),
-      resample_at_nan: bool = False,
-      resample_seed: int = 9999,
-      time_stamps: bool = False,
-  ):
-    """Data source constructor.
-
-    Args:
-      date_range: The date range (in days) applied. Data not falling in this
-        range is ignored.
-      batch_size: Size of the batch.
-      input_dataset: The path of a zarr dataset containing the input data.
-      input_variable_names: The names of the variables (in a tuple) to yield
-        from the input dataset.
-      input_member_indexer: The name of the ensemble member to sample from.
-      input_stats_dataset: The path of a zarr dataset containing the input
-        statistics.
-      output_dataset: The path of a zarr dataset containing the output data.
-      output_variables: The variables to yield from the output dataset.
-      dims_order_input: Order of the dimensions for the input.
-      dims_order_output: Order of the dimensions for the output.
-      dims_order_input_stats: Order of the dimensions for the input statistics.
-      resample_at_nan: Whether to resample when a NaN is detected in the data.
-      resample_seed: The random seed for resampling.
-      time_stamps: Whether to add the time stamps to the samples.
-    """
-
-    # Using LENS2 as input; each variable is indexed by ensemble member.
-    input_variables = {v: input_member_indexer for v in input_variable_names}
-
-    date_range = jax.tree.map(lambda x: np.datetime64(x, "D"), date_range)
-
-    input_ds = xrts.open_zarr(input_dataset).sel(time=slice(*date_range))
-    output_ds = xrts.open_zarr(output_dataset).sel(time=slice(*date_range))
-    input_stats_ds = xrts.open_zarr(input_stats_dataset)
-
-    if dims_order_input:
-      input_ds = input_ds.transpose(*dims_order_input)
-    if dims_order_output:
-      output_ds = output_ds.transpose(*dims_order_output)
-    if dims_order_input_stats:
-      input_stats_ds = input_stats_ds.transpose(*dims_order_input_stats)
-
-    self._input_arrays = {}
-    for v, indexers in input_variables.items():
-      self._input_arrays[v] = {}
-      for index in indexers:
-        idx = tuple(index.values())[0]
-        self._input_arrays[v][idx] = input_ds[v].sel(index)
-
-    self._input_stats_arrays = {}
-    for v, indexers in input_variables.items():
-      self._input_stats_arrays[v] = {}
-      for index in indexers:
-        idx = tuple(index.values())[0]
-        self._input_stats_arrays[v][idx] = input_stats_ds[v].sel(index)
-
-    self._output_arrays = {}
-    for v, indexers in output_variables.items():
-      self._output_arrays[v] = output_ds[v].sel(
-          indexers
-      )  # pytype: disable=wrong-arg-types
-
-    self._output_coords = output_ds.coords
-
-    self._indexes = [ind["member"] for ind in input_member_indexer]
-    # The times can be slightly off due to the leap years.
-    self._len_time = (
-        np.min([input_ds.dims["time"], output_ds.dims["time"]]) // batch_size
-    )
-    self._len = len(self._indexes) * self._len_time
-    self._input_time_array = xrts.read(input_ds["time"]).data
-    self._output_time_array = xrts.read(output_ds["time"]).data
-    self._resample_at_nan = resample_at_nan
-    self._resample_seed = resample_seed
-    self._time_stamps = time_stamps
-    self._batch_size = batch_size
-
-  def __len__(self):
-    return self._len
-
-  def __getitem__(self, record_key: SupportsIndex) -> Mapping[str, Any]:
-    """Retrieves a record, resampling if NaNs are found."""
-    idx = record_key.__index__()
-    if idx >= self._len:
-      raise ValueError(f"Index out of range: {idx} / {self._len - 1}")
-
-    item = self.get_item(idx)
-
-    if self._resample_at_nan:
-      rng = np.random.default_rng(self._resample_seed + idx)
-      while np.isnan(item["input"]).any() or np.isnan(item["output"]).any():
-        resample_idx = rng.integers(0, len(self))
-        item = self.get_item(resample_idx)
-
-    return item
-
-  def get_item(self, idx: int) -> Mapping[str, Any]:
-    """Returns the data record for a given index."""
-    item = {}
-
-    idx_member = idx // self._len_time
-    idx_time = idx % self._len_time
-
-    member = self._indexes[idx_member]
-    sample_input = {}
-    mean_input = {}
-    std_input = {}
-
-    for v, da in self._input_arrays.items():
-      array = xrts.read(
-          da[member].isel(
-              time=slice(
-                  idx_time * self._batch_size, (idx_time + 1) * self._batch_size
-              )
-          )
-      ).data
-
-      assert array.ndim == 3 or array.ndim == 4
-      sample_input[v] = (
-          np.expand_dims(array, axis=-1) if array.ndim == 3 else array
-      )
-
-    for v, da in self._input_stats_arrays.items():
-      mean_array = xrts.read(da[member].sel(stats="mean")).data
-
-      # There is no time dimension yet, so the mean and std are either
-      # 2- or 3-tensors.
-      assert mean_array.ndim == 2 or mean_array.ndim == 3
-      mean_array = (
-          np.expand_dims(mean_array, axis=-1)
-          if mean_array.ndim == 2
-          else mean_array
-      )
-      mean_input[v] = np.tile(
-          mean_array, (self._batch_size,) + (1,) * mean_array.ndim
-      )
-
-      std_array = xrts.read(da[member].sel(stats="std")).data
-
-      assert std_array.ndim == 2 or std_array.ndim == 3
-      std_array = (
-          np.expand_dims(std_array, axis=-1)
-          if std_array.ndim == 2
-          else std_array
-      )
-      std_input[v] = np.tile(
-          std_array, (self._batch_size,) + (1,) * std_array.ndim
-      )
-
-    item["input"] = sample_input
-    item["input_mean"] = mean_input
-    item["input_std"] = std_input
-
-    sample_output = {}
-    # TODO: Encapsulate these functions.
-    for v, da in self._output_arrays.items():
-      # The slicing is contiguous and non-overlapping.
-      array = xrts.read(
-          da.isel(
-              time=slice(
-                  idx_time * self._batch_size, (idx_time + 1) * self._batch_size
-              )
-          )
-      ).data
-      assert array.ndim == 3 or array.ndim == 4
-      sample_output[v] = (
-          np.expand_dims(array, axis=-1) if array.ndim == 3 else array
-      )
-
-    item["output"] = sample_output
-
-    if self._time_stamps:
-      item["input_time_stamp"] = self._input_time_array[
-          idx_time * self._batch_size : (idx_time + 1) * self._batch_size
-      ]
-      item["output_time_stamp"] = self._output_time_array[
-          idx_time * self._batch_size : (idx_time + 1) * self._batch_size
-      ]
-
-    return item
-
-  def get_output_coords(self):
-    return self._output_coords
-
-
 class DataSourceContiguousEnsembleNonOverlappingWithStatsLENS2:
   """A data source that loads ensemble LENS2 data with stats."""

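For reference, the removed data source is indexed as a flat map over (member, time-chunk) pairs: item idx selects member idx // len_time and the contiguous, non-overlapping time slice starting at (idx % len_time) * batch_size. Below is a minimal usage sketch; every path, variable name, and member label in it is a hypothetical placeholder (none of these literals come from the commit), and it assumes the referenced zarr stores exist.

# A hedged usage sketch of the removed class; all paths, variable names, and
# the member label are hypothetical examples, not values from the repository.
source = DataSourceContiguousNonOverlappingEnsembleWithStats(
    date_range=("2000-01-01", "2009-12-31"),
    batch_size=16,
    input_dataset="/data/lens2.zarr",  # hypothetical path
    input_variable_names=["TREFHT"],  # hypothetical LENS2 variable
    input_member_indexer=({"member": "member_0001"},),  # hypothetical member
    input_stats_dataset="/data/lens2_stats.zarr",  # hypothetical path
    output_dataset="/data/era5.zarr",  # hypothetical path
    output_variables={"2m_temperature": None},  # hypothetical ERA5 variable
    time_stamps=True,
)

# len(source) == num_members * (num_common_days // batch_size); each item is a
# dict with "input", "input_mean", "input_std", "output", plus the
# "input_time_stamp"/"output_time_stamp" arrays when time_stamps is enabled.
item = source[0]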
@@ -1433,40 +1224,23 @@ def create_ensemble_lens2_era5_loader_chunked_with_stats(
   Returns:
     A pygrain loader with the LENS2 and ERA5 data set.
   """
-  if overlapping_chunks:
-    source = DataSourceContiguousNonOverlappingEnsembleWithStats(
-        date_range=date_range,
-        batch_size=batch_size,
-        input_dataset=input_dataset_path,
-        input_variable_names=input_variable_names,
-        input_member_indexer=input_member_indexer,
-        input_stats_dataset=input_stats_path,
-        output_dataset=output_dataset_path,
-        output_variables=output_variables,
-        resample_at_nan=False,
-        dims_order_input=["member", "time", "longitude", "latitude"],
-        dims_order_output=["time", "longitude", "latitude", "level"],
-        dims_order_input_stats=["member", "longitude", "latitude", "stats"],
-        resample_seed=9999,
-        time_stamps=time_stamps,
-    )
-  else:
-    source = DataSourceContiguousEnsembleWithStats(
-        date_range=date_range,
-        batch_size=batch_size,
-        input_dataset=input_dataset_path,
-        input_variable_names=input_variable_names,
-        input_member_indexer=input_member_indexer,
-        input_stats_dataset=input_stats_path,
-        output_dataset=output_dataset_path,
-        output_variables=output_variables,
-        resample_at_nan=False,
-        dims_order_input=["member", "time", "longitude", "latitude"],
-        dims_order_output=["time", "longitude", "latitude", "level"],
-        dims_order_input_stats=["member", "longitude", "latitude", "stats"],
-        resample_seed=9999,
-        time_stamps=time_stamps,
-    )
+  del overlapping_chunks  # Unused.
+  source = DataSourceContiguousEnsembleWithStats(
+      date_range=date_range,
+      batch_size=batch_size,
+      input_dataset=input_dataset_path,
+      input_variable_names=input_variable_names,
+      input_member_indexer=input_member_indexer,
+      input_stats_dataset=input_stats_path,
+      output_dataset=output_dataset_path,
+      output_variables=output_variables,
+      resample_at_nan=False,
+      dims_order_input=["member", "time", "longitude", "latitude"],
+      dims_order_output=["time", "longitude", "latitude", "level"],
+      dims_order_input_stats=["member", "longitude", "latitude", "stats"],
+      resample_seed=9999,
+      time_stamps=time_stamps,
+  )

   member_indexer = input_member_indexer[0]
   # Just the first one to extract the statistics.
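Note that the removed branch was inverted: overlapping_chunks=True selected the NonOverlapping source. After this commit the flag is still accepted but explicitly discarded, and the loader always builds a DataSourceContiguousEnsembleWithStats. A minimal call sketch follows, using only the parameter names visible in this hunk; all literal values are hypothetical, and the full signature may require additional arguments not shown in the diff.

# A hedged sketch: overlapping_chunks is accepted but ignored after this
# change, so both values now behave identically. Values are placeholders.
loader = create_ensemble_lens2_era5_loader_chunked_with_stats(
    date_range=("2000-01-01", "2009-12-31"),
    batch_size=16,
    input_dataset_path="/data/lens2.zarr",  # hypothetical path
    input_variable_names=["TREFHT"],  # hypothetical variable
    input_member_indexer=({"member": "member_0001"},),  # hypothetical member
    input_stats_path="/data/lens2_stats.zarr",  # hypothetical path
    output_dataset_path="/data/era5.zarr",  # hypothetical path
    output_variables={"2m_temperature": None},  # hypothetical variable
    overlapping_chunks=True,  # ignored: the argument is deleted unused
    time_stamps=False,
)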
@@ -1622,42 +1396,25 @@ def create_ensemble_lens2_era5_loader_chunked_with_normalized_stats(
   Returns:
     A pygrain loader with the LENS2 and ERA5 data set.
   """
+  del overlapping_chunks  # Unused.
   chunk_size = batch_size // num_chunks

-  if overlapping_chunks:
-    source = DataSourceContiguousNonOverlappingEnsembleWithStats(
-        date_range=date_range,
-        batch_size=chunk_size,
-        input_dataset=input_dataset_path,
-        input_variable_names=input_variable_names,
-        input_member_indexer=input_member_indexer,
-        input_stats_dataset=input_stats_path,
-        output_dataset=output_dataset_path,
-        output_variables=output_variables,
-        resample_at_nan=False,
-        dims_order_input=["member", "time", "longitude", "latitude"],
-        dims_order_output=["time", "longitude", "latitude", "level"],
-        dims_order_input_stats=["member", "longitude", "latitude", "stats"],
-        resample_seed=9999,
-        time_stamps=time_stamps,
-    )
-  else:
-    source = DataSourceContiguousEnsembleWithStats(
-        date_range=date_range,
-        batch_size=chunk_size,
-        input_dataset=input_dataset_path,
-        input_variable_names=input_variable_names,
-        input_member_indexer=input_member_indexer,
-        input_stats_dataset=input_stats_path,
-        output_dataset=output_dataset_path,
-        output_variables=output_variables,
-        resample_at_nan=False,
-        dims_order_input=["member", "time", "longitude", "latitude"],
-        dims_order_output=["time", "longitude", "latitude", "level"],
-        dims_order_input_stats=["member", "longitude", "latitude", "stats"],
-        resample_seed=9999,
-        time_stamps=time_stamps,
-    )
+  source = DataSourceContiguousEnsembleWithStats(
+      date_range=date_range,
+      batch_size=chunk_size,
+      input_dataset=input_dataset_path,
+      input_variable_names=input_variable_names,
+      input_member_indexer=input_member_indexer,
+      input_stats_dataset=input_stats_path,
+      output_dataset=output_dataset_path,
+      output_variables=output_variables,
+      resample_at_nan=False,
+      dims_order_input=["member", "time", "longitude", "latitude"],
+      dims_order_output=["time", "longitude", "latitude", "level"],
+      dims_order_input_stats=["member", "longitude", "latitude", "stats"],
+      resample_seed=9999,
+      time_stamps=time_stamps,
+  )

   member_indexer = input_member_indexer[0]
   # Just the first one to extract the statistics.
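In this chunked variant the data source is built with batch_size=chunk_size, so each source item carries batch_size // num_chunks consecutive time steps; presumably a downstream batcher then packs num_chunks such items into one training batch. The arithmetic, with illustrative values:

# Hedged sketch of the chunking arithmetic above (values are illustrative).
batch_size, num_chunks = 16, 4
chunk_size = batch_size // num_chunks  # 4 time steps per data-source item
# Integer division silently drops any remainder, so batch_size is presumably
# expected to be a multiple of num_chunks.
assert chunk_size * num_chunks == batch_size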
