Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 707982474
  • Loading branch information
Forgotten authored and The swirl_dynamics Authors committed Dec 19, 2024
1 parent 0c27bf1 commit 38a996d
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions swirl_dynamics/projects/debiasing/rectified_flow/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,35 @@
_LENS2_STD_STATS_PATH = "/lzepedanunez/data/lens2/stats/lens2_std_stats_all_variables_240x121_lonlat_1961-2000.zarr"


def maybe_expand_dims(
array: np.ndarray,
allowed_dims: tuple[int, ...],
trigger_expand_dims: int,
axis: int = -1,
) -> np.ndarray:
"""Expands the dimensions of a numpy array if necessary.
Args:
array: The numpy array to be possibly expanded.
allowed_dims: The dimensions that the array can have, raise an error
otherwise.
trigger_expand_dims: The dimension that triggers the expansion.
axis: The axis in which the extra dimension is added.
Returns:
The array possibly expanded if its dimension is trigger_expand_dims.
"""
ndim = array.ndim
if ndim not in allowed_dims:
raise ValueError(
f"The array has {ndim} dimensions, but it should have one of the"
f" dimensions {allowed_dims}"
)
if ndim == trigger_expand_dims:
array = np.expand_dims(array, axis=axis)
return array


class UnpairedDataLoader:
"""Unpaired dataloader for loading samples from two distributions."""

Expand Down

0 comments on commit 38a996d

Please sign in to comment.